feat: migrate store layer to GORM with PostgreSQL support

- Migrate all store packages from raw database/sql to GORM ORM
- Add PostgreSQL support alongside SQLite
- Move EncryptedString type to crypto package for cleaner architecture
- Add automatic encryption/decryption for sensitive fields (API keys, secrets)
- Fix PostgreSQL AutoMigrate conflicts by skipping existing tables
- Fix duplicate /klines route registration
- Update tests to use GORM database connections
- Add database configuration support in config package
This commit is contained in:
tinkle-community
2026-01-01 19:32:49 +08:00
parent d547863ebb
commit 2d272bb7b8
32 changed files with 2573 additions and 3771 deletions
+13
View File
@@ -55,3 +55,16 @@ TRANSPORT_ENCRYPTION=false
# Telegram notifications (optional) # Telegram notifications (optional)
# TELEGRAM_BOT_TOKEN=your-bot-token # TELEGRAM_BOT_TOKEN=your-bot-token
# TELEGRAM_CHAT_ID=your-chat-id # TELEGRAM_CHAT_ID=your-chat-id
DB_TYPE=postgres
DB_HOST=10.
DB_PORT=5432
DB_USER=nofx_user
DB_PASSWORD=
DB_NAME=nofx
DB_SSLMODE=disable
# 数据库配置 - SQLite(默认)
DB_TYPE=sqlite
DB_PATH=data/data.db
+1 -1
View File
@@ -824,7 +824,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
return fmt.Errorf("AI model %s is not enabled yet", model.Name) return fmt.Errorf("AI model %s is not enabled yet", model.Name)
} }
apiKey := strings.TrimSpace(model.APIKey) apiKey := strings.TrimSpace(string(model.APIKey))
if apiKey == "" { if apiKey == "" {
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name) return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
} }
+43 -41
View File
@@ -54,7 +54,7 @@ func NewServer(traderManager *manager.TraderManager, st *store.Store, cryptoServ
cryptoHandler := NewCryptoHandler(cryptoService) cryptoHandler := NewCryptoHandler(cryptoService)
// Create debate store and handler // Create debate store and handler
debateStore := store.NewDebateStore(st.DB()) debateStore := store.NewDebateStore(st.GormDB())
if err := debateStore.InitSchema(); err != nil { if err := debateStore.InitSchema(); err != nil {
logger.Errorf("Failed to initialize debate schema: %v", err) logger.Errorf("Failed to initialize debate schema: %v", err)
} }
@@ -125,7 +125,6 @@ func (s *Server) setupRoutes() {
// Market data (no authentication required) // Market data (no authentication required)
api.GET("/klines", s.handleKlines) api.GET("/klines", s.handleKlines)
api.GET("/klines", s.handleKlines)
api.GET("/symbols", s.handleSymbols) api.GET("/symbols", s.handleSymbols)
// Authentication related routes (no authentication required) // Authentication related routes (no authentication required)
@@ -576,12 +575,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
var createErr error var createErr error
// Use ExchangeType (e.g., "binance") instead of ID (UUID) // Use ExchangeType (e.g., "binance") instead of ID (UUID)
// Convert EncryptedString fields to string
switch exchangeCfg.ExchangeType { switch exchangeCfg.ExchangeType {
case "binance": case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
case "hyperliquid": case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader( tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey, // private key string(exchangeCfg.APIKey), // private key
exchangeCfg.HyperliquidWalletAddr, exchangeCfg.HyperliquidWalletAddr,
exchangeCfg.Testnet, exchangeCfg.Testnet,
) )
@@ -589,31 +589,31 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
tempTrader, createErr = trader.NewAsterTrader( tempTrader, createErr = trader.NewAsterTrader(
exchangeCfg.AsterUser, exchangeCfg.AsterUser,
exchangeCfg.AsterSigner, exchangeCfg.AsterSigner,
exchangeCfg.AsterPrivateKey, string(exchangeCfg.AsterPrivateKey),
) )
case "bybit": case "bybit":
tempTrader = trader.NewBybitTrader( tempTrader = trader.NewBybitTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
) )
case "okx": case "okx":
tempTrader = trader.NewOKXTrader( tempTrader = trader.NewOKXTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
exchangeCfg.Passphrase, string(exchangeCfg.Passphrase),
) )
case "bitget": case "bitget":
tempTrader = trader.NewBitgetTrader( tempTrader = trader.NewBitgetTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
exchangeCfg.Passphrase, string(exchangeCfg.Passphrase),
) )
case "lighter": case "lighter":
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" { if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
// Lighter only supports mainnet // Lighter only supports mainnet
tempTrader, createErr = trader.NewLighterTraderV2( tempTrader, createErr = trader.NewLighterTraderV2(
exchangeCfg.LighterWalletAddr, exchangeCfg.LighterWalletAddr,
exchangeCfg.LighterAPIKeyPrivateKey, string(exchangeCfg.LighterAPIKeyPrivateKey),
exchangeCfg.LighterAPIKeyIndex, exchangeCfg.LighterAPIKeyIndex,
false, // Always use mainnet for Lighter false, // Always use mainnet for Lighter
) )
@@ -1095,12 +1095,13 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
var createErr error var createErr error
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
// Convert EncryptedString fields to string
switch exchangeCfg.ExchangeType { switch exchangeCfg.ExchangeType {
case "binance": case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
case "hyperliquid": case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader( tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.HyperliquidWalletAddr, exchangeCfg.HyperliquidWalletAddr,
exchangeCfg.Testnet, exchangeCfg.Testnet,
) )
@@ -1108,31 +1109,31 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
tempTrader, createErr = trader.NewAsterTrader( tempTrader, createErr = trader.NewAsterTrader(
exchangeCfg.AsterUser, exchangeCfg.AsterUser,
exchangeCfg.AsterSigner, exchangeCfg.AsterSigner,
exchangeCfg.AsterPrivateKey, string(exchangeCfg.AsterPrivateKey),
) )
case "bybit": case "bybit":
tempTrader = trader.NewBybitTrader( tempTrader = trader.NewBybitTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
) )
case "okx": case "okx":
tempTrader = trader.NewOKXTrader( tempTrader = trader.NewOKXTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
exchangeCfg.Passphrase, string(exchangeCfg.Passphrase),
) )
case "bitget": case "bitget":
tempTrader = trader.NewBitgetTrader( tempTrader = trader.NewBitgetTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
exchangeCfg.Passphrase, string(exchangeCfg.Passphrase),
) )
case "lighter": case "lighter":
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" { if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
// Lighter only supports mainnet // Lighter only supports mainnet
tempTrader, createErr = trader.NewLighterTraderV2( tempTrader, createErr = trader.NewLighterTraderV2(
exchangeCfg.LighterWalletAddr, exchangeCfg.LighterWalletAddr,
exchangeCfg.LighterAPIKeyPrivateKey, string(exchangeCfg.LighterAPIKeyPrivateKey),
exchangeCfg.LighterAPIKeyIndex, exchangeCfg.LighterAPIKeyIndex,
false, // Always use mainnet for Lighter false, // Always use mainnet for Lighter
) )
@@ -1246,12 +1247,13 @@ func (s *Server) handleClosePosition(c *gin.Context) {
var createErr error var createErr error
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
// Convert EncryptedString fields to string
switch exchangeCfg.ExchangeType { switch exchangeCfg.ExchangeType {
case "binance": case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
case "hyperliquid": case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader( tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.HyperliquidWalletAddr, exchangeCfg.HyperliquidWalletAddr,
exchangeCfg.Testnet, exchangeCfg.Testnet,
) )
@@ -1259,31 +1261,31 @@ func (s *Server) handleClosePosition(c *gin.Context) {
tempTrader, createErr = trader.NewAsterTrader( tempTrader, createErr = trader.NewAsterTrader(
exchangeCfg.AsterUser, exchangeCfg.AsterUser,
exchangeCfg.AsterSigner, exchangeCfg.AsterSigner,
exchangeCfg.AsterPrivateKey, string(exchangeCfg.AsterPrivateKey),
) )
case "bybit": case "bybit":
tempTrader = trader.NewBybitTrader( tempTrader = trader.NewBybitTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
) )
case "okx": case "okx":
tempTrader = trader.NewOKXTrader( tempTrader = trader.NewOKXTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
exchangeCfg.Passphrase, string(exchangeCfg.Passphrase),
) )
case "bitget": case "bitget":
tempTrader = trader.NewBitgetTrader( tempTrader = trader.NewBitgetTrader(
exchangeCfg.APIKey, string(exchangeCfg.APIKey),
exchangeCfg.SecretKey, string(exchangeCfg.SecretKey),
exchangeCfg.Passphrase, string(exchangeCfg.Passphrase),
) )
case "lighter": case "lighter":
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" { if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
// Lighter only supports mainnet // Lighter only supports mainnet
tempTrader, createErr = trader.NewLighterTraderV2( tempTrader, createErr = trader.NewLighterTraderV2(
exchangeCfg.LighterWalletAddr, exchangeCfg.LighterWalletAddr,
exchangeCfg.LighterAPIKeyPrivateKey, string(exchangeCfg.LighterAPIKeyPrivateKey),
exchangeCfg.LighterAPIKeyIndex, exchangeCfg.LighterAPIKeyIndex,
false, // Always use mainnet for Lighter false, // Always use mainnet for Lighter
) )
+10 -8
View File
@@ -549,32 +549,34 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
var aiClient mcp.AIClient var aiClient mcp.AIClient
provider := model.Provider provider := model.Provider
// Convert EncryptedString to string for API key
apiKey := string(model.APIKey)
switch provider { switch provider {
case "qwen": case "qwen":
aiClient = mcp.NewQwenClient() aiClient = mcp.NewQwenClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "deepseek": case "deepseek":
aiClient = mcp.NewDeepSeekClient() aiClient = mcp.NewDeepSeekClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "claude": case "claude":
aiClient = mcp.NewClaudeClient() aiClient = mcp.NewClaudeClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "kimi": case "kimi":
aiClient = mcp.NewKimiClient() aiClient = mcp.NewKimiClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "gemini": case "gemini":
aiClient = mcp.NewGeminiClient() aiClient = mcp.NewGeminiClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "grok": case "grok":
aiClient = mcp.NewGrokClient() aiClient = mcp.NewGrokClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "openai": case "openai":
aiClient = mcp.NewOpenAIClient() aiClient = mcp.NewOpenAIClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
default: default:
// Use generic client // Use generic client
aiClient = mcp.NewClient() aiClient = mcp.NewClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
} }
// Call AI API // Call AI API
+2 -2
View File
@@ -76,8 +76,8 @@ func enforceRetentionDB(maxRuns int) {
query := ` query := `
SELECT run_id FROM backtest_runs SELECT run_id FROM backtest_runs
WHERE state IN (?, ?, ?, ?) WHERE state IN (?, ?, ?, ?)
ORDER BY datetime(updated_at) DESC ORDER BY updated_at DESC
LIMIT -1 OFFSET ? OFFSET ?
` `
rows, err := persistenceDB.Query(query, rows, err := persistenceDB.Query(query,
finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns) finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns)
+4 -4
View File
@@ -166,7 +166,7 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) {
} }
func loadRunIDsDB() ([]string, error) { func loadRunIDsDB() ([]string, error) {
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`) rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY updated_at DESC`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -278,9 +278,9 @@ func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error)
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if cycle > 0 { if cycle > 0 {
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`, runID, cycle) rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle)
} else { } else {
rows, err = persistenceDB.Query(query+` ORDER BY datetime(created_at) DESC LIMIT 1`, runID) rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -461,7 +461,7 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) {
rows, err := persistenceDB.Query(` rows, err := persistenceDB.Query(`
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json
FROM backtest_runs FROM backtest_runs
ORDER BY datetime(updated_at) DESC ORDER BY updated_at DESC
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
+46
View File
@@ -20,6 +20,16 @@ type Config struct {
RegistrationEnabled bool RegistrationEnabled bool
MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10) MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10)
// Database configuration
DBType string // sqlite or postgres
DBPath string // SQLite database file path
DBHost string // PostgreSQL host
DBPort int // PostgreSQL port
DBUser string // PostgreSQL user
DBPassword string // PostgreSQL password
DBName string // PostgreSQL database name
DBSSLMode string // PostgreSQL SSL mode
// Security configuration // Security configuration
// TransportEncryption enables browser-side encryption for API keys // TransportEncryption enables browser-side encryption for API keys
// Requires HTTPS or localhost. Set to false for HTTP access via IP. // Requires HTTPS or localhost. Set to false for HTTP access via IP.
@@ -43,6 +53,14 @@ func Init() {
RegistrationEnabled: true, RegistrationEnabled: true,
MaxUsers: 10, // Default: 10 users allowed MaxUsers: 10, // Default: 10 users allowed
ExperienceImprovement: true, // Default: enabled to help improve the product ExperienceImprovement: true, // Default: enabled to help improve the product
// Database defaults
DBType: "sqlite",
DBPath: "data/data.db",
DBHost: "localhost",
DBPort: 5432,
DBUser: "postgres",
DBName: "nofx",
DBSSLMode: "disable",
} }
// Load from environment variables // Load from environment variables
@@ -86,6 +104,34 @@ func Init() {
cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY") cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY")
cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY") cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY")
// Database configuration
if v := os.Getenv("DB_TYPE"); v != "" {
cfg.DBType = strings.ToLower(v)
}
if v := os.Getenv("DB_PATH"); v != "" {
cfg.DBPath = v
}
if v := os.Getenv("DB_HOST"); v != "" {
cfg.DBHost = v
}
if v := os.Getenv("DB_PORT"); v != "" {
if port, err := strconv.Atoi(v); err == nil && port > 0 {
cfg.DBPort = port
}
}
if v := os.Getenv("DB_USER"); v != "" {
cfg.DBUser = v
}
if v := os.Getenv("DB_PASSWORD"); v != "" {
cfg.DBPassword = v
}
if v := os.Getenv("DB_NAME"); v != "" {
cfg.DBName = v
}
if v := os.Getenv("DB_SSLMODE"); v != "" {
cfg.DBSSLMode = v
}
global = cfg global = cfg
// Initialize experience improvement (installation ID will be set after database init) // Initialize experience improvement (installation ID will be set after database init)
+75
View File
@@ -7,6 +7,7 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"database/sql/driver"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@@ -392,3 +393,77 @@ func GenerateDataKey() (string, error) {
} }
return base64.StdEncoding.EncodeToString(key), nil return base64.StdEncoding.EncodeToString(key), nil
} }
// ============================================================================
// EncryptedString - GORM custom type for automatic encryption/decryption
// ============================================================================
// Global crypto service for EncryptedString
var globalCryptoService *CryptoService
// SetGlobalCryptoService sets the global crypto service for EncryptedString
func SetGlobalCryptoService(cs *CryptoService) {
globalCryptoService = cs
}
// EncryptedString is a custom type that automatically encrypts on save and decrypts on load
// Usage: Use EncryptedString instead of string for sensitive fields in GORM models
type EncryptedString string
// Scan implements sql.Scanner - called when reading from database
// Automatically decrypts the value
func (es *EncryptedString) Scan(value interface{}) error {
if value == nil {
*es = ""
return nil
}
var str string
switch v := value.(type) {
case string:
str = v
case []byte:
str = string(v)
default:
*es = ""
return nil
}
// Decrypt if crypto service is set
if globalCryptoService != nil && str != "" && globalCryptoService.IsEncryptedStorageValue(str) {
decrypted, err := globalCryptoService.DecryptFromStorage(str)
if err != nil {
// If decryption fails, return the original value
*es = EncryptedString(str)
} else {
*es = EncryptedString(decrypted)
}
} else {
*es = EncryptedString(str)
}
return nil
}
// Value implements driver.Valuer - called when writing to database
// Automatically encrypts the value
func (es EncryptedString) Value() (driver.Value, error) {
if es == "" {
return "", nil
}
// Encrypt if crypto service is set
if globalCryptoService != nil {
encrypted, err := globalCryptoService.EncryptForStorage(string(es))
if err != nil {
// If encryption fails, return the original value
return string(es), nil
}
return encrypted, nil
}
return string(es), nil
}
// String returns the plaintext string value
func (es EncryptedString) String() string {
return string(es)
}
+2 -2
View File
@@ -101,8 +101,8 @@ func (e *DebateEngine) InitializeClients(participants []*store.DebateParticipant
client = mcp.New() client = mcp.New()
} }
// Configure client // Configure client (convert EncryptedString to string)
client.SetAPIKey(aiModel.APIKey, aiModel.CustomAPIURL, aiModel.CustomModelName) client.SetAPIKey(string(aiModel.APIKey), aiModel.CustomAPIURL, aiModel.CustomModelName)
e.clients[p.AIModelID] = client e.clients[p.AIModelID] = client
} }
+9
View File
@@ -51,6 +51,12 @@ require (
github.com/goccy/go-json v0.10.4 // indirect github.com/goccy/go-json v0.10.4 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect github.com/goccy/go-yaml v1.18.0 // indirect
github.com/holiman/uint256 v1.3.2 // indirect github.com/holiman/uint256 v1.3.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect
github.com/jpillora/backoff v1.0.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
@@ -94,6 +100,9 @@ require (
golang.org/x/tools v0.36.0 // indirect golang.org/x/tools v0.36.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.6.0 // indirect
gorm.io/driver/sqlite v1.6.0 // indirect
gorm.io/gorm v1.31.1 // indirect
howett.net/plist v1.0.1 // indirect howett.net/plist v1.0.1 // indirect
modernc.org/libc v1.66.10 // indirect modernc.org/libc v1.66.10 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
+18
View File
@@ -111,7 +111,19 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA= github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E= github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
@@ -284,6 +296,12 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM= howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
+28 -48
View File
@@ -36,30 +36,44 @@ func main() {
cfg := config.Get() cfg := config.Get()
logger.Info("✅ Configuration loaded") logger.Info("✅ Configuration loaded")
// Initialize database from environment variables // Initialize encryption service BEFORE database (so EncryptedString can decrypt on read)
// DB_TYPE: sqlite (default) or postgres logger.Info("🔐 Initializing encryption service...")
// For SQLite: DB_PATH (default: data/data.db) cryptoService, err := crypto.NewCryptoService()
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE if err != nil {
dbPath := os.Getenv("DB_PATH") logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
if dbPath == "" {
dbPath = "data/data.db"
} }
// For backward compatibility: command line arg overrides env var (SQLite only) crypto.SetGlobalCryptoService(cryptoService)
logger.Info("✅ Encryption service initialized successfully")
// Initialize database from configuration
// For backward compatibility: command line arg overrides config (SQLite only)
if len(os.Args) > 1 { if len(os.Args) > 1 {
dbPath = os.Args[1] cfg.DBPath = os.Args[1]
os.Setenv("DB_PATH", dbPath)
} }
// Ensure data directory exists (for SQLite) // Ensure data directory exists (for SQLite)
if os.Getenv("DB_TYPE") == "" || os.Getenv("DB_TYPE") == "sqlite" { if cfg.DBType == "sqlite" {
if dir := filepath.Dir(dbPath); dir != "." { if dir := filepath.Dir(cfg.DBPath); dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
logger.Errorf("Failed to create data directory: %v", err) logger.Errorf("Failed to create data directory: %v", err)
} }
} }
} }
logger.Info("📋 Initializing database...") logger.Infof("📋 Initializing database (%s)...", cfg.DBType)
st, err := store.NewFromEnv() dbType := store.DBTypeSQLite
if cfg.DBType == "postgres" {
dbType = store.DBTypePostgres
}
st, err := store.NewWithConfig(store.DBConfig{
Type: dbType,
Path: cfg.DBPath,
Host: cfg.DBHost,
Port: cfg.DBPort,
User: cfg.DBUser,
Password: cfg.DBPassword,
DBName: cfg.DBName,
SSLMode: cfg.DBSSLMode,
})
if err != nil { if err != nil {
logger.Fatalf("❌ Failed to initialize database: %v", err) logger.Fatalf("❌ Failed to initialize database: %v", err)
} }
@@ -69,40 +83,6 @@ func main() {
// Initialize installation ID for experience improvement (anonymous statistics) // Initialize installation ID for experience improvement (anonymous statistics)
initInstallationID(st) initInstallationID(st)
// Initialize encryption service
logger.Info("🔐 Initializing encryption service...")
cryptoService, err := crypto.NewCryptoService()
if err != nil {
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
}
encryptFunc := func(plaintext string) string {
if plaintext == "" {
return plaintext
}
encrypted, err := cryptoService.EncryptForStorage(plaintext)
if err != nil {
logger.Warnf("⚠️ Encryption failed: %v", err)
return plaintext
}
return encrypted
}
decryptFunc := func(encrypted string) string {
if encrypted == "" {
return encrypted
}
if !cryptoService.IsEncryptedStorageValue(encrypted) {
return encrypted
}
decrypted, err := cryptoService.DecryptFromStorage(encrypted)
if err != nil {
logger.Warnf("⚠️ Decryption failed: %v", err)
return encrypted
}
return decrypted
}
st.SetCryptoFuncs(encryptFunc, decryptFunc)
logger.Info("✅ Encryption service initialized successfully")
// Set JWT secret // Set JWT secret
auth.SetJWTSecret(cfg.JWTSecret) auth.SetJWTSecret(cfg.JWTSecret)
logger.Info("🔑 JWT secret configured") logger.Info("🔑 JWT secret configured")
+19 -19
View File
@@ -664,46 +664,46 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
StrategyConfig: strategyConfig, StrategyConfig: strategyConfig,
} }
// Set API keys based on exchange type // Set API keys based on exchange type (convert EncryptedString to string)
switch exchangeCfg.ExchangeType { switch exchangeCfg.ExchangeType {
case "binance": case "binance":
traderConfig.BinanceAPIKey = exchangeCfg.APIKey traderConfig.BinanceAPIKey = string(exchangeCfg.APIKey)
traderConfig.BinanceSecretKey = exchangeCfg.SecretKey traderConfig.BinanceSecretKey = string(exchangeCfg.SecretKey)
case "bybit": case "bybit":
traderConfig.BybitAPIKey = exchangeCfg.APIKey traderConfig.BybitAPIKey = string(exchangeCfg.APIKey)
traderConfig.BybitSecretKey = exchangeCfg.SecretKey traderConfig.BybitSecretKey = string(exchangeCfg.SecretKey)
case "okx": case "okx":
traderConfig.OKXAPIKey = exchangeCfg.APIKey traderConfig.OKXAPIKey = string(exchangeCfg.APIKey)
traderConfig.OKXSecretKey = exchangeCfg.SecretKey traderConfig.OKXSecretKey = string(exchangeCfg.SecretKey)
traderConfig.OKXPassphrase = exchangeCfg.Passphrase traderConfig.OKXPassphrase = string(exchangeCfg.Passphrase)
case "bitget": case "bitget":
traderConfig.BitgetAPIKey = exchangeCfg.APIKey traderConfig.BitgetAPIKey = string(exchangeCfg.APIKey)
traderConfig.BitgetSecretKey = exchangeCfg.SecretKey traderConfig.BitgetSecretKey = string(exchangeCfg.SecretKey)
traderConfig.BitgetPassphrase = exchangeCfg.Passphrase traderConfig.BitgetPassphrase = string(exchangeCfg.Passphrase)
case "hyperliquid": case "hyperliquid":
traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey traderConfig.HyperliquidPrivateKey = string(exchangeCfg.APIKey)
traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr
case "aster": case "aster":
traderConfig.AsterUser = exchangeCfg.AsterUser traderConfig.AsterUser = exchangeCfg.AsterUser
traderConfig.AsterSigner = exchangeCfg.AsterSigner traderConfig.AsterSigner = exchangeCfg.AsterSigner
traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey traderConfig.AsterPrivateKey = string(exchangeCfg.AsterPrivateKey)
case "lighter": case "lighter":
traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey traderConfig.LighterPrivateKey = string(exchangeCfg.LighterPrivateKey)
traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr
traderConfig.LighterAPIKeyPrivateKey = exchangeCfg.LighterAPIKeyPrivateKey traderConfig.LighterAPIKeyPrivateKey = string(exchangeCfg.LighterAPIKeyPrivateKey)
traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex
traderConfig.LighterTestnet = exchangeCfg.Testnet traderConfig.LighterTestnet = exchangeCfg.Testnet
} }
// Set API keys based on AI model // Set API keys based on AI model (convert EncryptedString to string)
switch aiModelCfg.Provider { switch aiModelCfg.Provider {
case "qwen": case "qwen":
traderConfig.QwenKey = aiModelCfg.APIKey traderConfig.QwenKey = string(aiModelCfg.APIKey)
case "deepseek": case "deepseek":
traderConfig.DeepSeekKey = aiModelCfg.APIKey traderConfig.DeepSeekKey = string(aiModelCfg.APIKey)
default: default:
// For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey // For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey
traderConfig.CustomAPIKey = aiModelCfg.APIKey traderConfig.CustomAPIKey = string(aiModelCfg.APIKey)
} }
// Create trader instance // Create trader instance
+90 -174
View File
@@ -1,71 +1,52 @@
package store package store
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"nofx/crypto"
"nofx/logger" "nofx/logger"
"strings" "strings"
"time" "time"
"gorm.io/gorm"
) )
// AIModelStore AI model storage // AIModelStore AI model storage
type AIModelStore struct { type AIModelStore struct {
db *sql.DB db *gorm.DB
encryptFunc func(string) string
decryptFunc func(string) string
} }
// AIModel AI model configuration // AIModel AI model configuration
type AIModel struct { type AIModel struct {
ID string `json:"id"` ID string `gorm:"primaryKey" json:"id"`
UserID string `json:"user_id"` UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `json:"name"` Name string `gorm:"not null" json:"name"`
Provider string `json:"provider"` Provider string `gorm:"not null" json:"provider"`
Enabled bool `json:"enabled"` Enabled bool `gorm:"default:false" json:"enabled"`
APIKey string `json:"apiKey"` APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
CustomAPIURL string `json:"customApiUrl"` CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"`
CustomModelName string `json:"customModelName"` CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
}
func (AIModel) TableName() string { return "ai_models" }
// NewAIModelStore creates a new AIModelStore
func NewAIModelStore(db *gorm.DB) *AIModelStore {
return &AIModelStore{db: db}
} }
func (s *AIModelStore) initTables() error { func (s *AIModelStore) initTables() error {
_, err := s.db.Exec(` // For PostgreSQL with existing table, skip AutoMigrate
CREATE TABLE IF NOT EXISTS ai_models ( if s.db.Dialector.Name() == "postgres" {
id TEXT PRIMARY KEY, var tableExists int64
user_id TEXT NOT NULL DEFAULT 'default', s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'ai_models'`).Scan(&tableExists)
name TEXT NOT NULL, if tableExists > 0 {
provider TEXT NOT NULL, return nil
enabled BOOLEAN DEFAULT 0, }
api_key TEXT DEFAULT '',
custom_api_url TEXT DEFAULT '',
custom_model_name TEXT DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
} }
return s.db.AutoMigrate(&AIModel{})
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at
AFTER UPDATE ON ai_models
BEGIN
UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
if err != nil {
return err
}
// Backward compatibility: add potentially missing columns
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`)
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`)
return nil
} }
func (s *AIModelStore) initDefaultData() error { func (s *AIModelStore) initDefaultData() error {
@@ -73,51 +54,13 @@ func (s *AIModelStore) initDefaultData() error {
return nil return nil
} }
func (s *AIModelStore) encrypt(plaintext string) string {
if s.encryptFunc != nil {
return s.encryptFunc(plaintext)
}
return plaintext
}
func (s *AIModelStore) decrypt(encrypted string) string {
if s.decryptFunc != nil {
return s.decryptFunc(encrypted)
}
return encrypted
}
// List retrieves user's AI model list // List retrieves user's AI model list
func (s *AIModelStore) List(userID string) ([]*AIModel, error) { func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
rows, err := s.db.Query(` var models []*AIModel
SELECT id, user_id, name, provider, enabled, api_key, err := s.db.Where("user_id = ?", userID).Order("id").Find(&models).Error
COALESCE(custom_api_url, '') as custom_api_url,
COALESCE(custom_model_name, '') as custom_model_name,
created_at, updated_at
FROM ai_models WHERE user_id = ? ORDER BY id
`, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
models := make([]*AIModel, 0)
for rows.Next() {
var model AIModel
var createdAt, updatedAt string
err := rows.Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
models = append(models, &model)
}
return models, nil return models, nil
} }
@@ -140,27 +83,15 @@ func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
for _, uid := range candidates { for _, uid := range candidates {
var model AIModel var model AIModel
var createdAt, updatedAt string err := s.db.Where("user_id = ? AND id = ?", uid, modelID).First(&model).Error
err := s.db.QueryRow(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1
`, uid, modelID).Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
if err == nil { if err == nil {
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
return &model, nil return &model, nil
} }
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err return nil, err
} }
} }
return nil, sql.ErrNoRows return nil, gorm.ErrRecordNotFound
} }
// GetByID retrieves an AI model by ID only (for debate engine) // GetByID retrieves an AI model by ID only (for debate engine)
@@ -170,22 +101,10 @@ func (s *AIModelStore) GetByID(modelID string) (*AIModel, error) {
} }
var model AIModel var model AIModel
var createdAt, updatedAt string err := s.db.Where("id = ?", modelID).First(&model).Error
err := s.db.QueryRow(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
FROM ai_models WHERE id = ? LIMIT 1
`, modelID).Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
return &model, nil return &model, nil
} }
@@ -198,7 +117,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
if err == nil { if err == nil {
return model, nil return model, nil
} }
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err return nil, err
} }
if userID != "default" { if userID != "default" {
@@ -209,23 +128,12 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) { func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
var model AIModel var model AIModel
var createdAt, updatedAt string err := s.db.Where("user_id = ? AND enabled = ?", userID, true).
err := s.db.QueryRow(` Order("updated_at DESC, id ASC").
SELECT id, user_id, name, provider, enabled, api_key, First(&model).Error
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
FROM ai_models WHERE user_id = ? AND enabled = 1
ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1
`, userID).Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
return &model, nil return &model, nil
} }
@@ -233,44 +141,38 @@ func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
// IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten) // IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten)
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error { func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
// Try exact ID match first // Try exact ID match first
var existingID string var existingModel AIModel
err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID) err := s.db.Where("user_id = ? AND id = ?", userID, id).First(&existingModel).Error
if err == nil { if err == nil {
// If apiKey is empty, preserve the existing API key // Update existing model
if apiKey == "" { updates := map[string]interface{}{
_, err = s.db.Exec(` "enabled": enabled,
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') "custom_api_url": customAPIURL,
WHERE id = ? AND user_id = ? "custom_model_name": customModelName,
`, enabled, customAPIURL, customModelName, existingID, userID) "updated_at": time.Now(),
} else {
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
WHERE id = ? AND user_id = ?
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
} }
return err // If apiKey is not empty, update it (encryption handled by crypto.EncryptedString)
if apiKey != "" {
updates["api_key"] = crypto.EncryptedString(apiKey)
}
return s.db.Model(&existingModel).Updates(updates).Error
} }
// Try legacy logic compatibility: use id as provider to search // Try legacy logic compatibility: use id as provider to search
provider := id provider := id
err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID) err = s.db.Where("user_id = ? AND provider = ?", userID, provider).First(&existingModel).Error
if err == nil { if err == nil {
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingID) logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingModel.ID)
// If apiKey is empty, preserve the existing API key updates := map[string]interface{}{
if apiKey == "" { "enabled": enabled,
_, err = s.db.Exec(` "custom_api_url": customAPIURL,
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') "custom_model_name": customModelName,
WHERE id = ? AND user_id = ? "updated_at": time.Now(),
`, enabled, customAPIURL, customModelName, existingID, userID)
} else {
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
WHERE id = ? AND user_id = ?
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
} }
return err if apiKey != "" {
updates["api_key"] = crypto.EncryptedString(apiKey)
}
return s.db.Model(&existingModel).Updates(updates).Error
} }
// Create new record // Create new record
@@ -285,9 +187,12 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
} }
} }
// Try to get name from existing model with same provider
var refModel AIModel
var name string var name string
err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name) if err := s.db.Where("provider = ?", provider).First(&refModel).Error; err == nil {
if err != nil { name = refModel.Name
} else {
if provider == "deepseek" { if provider == "deepseek" {
name = "DeepSeek AI" name = "DeepSeek AI"
} else if provider == "qwen" { } else if provider == "qwen" {
@@ -303,19 +208,30 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
} }
logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
encryptedAPIKey := s.encrypt(apiKey) newModel := &AIModel{
_, err = s.db.Exec(` ID: newModelID,
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) UserID: userID,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) Name: name,
`, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName) Provider: provider,
return err Enabled: enabled,
APIKey: crypto.EncryptedString(apiKey),
CustomAPIURL: customAPIURL,
CustomModelName: customModelName,
}
return s.db.Create(newModel).Error
} }
// Create creates an AI model // Create creates an AI model
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error { func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
_, err := s.db.Exec(` model := &AIModel{
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url) ID: id,
VALUES (?, ?, ?, ?, ?, ?, ?) UserID: userID,
`, id, userID, name, provider, enabled, apiKey, customAPIURL) Name: name,
return err Provider: provider,
Enabled: enabled,
APIKey: crypto.EncryptedString(apiKey),
CustomAPIURL: customAPIURL,
}
// Use FirstOrCreate to ignore if already exists
return s.db.Where("id = ?", id).FirstOrCreate(model).Error
} }
+339 -351
View File
@@ -1,15 +1,26 @@
package store package store
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"gorm.io/gorm"
) )
// BacktestStore backtest data storage // BacktestStore backtest data storage
type BacktestStore struct { type BacktestStore struct {
db *sql.DB db *gorm.DB
}
// NewBacktestStore creates a new backtest store
func NewBacktestStore(db *gorm.DB) *BacktestStore {
return &BacktestStore{db: db}
}
// isPostgres checks if the database is PostgreSQL
func (s *BacktestStore) isPostgres() bool {
return s.db.Dialector.Name() == "postgres"
} }
// RunState backtest state // RunState backtest state
@@ -92,492 +103,469 @@ type RunIndexEntry struct {
UpdatedAtISO string `json:"updated_at"` UpdatedAtISO string `json:"updated_at"`
} }
// BacktestRun GORM model for backtest_runs table
type BacktestRun struct {
RunID string `gorm:"column:run_id;primaryKey"`
UserID string `gorm:"column:user_id;not null;default:''"`
ConfigJSON []byte `gorm:"column:config_json"`
State string `gorm:"column:state;not null;default:created"`
Label string `gorm:"column:label;default:''"`
SymbolCount int `gorm:"column:symbol_count;default:0"`
DecisionTF string `gorm:"column:decision_tf;default:''"`
ProcessedBars int `gorm:"column:processed_bars;default:0"`
ProgressPct float64 `gorm:"column:progress_pct;default:0"`
EquityLast float64 `gorm:"column:equity_last;default:0"`
MaxDrawdownPct float64 `gorm:"column:max_drawdown_pct;default:0"`
Liquidated bool `gorm:"column:liquidated;default:false"`
LiquidationNote string `gorm:"column:liquidation_note;default:''"`
PromptTemplate string `gorm:"column:prompt_template;default:''"`
CustomPrompt string `gorm:"column:custom_prompt;default:''"`
OverridePrompt bool `gorm:"column:override_prompt;default:false"`
AIProvider string `gorm:"column:ai_provider;default:''"`
AIModel string `gorm:"column:ai_model;default:''"`
LastError string `gorm:"column:last_error;default:''"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (BacktestRun) TableName() string {
return "backtest_runs"
}
// BacktestCheckpoint GORM model
type BacktestCheckpoint struct {
RunID string `gorm:"column:run_id;primaryKey"`
Payload []byte `gorm:"column:payload;not null"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (BacktestCheckpoint) TableName() string {
return "backtest_checkpoints"
}
// BacktestEquity GORM model
type BacktestEquity struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"`
TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"`
Equity float64 `gorm:"column:equity;not null"`
Available float64 `gorm:"column:available;not null"`
PnL float64 `gorm:"column:pnl;not null"`
PnLPct float64 `gorm:"column:pnl_pct;not null"`
DDPct float64 `gorm:"column:dd_pct;not null"`
Cycle int `gorm:"column:cycle;not null"`
}
func (BacktestEquity) TableName() string {
return "backtest_equity"
}
// BacktestTrade GORM model
type BacktestTrade struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"`
TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"`
Symbol string `gorm:"column:symbol;not null"`
Action string `gorm:"column:action;not null"`
Side string `gorm:"column:side;default:''"`
Qty float64 `gorm:"column:qty;default:0"`
Price float64 `gorm:"column:price;default:0"`
Fee float64 `gorm:"column:fee;default:0"`
Slippage float64 `gorm:"column:slippage;default:0"`
OrderValue float64 `gorm:"column:order_value;default:0"`
RealizedPnL float64 `gorm:"column:realized_pnl;default:0"`
Leverage int `gorm:"column:leverage;default:0"`
Cycle int `gorm:"column:cycle;default:0"`
PositionAfter float64 `gorm:"column:position_after;default:0"`
Liquidation bool `gorm:"column:liquidation;default:false"`
Note string `gorm:"column:note;default:''"`
}
func (BacktestTrade) TableName() string {
return "backtest_trades"
}
// BacktestMetrics GORM model
type BacktestMetrics struct {
RunID string `gorm:"column:run_id;primaryKey"`
Payload []byte `gorm:"column:payload;not null"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (BacktestMetrics) TableName() string {
return "backtest_metrics"
}
// BacktestDecision GORM model
type BacktestDecision struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
RunID string `gorm:"column:run_id;not null;index:idx_backtest_decisions_run_cycle"`
Cycle int `gorm:"column:cycle;not null;index:idx_backtest_decisions_run_cycle"`
Payload []byte `gorm:"column:payload;not null"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
}
func (BacktestDecision) TableName() string {
return "backtest_decisions"
}
// initTables initializes backtest related tables // initTables initializes backtest related tables
func (s *BacktestStore) initTables() error { func (s *BacktestStore) initTables() error {
queries := []string{ // For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts
// Backtest runs main table if s.db.Dialector.Name() == "postgres" {
`CREATE TABLE IF NOT EXISTS backtest_runs ( var tableExists int64
run_id TEXT PRIMARY KEY, s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists)
user_id TEXT NOT NULL DEFAULT '',
config_json TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'created',
label TEXT DEFAULT '',
symbol_count INTEGER DEFAULT 0,
decision_tf TEXT DEFAULT '',
processed_bars INTEGER DEFAULT 0,
progress_pct REAL DEFAULT 0,
equity_last REAL DEFAULT 0,
max_drawdown_pct REAL DEFAULT 0,
liquidated BOOLEAN DEFAULT 0,
liquidation_note TEXT DEFAULT '',
prompt_template TEXT DEFAULT '',
custom_prompt TEXT DEFAULT '',
override_prompt BOOLEAN DEFAULT 0,
ai_provider TEXT DEFAULT '',
ai_model TEXT DEFAULT '',
last_error TEXT DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// Backtest checkpoints if tableExists > 0 {
`CREATE TABLE IF NOT EXISTS backtest_checkpoints ( // Tables exist - just ensure indexes exist
run_id TEXT PRIMARY KEY, s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`)
payload BLOB NOT NULL, s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`)
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`)
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE return nil
)`,
// Backtest equity curve
`CREATE TABLE IF NOT EXISTS backtest_equity (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
ts INTEGER NOT NULL,
equity REAL NOT NULL,
available REAL NOT NULL,
pnl REAL NOT NULL,
pnl_pct REAL NOT NULL,
dd_pct REAL NOT NULL,
cycle INTEGER NOT NULL,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest trade records
`CREATE TABLE IF NOT EXISTS backtest_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
ts INTEGER NOT NULL,
symbol TEXT NOT NULL,
action TEXT NOT NULL,
side TEXT DEFAULT '',
qty REAL DEFAULT 0,
price REAL DEFAULT 0,
fee REAL DEFAULT 0,
slippage REAL DEFAULT 0,
order_value REAL DEFAULT 0,
realized_pnl REAL DEFAULT 0,
leverage INTEGER DEFAULT 0,
cycle INTEGER DEFAULT 0,
position_after REAL DEFAULT 0,
liquidation BOOLEAN DEFAULT 0,
note TEXT DEFAULT '',
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest metrics
`CREATE TABLE IF NOT EXISTS backtest_metrics (
run_id TEXT PRIMARY KEY,
payload BLOB NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest decision logs
`CREATE TABLE IF NOT EXISTS backtest_decisions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
cycle INTEGER NOT NULL,
payload BLOB NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Indexes
`CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
} }
} }
// Add potentially missing columns (backward compatibility) // AutoMigrate all backtest tables
s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''") if err := s.db.AutoMigrate(
s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''") &BacktestRun{},
s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0") &BacktestCheckpoint{},
&BacktestEquity{},
&BacktestTrade{},
&BacktestMetrics{},
&BacktestDecision{},
); err != nil {
return fmt.Errorf("failed to migrate backtest tables: %w", err)
}
return nil return nil
} }
func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) {
rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
if err != nil {
return
}
defer rows.Close()
for rows.Next() {
var cid int
var name, ctype string
var notnull, pk int
var dflt interface{}
if err := rows.Scan(&cid, &name, &ctype, &notnull, &dflt, &pk); err != nil {
continue
}
if name == column {
return // Column already exists
}
}
s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
}
// SaveCheckpoint saves checkpoint // SaveCheckpoint saves checkpoint
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error { func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
_, err := s.db.Exec(` checkpoint := BacktestCheckpoint{
INSERT INTO backtest_checkpoints (run_id, payload, updated_at) RunID: runID,
VALUES (?, ?, CURRENT_TIMESTAMP) Payload: payload,
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP }
`, runID, payload) return s.db.Save(&checkpoint).Error
return err
} }
// LoadCheckpoint loads checkpoint // LoadCheckpoint loads checkpoint
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) { func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
var payload []byte var checkpoint BacktestCheckpoint
err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload) err := s.db.Where("run_id = ?", runID).First(&checkpoint).Error
return payload, err if err != nil {
return nil, err
}
return checkpoint.Payload, nil
} }
// SaveRunMetadata saves run metadata // SaveRunMetadata saves run metadata
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error { func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
created := meta.CreatedAt.UTC().Format(time.RFC3339) run := BacktestRun{
updated := meta.UpdatedAt.UTC().Format(time.RFC3339) RunID: meta.RunID,
userID := meta.UserID UserID: meta.UserID,
State: string(meta.State),
if _, err := s.db.Exec(` Label: meta.Label,
INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at) LastError: meta.LastError,
VALUES (?, ?, ?, ?, ?, ?) SymbolCount: meta.Summary.SymbolCount,
ON CONFLICT(run_id) DO NOTHING DecisionTF: meta.Summary.DecisionTF,
`, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil { ProcessedBars: meta.Summary.ProcessedBars,
return err ProgressPct: meta.Summary.ProgressPct,
EquityLast: meta.Summary.EquityLast,
MaxDrawdownPct: meta.Summary.MaxDrawdownPct,
Liquidated: meta.Summary.Liquidated,
LiquidationNote: meta.Summary.LiquidationNote,
CreatedAt: meta.CreatedAt,
UpdatedAt: meta.UpdatedAt,
} }
return s.db.Save(&run).Error
_, err := s.db.Exec(`
UPDATE backtest_runs
SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?,
progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?,
liquidation_note = ?, label = ?, last_error = ?, updated_at = ?
WHERE run_id = ?
`, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF,
meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast,
meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote,
meta.Label, meta.LastError, updated, meta.RunID)
return err
} }
// LoadRunMetadata loads run metadata // LoadRunMetadata loads run metadata
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) { func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
var ( var run BacktestRun
userID string err := s.db.Where("run_id = ?", runID).First(&run).Error
state string
label string
lastErr string
symbolCount int
decisionTF string
processedBars int
progressPct float64
equityLast float64
maxDD float64
liquidated bool
liquidationNote string
createdISO string
updatedISO string
)
err := s.db.QueryRow(`
SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars,
progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note,
created_at, updated_at
FROM backtest_runs WHERE run_id = ?
`, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF,
&processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote,
&createdISO, &updatedISO)
if err != nil { if err != nil {
return nil, err return nil, err
} }
meta := &RunMetadata{ return &RunMetadata{
RunID: runID, RunID: run.RunID,
UserID: userID, UserID: run.UserID,
Version: 1, Version: 1,
State: RunState(state), State: RunState(run.State),
Label: label, Label: run.Label,
LastError: lastErr, LastError: run.LastError,
Summary: RunSummary{ Summary: RunSummary{
SymbolCount: symbolCount, SymbolCount: run.SymbolCount,
DecisionTF: decisionTF, DecisionTF: run.DecisionTF,
ProcessedBars: processedBars, ProcessedBars: run.ProcessedBars,
ProgressPct: progressPct, ProgressPct: run.ProgressPct,
EquityLast: equityLast, EquityLast: run.EquityLast,
MaxDrawdownPct: maxDD, MaxDrawdownPct: run.MaxDrawdownPct,
Liquidated: liquidated, Liquidated: run.Liquidated,
LiquidationNote: liquidationNote, LiquidationNote: run.LiquidationNote,
}, },
} CreatedAt: run.CreatedAt,
UpdatedAt: run.UpdatedAt,
meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO) }, nil
meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO)
return meta, nil
} }
// ListRunIDs lists all run IDs // ListRunIDs lists all run IDs
func (s *BacktestStore) ListRunIDs() ([]string, error) { func (s *BacktestStore) ListRunIDs() ([]string, error) {
rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`) var runs []BacktestRun
err := s.db.Order("updated_at DESC").Find(&runs).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var ids []string ids := make([]string, len(runs))
for rows.Next() { for i, run := range runs {
var runID string ids[i] = run.RunID
if err := rows.Scan(&runID); err != nil {
return nil, err
}
ids = append(ids, runID)
} }
return ids, rows.Err() return ids, nil
} }
// AppendEquityPoint appends equity point // AppendEquityPoint appends equity point
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error { func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
_, err := s.db.Exec(` eq := BacktestEquity{
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle) RunID: runID,
VALUES (?, ?, ?, ?, ?, ?, ?, ?) TS: point.Timestamp,
`, runID, point.Timestamp, point.Equity, point.Available, point.PnL, Equity: point.Equity,
point.PnLPct, point.DrawdownPct, point.Cycle) Available: point.Available,
return err PnL: point.PnL,
PnLPct: point.PnLPct,
DDPct: point.DrawdownPct,
Cycle: point.Cycle,
}
return s.db.Create(&eq).Error
} }
// LoadEquityPoints loads equity points // LoadEquityPoints loads equity points
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) { func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
rows, err := s.db.Query(` var eqs []BacktestEquity
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&eqs).Error
FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC
`, runID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
points := make([]EquityPoint, 0) points := make([]EquityPoint, len(eqs))
for rows.Next() { for i, eq := range eqs {
var point EquityPoint points[i] = EquityPoint{
if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available, Timestamp: eq.TS,
&point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil { Equity: eq.Equity,
return nil, err Available: eq.Available,
PnL: eq.PnL,
PnLPct: eq.PnLPct,
DrawdownPct: eq.DDPct,
Cycle: eq.Cycle,
} }
points = append(points, point)
} }
return points, rows.Err() return points, nil
} }
// AppendTradeEvent appends trade event // AppendTradeEvent appends trade event
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error { func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
_, err := s.db.Exec(` trade := BacktestTrade{
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee, RunID: runID,
slippage, order_value, realized_pnl, leverage, cycle, TS: event.Timestamp,
position_after, liquidation, note) Symbol: event.Symbol,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) Action: event.Action,
`, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, Side: event.Side,
event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, Qty: event.Quantity,
event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note) Price: event.Price,
return err Fee: event.Fee,
Slippage: event.Slippage,
OrderValue: event.OrderValue,
RealizedPnL: event.RealizedPnL,
Leverage: event.Leverage,
Cycle: event.Cycle,
PositionAfter: event.PositionAfter,
Liquidation: event.LiquidationFlag,
Note: event.Note,
}
return s.db.Create(&trade).Error
} }
// LoadTradeEvents loads trade events // LoadTradeEvents loads trade events
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) { func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
rows, err := s.db.Query(` var trades []BacktestTrade
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value, err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&trades).Error
realized_pnl, leverage, cycle, position_after, liquidation, note
FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC
`, runID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
events := make([]TradeEvent, 0) events := make([]TradeEvent, len(trades))
for rows.Next() { for i, trade := range trades {
var event TradeEvent events[i] = TradeEvent{
if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side, Timestamp: trade.TS,
&event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue, Symbol: trade.Symbol,
&event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter, Action: trade.Action,
&event.LiquidationFlag, &event.Note); err != nil { Side: trade.Side,
return nil, err Quantity: trade.Qty,
Price: trade.Price,
Fee: trade.Fee,
Slippage: trade.Slippage,
OrderValue: trade.OrderValue,
RealizedPnL: trade.RealizedPnL,
Leverage: trade.Leverage,
Cycle: trade.Cycle,
PositionAfter: trade.PositionAfter,
LiquidationFlag: trade.Liquidation,
Note: trade.Note,
} }
events = append(events, event)
} }
return events, rows.Err() return events, nil
} }
// SaveMetrics saves metrics // SaveMetrics saves metrics
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error { func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
_, err := s.db.Exec(` metrics := BacktestMetrics{
INSERT INTO backtest_metrics (run_id, payload, updated_at) RunID: runID,
VALUES (?, ?, CURRENT_TIMESTAMP) Payload: payload,
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP }
`, runID, payload) return s.db.Save(&metrics).Error
return err
} }
// LoadMetrics loads metrics // LoadMetrics loads metrics
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) { func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
var payload []byte var metrics BacktestMetrics
err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload) err := s.db.Where("run_id = ?", runID).First(&metrics).Error
return payload, err if err != nil {
return nil, err
}
return metrics.Payload, nil
} }
// SaveDecisionRecord saves decision record // SaveDecisionRecord saves decision record
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error { func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
_, err := s.db.Exec(` decision := BacktestDecision{
INSERT INTO backtest_decisions (run_id, cycle, payload) RunID: runID,
VALUES (?, ?, ?) Cycle: cycle,
`, runID, cycle, payload) Payload: payload,
return err }
return s.db.Create(&decision).Error
} }
// LoadDecisionRecords loads decision records // LoadDecisionRecords loads decision records
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) { func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
rows, err := s.db.Query(` var decisions []BacktestDecision
SELECT payload FROM backtest_decisions err := s.db.Where("run_id = ?", runID).
WHERE run_id = ? Order("id DESC").
ORDER BY id DESC Limit(limit).
LIMIT ? OFFSET ? Offset(offset).
`, runID, limit, offset) Find(&decisions).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
records := make([]json.RawMessage, 0, limit) records := make([]json.RawMessage, len(decisions))
for rows.Next() { for i, d := range decisions {
var payload []byte records[i] = json.RawMessage(d.Payload)
if err := rows.Scan(&payload); err != nil {
return nil, err
}
records = append(records, json.RawMessage(payload))
} }
return records, rows.Err() return records, nil
} }
// LoadLatestDecision loads latest decision // LoadLatestDecision loads latest decision
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) { func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
var query string var decision BacktestDecision
var args []interface{} query := s.db.Where("run_id = ?", runID)
if cycle > 0 { if cycle > 0 {
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1` query = query.Where("cycle = ?", cycle)
args = []interface{}{runID, cycle}
} else {
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1`
args = []interface{}{runID}
} }
err := query.Order("created_at DESC").First(&decision).Error
var payload []byte if err != nil {
err := s.db.QueryRow(query, args...).Scan(&payload) return nil, err
return payload, err }
return decision.Payload, nil
} }
// UpdateProgress updates progress // UpdateProgress updates progress
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error { func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
_, err := s.db.Exec(` return s.db.Model(&BacktestRun{}).Where("run_id = ?", runID).Updates(map[string]interface{}{
UPDATE backtest_runs "progress_pct": progressPct,
SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP "equity_last": equity,
WHERE run_id = ? "processed_bars": barIndex,
`, progressPct, equity, barIndex, liquidated, runID) "liquidated": liquidated,
return err }).Error
} }
// ListIndexEntries lists index entries // ListIndexEntries lists index entries
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) { func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
rows, err := s.db.Query(` var runs []BacktestRun
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, err := s.db.Order("updated_at DESC").Find(&runs).Error
created_at, updated_at, config_json
FROM backtest_runs
ORDER BY datetime(updated_at) DESC
`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var entries []RunIndexEntry entries := make([]RunIndexEntry, len(runs))
for rows.Next() { for i, run := range runs {
var entry RunIndexEntry entry := RunIndexEntry{
var symbolCnt int RunID: run.RunID,
var cfgJSON []byte State: run.State,
var createdISO, updatedISO string DecisionTF: run.DecisionTF,
EquityLast: run.EquityLast,
if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF, MaxDrawdownPct: run.MaxDrawdownPct,
&entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil { CreatedAtISO: run.CreatedAt.Format(time.RFC3339),
return nil, err UpdatedAtISO: run.UpdatedAt.Format(time.RFC3339),
Symbols: make([]string, 0, run.SymbolCount),
} }
entry.CreatedAtISO = createdISO if len(run.ConfigJSON) > 0 {
entry.UpdatedAtISO = updatedISO
entry.Symbols = make([]string, 0, symbolCnt)
// Try to extract more information from config
if len(cfgJSON) > 0 {
var cfg struct { var cfg struct {
Symbols []string `json:"symbols"` Symbols []string `json:"symbols"`
StartTS int64 `json:"start_ts"` StartTS int64 `json:"start_ts"`
EndTS int64 `json:"end_ts"` EndTS int64 `json:"end_ts"`
} }
if json.Unmarshal(cfgJSON, &cfg) == nil { if json.Unmarshal(run.ConfigJSON, &cfg) == nil {
entry.Symbols = cfg.Symbols entry.Symbols = cfg.Symbols
entry.StartTS = cfg.StartTS entry.StartTS = cfg.StartTS
entry.EndTS = cfg.EndTS entry.EndTS = cfg.EndTS
} }
} }
entries = append(entries, entry) entries[i] = entry
} }
return entries, rows.Err() return entries, nil
} }
// DeleteRun deletes run // DeleteRun deletes run
func (s *BacktestStore) DeleteRun(runID string) error { func (s *BacktestStore) DeleteRun(runID string) error {
_, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID) // Delete related records first (cascade may not work in all cases)
return err s.db.Where("run_id = ?", runID).Delete(&BacktestCheckpoint{})
s.db.Where("run_id = ?", runID).Delete(&BacktestEquity{})
s.db.Where("run_id = ?", runID).Delete(&BacktestTrade{})
s.db.Where("run_id = ?", runID).Delete(&BacktestMetrics{})
s.db.Where("run_id = ?", runID).Delete(&BacktestDecision{})
return s.db.Where("run_id = ?", runID).Delete(&BacktestRun{}).Error
} }
// SaveConfig saves config // SaveConfig saves config
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error { func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
now := time.Now().UTC().Format(time.RFC3339)
if userID == "" { if userID == "" {
userID = "default" userID = "default"
} }
_, err := s.db.Exec(` run := BacktestRun{
INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt, RunID: runID,
override_prompt, ai_provider, ai_model, created_at, updated_at) UserID: userID,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ConfigJSON: configJSON,
ON CONFLICT(run_id) DO NOTHING PromptTemplate: template,
`, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now) CustomPrompt: customPrompt,
if err != nil { OverridePrompt: override,
return err AIProvider: provider,
AIModel: model,
} }
return s.db.Save(&run).Error
_, err = s.db.Exec(`
UPDATE backtest_runs
SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?,
override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP
WHERE run_id = ?
`, userID, configJSON, template, customPrompt, override, provider, model, runID)
return err
} }
// LoadConfig loads config // LoadConfig loads config
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) { func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
var payload []byte var run BacktestRun
err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload) err := s.db.Where("run_id = ?", runID).First(&run).Error
return payload, err if err != nil {
return nil, err
}
return run.ConfigJSON, nil
} }
+225 -491
View File
@@ -1,12 +1,11 @@
package store package store
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"gorm.io/gorm"
) )
// DebateStatus represents the status of a debate session // DebateStatus represents the status of a debate session
@@ -49,30 +48,6 @@ var PersonalityEmojis = map[DebatePersonality]string{
PersonalityRiskManager: "🛡️", PersonalityRiskManager: "🛡️",
} }
// DebateSession represents a debate session
type DebateSession struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
StrategyID string `json:"strategy_id"`
Status DebateStatus `json:"status"`
Symbol string `json:"symbol"` // Primary symbol (for backward compat, may be empty for multi-coin)
MaxRounds int `json:"max_rounds"`
CurrentRound int `json:"current_round"`
IntervalMinutes int `json:"interval_minutes"` // Debate interval (5, 15, 30, 60 minutes)
PromptVariant string `json:"prompt_variant"` // balanced/aggressive/conservative/scalping
FinalDecision *DebateDecision `json:"final_decision,omitempty"` // Single decision (backward compat)
FinalDecisions []*DebateDecision `json:"final_decisions,omitempty"` // Multi-coin decisions
AutoExecute bool `json:"auto_execute"`
TraderID string `json:"trader_id,omitempty"` // Trader to use for auto-execute
// OI Ranking data options
EnableOIRanking bool `json:"enable_oi_ranking"` // Whether to include OI ranking data
OIRankingLimit int `json:"oi_ranking_limit"` // Number of OI ranking entries (default 10)
OIDuration string `json:"oi_duration"` // Duration for OI data (1h, 4h, 24h, etc.)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// DebateDecision represents a trading decision from the debate // DebateDecision represents a trading decision from the debate
type DebateDecision struct { type DebateDecision struct {
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
@@ -86,178 +61,187 @@ type DebateDecision struct {
Reasoning string `json:"reasoning"` // Brief reasoning Reasoning string `json:"reasoning"` // Brief reasoning
// Execution tracking // Execution tracking
Executed bool `json:"executed"` // Whether this decision was executed Executed bool `json:"executed"` // Whether this decision was executed
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
OrderID string `json:"order_id,omitempty"` // Exchange order ID OrderID string `json:"order_id,omitempty"` // Exchange order ID
Error string `json:"error,omitempty"` // Execution error if any Error string `json:"error,omitempty"` // Execution error if any
} }
// DebateSession represents a debate session (API struct)
type DebateSession struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
StrategyID string `json:"strategy_id"`
Status DebateStatus `json:"status"`
Symbol string `json:"symbol"` // Primary symbol (for backward compat, may be empty for multi-coin)
MaxRounds int `json:"max_rounds"`
CurrentRound int `json:"current_round"`
IntervalMinutes int `json:"interval_minutes"` // Debate interval (5, 15, 30, 60 minutes)
PromptVariant string `json:"prompt_variant"` // balanced/aggressive/conservative/scalping
FinalDecision *DebateDecision `json:"final_decision,omitempty"` // Single decision (backward compat)
FinalDecisions []*DebateDecision `json:"final_decisions,omitempty"` // Multi-coin decisions
AutoExecute bool `json:"auto_execute"`
TraderID string `json:"trader_id,omitempty"` // Trader to use for auto-execute
// OI Ranking data options
EnableOIRanking bool `json:"enable_oi_ranking"` // Whether to include OI ranking data
OIRankingLimit int `json:"oi_ranking_limit"` // Number of OI ranking entries (default 10)
OIDuration string `json:"oi_duration"` // Duration for OI data (1h, 4h, 24h, etc.)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// DebateSessionDB is the GORM model for debate_sessions
type DebateSessionDB struct {
ID string `gorm:"column:id;primaryKey"`
UserID string `gorm:"column:user_id;not null;index"`
Name string `gorm:"column:name;not null"`
StrategyID string `gorm:"column:strategy_id;not null"`
Status DebateStatus `gorm:"column:status;not null;default:pending;index"`
Symbol string `gorm:"column:symbol;not null"`
MaxRounds int `gorm:"column:max_rounds;default:3"`
CurrentRound int `gorm:"column:current_round;default:0"`
IntervalMinutes int `gorm:"column:interval_minutes;default:5"`
PromptVariant string `gorm:"column:prompt_variant;default:balanced"`
FinalDecision string `gorm:"column:final_decision"` // JSON string
AutoExecute bool `gorm:"column:auto_execute;default:false"`
TraderID string `gorm:"column:trader_id"`
EnableOIRanking bool `gorm:"column:enable_oi_ranking;default:false"`
OIRankingLimit int `gorm:"column:oi_ranking_limit;default:10"`
OIDuration string `gorm:"column:oi_duration;default:1h"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (DebateSessionDB) TableName() string {
return "debate_sessions"
}
func (db *DebateSessionDB) toSession() *DebateSession {
s := &DebateSession{
ID: db.ID,
UserID: db.UserID,
Name: db.Name,
StrategyID: db.StrategyID,
Status: db.Status,
Symbol: db.Symbol,
MaxRounds: db.MaxRounds,
CurrentRound: db.CurrentRound,
IntervalMinutes: db.IntervalMinutes,
PromptVariant: db.PromptVariant,
AutoExecute: db.AutoExecute,
TraderID: db.TraderID,
EnableOIRanking: db.EnableOIRanking,
OIRankingLimit: db.OIRankingLimit,
OIDuration: db.OIDuration,
CreatedAt: db.CreatedAt,
UpdatedAt: db.UpdatedAt,
}
// Set defaults
if s.IntervalMinutes == 0 {
s.IntervalMinutes = 5
}
if s.PromptVariant == "" {
s.PromptVariant = "balanced"
}
if s.OIRankingLimit == 0 {
s.OIRankingLimit = 10
}
if s.OIDuration == "" {
s.OIDuration = "1h"
}
// Parse final decision
if db.FinalDecision != "" {
var decision DebateDecision
if json.Unmarshal([]byte(db.FinalDecision), &decision) == nil {
s.FinalDecision = &decision
}
}
return s
}
// DebateParticipant represents an AI participant in a debate // DebateParticipant represents an AI participant in a debate
type DebateParticipant struct { type DebateParticipant struct {
ID string `json:"id"` ID string `gorm:"column:id;primaryKey" json:"id"`
SessionID string `json:"session_id"` SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
AIModelID string `json:"ai_model_id"` AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
AIModelName string `json:"ai_model_name"` AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
Provider string `json:"provider"` Provider string `gorm:"column:provider;not null" json:"provider"`
Personality DebatePersonality `json:"personality"` Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
Color string `json:"color"` Color string `gorm:"column:color;not null" json:"color"`
SpeakOrder int `json:"speak_order"` SpeakOrder int `gorm:"column:speak_order;default:0" json:"speak_order"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
func (DebateParticipant) TableName() string {
return "debate_participants"
} }
// DebateMessage represents a message in the debate // DebateMessage represents a message in the debate
type DebateMessage struct { type DebateMessage struct {
ID string `json:"id"` ID string `gorm:"column:id;primaryKey" json:"id"`
SessionID string `json:"session_id"` SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
Round int `json:"round"` Round int `gorm:"column:round;not null" json:"round"`
AIModelID string `json:"ai_model_id"` AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
AIModelName string `json:"ai_model_name"` AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
Provider string `json:"provider"` Provider string `gorm:"column:provider;not null" json:"provider"`
Personality DebatePersonality `json:"personality"` Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
MessageType string `json:"message_type"` // analysis/rebuttal/final/vote MessageType string `gorm:"column:message_type;not null" json:"message_type"` // analysis/rebuttal/final/vote
Content string `json:"content"` Content string `gorm:"column:content;not null" json:"content"`
Decision *DebateDecision `json:"decision,omitempty"` // Single decision (backward compat) DecisionRaw string `gorm:"column:decision" json:"-"` // JSON string in DB
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions Decision *DebateDecision `gorm:"-" json:"decision,omitempty"` // Parsed for API
Confidence int `json:"confidence"` Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
CreatedAt time.Time `json:"created_at"` Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
func (DebateMessage) TableName() string {
return "debate_messages"
} }
// DebateVote represents a final vote from an AI (can contain multiple coin decisions) // DebateVote represents a final vote from an AI (can contain multiple coin decisions)
type DebateVote struct { type DebateVote struct {
ID string `json:"id"` ID string `gorm:"column:id;primaryKey" json:"id"`
SessionID string `json:"session_id"` SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
AIModelID string `json:"ai_model_id"` AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
AIModelName string `json:"ai_model_name"` AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
Action string `json:"action"` // Primary action (backward compat) Action string `gorm:"column:action;not null" json:"action"` // Primary action (backward compat)
Symbol string `json:"symbol"` // Primary symbol (backward compat) Symbol string `gorm:"column:symbol;not null" json:"symbol"` // Primary symbol (backward compat)
Confidence int `json:"confidence"` Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
Leverage int `json:"leverage"` Leverage int `gorm:"column:leverage;default:5" json:"leverage"`
PositionPct float64 `json:"position_pct"` PositionPct float64 `gorm:"column:position_pct;default:0.2" json:"position_pct"`
StopLossPct float64 `json:"stop_loss_pct"` StopLossPct float64 `gorm:"column:stop_loss_pct;default:0.03" json:"stop_loss_pct"`
TakeProfitPct float64 `json:"take_profit_pct"` TakeProfitPct float64 `gorm:"column:take_profit_pct;default:0.06" json:"take_profit_pct"`
Reasoning string `json:"reasoning"` Reasoning string `gorm:"column:reasoning" json:"reasoning"`
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
func (DebateVote) TableName() string {
return "debate_votes"
} }
// DebateStore handles database operations for debates // DebateStore handles database operations for debates
type DebateStore struct { type DebateStore struct {
db *sql.DB db *gorm.DB
} }
// NewDebateStore creates a new DebateStore // NewDebateStore creates a new DebateStore
func NewDebateStore(db *sql.DB) *DebateStore { func NewDebateStore(db *gorm.DB) *DebateStore {
return &DebateStore{db: db} return &DebateStore{db: db}
} }
// InitSchema creates the debate tables // InitSchema creates the debate tables using GORM AutoMigrate
func (s *DebateStore) InitSchema() error { func (s *DebateStore) InitSchema() error {
schemas := []string{ return s.db.AutoMigrate(
`CREATE TABLE IF NOT EXISTS debate_sessions ( &DebateSessionDB{},
id TEXT PRIMARY KEY, &DebateParticipant{},
user_id TEXT NOT NULL, &DebateMessage{},
name TEXT NOT NULL, &DebateVote{},
strategy_id TEXT NOT NULL, )
status TEXT NOT NULL DEFAULT 'pending',
symbol TEXT NOT NULL,
max_rounds INTEGER DEFAULT 3,
current_round INTEGER DEFAULT 0,
interval_minutes INTEGER DEFAULT 5,
prompt_variant TEXT DEFAULT 'balanced',
final_decision TEXT,
auto_execute BOOLEAN DEFAULT 0,
trader_id TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_user_id ON debate_sessions(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_status ON debate_sessions(status)`,
`CREATE TABLE IF NOT EXISTS debate_participants (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
ai_model_name TEXT NOT NULL,
provider TEXT NOT NULL,
personality TEXT NOT NULL,
color TEXT NOT NULL,
speak_order INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_participants_session ON debate_participants(session_id)`,
`CREATE TABLE IF NOT EXISTS debate_messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
round INTEGER NOT NULL,
ai_model_id TEXT NOT NULL,
ai_model_name TEXT NOT NULL,
provider TEXT NOT NULL,
personality TEXT NOT NULL,
message_type TEXT NOT NULL,
content TEXT NOT NULL,
decision TEXT,
confidence INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_messages_session ON debate_messages(session_id)`,
`CREATE INDEX IF NOT EXISTS idx_debate_messages_round ON debate_messages(session_id, round)`,
`CREATE TABLE IF NOT EXISTS debate_votes (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
ai_model_name TEXT NOT NULL,
action TEXT NOT NULL,
symbol TEXT NOT NULL,
confidence INTEGER DEFAULT 0,
leverage INTEGER DEFAULT 5,
position_pct REAL DEFAULT 0.2,
stop_loss_pct REAL DEFAULT 0.03,
take_profit_pct REAL DEFAULT 0.06,
reasoning TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_votes_session ON debate_votes(session_id)`,
// Trigger to update updated_at
`CREATE TRIGGER IF NOT EXISTS update_debate_sessions_timestamp
AFTER UPDATE ON debate_sessions
FOR EACH ROW
BEGIN
UPDATE debate_sessions SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id;
END`,
}
for _, schema := range schemas {
if _, err := s.db.Exec(schema); err != nil {
return fmt.Errorf("failed to create debate schema: %w", err)
}
}
// Migrate: Add new columns to existing tables (ignore errors if columns already exist)
migrations := []string{
`ALTER TABLE debate_sessions ADD COLUMN interval_minutes INTEGER DEFAULT 5`,
`ALTER TABLE debate_sessions ADD COLUMN prompt_variant TEXT DEFAULT 'balanced'`,
`ALTER TABLE debate_sessions ADD COLUMN trader_id TEXT`,
`ALTER TABLE debate_sessions ADD COLUMN enable_oi_ranking BOOLEAN DEFAULT 0`,
`ALTER TABLE debate_sessions ADD COLUMN oi_ranking_limit INTEGER DEFAULT 10`,
`ALTER TABLE debate_sessions ADD COLUMN oi_duration TEXT DEFAULT '1h'`,
`ALTER TABLE debate_votes ADD COLUMN leverage INTEGER DEFAULT 5`,
`ALTER TABLE debate_votes ADD COLUMN position_pct REAL DEFAULT 0.2`,
`ALTER TABLE debate_votes ADD COLUMN stop_loss_pct REAL DEFAULT 0.03`,
`ALTER TABLE debate_votes ADD COLUMN take_profit_pct REAL DEFAULT 0.06`,
}
for _, migration := range migrations {
// Ignore errors - column may already exist
s.db.Exec(migration)
}
return nil
} }
// CreateSession creates a new debate session // CreateSession creates a new debate session
@@ -279,227 +263,73 @@ func (s *DebateStore) CreateSession(session *DebateSession) error {
if session.OIDuration == "" { if session.OIDuration == "" {
session.OIDuration = "1h" session.OIDuration = "1h"
} }
session.CreatedAt = time.Now()
session.UpdatedAt = time.Now()
_, err := s.db.Exec(` db := &DebateSessionDB{
INSERT INTO debate_sessions (id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, interval_minutes, prompt_variant, auto_execute, trader_id, enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at) ID: session.ID,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, UserID: session.UserID,
session.ID, session.UserID, session.Name, session.StrategyID, session.Status, Name: session.Name,
session.Symbol, session.MaxRounds, session.CurrentRound, session.IntervalMinutes, session.PromptVariant, StrategyID: session.StrategyID,
session.AutoExecute, session.TraderID, session.EnableOIRanking, session.OIRankingLimit, session.OIDuration, Status: session.Status,
session.CreatedAt, session.UpdatedAt, Symbol: session.Symbol,
) MaxRounds: session.MaxRounds,
return err CurrentRound: session.CurrentRound,
IntervalMinutes: session.IntervalMinutes,
PromptVariant: session.PromptVariant,
AutoExecute: session.AutoExecute,
TraderID: session.TraderID,
EnableOIRanking: session.EnableOIRanking,
OIRankingLimit: session.OIRankingLimit,
OIDuration: session.OIDuration,
}
return s.db.Create(db).Error
} }
// GetSession gets a debate session by ID // GetSession gets a debate session by ID
func (s *DebateStore) GetSession(id string) (*DebateSession, error) { func (s *DebateStore) GetSession(id string) (*DebateSession, error) {
var session DebateSession var db DebateSessionDB
var finalDecisionJSON sql.NullString if err := s.db.Where("id = ?", id).First(&db).Error; err != nil {
var traderID sql.NullString return nil, err
var intervalMinutes sql.NullInt64
var promptVariant sql.NullString
var enableOIRanking sql.NullBool
var oiRankingLimit sql.NullInt64
var oiDuration sql.NullString
// Try new schema first
err := s.db.QueryRow(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id,
enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at
FROM debate_sessions WHERE id = ?`, id,
).Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&intervalMinutes, &promptVariant,
&finalDecisionJSON, &session.AutoExecute, &traderID,
&enableOIRanking, &oiRankingLimit, &oiDuration, &session.CreatedAt, &session.UpdatedAt,
)
// Fallback to basic schema if new columns don't exist
if err != nil {
err = s.db.QueryRow(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
final_decision, auto_execute, created_at, updated_at
FROM debate_sessions WHERE id = ?`, id,
).Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
)
if err != nil {
return nil, err
}
// Set defaults for new fields
session.IntervalMinutes = 5
session.PromptVariant = "balanced"
session.OIRankingLimit = 10
session.OIDuration = "1h"
} else {
// Set defaults for nullable fields
session.IntervalMinutes = 5
if intervalMinutes.Valid {
session.IntervalMinutes = int(intervalMinutes.Int64)
}
session.PromptVariant = "balanced"
if promptVariant.Valid {
session.PromptVariant = promptVariant.String
}
if traderID.Valid {
session.TraderID = traderID.String
}
if enableOIRanking.Valid {
session.EnableOIRanking = enableOIRanking.Bool
}
session.OIRankingLimit = 10
if oiRankingLimit.Valid {
session.OIRankingLimit = int(oiRankingLimit.Int64)
}
session.OIDuration = "1h"
if oiDuration.Valid {
session.OIDuration = oiDuration.String
}
} }
return db.toSession(), nil
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
session.FinalDecision = &decision
}
}
return &session, nil
} }
// GetSessionsByUser gets all debate sessions for a user // GetSessionsByUser gets all debate sessions for a user
func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) { func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) {
// First try the new schema with all columns var dbs []DebateSessionDB
rows, err := s.db.Query(` if err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&dbs).Error; err != nil {
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, return nil, err
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id, created_at, updated_at
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
)
// If query fails (likely due to missing columns), try basic query
if err != nil {
return s.getSessionsByUserBasic(userID)
} }
defer rows.Close()
var sessions []*DebateSession sessions := make([]*DebateSession, len(dbs))
for rows.Next() { for i, db := range dbs {
var session DebateSession sessions[i] = db.toSession()
var finalDecisionJSON sql.NullString
var traderID sql.NullString
var intervalMinutes sql.NullInt64
var promptVariant sql.NullString
if err := rows.Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&intervalMinutes, &promptVariant,
&finalDecisionJSON, &session.AutoExecute, &traderID, &session.CreatedAt, &session.UpdatedAt,
); err != nil {
return nil, err
}
// Set defaults for nullable fields
session.IntervalMinutes = 5
if intervalMinutes.Valid {
session.IntervalMinutes = int(intervalMinutes.Int64)
}
session.PromptVariant = "balanced"
if promptVariant.Valid {
session.PromptVariant = promptVariant.String
}
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
session.FinalDecision = &decision
}
}
if traderID.Valid {
session.TraderID = traderID.String
}
sessions = append(sessions, &session)
} }
return sessions, nil return sessions, nil
} }
// ListAllSessions returns all debate sessions (for cleanup on startup) // ListAllSessions returns all debate sessions (for cleanup on startup)
func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) { func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) {
rows, err := s.db.Query(`SELECT id, status FROM debate_sessions`) var dbs []DebateSessionDB
if err != nil { if err := s.db.Select("id, status").Find(&dbs).Error; err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var sessions []*DebateSession sessions := make([]*DebateSession, len(dbs))
for rows.Next() { for i, db := range dbs {
var session DebateSession sessions[i] = &DebateSession{ID: db.ID, Status: db.Status}
if err := rows.Scan(&session.ID, &session.Status); err != nil {
return nil, err
}
sessions = append(sessions, &session)
}
return sessions, nil
}
// getSessionsByUserBasic is a fallback for old schema without new columns
func (s *DebateStore) getSessionsByUserBasic(userID string) ([]*DebateSession, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
final_decision, auto_execute, created_at, updated_at
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var sessions []*DebateSession
for rows.Next() {
var session DebateSession
var finalDecisionJSON sql.NullString
if err := rows.Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
); err != nil {
return nil, err
}
// Set defaults for new fields
session.IntervalMinutes = 5
session.PromptVariant = "balanced"
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
session.FinalDecision = &decision
}
}
sessions = append(sessions, &session)
} }
return sessions, nil return sessions, nil
} }
// UpdateSessionStatus updates the status of a debate session // UpdateSessionStatus updates the status of a debate session
func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error { func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error {
_, err := s.db.Exec(`UPDATE debate_sessions SET status = ? WHERE id = ?`, status, id) return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("status", status).Error
return err
} }
// UpdateSessionRound updates the current round of a debate session // UpdateSessionRound updates the current round of a debate session
func (s *DebateStore) UpdateSessionRound(id string, round int) error { func (s *DebateStore) UpdateSessionRound(id string, round int) error {
_, err := s.db.Exec(`UPDATE debate_sessions SET current_round = ? WHERE id = ?`, round, id) return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("current_round", round).Error
return err
} }
// UpdateSessionFinalDecision updates the final decision of a debate session (single decision) // UpdateSessionFinalDecision updates the final decision of a debate session (single decision)
@@ -508,30 +338,31 @@ func (s *DebateStore) UpdateSessionFinalDecision(id string, decision *DebateDeci
if err != nil { if err != nil {
return err return err
} }
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`, return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
string(decisionJSON), DebateStatusCompleted, id) "final_decision": string(decisionJSON),
return err "status": DebateStatusCompleted,
}).Error
} }
// UpdateSessionFinalDecisions updates both single and multi-coin final decisions // UpdateSessionFinalDecisions updates both single and multi-coin final decisions
func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error { func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error {
// Always store primary decision as a single object (for backward compat)
// This ensures GetSession can deserialize it correctly
primaryJSON, err := json.Marshal(primaryDecision) primaryJSON, err := json.Marshal(primaryDecision)
if err != nil { if err != nil {
return err return err
} }
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
// Update final_decision with primary decision and set status to completed "final_decision": string(primaryJSON),
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`, "status": DebateStatusCompleted,
string(primaryJSON), DebateStatusCompleted, id) }).Error
return err
} }
// DeleteSession deletes a debate session and all related data // DeleteSession deletes a debate session and all related data
func (s *DebateStore) DeleteSession(id string) error { func (s *DebateStore) DeleteSession(id string) error {
_, err := s.db.Exec(`DELETE FROM debate_sessions WHERE id = ?`, id) // Delete related data first
return err s.db.Where("session_id = ?", id).Delete(&DebateParticipant{})
s.db.Where("session_id = ?", id).Delete(&DebateMessage{})
s.db.Where("session_id = ?", id).Delete(&DebateVote{})
return s.db.Where("id = ?", id).Delete(&DebateSessionDB{}).Error
} }
// AddParticipant adds a participant to a debate session // AddParticipant adds a participant to a debate session
@@ -539,9 +370,6 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
if participant.ID == "" { if participant.ID == "" {
participant.ID = uuid.New().String() participant.ID = uuid.New().String()
} }
participant.CreatedAt = time.Now()
// Set color based on personality if not provided
if participant.Color == "" { if participant.Color == "" {
if color, ok := PersonalityColors[participant.Personality]; ok { if color, ok := PersonalityColors[participant.Personality]; ok {
participant.Color = color participant.Color = color
@@ -549,39 +377,14 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
participant.Color = "#6B7280" // Default gray participant.Color = "#6B7280" // Default gray
} }
} }
return s.db.Create(participant).Error
_, err := s.db.Exec(`
INSERT INTO debate_participants (id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
participant.ID, participant.SessionID, participant.AIModelID, participant.AIModelName,
participant.Provider, participant.Personality, participant.Color, participant.SpeakOrder, participant.CreatedAt,
)
return err
} }
// GetParticipants gets all participants for a debate session // GetParticipants gets all participants for a debate session
func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) { func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) {
rows, err := s.db.Query(`
SELECT id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at
FROM debate_participants WHERE session_id = ? ORDER BY speak_order`, sessionID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var participants []*DebateParticipant var participants []*DebateParticipant
for rows.Next() { err := s.db.Where("session_id = ?", sessionID).Order("speak_order").Find(&participants).Error
var p DebateParticipant return participants, err
if err := rows.Scan(
&p.ID, &p.SessionID, &p.AIModelID, &p.AIModelName,
&p.Provider, &p.Personality, &p.Color, &p.SpeakOrder, &p.CreatedAt,
); err != nil {
return nil, err
}
participants = append(participants, &p)
}
return participants, nil
} }
// AddMessage adds a message to a debate session // AddMessage adds a message to a debate session
@@ -589,95 +392,52 @@ func (s *DebateStore) AddMessage(msg *DebateMessage) error {
if msg.ID == "" { if msg.ID == "" {
msg.ID = uuid.New().String() msg.ID = uuid.New().String()
} }
msg.CreatedAt = time.Now()
var decisionJSON sql.NullString
if msg.Decision != nil { if msg.Decision != nil {
data, err := json.Marshal(msg.Decision) data, err := json.Marshal(msg.Decision)
if err != nil { if err != nil {
return err return err
} }
decisionJSON = sql.NullString{String: string(data), Valid: true} msg.DecisionRaw = string(data)
} }
return s.db.Create(msg).Error
_, err := s.db.Exec(`
INSERT INTO debate_messages (id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
msg.ID, msg.SessionID, msg.Round, msg.AIModelID, msg.AIModelName,
msg.Provider, msg.Personality, msg.MessageType, msg.Content,
decisionJSON, msg.Confidence, msg.CreatedAt,
)
return err
} }
// GetMessages gets all messages for a debate session // GetMessages gets all messages for a debate session
func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) { func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) {
rows, err := s.db.Query(` var messages []*DebateMessage
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at err := s.db.Where("session_id = ?", sessionID).Order("round, created_at").Find(&messages).Error
FROM debate_messages WHERE session_id = ? ORDER BY round, created_at`, sessionID,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var messages []*DebateMessage // Parse decision JSON
for rows.Next() { for _, msg := range messages {
var msg DebateMessage if msg.DecisionRaw != "" {
var decisionJSON sql.NullString
if err := rows.Scan(
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
); err != nil {
return nil, err
}
if decisionJSON.Valid && decisionJSON.String != "" {
var decision DebateDecision var decision DebateDecision
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil { if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
msg.Decision = &decision msg.Decision = &decision
} }
} }
messages = append(messages, &msg)
} }
return messages, nil return messages, nil
} }
// GetMessagesByRound gets messages for a specific round // GetMessagesByRound gets messages for a specific round
func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) { func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) {
rows, err := s.db.Query(` var messages []*DebateMessage
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at err := s.db.Where("session_id = ? AND round = ?", sessionID, round).Order("created_at").Find(&messages).Error
FROM debate_messages WHERE session_id = ? AND round = ? ORDER BY created_at`, sessionID, round,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var messages []*DebateMessage // Parse decision JSON
for rows.Next() { for _, msg := range messages {
var msg DebateMessage if msg.DecisionRaw != "" {
var decisionJSON sql.NullString
if err := rows.Scan(
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
); err != nil {
return nil, err
}
if decisionJSON.Valid && decisionJSON.String != "" {
var decision DebateDecision var decision DebateDecision
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil { if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
msg.Decision = &decision msg.Decision = &decision
} }
} }
messages = append(messages, &msg)
} }
return messages, nil return messages, nil
} }
@@ -687,40 +447,14 @@ func (s *DebateStore) AddVote(vote *DebateVote) error {
if vote.ID == "" { if vote.ID == "" {
vote.ID = uuid.New().String() vote.ID = uuid.New().String()
} }
vote.CreatedAt = time.Now() return s.db.Create(vote).Error
_, err := s.db.Exec(`
INSERT INTO debate_votes (id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
vote.ID, vote.SessionID, vote.AIModelID, vote.AIModelName,
vote.Action, vote.Symbol, vote.Confidence, vote.Leverage, vote.PositionPct, vote.StopLossPct, vote.TakeProfitPct, vote.Reasoning, vote.CreatedAt,
)
return err
} }
// GetVotes gets all votes for a debate session // GetVotes gets all votes for a debate session
func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) { func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) {
rows, err := s.db.Query(`
SELECT id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at
FROM debate_votes WHERE session_id = ? ORDER BY created_at`, sessionID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var votes []*DebateVote var votes []*DebateVote
for rows.Next() { err := s.db.Where("session_id = ?", sessionID).Order("created_at").Find(&votes).Error
var vote DebateVote return votes, err
if err := rows.Scan(
&vote.ID, &vote.SessionID, &vote.AIModelID, &vote.AIModelName,
&vote.Action, &vote.Symbol, &vote.Confidence, &vote.Leverage, &vote.PositionPct, &vote.StopLossPct, &vote.TakeProfitPct, &vote.Reasoning, &vote.CreatedAt,
); err != nil {
return nil, err
}
votes = append(votes, &vote)
}
return votes, nil
} }
// DebateSessionWithDetails combines session with participants and messages // DebateSessionWithDetails combines session with participants and messages
+135 -198
View File
@@ -1,18 +1,41 @@
package store package store
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"gorm.io/gorm"
) )
// DecisionStore decision log storage // DecisionStore decision log storage
type DecisionStore struct { type DecisionStore struct {
db *sql.DB db *gorm.DB
} }
// DecisionRecord decision record // DecisionRecordDB internal GORM model for decision_records table
type DecisionRecordDB struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
TraderID string `gorm:"column:trader_id;not null;index:idx_decision_records_trader_time"`
CycleNumber int `gorm:"column:cycle_number;not null"`
Timestamp time.Time `gorm:"not null;index:idx_decision_records_trader_time,sort:desc;index:idx_decision_records_timestamp,sort:desc"`
SystemPrompt string `gorm:"column:system_prompt;default:''"`
InputPrompt string `gorm:"column:input_prompt;default:''"`
CoTTrace string `gorm:"column:cot_trace;default:''"`
DecisionJSON string `gorm:"column:decision_json;default:''"`
RawResponse string `gorm:"column:raw_response;default:''"`
CandidateCoins string `gorm:"column:candidate_coins;default:''"`
ExecutionLog string `gorm:"column:execution_log;default:''"`
Decisions string `gorm:"column:decisions;default:'[]'"`
Success bool `gorm:"default:false"`
ErrorMessage string `gorm:"column:error_message;default:''"`
AIRequestDurationMs int64 `gorm:"column:ai_request_duration_ms;default:0"`
CreatedAt time.Time `json:"created_at"`
}
func (DecisionRecordDB) TableName() string { return "decision_records" }
// DecisionRecord decision record (external API struct)
type DecisionRecord struct { type DecisionRecord struct {
ID int64 `json:"id"` ID int64 `json:"id"`
TraderID string `json:"trader_id"` TraderID string `json:"trader_id"`
@@ -81,49 +104,47 @@ type Statistics struct {
TotalClosePositions int `json:"total_close_positions"` TotalClosePositions int `json:"total_close_positions"`
} }
// initTables initializes AI decision log tables // NewDecisionStore creates a new DecisionStore
// Note: Account equity curve data has been migrated to trader_equity_snapshots table (managed by EquityStore) func NewDecisionStore(db *gorm.DB) *DecisionStore {
func (s *DecisionStore) initTables() error { return &DecisionStore{db: db}
queries := []string{
// AI decision log table (records AI input/output, chain of thought, etc.)
`CREATE TABLE IF NOT EXISTS decision_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
cycle_number INTEGER NOT NULL,
timestamp DATETIME NOT NULL,
system_prompt TEXT DEFAULT '',
input_prompt TEXT DEFAULT '',
cot_trace TEXT DEFAULT '',
decision_json TEXT DEFAULT '',
raw_response TEXT DEFAULT '',
candidate_coins TEXT DEFAULT '',
execution_log TEXT DEFAULT '',
success BOOLEAN DEFAULT 0,
error_message TEXT DEFAULT '',
ai_request_duration_ms INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// Indexes
`CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`,
`CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
}
}
// Migration: add raw_response column if not exists
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN raw_response TEXT DEFAULT ''`)
// Migration: add decisions column if not exists
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN decisions TEXT DEFAULT '[]'`)
return nil
} }
// LogDecision logs decision (only saves AI decision log, equity curve has been migrated to equity table) // initTables initializes AI decision log tables
func (s *DecisionStore) initTables() error {
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'decision_records'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
}
return s.db.AutoMigrate(&DecisionRecordDB{})
}
// toRecord converts DB model to API struct
func (db *DecisionRecordDB) toRecord() *DecisionRecord {
record := &DecisionRecord{
ID: db.ID,
TraderID: db.TraderID,
CycleNumber: db.CycleNumber,
Timestamp: db.Timestamp,
SystemPrompt: db.SystemPrompt,
InputPrompt: db.InputPrompt,
CoTTrace: db.CoTTrace,
DecisionJSON: db.DecisionJSON,
RawResponse: db.RawResponse,
Success: db.Success,
ErrorMessage: db.ErrorMessage,
AIRequestDurationMs: db.AIRequestDurationMs,
}
json.Unmarshal([]byte(db.CandidateCoins), &record.CandidateCoins)
json.Unmarshal([]byte(db.ExecutionLog), &record.ExecutionLog)
json.Unmarshal([]byte(db.Decisions), &record.Decisions)
return record
}
// LogDecision logs decision
func (s *DecisionStore) LogDecision(record *DecisionRecord) error { func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
if record.Timestamp.IsZero() { if record.Timestamp.IsZero() {
record.Timestamp = time.Now().UTC() record.Timestamp = time.Now().UTC()
@@ -131,65 +152,49 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
record.Timestamp = record.Timestamp.UTC() record.Timestamp = record.Timestamp.UTC()
} }
// Serialize candidate coins, execution log and decisions to JSON // Serialize arrays to JSON
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins) candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
executionLogJSON, _ := json.Marshal(record.ExecutionLog) executionLogJSON, _ := json.Marshal(record.ExecutionLog)
decisionsJSON, _ := json.Marshal(record.Decisions) decisionsJSON, _ := json.Marshal(record.Decisions)
// Insert decision record main table (only save AI decision related content) dbRecord := &DecisionRecordDB{
result, err := s.db.Exec(` TraderID: record.TraderID,
INSERT INTO decision_records ( CycleNumber: record.CycleNumber,
trader_id, cycle_number, timestamp, system_prompt, input_prompt, Timestamp: record.Timestamp,
cot_trace, decision_json, raw_response, candidate_coins, execution_log, SystemPrompt: record.SystemPrompt,
decisions, success, error_message, ai_request_duration_ms InputPrompt: record.InputPrompt,
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) CoTTrace: record.CoTTrace,
`, DecisionJSON: record.DecisionJSON,
record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339), RawResponse: record.RawResponse,
record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON, CandidateCoins: string(candidateCoinsJSON),
record.RawResponse, string(candidateCoinsJSON), string(executionLogJSON), ExecutionLog: string(executionLogJSON),
string(decisionsJSON), record.Success, record.ErrorMessage, record.AIRequestDurationMs, Decisions: string(decisionsJSON),
) Success: record.Success,
if err != nil { ErrorMessage: record.ErrorMessage,
AIRequestDurationMs: record.AIRequestDurationMs,
}
if err := s.db.Create(dbRecord).Error; err != nil {
return fmt.Errorf("failed to insert decision record: %w", err) return fmt.Errorf("failed to insert decision record: %w", err)
} }
record.ID = dbRecord.ID
decisionID, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("failed to get decision ID: %w", err)
}
record.ID = decisionID
return nil return nil
} }
// GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new) // GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new)
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) { func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
rows, err := s.db.Query(` var dbRecords []*DecisionRecordDB
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, err := s.db.Where("trader_id = ?", traderID).
cot_trace, decision_json, candidate_coins, execution_log, Order("timestamp DESC").
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms Limit(n).
FROM decision_records Find(&dbRecords).Error
WHERE trader_id = ?
ORDER BY timestamp DESC
LIMIT ?
`, traderID, n)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query decision records: %w", err) return nil, fmt.Errorf("failed to query decision records: %w", err)
} }
defer rows.Close()
var records []*DecisionRecord records := make([]*DecisionRecord, len(dbRecords))
for rows.Next() { for i, db := range dbRecords {
record, err := s.scanDecisionRecord(rows) records[i] = db.toRecord()
if err != nil {
continue
}
records = append(records, record)
}
// Fill associated data
for _, record := range records {
s.fillRecordDetails(record)
} }
// Reverse array to sort time from old to new // Reverse array to sort time from old to new
@@ -202,26 +207,15 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
// GetAllLatestRecords gets the latest N records for all traders // GetAllLatestRecords gets the latest N records for all traders
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) { func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
rows, err := s.db.Query(` var dbRecords []*DecisionRecordDB
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, err := s.db.Order("timestamp DESC").Limit(n).Find(&dbRecords).Error
cot_trace, decision_json, candidate_coins, execution_log,
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
FROM decision_records
ORDER BY timestamp DESC
LIMIT ?
`, n)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query decision records: %w", err) return nil, fmt.Errorf("failed to query decision records: %w", err)
} }
defer rows.Close()
var records []*DecisionRecord records := make([]*DecisionRecord, len(dbRecords))
for rows.Next() { for i, db := range dbRecords {
record, err := s.scanDecisionRecord(rows) records[i] = db.toRecord()
if err != nil {
continue
}
records = append(records, record)
} }
// Reverse array // Reverse array
@@ -236,26 +230,17 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) { func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
dateStr := date.Format("2006-01-02") dateStr := date.Format("2006-01-02")
rows, err := s.db.Query(` var dbRecords []*DecisionRecordDB
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, err := s.db.Where("trader_id = ? AND DATE(timestamp) = ?", traderID, dateStr).
cot_trace, decision_json, candidate_coins, execution_log, Order("timestamp ASC").
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms Find(&dbRecords).Error
FROM decision_records
WHERE trader_id = ? AND DATE(timestamp) = ?
ORDER BY timestamp ASC
`, traderID, dateStr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query decision records: %w", err) return nil, fmt.Errorf("failed to query decision records: %w", err)
} }
defer rows.Close()
var records []*DecisionRecord records := make([]*DecisionRecord, len(dbRecords))
for rows.Next() { for i, db := range dbRecords {
record, err := s.scanDecisionRecord(rows) records[i] = db.toRecord()
if err != nil {
continue
}
records = append(records, record)
} }
return records, nil return records, nil
@@ -263,48 +248,31 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De
// CleanOldRecords cleans old records from N days ago // CleanOldRecords cleans old records from N days ago
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) { func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339) cutoffTime := time.Now().AddDate(0, 0, -days)
result, err := s.db.Exec(` result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
DELETE FROM decision_records Delete(&DecisionRecordDB{})
WHERE trader_id = ? AND timestamp < ? if result.Error != nil {
`, traderID, cutoffTime) return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
if err != nil {
return 0, fmt.Errorf("failed to clean old records: %w", err)
} }
return result.RowsAffected, nil
return result.RowsAffected()
} }
// GetStatistics gets statistics information for specified trader // GetStatistics gets statistics information for specified trader
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) { func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
stats := &Statistics{} stats := &Statistics{}
err := s.db.QueryRow(` var totalCount, successCount int64
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? s.db.Model(&DecisionRecordDB{}).Where("trader_id = ?", traderID).Count(&totalCount)
`, traderID).Scan(&stats.TotalCycles) s.db.Model(&DecisionRecordDB{}).Where("trader_id = ? AND success = ?", traderID, true).Count(&successCount)
if err != nil {
return nil, fmt.Errorf("failed to query total cycles: %w", err)
}
err = s.db.QueryRow(` stats.TotalCycles = int(totalCount)
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1 stats.SuccessfulCycles = int(successCount)
`, traderID).Scan(&stats.SuccessfulCycles)
if err != nil {
return nil, fmt.Errorf("failed to query successful cycles: %w", err)
}
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
// Count from trader_positions table // Count from trader_positions table using raw query for cross-table
s.db.QueryRow(` s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ?", traderID).Scan(&stats.TotalOpenPositions)
SELECT COUNT(*) FROM trader_positions s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ? AND status = 'CLOSED'", traderID).Scan(&stats.TotalClosePositions)
WHERE trader_id = ?
`, traderID).Scan(&stats.TotalOpenPositions)
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_positions
WHERE trader_id = ? AND status = 'CLOSED'
`, traderID).Scan(&stats.TotalClosePositions)
return stats, nil return stats, nil
} }
@@ -313,64 +281,33 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) { func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
stats := &Statistics{} stats := &Statistics{}
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles) var totalCount, successCount int64
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles) s.db.Model(&DecisionRecordDB{}).Count(&totalCount)
s.db.Model(&DecisionRecordDB{}).Where("success = ?", true).Count(&successCount)
stats.TotalCycles = int(totalCount)
stats.SuccessfulCycles = int(successCount)
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
// Count from trader_positions table // Count from trader_positions table
s.db.QueryRow(` s.db.Raw("SELECT COUNT(*) FROM trader_positions").Scan(&stats.TotalOpenPositions)
SELECT COUNT(*) FROM trader_positions s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE status = 'CLOSED'").Scan(&stats.TotalClosePositions)
`).Scan(&stats.TotalOpenPositions)
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_positions
WHERE status = 'CLOSED'
`).Scan(&stats.TotalClosePositions)
return stats, nil return stats, nil
} }
// GetLastCycleNumber gets the last cycle number for specified trader // GetLastCycleNumber gets the last cycle number for specified trader
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) { func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
var cycleNumber int var cycleNumber *int
err := s.db.QueryRow(` err := s.db.Model(&DecisionRecordDB{}).
SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ? Where("trader_id = ?", traderID).
`, traderID).Scan(&cycleNumber) Select("MAX(cycle_number)").
Scan(&cycleNumber).Error
if err != nil { if err != nil {
return 0, err return 0, err
} }
return cycleNumber, nil if cycleNumber == nil {
} return 0, nil
// scanDecisionRecord scans decision record from row
func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) {
var record DecisionRecord
var timestampStr string
var candidateCoinsJSON, executionLogJSON, decisionsJSON string
err := rows.Scan(
&record.ID, &record.TraderID, &record.CycleNumber, &timestampStr,
&record.SystemPrompt, &record.InputPrompt, &record.CoTTrace,
&record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON,
&decisionsJSON, &record.Success, &record.ErrorMessage, &record.AIRequestDurationMs,
)
if err != nil {
return nil, err
} }
return *cycleNumber, nil
record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins)
json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog)
json.Unmarshal([]byte(decisionsJSON), &record.Decisions)
return &record, nil
}
// fillRecordDetails fills associated data for decision record (old associated tables removed, this function kept for compatibility)
// Note: Account snapshot, position snapshot, decision action data are no longer stored in decision related tables
// - For equity data use EquityStore.GetLatest()
// - For order data use OrderStore
func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) {
// Old associated tables removed, no longer need to fill
// AccountState, Positions, Decisions fields will remain at zero values
} }
+41
View File
@@ -238,3 +238,44 @@ func getEnv(key, defaultValue string) string {
} }
return defaultValue return defaultValue
} }
// convertQuery converts ? placeholders to $1, $2 for PostgreSQL
// and handles other database-specific syntax differences
func convertQuery(query string, dbType DBType) string {
if dbType != DBTypePostgres {
return query
}
result := query
// Convert ? to $1, $2, etc. for PostgreSQL
index := 1
for strings.Contains(result, "?") {
result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1)
index++
}
// Convert datetime('now') to CURRENT_TIMESTAMP
result = strings.ReplaceAll(result, "datetime('now')", "CURRENT_TIMESTAMP")
// Remove datetime() wrapper for ORDER BY (PostgreSQL timestamps sort correctly)
// This handles patterns like "ORDER BY datetime(column) DESC"
result = strings.ReplaceAll(result, "datetime(updated_at)", "updated_at")
result = strings.ReplaceAll(result, "datetime(created_at)", "created_at")
return result
}
// boolDefault returns database-appropriate boolean default for COALESCE
// Use in queries like: COALESCE(column, %s)
func boolDefault(dbType DBType, value bool) string {
if dbType == DBTypePostgres {
if value {
return "TRUE"
}
return "FALSE"
}
if value {
return "1"
}
return "0"
}
+61 -139
View File
@@ -1,55 +1,48 @@
package store package store
import ( import (
"database/sql"
"fmt" "fmt"
"time" "time"
"gorm.io/gorm"
) )
// EquityStore account equity storage (for plotting return curves) // EquityStore account equity storage (for plotting return curves)
type EquityStore struct { type EquityStore struct {
db *sql.DB db *gorm.DB
} }
// EquitySnapshot equity snapshot // EquitySnapshot equity snapshot
type EquitySnapshot struct { type EquitySnapshot struct {
ID int64 `json:"id"` ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TraderID string `json:"trader_id"` TraderID string `gorm:"column:trader_id;not null;index:idx_equity_trader_time" json:"trader_id"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `gorm:"not null;index:idx_equity_trader_time,sort:desc;index:idx_equity_timestamp,sort:desc" json:"timestamp"`
TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL) TotalEquity float64 `gorm:"column:total_equity;not null;default:0" json:"total_equity"`
Balance float64 `json:"balance"` // Account balance Balance float64 `gorm:"not null;default:0" json:"balance"`
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss UnrealizedPnL float64 `gorm:"column:unrealized_pnl;not null;default:0" json:"unrealized_pnl"`
PositionCount int `json:"position_count"` // Position count PositionCount int `gorm:"column:position_count;default:0" json:"position_count"`
MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage MarginUsedPct float64 `gorm:"column:margin_used_pct;default:0" json:"margin_used_pct"`
CreatedAt time.Time `json:"created_at"`
}
func (EquitySnapshot) TableName() string { return "trader_equity_snapshots" }
// NewEquityStore creates a new EquityStore
func NewEquityStore(db *gorm.DB) *EquityStore {
return &EquityStore{db: db}
} }
// initTables initializes equity tables // initTables initializes equity tables
func (s *EquityStore) initTables() error { func (s *EquityStore) initTables() error {
queries := []string{ // For PostgreSQL with existing table, skip AutoMigrate
// Equity snapshot table - specifically for return curves if s.db.Dialector.Name() == "postgres" {
`CREATE TABLE IF NOT EXISTS trader_equity_snapshots ( var tableExists int64
id INTEGER PRIMARY KEY AUTOINCREMENT, s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_equity_snapshots'`).Scan(&tableExists)
trader_id TEXT NOT NULL, if tableExists > 0 {
timestamp DATETIME NOT NULL, return nil
total_equity REAL NOT NULL DEFAULT 0,
balance REAL NOT NULL DEFAULT 0,
unrealized_pnl REAL NOT NULL DEFAULT 0,
position_count INTEGER DEFAULT 0,
margin_used_pct REAL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// Indexes
`CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`,
`CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
} }
} }
return s.db.AutoMigrate(&EquitySnapshot{})
return nil
} }
// Save saves equity snapshot // Save saves equity snapshot
@@ -60,58 +53,22 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
snapshot.Timestamp = snapshot.Timestamp.UTC() snapshot.Timestamp = snapshot.Timestamp.UTC()
} }
result, err := s.db.Exec(` if err := s.db.Create(snapshot).Error; err != nil {
INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
) VALUES (?, ?, ?, ?, ?, ?, ?)
`,
snapshot.TraderID,
snapshot.Timestamp.Format(time.RFC3339),
snapshot.TotalEquity,
snapshot.Balance,
snapshot.UnrealizedPnL,
snapshot.PositionCount,
snapshot.MarginUsedPct,
)
if err != nil {
return fmt.Errorf("failed to save equity snapshot: %w", err) return fmt.Errorf("failed to save equity snapshot: %w", err)
} }
id, _ := result.LastInsertId()
snapshot.ID = id
return nil return nil
} }
// GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new) // GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new)
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) { func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
rows, err := s.db.Query(` var snapshots []*EquitySnapshot
SELECT id, trader_id, timestamp, total_equity, balance, err := s.db.Where("trader_id = ?", traderID).
unrealized_pnl, position_count, margin_used_pct Order("timestamp DESC").
FROM trader_equity_snapshots Limit(limit).
WHERE trader_id = ? Find(&snapshots).Error
ORDER BY timestamp DESC
LIMIT ?
`, traderID, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query equity records: %w", err) return nil, fmt.Errorf("failed to query equity records: %w", err)
} }
defer rows.Close()
var snapshots []*EquitySnapshot
for rows.Next() {
snap := &EquitySnapshot{}
var timestampStr string
err := rows.Scan(
&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
)
if err != nil {
continue
}
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
snapshots = append(snapshots, snap)
}
// Reverse the array to sort time from old to new (suitable for plotting curves) // Reverse the array to sort time from old to new (suitable for plotting curves)
for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 { for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
@@ -123,116 +80,81 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
// GetByTimeRange gets equity records within specified time range // GetByTimeRange gets equity records within specified time range
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) { func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
rows, err := s.db.Query(` var snapshots []*EquitySnapshot
SELECT id, trader_id, timestamp, total_equity, balance, err := s.db.Where("trader_id = ? AND timestamp >= ? AND timestamp <= ?", traderID, start, end).
unrealized_pnl, position_count, margin_used_pct Order("timestamp ASC").
FROM trader_equity_snapshots Find(&snapshots).Error
WHERE trader_id = ? AND timestamp >= ? AND timestamp <= ?
ORDER BY timestamp ASC
`, traderID, start.Format(time.RFC3339), end.Format(time.RFC3339))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query equity records: %w", err) return nil, fmt.Errorf("failed to query equity records: %w", err)
} }
defer rows.Close()
var snapshots []*EquitySnapshot
for rows.Next() {
snap := &EquitySnapshot{}
var timestampStr string
err := rows.Scan(
&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
)
if err != nil {
continue
}
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
snapshots = append(snapshots, snap)
}
return snapshots, nil return snapshots, nil
} }
// GetAllTradersLatest gets latest equity for all traders (for leaderboards) // GetAllTradersLatest gets latest equity for all traders (for leaderboards)
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) { func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
rows, err := s.db.Query(` // Use raw SQL for this complex query with subquery
var snapshots []*EquitySnapshot
err := s.db.Raw(`
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance, SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
e.unrealized_pnl, e.position_count, e.margin_used_pct e.unrealized_pnl, e.position_count, e.margin_used_pct, e.created_at
FROM trader_equity_snapshots e FROM trader_equity_snapshots e
INNER JOIN ( INNER JOIN (
SELECT trader_id, MAX(timestamp) as max_ts SELECT trader_id, MAX(timestamp) as max_ts
FROM trader_equity_snapshots FROM trader_equity_snapshots
GROUP BY trader_id GROUP BY trader_id
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts ) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
`) `).Scan(&snapshots).Error
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query latest equity: %w", err) return nil, fmt.Errorf("failed to query latest equity: %w", err)
} }
defer rows.Close()
result := make(map[string]*EquitySnapshot) result := make(map[string]*EquitySnapshot)
for rows.Next() { for _, snap := range snapshots {
snap := &EquitySnapshot{}
var timestampStr string
err := rows.Scan(
&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
)
if err != nil {
continue
}
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
result[snap.TraderID] = snap result[snap.TraderID] = snap
} }
return result, nil return result, nil
} }
// CleanOldRecords cleans old records from N days ago // CleanOldRecords cleans old records from N days ago
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) { func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339) cutoffTime := time.Now().AddDate(0, 0, -days)
result, err := s.db.Exec(` result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
DELETE FROM trader_equity_snapshots Delete(&EquitySnapshot{})
WHERE trader_id = ? AND timestamp < ? if result.Error != nil {
`, traderID, cutoffTime) return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
if err != nil {
return 0, fmt.Errorf("failed to clean old records: %w", err)
} }
return result.RowsAffected, nil
return result.RowsAffected()
} }
// GetCount gets record count for specified trader // GetCount gets record count for specified trader
func (s *EquityStore) GetCount(traderID string) (int, error) { func (s *EquityStore) GetCount(traderID string) (int, error) {
var count int var count int64
err := s.db.QueryRow(` err := s.db.Model(&EquitySnapshot{}).Where("trader_id = ?", traderID).Count(&count).Error
SELECT COUNT(*) FROM trader_equity_snapshots WHERE trader_id = ? return int(count), err
`, traderID).Scan(&count)
return count, err
} }
// MigrateFromDecision migrates data from old decision_account_snapshots table // MigrateFromDecision migrates data from old decision_account_snapshots table
func (s *EquityStore) MigrateFromDecision() (int64, error) { func (s *EquityStore) MigrateFromDecision() (int64, error) {
// Check if migration is needed (whether new table is empty) // Check if migration is needed (whether new table is empty)
var count int var count int64
s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count) s.db.Model(&EquitySnapshot{}).Count(&count)
if count > 0 { if count > 0 {
return 0, nil // Already has data, skip migration return 0, nil // Already has data, skip migration
} }
// Check if old table exists // Check if old table exists (SQLite specific check, but works for migration)
var tableName string var tableName string
err := s.db.QueryRow(` err := s.db.Raw(`
SELECT name FROM sqlite_master SELECT name FROM sqlite_master
WHERE type='table' AND name='decision_account_snapshots' WHERE type='table' AND name='decision_account_snapshots'
`).Scan(&tableName) `).Scan(&tableName).Error
if err != nil { if err != nil || tableName == "" {
return 0, nil // Old table doesn't exist, skip return 0, nil // Old table doesn't exist, skip
} }
// Migrate data: join query from decision_records + decision_account_snapshots // Migrate data: join query from decision_records + decision_account_snapshots
result, err := s.db.Exec(` result := s.db.Exec(`
INSERT INTO trader_equity_snapshots ( INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance, trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct unrealized_pnl, position_count, margin_used_pct
@@ -249,9 +171,9 @@ func (s *EquityStore) MigrateFromDecision() (int64, error) {
JOIN decision_account_snapshots das ON dr.id = das.decision_id JOIN decision_account_snapshots das ON dr.id = das.decision_id
ORDER BY dr.timestamp ASC ORDER BY dr.timestamp ASC
`) `)
if err != nil { if result.Error != nil {
return 0, fmt.Errorf("failed to migrate data: %w", err) return 0, fmt.Errorf("failed to migrate data: %w", result.Error)
} }
return result.RowsAffected() return result.RowsAffected, nil
} }
+153 -302
View File
@@ -1,83 +1,68 @@
package store package store
import ( import (
"database/sql"
"fmt" "fmt"
"nofx/crypto"
"nofx/logger" "nofx/logger"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"gorm.io/gorm"
) )
// ExchangeStore exchange storage // ExchangeStore exchange storage
type ExchangeStore struct { type ExchangeStore struct {
db *sql.DB db *gorm.DB
encryptFunc func(string) string
decryptFunc func(string) string
} }
// Exchange exchange configuration // Exchange exchange configuration
type Exchange struct { type Exchange struct {
ID string `json:"id"` // UUID ID string `gorm:"primaryKey" json:"id"`
ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter" ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
AccountName string `json:"account_name"` // User-defined account name AccountName string `gorm:"column:account_name;not null;default:''" json:"account_name"`
UserID string `json:"user_id"` UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `json:"name"` // Display name (auto-generated or user-defined) Name string `gorm:"not null" json:"name"`
Type string `json:"type"` // "cex" or "dex" Type string `gorm:"not null" json:"type"` // "cex" or "dex"
Enabled bool `json:"enabled"` Enabled bool `gorm:"default:false" json:"enabled"`
APIKey string `json:"apiKey"` APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
SecretKey string `json:"secretKey"` SecretKey crypto.EncryptedString `gorm:"column:secret_key;default:''" json:"secretKey"`
Passphrase string `json:"passphrase"` // OKX-specific Passphrase crypto.EncryptedString `gorm:"column:passphrase;default:''" json:"passphrase"`
Testnet bool `json:"testnet"` Testnet bool `gorm:"default:false" json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` HyperliquidWalletAddr string `gorm:"column:hyperliquid_wallet_addr;default:''" json:"hyperliquidWalletAddr"`
AsterUser string `json:"asterUser"` AsterUser string `gorm:"column:aster_user;default:''" json:"asterUser"`
AsterSigner string `json:"asterSigner"` AsterSigner string `gorm:"column:aster_signer;default:''" json:"asterSigner"`
AsterPrivateKey string `json:"asterPrivateKey"` AsterPrivateKey crypto.EncryptedString `gorm:"column:aster_private_key;default:''" json:"asterPrivateKey"`
LighterWalletAddr string `json:"lighterWalletAddr"` LighterWalletAddr string `gorm:"column:lighter_wallet_addr;default:''" json:"lighterWalletAddr"`
LighterPrivateKey string `json:"lighterPrivateKey"` LighterPrivateKey crypto.EncryptedString `gorm:"column:lighter_private_key;default:''" json:"lighterPrivateKey"`
LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"` LighterAPIKeyPrivateKey crypto.EncryptedString `gorm:"column:lighter_api_key_private_key;default:''" json:"lighterAPIKeyPrivateKey"`
LighterAPIKeyIndex int `json:"lighterAPIKeyIndex"` LighterAPIKeyIndex int `gorm:"column:lighter_api_key_index;default:0" json:"lighterAPIKeyIndex"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
}
func (Exchange) TableName() string { return "exchanges" }
// NewExchangeStore creates a new ExchangeStore
func NewExchangeStore(db *gorm.DB) *ExchangeStore {
return &ExchangeStore{db: db}
} }
func (s *ExchangeStore) initTables() error { func (s *ExchangeStore) initTables() error {
// Create new table structure with UUID as primary key // For PostgreSQL with existing table, skip AutoMigrate
_, err := s.db.Exec(` if s.db.Dialector.Name() == "postgres" {
CREATE TABLE IF NOT EXISTS exchanges ( var tableExists int64
id TEXT PRIMARY KEY, s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'exchanges'`).Scan(&tableExists)
exchange_type TEXT NOT NULL DEFAULT '', if tableExists > 0 {
account_name TEXT NOT NULL DEFAULT '', // Still run data migrations
user_id TEXT NOT NULL DEFAULT 'default', s.migrateToMultiAccount()
name TEXT NOT NULL, s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
type TEXT NOT NULL, return nil
enabled BOOLEAN DEFAULT 0, }
api_key TEXT DEFAULT '',
secret_key TEXT DEFAULT '',
passphrase TEXT DEFAULT '',
testnet BOOLEAN DEFAULT 0,
hyperliquid_wallet_addr TEXT DEFAULT '',
aster_user TEXT DEFAULT '',
aster_signer TEXT DEFAULT '',
aster_private_key TEXT DEFAULT '',
lighter_wallet_addr TEXT DEFAULT '',
lighter_private_key TEXT DEFAULT '',
lighter_api_key_private_key TEXT DEFAULT '',
lighter_api_key_index INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
} }
// Migration: add new columns if not exists if err := s.db.AutoMigrate(&Exchange{}); err != nil {
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN passphrase TEXT DEFAULT ''`) return err
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN exchange_type TEXT NOT NULL DEFAULT ''`) }
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN account_name TEXT NOT NULL DEFAULT ''`)
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN lighter_api_key_index INTEGER DEFAULT 0`)
// Run migration to multi-account if needed // Run migration to multi-account if needed
if err := s.migrateToMultiAccount(); err != nil { if err := s.migrateToMultiAccount(); err != nil {
@@ -85,120 +70,65 @@ func (s *ExchangeStore) initTables() error {
} }
// Fix empty account_name for existing records // Fix empty account_name for existing records
s.db.Exec(`UPDATE exchanges SET account_name = 'Default' WHERE account_name = '' OR account_name IS NULL`) s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
// Update trigger for new schema return nil
s.db.Exec(`DROP TRIGGER IF EXISTS update_exchanges_updated_at`)
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at
AFTER UPDATE ON exchanges
BEGIN
UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
return err
} }
// migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID) // migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID)
func (s *ExchangeStore) migrateToMultiAccount() error { func (s *ExchangeStore) migrateToMultiAccount() error {
// Check if migration is needed by looking for old-style IDs (non-UUID) // Check if migration is needed by looking for old-style IDs (non-UUID)
var count int var count int64
err := s.db.QueryRow(` err := s.db.Model(&Exchange{}).
SELECT COUNT(*) FROM exchanges Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter') Count(&count).Error
`).Scan(&count)
if err != nil { if err != nil {
return err return err
} }
if count == 0 { if count == 0 {
// No migration needed
return nil return nil
} }
logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count) logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count)
// Get all old records // Get all old records
rows, err := s.db.Query(` var records []Exchange
SELECT id, user_id, name, type, enabled, api_key, secret_key, err = s.db.Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
COALESCE(passphrase, '') as passphrase, testnet, Find(&records).Error
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
COALESCE(lighter_private_key, '') as lighter_private_key,
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key
FROM exchanges
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
`)
if err != nil { if err != nil {
return err return err
} }
defer rows.Close()
type oldRecord struct {
id, userID, name, typ string
enabled, testnet bool
apiKey, secretKey, passphrase string
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string
lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string
}
var records []oldRecord
for rows.Next() {
var r oldRecord
if err := rows.Scan(&r.id, &r.userID, &r.name, &r.typ, &r.enabled,
&r.apiKey, &r.secretKey, &r.passphrase, &r.testnet,
&r.hyperliquidWalletAddr, &r.asterUser, &r.asterSigner, &r.asterPrivateKey,
&r.lighterWalletAddr, &r.lighterPrivateKey, &r.lighterApiKeyPrivateKey); err != nil {
return err
}
records = append(records, r)
}
// Begin transaction // Begin transaction
tx, err := s.db.Begin() return s.db.Transaction(func(tx *gorm.DB) error {
if err != nil { for _, r := range records {
return err newID := uuid.New().String()
} oldID := r.ID // This is the exchange type (e.g., "binance")
defer tx.Rollback()
// Migrate each record // Update traders table to use new UUID
for _, r := range records { if err := tx.Exec("UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?",
newID := uuid.New().String() newID, oldID, r.UserID).Error; err != nil {
oldID := r.id // This is the exchange type (e.g., "binance") logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err)
return err
}
// Update traders table to use new UUID // Update the exchange record
_, err = tx.Exec(`UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?`, if err := tx.Model(&Exchange{}).
newID, oldID, r.userID) Where("id = ? AND user_id = ?", oldID, r.UserID).
if err != nil { Updates(map[string]interface{}{
logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err) "id": newID,
return err "exchange_type": oldID,
"account_name": "Default",
}).Error; err != nil {
logger.Errorf("Failed to migrate exchange %s: %v", oldID, err)
return err
}
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.UserID)
} }
return nil
// Update the exchange record })
_, err = tx.Exec(`
UPDATE exchanges SET
id = ?,
exchange_type = ?,
account_name = ?
WHERE id = ? AND user_id = ?
`, newID, oldID, "Default", oldID, r.userID)
if err != nil {
logger.Errorf("Failed to migrate exchange %s: %v", oldID, err)
return err
}
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.userID)
}
if err := tx.Commit(); err != nil {
return err
}
logger.Infof("✅ Multi-account migration completed successfully")
return nil
} }
func (s *ExchangeStore) initDefaultData() error { func (s *ExchangeStore) initDefaultData() error {
@@ -206,108 +136,24 @@ func (s *ExchangeStore) initDefaultData() error {
return nil return nil
} }
func (s *ExchangeStore) encrypt(plaintext string) string {
if s.encryptFunc != nil {
return s.encryptFunc(plaintext)
}
return plaintext
}
func (s *ExchangeStore) decrypt(encrypted string) string {
if s.decryptFunc != nil {
return s.decryptFunc(encrypted)
}
return encrypted
}
// List gets user's exchange list // List gets user's exchange list
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) { func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
rows, err := s.db.Query(` var exchanges []*Exchange
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name, err := s.db.Where("user_id = ?", userID).Order("exchange_type, account_name").Find(&exchanges).Error
user_id, name, type, enabled, api_key, secret_key,
COALESCE(passphrase, '') as passphrase, testnet,
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
COALESCE(lighter_private_key, '') as lighter_private_key,
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
created_at, updated_at
FROM exchanges WHERE user_id = ? ORDER BY exchange_type, account_name
`, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
exchanges := make([]*Exchange, 0)
for rows.Next() {
var e Exchange
var createdAt, updatedAt string
err := rows.Scan(
&e.ID, &e.ExchangeType, &e.AccountName,
&e.UserID, &e.Name, &e.Type,
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
&createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
e.APIKey = s.decrypt(e.APIKey)
e.SecretKey = s.decrypt(e.SecretKey)
e.Passphrase = s.decrypt(e.Passphrase)
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
exchanges = append(exchanges, &e)
}
return exchanges, nil return exchanges, nil
} }
// GetByID gets a specific exchange by UUID // GetByID gets a specific exchange by UUID
func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) { func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) {
var e Exchange var exchange Exchange
var createdAt, updatedAt string err := s.db.Where("id = ? AND user_id = ?", id, userID).First(&exchange).Error
err := s.db.QueryRow(`
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
user_id, name, type, enabled, api_key, secret_key,
COALESCE(passphrase, '') as passphrase, testnet,
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
COALESCE(lighter_private_key, '') as lighter_private_key,
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
created_at, updated_at
FROM exchanges WHERE id = ? AND user_id = ?
`, id, userID).Scan(
&e.ID, &e.ExchangeType, &e.AccountName,
&e.UserID, &e.Name, &e.Type,
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
&createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) return &exchange, nil
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
e.APIKey = s.decrypt(e.APIKey)
e.SecretKey = s.decrypt(e.SecretKey)
e.Passphrase = s.decrypt(e.Passphrase)
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
return &e, nil
} }
// getExchangeNameAndType returns the display name and type for an exchange type // getExchangeNameAndType returns the display name and type for an exchange type
@@ -341,7 +187,6 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
id := uuid.New().String() id := uuid.New().String()
name, typ := getExchangeNameAndType(exchangeType) name, typ := getExchangeNameAndType(exchangeType)
// If account name is empty, use "Default"
if accountName == "" { if accountName == "" {
accountName = "Default" accountName = "Default"
} }
@@ -349,19 +194,29 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s", logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s",
userID, exchangeType, accountName, id) userID, exchangeType, accountName, id)
_, err := s.db.Exec(` exchange := &Exchange{
INSERT INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled, ID: id,
api_key, secret_key, passphrase, testnet, ExchangeType: exchangeType,
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, AccountName: accountName,
lighter_wallet_addr, lighter_private_key, lighter_api_key_private_key, lighter_api_key_index, UserID: userID,
created_at, updated_at) Name: name,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) Type: typ,
`, id, exchangeType, accountName, userID, name, typ, enabled, Enabled: enabled,
s.encrypt(apiKey), s.encrypt(secretKey), s.encrypt(passphrase), testnet, APIKey: crypto.EncryptedString(apiKey),
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey), SecretKey: crypto.EncryptedString(secretKey),
lighterWalletAddr, s.encrypt(lighterPrivateKey), s.encrypt(lighterApiKeyPrivateKey), lighterApiKeyIndex) Passphrase: crypto.EncryptedString(passphrase),
Testnet: testnet,
HyperliquidWalletAddr: hyperliquidWalletAddr,
AsterUser: asterUser,
AsterSigner: asterSigner,
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
LighterWalletAddr: lighterWalletAddr,
LighterPrivateKey: crypto.EncryptedString(lighterPrivateKey),
LighterAPIKeyPrivateKey: crypto.EncryptedString(lighterApiKeyPrivateKey),
LighterAPIKeyIndex: lighterApiKeyIndex,
}
if err != nil { if err := s.db.Create(exchange).Error; err != nil {
return "", err return "", err
} }
return id, nil return id, nil
@@ -373,53 +228,42 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled) logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled)
setClauses := []string{ updates := map[string]interface{}{
"enabled = ?", "enabled": enabled,
"testnet = ?", "testnet": testnet,
"hyperliquid_wallet_addr = ?", "hyperliquid_wallet_addr": hyperliquidWalletAddr,
"aster_user = ?", "aster_user": asterUser,
"aster_signer = ?", "aster_signer": asterSigner,
"lighter_wallet_addr = ?", "lighter_wallet_addr": lighterWalletAddr,
"lighter_api_key_index = ?", "lighter_api_key_index": lighterApiKeyIndex,
"updated_at = datetime('now')", "updated_at": time.Now(),
} }
args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr, lighterApiKeyIndex}
// Only update encrypted fields if not empty
if apiKey != "" { if apiKey != "" {
setClauses = append(setClauses, "api_key = ?") updates["api_key"] = crypto.EncryptedString(apiKey)
args = append(args, s.encrypt(apiKey))
} }
if secretKey != "" { if secretKey != "" {
setClauses = append(setClauses, "secret_key = ?") updates["secret_key"] = crypto.EncryptedString(secretKey)
args = append(args, s.encrypt(secretKey))
} }
if passphrase != "" { if passphrase != "" {
setClauses = append(setClauses, "passphrase = ?") updates["passphrase"] = crypto.EncryptedString(passphrase)
args = append(args, s.encrypt(passphrase))
} }
if asterPrivateKey != "" { if asterPrivateKey != "" {
setClauses = append(setClauses, "aster_private_key = ?") updates["aster_private_key"] = crypto.EncryptedString(asterPrivateKey)
args = append(args, s.encrypt(asterPrivateKey))
} }
if lighterPrivateKey != "" { if lighterPrivateKey != "" {
setClauses = append(setClauses, "lighter_private_key = ?") updates["lighter_private_key"] = crypto.EncryptedString(lighterPrivateKey)
args = append(args, s.encrypt(lighterPrivateKey))
} }
if lighterApiKeyPrivateKey != "" { if lighterApiKeyPrivateKey != "" {
setClauses = append(setClauses, "lighter_api_key_private_key = ?") updates["lighter_api_key_private_key"] = crypto.EncryptedString(lighterApiKeyPrivateKey)
args = append(args, s.encrypt(lighterApiKeyPrivateKey))
} }
args = append(args, id, userID) result := s.db.Model(&Exchange{}).Where("id = ? AND user_id = ?", id, userID).Updates(updates)
query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", ")) if result.Error != nil {
return result.Error
result, err := s.db.Exec(query, args...)
if err != nil {
return err
} }
if result.RowsAffected == 0 {
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID) return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
} }
return nil return nil
@@ -427,13 +271,16 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
// UpdateAccountName updates the account name for an exchange // UpdateAccountName updates the account name for an exchange
func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error { func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error {
result, err := s.db.Exec(`UPDATE exchanges SET account_name = ?, updated_at = datetime('now') WHERE id = ? AND user_id = ?`, result := s.db.Model(&Exchange{}).
accountName, id, userID) Where("id = ? AND user_id = ?", id, userID).
if err != nil { Updates(map[string]interface{}{
return err "account_name": accountName,
"updated_at": time.Now(),
})
if result.Error != nil {
return result.Error
} }
rowsAffected, _ := result.RowsAffected() if result.RowsAffected == 0 {
if rowsAffected == 0 {
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID) return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
} }
return nil return nil
@@ -441,12 +288,11 @@ func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error
// Delete deletes an exchange account // Delete deletes an exchange account
func (s *ExchangeStore) Delete(userID, id string) error { func (s *ExchangeStore) Delete(userID, id string) error {
result, err := s.db.Exec(`DELETE FROM exchanges WHERE id = ? AND user_id = ?`, id, userID) result := s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Exchange{})
if err != nil { if result.Error != nil {
return err return result.Error
} }
rowsAffected, _ := result.RowsAffected() if result.RowsAffected == 0 {
if rowsAffected == 0 {
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID) return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
} }
logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID) logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID)
@@ -460,20 +306,25 @@ func (s *ExchangeStore) CreateLegacy(userID, id, name, typ string, enabled bool,
// Check if this is an old-style ID (exchange type as ID) // Check if this is an old-style ID (exchange type as ID)
if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" { if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" {
// Use new Create method with exchange type
_, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet, _, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet,
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0) hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0)
return err return err
} }
// Otherwise assume it's already a UUID // Otherwise assume it's already a UUID
_, err := s.db.Exec(` exchange := &Exchange{
INSERT OR IGNORE INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled, ID: id,
api_key, secret_key, testnet, UserID: userID,
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, Name: name,
lighter_wallet_addr, lighter_private_key) Type: typ,
VALUES (?, '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '') Enabled: enabled,
`, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet, APIKey: crypto.EncryptedString(apiKey),
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey)) SecretKey: crypto.EncryptedString(secretKey),
return err Testnet: testnet,
HyperliquidWalletAddr: hyperliquidWalletAddr,
AsterUser: asterUser,
AsterSigner: asterSigner,
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
}
return s.db.Where("id = ?", id).FirstOrCreate(exchange).Error
} }
+146
View File
@@ -0,0 +1,146 @@
package store
import (
"fmt"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// GormDB is the global GORM database connection
var gormDB *gorm.DB
// DB returns the GORM database connection
func DB() *gorm.DB {
return gormDB
}
// InitGorm initializes GORM with SQLite
func InitGorm(dbPath string) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
}
// Set connection pool for SQLite
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
// Enable foreign keys for SQLite
db.Exec("PRAGMA foreign_keys = ON")
db.Exec("PRAGMA journal_mode = DELETE")
db.Exec("PRAGMA synchronous = FULL")
db.Exec("PRAGMA busy_timeout = 5000")
gormDB = db
return db, nil
}
// InitGormPostgres initializes GORM with PostgreSQL
func InitGormPostgres(host string, port int, user, password, dbname, sslmode string) (*gorm.DB, error) {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
host, port, user, password, dbname, sslmode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
}
// Set connection pool for PostgreSQL
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxOpenConns(25)
sqlDB.SetMaxIdleConns(5)
gormDB = db
return db, nil
}
// InitGormWithConfig initializes GORM with provided configuration
// Uses DBConfig from driver.go
func InitGormWithConfig(cfg DBConfig) (*gorm.DB, error) {
switch cfg.Type {
case DBTypeSQLite:
return InitGorm(cfg.Path)
case DBTypePostgres:
return InitGormPostgres(
cfg.Host,
cfg.Port,
cfg.User,
cfg.Password,
cfg.DBName,
cfg.SSLMode,
)
default:
return nil, fmt.Errorf("unsupported DB_TYPE: %s (use 'sqlite' or 'postgres')", cfg.Type)
}
}
// ============================================================================
// Query Scopes - Reusable query helpers
// ============================================================================
// ForUser returns a scope that filters by user_id
func ForUser(userID string) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("user_id = ?", userID)
}
}
// ForTrader returns a scope that filters by trader_id
func ForTrader(traderID string) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("trader_id = ?", traderID)
}
}
// OpenPositions returns a scope for open positions
func OpenPositions() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("status = ?", "OPEN")
}
}
// ClosedPositions returns a scope for closed positions
func ClosedPositions() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("status = ?", "CLOSED")
}
}
// OrderByCreatedDesc returns a scope that orders by created_at DESC
func OrderByCreatedDesc() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Order("created_at DESC")
}
}
// OrderByUpdatedDesc returns a scope that orders by updated_at DESC
func OrderByUpdatedDesc() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Order("updated_at DESC")
}
}
// Paginate returns a scope for pagination
func Paginate(limit, offset int) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Limit(limit).Offset(offset)
}
}
+199 -448
View File
@@ -1,495 +1,255 @@
package store package store
import ( import (
"database/sql"
"fmt" "fmt"
"strings"
"time" "time"
"gorm.io/gorm"
) )
// TraderOrder 订单记录(完整的订单生命周期追踪) // TraderOrder order record
type TraderOrder struct { type TraderOrder struct {
ID int64 `json:"id"` ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TraderID string `json:"trader_id"` TraderID string `gorm:"column:trader_id;not null;index:idx_orders_trader_id" json:"trader_id"`
ExchangeID string `json:"exchange_id"` // Exchange account UUID ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc) ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID ExchangeOrderID string `gorm:"column:exchange_order_id;not null;uniqueIndex:idx_orders_exchange_unique,priority:2" json:"exchange_order_id"`
ClientOrderID string `json:"client_order_id"` // Client order ID ClientOrderID string `gorm:"column:client_order_id;default:''" json:"client_order_id"`
Symbol string `json:"symbol"` // Trading pair Symbol string `gorm:"column:symbol;not null;index:idx_orders_symbol" json:"symbol"`
Side string `json:"side"` // BUY/SELL Side string `gorm:"column:side;not null" json:"side"`
PositionSide string `json:"position_side"` // LONG/SHORT (hedge mode) PositionSide string `gorm:"column:position_side;default:''" json:"position_side"`
Type string `json:"type"` // MARKET/LIMIT/STOP/STOP_MARKET/TAKE_PROFIT/TAKE_PROFIT_MARKET Type string `gorm:"column:type;not null" json:"type"`
TimeInForce string `json:"time_in_force"` // GTC/IOC/FOK TimeInForce string `gorm:"column:time_in_force;default:GTC" json:"time_in_force"`
Quantity float64 `json:"quantity"` // 订单数量 Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
Price float64 `json:"price"` // 限价单价格 Price float64 `gorm:"column:price;default:0" json:"price"`
StopPrice float64 `json:"stop_price"` // 止损/止盈触发价格 StopPrice float64 `gorm:"column:stop_price;default:0" json:"stop_price"`
Status string `json:"status"` // NEW/PARTIALLY_FILLED/FILLED/CANCELED/REJECTED/EXPIRED Status string `gorm:"column:status;not null;default:NEW;index:idx_orders_status" json:"status"`
FilledQuantity float64 `json:"filled_quantity"` // 已成交数量 FilledQuantity float64 `gorm:"column:filled_quantity;default:0" json:"filled_quantity"`
AvgFillPrice float64 `json:"avg_fill_price"` // 平均成交价格 AvgFillPrice float64 `gorm:"column:avg_fill_price;default:0" json:"avg_fill_price"`
Commission float64 `json:"commission"` // 手续费总额 Commission float64 `gorm:"column:commission;default:0" json:"commission"`
CommissionAsset string `json:"commission_asset"` // 手续费资产(USDT等) CommissionAsset string `gorm:"column:commission_asset;default:USDT" json:"commission_asset"`
Leverage int `json:"leverage"` // 杠杆倍数 Leverage int `gorm:"column:leverage;default:1" json:"leverage"`
ReduceOnly bool `json:"reduce_only"` // 是否只减仓 ReduceOnly bool `gorm:"column:reduce_only;default:false" json:"reduce_only"`
ClosePosition bool `json:"close_position"` // 是否平仓单 ClosePosition bool `gorm:"column:close_position;default:false" json:"close_position"`
WorkingType string `json:"working_type"` // CONTRACT_PRICE/MARK_PRICE WorkingType string `gorm:"column:working_type;default:CONTRACT_PRICE" json:"working_type"`
PriceProtect bool `json:"price_protect"` // 价格保护 PriceProtect bool `gorm:"column:price_protect;default:false" json:"price_protect"`
OrderAction string `json:"order_action"` // OPEN_LONG/OPEN_SHORT/CLOSE_LONG/CLOSE_SHORT/ADD_LONG/ADD_SHORT/STOP_LOSS/TAKE_PROFIT OrderAction string `gorm:"column:order_action;default:''" json:"order_action"`
RelatedPositionID int64 `json:"related_position_id"` // 关联的仓位ID RelatedPositionID int64 `gorm:"column:related_position_id;default:0" json:"related_position_id"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
FilledAt time.Time `json:"filled_at"` // 完全成交时间 FilledAt time.Time `gorm:"column:filled_at" json:"filled_at"`
} }
// TraderFill trade record (one order may have multiple fills) // TableName returns the table name for TraderOrder
func (TraderOrder) TableName() string {
return "trader_orders"
}
// TraderFill trade record
type TraderFill struct { type TraderFill struct {
ID int64 `json:"id"` ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TraderID string `json:"trader_id"` TraderID string `gorm:"column:trader_id;not null;index:idx_fills_trader_id" json:"trader_id"`
ExchangeID string `json:"exchange_id"` // Exchange account UUID ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc) ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
OrderID int64 `json:"order_id"` // Related order ID OrderID int64 `gorm:"column:order_id;not null;index:idx_fills_order_id" json:"order_id"`
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID ExchangeOrderID string `gorm:"column:exchange_order_id;not null" json:"exchange_order_id"`
ExchangeTradeID string `json:"exchange_trade_id"` // Exchange trade ID ExchangeTradeID string `gorm:"column:exchange_trade_id;not null;uniqueIndex:idx_fills_exchange_unique,priority:2" json:"exchange_trade_id"`
Symbol string `json:"symbol"` Symbol string `gorm:"column:symbol;not null" json:"symbol"`
Side string `json:"side"` // BUY/SELL Side string `gorm:"column:side;not null" json:"side"`
Price float64 `json:"price"` // 成交价格 Price float64 `gorm:"column:price;not null" json:"price"`
Quantity float64 `json:"quantity"` // 成交数量 Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
QuoteQuantity float64 `json:"quote_quantity"` // 成交金额(USDT QuoteQuantity float64 `gorm:"column:quote_quantity;not null" json:"quote_quantity"`
Commission float64 `json:"commission"` // 手续费 Commission float64 `gorm:"column:commission;not null" json:"commission"`
CommissionAsset string `json:"commission_asset"` CommissionAsset string `gorm:"column:commission_asset;not null" json:"commission_asset"`
RealizedPnL float64 `json:"realized_pnl"` // 实现盈亏(平仓时) RealizedPnL float64 `gorm:"column:realized_pnl;default:0" json:"realized_pnl"`
IsMaker bool `json:"is_maker"` // 是否为maker IsMaker bool `gorm:"column:is_maker;default:false" json:"is_maker"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
} }
// OrderStore 订单存储 // TableName returns the table name for TraderFill
func (TraderFill) TableName() string {
return "trader_fills"
}
// OrderStore order storage
type OrderStore struct { type OrderStore struct {
db *sql.DB db *gorm.DB
} }
// NewOrderStore 创建订单存储实例 // NewOrderStore creates order storage instance
func NewOrderStore(db *sql.DB) *OrderStore { func NewOrderStore(db *gorm.DB) *OrderStore {
return &OrderStore{db: db} return &OrderStore{db: db}
} }
// InitTables 初始化订单表 // InitTables initializes order tables
func (s *OrderStore) InitTables() error { func (s *OrderStore) InitTables() error {
// 创建订单表 // For PostgreSQL, check if tables exist to avoid AutoMigrate index conflicts
_, err := s.db.Exec(` if s.db.Dialector.Name() == "postgres" {
CREATE TABLE IF NOT EXISTS trader_orders ( var ordersExist, fillsExist int64
id INTEGER PRIMARY KEY AUTOINCREMENT, s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_orders'`).Scan(&ordersExist)
trader_id TEXT NOT NULL, s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_fills'`).Scan(&fillsExist)
exchange_id TEXT NOT NULL DEFAULT '',
exchange_type TEXT NOT NULL DEFAULT '', if ordersExist > 0 && fillsExist > 0 {
exchange_order_id TEXT NOT NULL, // Tables exist - just ensure indexes exist, skip AutoMigrate
client_order_id TEXT DEFAULT '', s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
symbol TEXT NOT NULL, s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
side TEXT NOT NULL, s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`)
position_side TEXT DEFAULT '', s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`)
type TEXT NOT NULL, s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`)
time_in_force TEXT DEFAULT 'GTC', s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`)
quantity REAL NOT NULL, s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
price REAL DEFAULT 0, return nil
stop_price REAL DEFAULT 0, }
status TEXT NOT NULL DEFAULT 'NEW',
filled_quantity REAL DEFAULT 0,
avg_fill_price REAL DEFAULT 0,
commission REAL DEFAULT 0,
commission_asset TEXT DEFAULT 'USDT',
leverage INTEGER DEFAULT 1,
reduce_only INTEGER DEFAULT 0,
close_position INTEGER DEFAULT 0,
working_type TEXT DEFAULT 'CONTRACT_PRICE',
price_protect INTEGER DEFAULT 0,
order_action TEXT DEFAULT '',
related_position_id INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
filled_at DATETIME,
UNIQUE(exchange_id, exchange_order_id)
)
`)
if err != nil {
return fmt.Errorf("failed to create trader_orders table: %w", err)
} }
// 创建成交记录表 if err := s.db.AutoMigrate(&TraderOrder{}, &TraderFill{}); err != nil {
_, err = s.db.Exec(` return fmt.Errorf("failed to migrate order tables: %w", err)
CREATE TABLE IF NOT EXISTS trader_fills (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
exchange_id TEXT NOT NULL DEFAULT '',
exchange_type TEXT NOT NULL DEFAULT '',
order_id INTEGER NOT NULL,
exchange_order_id TEXT NOT NULL,
exchange_trade_id TEXT NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
price REAL NOT NULL,
quantity REAL NOT NULL,
quote_quantity REAL NOT NULL,
commission REAL NOT NULL,
commission_asset TEXT NOT NULL,
realized_pnl REAL DEFAULT 0,
is_maker INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(exchange_id, exchange_trade_id),
FOREIGN KEY (order_id) REFERENCES trader_orders(id)
)
`)
if err != nil {
return fmt.Errorf("failed to create trader_fills table: %w", err)
} }
// 创建索引 // Create unique composite index for exchange_id + exchange_order_id
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`) s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`) // Create unique composite index for exchange_id + exchange_trade_id
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`) s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_exchange_order_id ON trader_orders(exchange_id, exchange_order_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`)
return nil return nil
} }
// CreateOrder 创建订单记录(去重:如果订单已存在则返回已有记录) // CreateOrder creates order record
func (s *OrderStore) CreateOrder(order *TraderOrder) error { func (s *OrderStore) CreateOrder(order *TraderOrder) error {
// 1. 先检查订单是否已存在(去重) // Check if order already exists
existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID) existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check existing order: %w", err) return fmt.Errorf("failed to check existing order: %w", err)
} }
if existing != nil { if existing != nil {
// 订单已存在,返回已有记录的ID
order.ID = existing.ID order.ID = existing.ID
order.CreatedAt = existing.CreatedAt order.CreatedAt = existing.CreatedAt
order.UpdatedAt = existing.UpdatedAt order.UpdatedAt = existing.UpdatedAt
return nil // 不是错误,只是跳过插入 return nil
} }
// 2. 订单不存在,插入新记录 return s.db.Create(order).Error
now := time.Now()
order.CreatedAt = now
order.UpdatedAt = now
result, err := s.db.Exec(`
INSERT INTO trader_orders (
trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
symbol, side, position_side, type, time_in_force,
quantity, price, stop_price, status,
filled_quantity, avg_fill_price, commission, commission_asset,
leverage, reduce_only, close_position, working_type, price_protect,
order_action, related_position_id,
created_at, updated_at, filled_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
order.TraderID, order.ExchangeID, order.ExchangeType, order.ExchangeOrderID, order.ClientOrderID,
order.Symbol, order.Side, order.PositionSide, order.Type, order.TimeInForce,
order.Quantity, order.Price, order.StopPrice, order.Status,
order.FilledQuantity, order.AvgFillPrice, order.Commission, order.CommissionAsset,
order.Leverage, order.ReduceOnly, order.ClosePosition, order.WorkingType, order.PriceProtect,
order.OrderAction, order.RelatedPositionID,
formatTimeOrNow(order.CreatedAt, now), formatTimeOrNow(order.UpdatedAt, now),
formatTimePtr(order.FilledAt),
)
if err != nil {
return fmt.Errorf("failed to create order: %w", err)
}
id, _ := result.LastInsertId()
order.ID = id
return nil
} }
// UpdateOrderStatus 更新订单状态 // UpdateOrderStatus updates order status
func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error { func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error {
now := time.Now() updates := map[string]interface{}{
updateSQL := ` "status": status,
UPDATE trader_orders SET "filled_quantity": filledQty,
status = ?, "avg_fill_price": avgPrice,
filled_quantity = ?, "commission": commission,
avg_fill_price = ?, }
commission = ?,
updated_at = ?
`
args := []interface{}{status, filledQty, avgPrice, commission, now.Format(time.RFC3339)}
// 如果完全成交,记录成交时间
if status == "FILLED" { if status == "FILLED" {
updateSQL += `, filled_at = ?` updates["filled_at"] = time.Now()
args = append(args, now.Format(time.RFC3339))
} }
updateSQL += ` WHERE id = ?` return s.db.Model(&TraderOrder{}).Where("id = ?", id).Updates(updates).Error
args = append(args, id)
_, err := s.db.Exec(updateSQL, args...)
if err != nil {
return fmt.Errorf("failed to update order status: %w", err)
}
return nil
} }
// CreateFill 创建成交记录(去重:如果成交已存在则跳过) // CreateFill creates fill record
func (s *OrderStore) CreateFill(fill *TraderFill) error { func (s *OrderStore) CreateFill(fill *TraderFill) error {
// 1. 先检查成交是否已存在(去重) // Check if fill already exists
existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID) existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check existing fill: %w", err) return fmt.Errorf("failed to check existing fill: %w", err)
} }
if existing != nil { if existing != nil {
// 成交已存在,返回已有记录的ID
fill.ID = existing.ID fill.ID = existing.ID
fill.CreatedAt = existing.CreatedAt fill.CreatedAt = existing.CreatedAt
return nil // 不是错误,只是跳过插入 return nil
} }
// 2. 成交不存在,插入新记录 return s.db.Create(fill).Error
now := time.Now()
fill.CreatedAt = now
result, err := s.db.Exec(`
INSERT INTO trader_fills (
trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
symbol, side, price, quantity, quote_quantity,
commission, commission_asset, realized_pnl, is_maker,
created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
fill.TraderID, fill.ExchangeID, fill.ExchangeType, fill.OrderID, fill.ExchangeOrderID, fill.ExchangeTradeID,
fill.Symbol, fill.Side, fill.Price, fill.Quantity, fill.QuoteQuantity,
fill.Commission, fill.CommissionAsset, fill.RealizedPnL, fill.IsMaker,
now.Format(time.RFC3339),
)
if err != nil {
return fmt.Errorf("failed to create fill: %w", err)
}
id, _ := result.LastInsertId()
fill.ID = id
return nil
} }
// GetFillByExchangeTradeID 根据交易所成交ID获取成交记录 // GetFillByExchangeTradeID gets fill by exchange trade ID
func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) { func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) {
row := s.db.QueryRow(`
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
symbol, side, price, quantity, quote_quantity,
commission, commission_asset, realized_pnl, is_maker,
created_at
FROM trader_fills
WHERE exchange_id = ? AND exchange_trade_id = ?
`, exchangeID, exchangeTradeID)
var fill TraderFill var fill TraderFill
var createdAt sql.NullString err := s.db.Where("exchange_id = ? AND exchange_trade_id = ?", exchangeID, exchangeTradeID).First(&fill).Error
err := row.Scan(
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
&createdAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get fill: %w", err) return nil, fmt.Errorf("failed to get fill: %w", err)
} }
// Parse time
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
fill.CreatedAt = t
}
}
return &fill, nil return &fill, nil
} }
// GetOrderByExchangeID 根据交易所订单ID获取订单 // GetOrderByExchangeID gets order by exchange order ID
func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) { func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) {
row := s.db.QueryRow(`
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
symbol, side, position_side, type, time_in_force,
quantity, price, stop_price, status,
filled_quantity, avg_fill_price, commission, commission_asset,
leverage, reduce_only, close_position, working_type, price_protect,
order_action, related_position_id,
created_at, updated_at, filled_at
FROM trader_orders
WHERE exchange_id = ? AND exchange_order_id = ?
`, exchangeID, exchangeOrderID)
var order TraderOrder var order TraderOrder
var createdAt, updatedAt, filledAt sql.NullString err := s.db.Where("exchange_id = ? AND exchange_order_id = ?", exchangeID, exchangeOrderID).First(&order).Error
err := row.Scan(
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
&order.OrderAction, &order.RelatedPositionID,
&createdAt, &updatedAt, &filledAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get order: %w", err) return nil, fmt.Errorf("failed to get order: %w", err)
} }
// Parse times
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
order.CreatedAt = t
}
}
if updatedAt.Valid {
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
order.UpdatedAt = t
}
}
if filledAt.Valid {
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
order.FilledAt = t
}
}
return &order, nil return &order, nil
} }
// GetTraderOrders 获取trader的订单列表 // GetTraderOrders gets trader's order list
func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) { func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) {
rows, err := s.db.Query(` var orders []*TraderOrder
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id, err := s.db.Where("trader_id = ?", traderID).
symbol, side, position_side, type, time_in_force, Order("created_at DESC").
quantity, price, stop_price, status, Limit(limit).
filled_quantity, avg_fill_price, commission, commission_asset, Find(&orders).Error
leverage, reduce_only, close_position, working_type, price_protect,
order_action, related_position_id,
created_at, updated_at, filled_at
FROM trader_orders
WHERE trader_id = ?
ORDER BY created_at DESC
LIMIT ?
`, traderID, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query orders: %w", err) return nil, fmt.Errorf("failed to query orders: %w", err)
} }
defer rows.Close()
var orders []*TraderOrder
for rows.Next() {
var order TraderOrder
var createdAt, updatedAt, filledAt sql.NullString
err := rows.Scan(
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
&order.OrderAction, &order.RelatedPositionID,
&createdAt, &updatedAt, &filledAt,
)
if err != nil {
continue
}
// Parse times
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
order.CreatedAt = t
}
}
if updatedAt.Valid {
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
order.UpdatedAt = t
}
}
if filledAt.Valid {
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
order.FilledAt = t
}
}
orders = append(orders, &order)
}
return orders, nil return orders, nil
} }
// GetOrderFills 获取订单的成交记录 // GetOrderFills gets order's fill records
func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) { func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) {
rows, err := s.db.Query(` var fills []*TraderFill
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id, err := s.db.Where("order_id = ?", orderID).
symbol, side, price, quantity, quote_quantity, Order("created_at ASC").
commission, commission_asset, realized_pnl, is_maker, Find(&fills).Error
created_at
FROM trader_fills
WHERE order_id = ?
ORDER BY created_at ASC
`, orderID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query fills: %w", err) return nil, fmt.Errorf("failed to query fills: %w", err)
} }
defer rows.Close()
var fills []*TraderFill
for rows.Next() {
var fill TraderFill
var createdAt sql.NullString
err := rows.Scan(
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
&createdAt,
)
if err != nil {
continue
}
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
fill.CreatedAt = t
}
}
fills = append(fills, &fill)
}
return fills, nil return fills, nil
} }
// GetTraderOrderStats 获取trader的订单统计 // GetTraderOrderStats gets trader's order statistics
func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) { func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) {
var totalOrders, filledOrders, canceledOrders int type result struct {
var totalCommission, totalVolume float64 TotalOrders int
FilledOrders int
err := s.db.QueryRow(` CanceledOrders int
SELECT TotalCommission float64
COUNT(*) as total_orders, TotalVolume float64
SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders, }
SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders, var r result
SUM(commission) as total_commission,
SUM(filled_quantity * avg_fill_price) as total_volume
FROM trader_orders
WHERE trader_id = ?
`, traderID).Scan(&totalOrders, &filledOrders, &canceledOrders, &totalCommission, &totalVolume)
err := s.db.Model(&TraderOrder{}).
Select(`COUNT(*) as total_orders,
SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders,
SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders,
SUM(commission) as total_commission,
SUM(filled_quantity * avg_fill_price) as total_volume`).
Where("trader_id = ?", traderID).
Scan(&r).Error
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get order stats: %w", err) return nil, fmt.Errorf("failed to get order stats: %w", err)
} }
return map[string]interface{}{ return map[string]interface{}{
"total_orders": totalOrders, "total_orders": r.TotalOrders,
"filled_orders": filledOrders, "filled_orders": r.FilledOrders,
"canceled_orders": canceledOrders, "canceled_orders": r.CanceledOrders,
"total_commission": totalCommission, "total_commission": r.TotalCommission,
"total_volume": totalVolume, "total_volume": r.TotalVolume,
}, nil }, nil
} }
// CleanupDuplicateOrders 清理重复的订单记录(保留最早创建的记录) // CleanupDuplicateOrders cleans up duplicate order records
func (s *OrderStore) CleanupDuplicateOrders() (int, error) { func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
result, err := s.db.Exec(` result := s.db.Exec(`
DELETE FROM trader_orders DELETE FROM trader_orders
WHERE id NOT IN ( WHERE id NOT IN (
SELECT MIN(id) SELECT MIN(id)
@@ -497,17 +257,15 @@ func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
GROUP BY exchange_id, exchange_order_id GROUP BY exchange_id, exchange_order_id
) )
`) `)
if err != nil { if result.Error != nil {
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", err) return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", result.Error)
} }
return int(result.RowsAffected), nil
rowsAffected, _ := result.RowsAffected()
return int(rowsAffected), nil
} }
// CleanupDuplicateFills 清理重复的成交记录(保留最早创建的记录) // CleanupDuplicateFills cleans up duplicate fill records
func (s *OrderStore) CleanupDuplicateFills() (int, error) { func (s *OrderStore) CleanupDuplicateFills() (int, error) {
result, err := s.db.Exec(` result := s.db.Exec(`
DELETE FROM trader_fills DELETE FROM trader_fills
WHERE id NOT IN ( WHERE id NOT IN (
SELECT MIN(id) SELECT MIN(id)
@@ -515,73 +273,66 @@ func (s *OrderStore) CleanupDuplicateFills() (int, error) {
GROUP BY exchange_id, exchange_trade_id GROUP BY exchange_id, exchange_trade_id
) )
`) `)
if err != nil { if result.Error != nil {
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", err) return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", result.Error)
} }
return int(result.RowsAffected), nil
rowsAffected, _ := result.RowsAffected()
return int(rowsAffected), nil
} }
// GetDuplicateOrdersCount 获取重复订单的数量(用于诊断) // GetDuplicateOrdersCount gets duplicate orders count
func (s *OrderStore) GetDuplicateOrdersCount() (int, error) { func (s *OrderStore) GetDuplicateOrdersCount() (int, error) {
var count int var total, distinct int64
err := s.db.QueryRow(` s.db.Model(&TraderOrder{}).Count(&total)
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_order_id)
FROM trader_orders // Count distinct combinations
`).Scan(&count) var distinctResult struct{ Count int64 }
return count, err s.db.Model(&TraderOrder{}).
Select("COUNT(DISTINCT exchange_id || ',' || exchange_order_id) as count").
Scan(&distinctResult)
distinct = distinctResult.Count
return int(total - distinct), nil
} }
// GetDuplicateFillsCount 获取重复成交的数量(用于诊断) // GetDuplicateFillsCount gets duplicate fills count
func (s *OrderStore) GetDuplicateFillsCount() (int, error) { func (s *OrderStore) GetDuplicateFillsCount() (int, error) {
var count int var total, distinct int64
err := s.db.QueryRow(` s.db.Model(&TraderFill{}).Count(&total)
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_trade_id)
FROM trader_fills var distinctResult struct{ Count int64 }
`).Scan(&count) s.db.Model(&TraderFill{}).
return count, err Select("COUNT(DISTINCT exchange_id || ',' || exchange_trade_id) as count").
Scan(&distinctResult)
distinct = distinctResult.Count
return int(total - distinct), nil
} }
// GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange // GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange
// Used for incremental sync - only fetch trades with ID > maxTradeID
func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) { func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) {
rows, err := s.db.Query(` type symbolMaxID struct {
SELECT symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id Symbol string
FROM trader_fills MaxTradeID int64
WHERE exchange_id = ? AND exchange_trade_id != '' }
GROUP BY symbol var results []symbolMaxID
`, exchangeID)
err := s.db.Model(&TraderFill{}).
Select("symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id").
Where("exchange_id = ? AND exchange_trade_id != ''", exchangeID).
Group("symbol").
Find(&results).Error
if err != nil { if err != nil {
// If CAST fails (non-numeric trade IDs), fallback to string comparison
if strings.Contains(err.Error(), "CAST") || strings.Contains(err.Error(), "invalid") {
return make(map[string]int64), nil
}
return nil, fmt.Errorf("failed to query max trade IDs: %w", err) return nil, fmt.Errorf("failed to query max trade IDs: %w", err)
} }
defer rows.Close()
result := make(map[string]int64) result := make(map[string]int64)
for rows.Next() { for _, r := range results {
var symbol string result[r.Symbol] = r.MaxTradeID
var maxID int64
if err := rows.Scan(&symbol, &maxID); err != nil {
continue
}
result[symbol] = maxID
} }
return result, nil return result, nil
} }
// formatTimePtr formats time.Time to RFC3339 string, returns NULL for zero time
func formatTimePtr(t time.Time) interface{} {
if t.IsZero() {
return nil
}
return t.Format(time.RFC3339)
}
// formatTimeOrNow returns the formatted time if not zero, otherwise returns now
func formatTimeOrNow(t time.Time, now time.Time) string {
if t.IsZero() {
return now.Format(time.RFC3339)
}
return t.Format(time.RFC3339)
}
+534 -826
View File
File diff suppressed because it is too large Load Diff
+105 -81
View File
@@ -7,12 +7,15 @@ import (
"fmt" "fmt"
"nofx/logger" "nofx/logger"
"sync" "sync"
"gorm.io/gorm"
) )
// Store unified data storage interface // Store unified data storage interface
type Store struct { type Store struct {
db *sql.DB gdb *gorm.DB // GORM database connection
driver *DBDriver // Database driver for abstraction db *sql.DB // Legacy sql.DB for backward compatibility
driver *DBDriver // Database driver for abstraction (legacy)
// Sub-stores (lazy initialization) // Sub-stores (lazy initialization)
user *UserStore user *UserStore
@@ -26,105 +29,103 @@ type Store struct {
equity *EquityStore equity *EquityStore
order *OrderStore order *OrderStore
// Encryption functions
encryptFunc func(string) string
decryptFunc func(string) string
mu sync.RWMutex mu sync.RWMutex
} }
// New creates new Store instance (SQLite mode for backward compatibility) // New creates new Store instance (SQLite mode for backward compatibility)
func New(dbPath string) (*Store, error) { func New(dbPath string) (*Store, error) {
driver, err := NewDBDriver(DBConfig{Type: DBTypeSQLite, Path: dbPath}) gdb, err := InitGorm(dbPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
s := &Store{db: driver.DB(), driver: driver} // Get underlying sql.DB for legacy compatibility
sqlDB, err := gdb.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
s := &Store{gdb: gdb, db: sqlDB}
// Initialize all table structures // Initialize all table structures
if err := s.initTables(); err != nil { if err := s.initTables(); err != nil {
driver.Close() sqlDB.Close()
return nil, fmt.Errorf("failed to initialize table structure: %w", err) return nil, fmt.Errorf("failed to initialize table structure: %w", err)
} }
// Initialize default data // Initialize default data
if err := s.initDefaultData(); err != nil { if err := s.initDefaultData(); err != nil {
driver.Close() sqlDB.Close()
return nil, fmt.Errorf("failed to initialize default data: %w", err) return nil, fmt.Errorf("failed to initialize default data: %w", err)
} }
logger.Infof("✅ Database initialized (type: %s)", driver.Type) logger.Infof("✅ Database initialized (GORM, SQLite)")
return s, nil return s, nil
} }
// NewFromEnv creates new Store instance from environment variables // NewWithConfig creates new Store instance with provided database configuration
// DB_TYPE: sqlite (default) or postgres func NewWithConfig(cfg DBConfig) (*Store, error) {
// For SQLite: DB_PATH (default: data/data.db) gdb, err := InitGormWithConfig(cfg)
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
func NewFromEnv() (*Store, error) {
driver, err := NewDBDriverFromEnv()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
s := &Store{db: driver.DB(), driver: driver} // Get underlying sql.DB for legacy compatibility
sqlDB, err := gdb.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
s := &Store{gdb: gdb, db: sqlDB}
// Initialize all table structures // Initialize all table structures
if err := s.initTables(); err != nil { if err := s.initTables(); err != nil {
driver.Close() sqlDB.Close()
return nil, fmt.Errorf("failed to initialize table structure: %w", err) return nil, fmt.Errorf("failed to initialize table structure: %w", err)
} }
// Initialize default data // Initialize default data
if err := s.initDefaultData(); err != nil { if err := s.initDefaultData(); err != nil {
driver.Close() sqlDB.Close()
return nil, fmt.Errorf("failed to initialize default data: %w", err) return nil, fmt.Errorf("failed to initialize default data: %w", err)
} }
logger.Infof("✅ Database initialized (type: %s)", driver.Type) dbTypeStr := "SQLite"
if cfg.Type == DBTypePostgres {
dbTypeStr = "PostgreSQL"
}
logger.Infof("✅ Database initialized (GORM, %s)", dbTypeStr)
return s, nil return s, nil
} }
// NewFromDB creates Store from existing database connection // NewFromGorm creates Store from existing GORM connection
func NewFromGorm(gdb *gorm.DB) (*Store, error) {
sqlDB, err := gdb.DB()
if err != nil {
return nil, err
}
return &Store{gdb: gdb, db: sqlDB}, nil
}
// NewFromDB creates Store from existing database connection (legacy)
// Deprecated: Use NewFromGorm instead
func NewFromDB(db *sql.DB) *Store { func NewFromDB(db *sql.DB) *Store {
return &Store{db: db} return &Store{db: db}
} }
// SetCryptoFuncs sets encryption/decryption functions // initTables initializes all database tables using GORM AutoMigrate
func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
s.mu.Lock()
defer s.mu.Unlock()
s.encryptFunc = encrypt
s.decryptFunc = decrypt
// Update already initialized sub-stores
if s.aiModel != nil {
s.aiModel.encryptFunc = encrypt
s.aiModel.decryptFunc = decrypt
}
if s.exchange != nil {
s.exchange.encryptFunc = encrypt
s.exchange.decryptFunc = decrypt
}
if s.trader != nil {
s.trader.decryptFunc = decrypt
}
}
// initTables initializes all database tables
func (s *Store) initTables() error { func (s *Store) initTables() error {
// Initialize system config table first // Create system_config table (GORM handles this via raw SQL for simplicity)
if _, err := s.db.Exec(` if err := s.gdb.Exec(`
CREATE TABLE IF NOT EXISTS system_config ( CREATE TABLE IF NOT EXISTS system_config (
key TEXT PRIMARY KEY, key TEXT PRIMARY KEY,
value TEXT NOT NULL value TEXT NOT NULL
) )
`); err != nil { `).Error; err != nil {
return fmt.Errorf("failed to create system_config table: %w", err) return fmt.Errorf("failed to create system_config table: %w", err)
} }
// Initialize in dependency order // Initialize sub-store tables
if err := s.User().initTables(); err != nil { if err := s.User().initTables(); err != nil {
return fmt.Errorf("failed to initialize user tables: %w", err) return fmt.Errorf("failed to initialize user tables: %w", err)
} }
@@ -183,7 +184,7 @@ func (s *Store) User() *UserStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.user == nil { if s.user == nil {
s.user = &UserStore{db: s.db} s.user = NewUserStore(s.gdb)
} }
return s.user return s.user
} }
@@ -193,11 +194,7 @@ func (s *Store) AIModel() *AIModelStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.aiModel == nil { if s.aiModel == nil {
s.aiModel = &AIModelStore{ s.aiModel = NewAIModelStore(s.gdb)
db: s.db,
encryptFunc: s.encryptFunc,
decryptFunc: s.decryptFunc,
}
} }
return s.aiModel return s.aiModel
} }
@@ -207,11 +204,7 @@ func (s *Store) Exchange() *ExchangeStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.exchange == nil { if s.exchange == nil {
s.exchange = &ExchangeStore{ s.exchange = NewExchangeStore(s.gdb)
db: s.db,
encryptFunc: s.encryptFunc,
decryptFunc: s.decryptFunc,
}
} }
return s.exchange return s.exchange
} }
@@ -221,10 +214,7 @@ func (s *Store) Trader() *TraderStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.trader == nil { if s.trader == nil {
s.trader = &TraderStore{ s.trader = NewTraderStore(s.gdb)
db: s.db,
decryptFunc: s.decryptFunc,
}
} }
return s.trader return s.trader
} }
@@ -234,7 +224,7 @@ func (s *Store) Decision() *DecisionStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.decision == nil { if s.decision == nil {
s.decision = &DecisionStore{db: s.db} s.decision = NewDecisionStore(s.gdb)
} }
return s.decision return s.decision
} }
@@ -244,7 +234,7 @@ func (s *Store) Backtest() *BacktestStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.backtest == nil { if s.backtest == nil {
s.backtest = &BacktestStore{db: s.db} s.backtest = NewBacktestStore(s.gdb)
} }
return s.backtest return s.backtest
} }
@@ -254,7 +244,7 @@ func (s *Store) Position() *PositionStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.position == nil { if s.position == nil {
s.position = NewPositionStore(s.db) s.position = NewPositionStore(s.gdb)
} }
return s.position return s.position
} }
@@ -264,7 +254,7 @@ func (s *Store) Strategy() *StrategyStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.strategy == nil { if s.strategy == nil {
s.strategy = &StrategyStore{db: s.db} s.strategy = NewStrategyStore(s.gdb)
} }
return s.strategy return s.strategy
} }
@@ -274,7 +264,7 @@ func (s *Store) Equity() *EquityStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.equity == nil { if s.equity == nil {
s.equity = &EquityStore{db: s.db} s.equity = NewEquityStore(s.gdb)
} }
return s.equity return s.equity
} }
@@ -284,7 +274,7 @@ func (s *Store) Order() *OrderStore {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.order == nil { if s.order == nil {
s.order = NewOrderStore(s.db) s.order = NewOrderStore(s.gdb)
} }
return s.order return s.order
} }
@@ -294,10 +284,18 @@ func (s *Store) Close() error {
if s.driver != nil { if s.driver != nil {
return s.driver.Close() return s.driver.Close()
} }
return s.db.Close() if s.db != nil {
return s.db.Close()
}
return nil
} }
// Driver returns database driver for abstraction // GormDB returns the GORM database connection
func (s *Store) GormDB() *gorm.DB {
return s.gdb
}
// Driver returns database driver for abstraction (legacy)
func (s *Store) Driver() *DBDriver { func (s *Store) Driver() *DBDriver {
return s.driver return s.driver
} }
@@ -307,11 +305,25 @@ func (s *Store) DBType() DBType {
if s.driver != nil { if s.driver != nil {
return s.driver.Type return s.driver.Type
} }
// Detect from GORM dialector
if s.gdb != nil {
switch s.gdb.Dialector.Name() {
case "postgres":
return DBTypePostgres
default:
return DBTypeSQLite
}
}
return DBTypeSQLite return DBTypeSQLite
} }
// DB gets underlying database connection (for legacy code compatibility, gradually deprecated) // q converts query placeholders for current database type (legacy helper)
// Deprecated: use Store methods instead func (s *Store) q(query string) string {
return convertQuery(query, s.DBType())
}
// DB gets underlying database connection (for legacy code compatibility)
// Deprecated: use GormDB() instead
func (s *Store) DB() *sql.DB { func (s *Store) DB() *sql.DB {
return s.db return s.db
} }
@@ -319,24 +331,36 @@ func (s *Store) DB() *sql.DB {
// GetSystemConfig gets a system configuration value by key // GetSystemConfig gets a system configuration value by key
func (s *Store) GetSystemConfig(key string) (string, error) { func (s *Store) GetSystemConfig(key string) (string, error) {
var value string var value string
err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value) result := s.gdb.Raw("SELECT value FROM system_config WHERE key = ?", key).Scan(&value)
if err == sql.ErrNoRows { if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return "", nil
}
return "", result.Error
}
if result.RowsAffected == 0 {
return "", nil return "", nil
} }
return value, err return value, nil
} }
// SetSystemConfig sets a system configuration value // SetSystemConfig sets a system configuration value
func (s *Store) SetSystemConfig(key, value string) error { func (s *Store) SetSystemConfig(key, value string) error {
_, err := s.db.Exec(` // Use GORM-compatible upsert
return s.gdb.Exec(`
INSERT INTO system_config (key, value) VALUES (?, ?) INSERT INTO system_config (key, value) VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value = excluded.value ON CONFLICT(key) DO UPDATE SET value = excluded.value
`, key, value) `, key, value).Error
return err
} }
// Transaction executes transaction // Transaction executes transaction with GORM
func (s *Store) Transaction(fn func(tx *sql.Tx) error) error { func (s *Store) Transaction(fn func(tx *gorm.DB) error) error {
return s.gdb.Transaction(fn)
}
// TransactionSQL executes transaction with sql.Tx (legacy)
// Deprecated: Use Transaction() instead
func (s *Store) TransactionSQL(fn func(tx *sql.Tx) error) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)
+57 -155
View File
@@ -1,30 +1,33 @@
package store package store
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"gorm.io/gorm"
) )
// StrategyStore strategy storage // StrategyStore strategy storage
type StrategyStore struct { type StrategyStore struct {
db *sql.DB db *gorm.DB
} }
// Strategy strategy configuration // Strategy strategy configuration
type Strategy struct { type Strategy struct {
ID string `json:"id"` ID string `gorm:"primaryKey" json:"id"`
UserID string `json:"user_id"` UserID string `gorm:"column:user_id;not null;default:'';index" json:"user_id"`
Name string `json:"name"` Name string `gorm:"not null" json:"name"`
Description string `json:"description"` Description string `gorm:"default:''" json:"description"`
IsActive bool `json:"is_active"` // whether it is active (a user can only have one active strategy) IsActive bool `gorm:"column:is_active;default:false;index" json:"is_active"`
IsDefault bool `json:"is_default"` // whether it is a system default strategy IsDefault bool `gorm:"column:is_default;default:false" json:"is_default"`
Config string `json:"config"` // strategy configuration in JSON format Config string `gorm:"not null;default:'{}'" json:"config"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
func (Strategy) TableName() string { return "strategies" }
// StrategyConfig strategy configuration details (JSON structure) // StrategyConfig strategy configuration details (JSON structure)
type StrategyConfig struct { type StrategyConfig struct {
// coin source configuration // coin source configuration
@@ -136,24 +139,6 @@ type ExternalDataSource struct {
} }
// RiskControlConfig risk control configuration // RiskControlConfig risk control configuration
// All parameters are clearly defined without ambiguity:
//
// Position Limits:
// - MaxPositions: max number of coins held simultaneously (CODE ENFORCED)
//
// Trading Leverage (exchange leverage for opening positions):
// - BTCETHMaxLeverage: BTC/ETH max exchange leverage (AI guided)
// - AltcoinMaxLeverage: Altcoin max exchange leverage (AI guided)
//
// Position Value Limits (single position notional value / account equity):
// - BTCETHMaxPositionValueRatio: BTC/ETH max = equity × ratio (CODE ENFORCED)
// - AltcoinMaxPositionValueRatio: Altcoin max = equity × ratio (CODE ENFORCED)
//
// Risk Controls:
// - MaxMarginUsage: max margin utilization percentage (CODE ENFORCED)
// - MinPositionSize: minimum position size in USDT (CODE ENFORCED)
// - MinRiskRewardRatio: min take_profit / stop_loss ratio (AI guided)
// - MinConfidence: min AI confidence to open position (AI guided)
type RiskControlConfig struct { type RiskControlConfig struct {
// Max number of coins held simultaneously (CODE ENFORCED) // Max number of coins held simultaneously (CODE ENFORCED)
MaxPositions int `json:"max_positions"` MaxPositions int `json:"max_positions"`
@@ -179,38 +164,21 @@ type RiskControlConfig struct {
MinConfidence int `json:"min_confidence"` MinConfidence int `json:"min_confidence"`
} }
// NewStrategyStore creates a new StrategyStore
func NewStrategyStore(db *gorm.DB) *StrategyStore {
return &StrategyStore{db: db}
}
func (s *StrategyStore) initTables() error { func (s *StrategyStore) initTables() error {
_, err := s.db.Exec(` // For PostgreSQL with existing table, skip AutoMigrate
CREATE TABLE IF NOT EXISTS strategies ( if s.db.Dialector.Name() == "postgres" {
id TEXT PRIMARY KEY, var tableExists int64
user_id TEXT NOT NULL DEFAULT '', s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'strategies'`).Scan(&tableExists)
name TEXT NOT NULL, if tableExists > 0 {
description TEXT DEFAULT '', return nil
is_active BOOLEAN DEFAULT 0, }
is_default BOOLEAN DEFAULT 0,
config TEXT NOT NULL DEFAULT '{}',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
} }
return s.db.AutoMigrate(&Strategy{})
// create indexes
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_user_id ON strategies(user_id)`)
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_is_active ON strategies(is_active)`)
// trigger: automatically update updated_at on update
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_strategies_updated_at
AFTER UPDATE ON strategies
BEGIN
UPDATE strategies SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
return err
} }
func (s *StrategyStore) initDefaultData() error { func (s *StrategyStore) initDefaultData() error {
@@ -322,159 +290,93 @@ Only enter positions when multiple signals resonate. Freely use any effective an
// Create create a strategy // Create create a strategy
func (s *StrategyStore) Create(strategy *Strategy) error { func (s *StrategyStore) Create(strategy *Strategy) error {
_, err := s.db.Exec(` return s.db.Create(strategy).Error
INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config)
VALUES (?, ?, ?, ?, ?, ?, ?)
`, strategy.ID, strategy.UserID, strategy.Name, strategy.Description, strategy.IsActive, strategy.IsDefault, strategy.Config)
return err
} }
// Update update a strategy // Update update a strategy
func (s *StrategyStore) Update(strategy *Strategy) error { func (s *StrategyStore) Update(strategy *Strategy) error {
_, err := s.db.Exec(` return s.db.Model(&Strategy{}).
UPDATE strategies SET Where("id = ? AND user_id = ?", strategy.ID, strategy.UserID).
name = ?, description = ?, config = ?, updated_at = CURRENT_TIMESTAMP Updates(map[string]interface{}{
WHERE id = ? AND user_id = ? "name": strategy.Name,
`, strategy.Name, strategy.Description, strategy.Config, strategy.ID, strategy.UserID) "description": strategy.Description,
return err "config": strategy.Config,
"updated_at": time.Now(),
}).Error
} }
// Delete delete a strategy // Delete delete a strategy
func (s *StrategyStore) Delete(userID, id string) error { func (s *StrategyStore) Delete(userID, id string) error {
// do not allow deleting system default strategy // do not allow deleting system default strategy
var isDefault bool var st Strategy
s.db.QueryRow(`SELECT is_default FROM strategies WHERE id = ?`, id).Scan(&isDefault) if err := s.db.Where("id = ?", id).First(&st).Error; err == nil && st.IsDefault {
if isDefault {
return fmt.Errorf("cannot delete system default strategy") return fmt.Errorf("cannot delete system default strategy")
} }
_, err := s.db.Exec(`DELETE FROM strategies WHERE id = ? AND user_id = ?`, id, userID) return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Strategy{}).Error
return err
} }
// List get user's strategy list // List get user's strategy list
func (s *StrategyStore) List(userID string) ([]*Strategy, error) { func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
// get user's own strategies + system default strategy var strategies []*Strategy
rows, err := s.db.Query(` err := s.db.Where("user_id = ? OR is_default = ?", userID, true).
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at Order("is_default DESC, created_at DESC").
FROM strategies Find(&strategies).Error
WHERE user_id = ? OR is_default = 1
ORDER BY is_default DESC, created_at DESC
`, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var strategies []*Strategy
for rows.Next() {
var st Strategy
var createdAt, updatedAt string
err := rows.Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
strategies = append(strategies, &st)
}
return strategies, nil return strategies, nil
} }
// Get get a single strategy // Get get a single strategy
func (s *StrategyStore) Get(userID, id string) (*Strategy, error) { func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
var st Strategy var st Strategy
var createdAt, updatedAt string err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", id, userID, true).
err := s.db.QueryRow(` First(&st).Error
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE id = ? AND (user_id = ? OR is_default = 1)
`, id, userID).Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &st, nil return &st, nil
} }
// GetActive get user's currently active strategy // GetActive get user's currently active strategy
func (s *StrategyStore) GetActive(userID string) (*Strategy, error) { func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
var st Strategy var st Strategy
var createdAt, updatedAt string err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&st).Error
err := s.db.QueryRow(` if err == gorm.ErrRecordNotFound {
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE user_id = ? AND is_active = 1
`, userID).Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
if err == sql.ErrNoRows {
// no active strategy, return system default strategy // no active strategy, return system default strategy
return s.GetDefault() return s.GetDefault()
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &st, nil return &st, nil
} }
// GetDefault get system default strategy // GetDefault get system default strategy
func (s *StrategyStore) GetDefault() (*Strategy, error) { func (s *StrategyStore) GetDefault() (*Strategy, error) {
var st Strategy var st Strategy
var createdAt, updatedAt string err := s.db.Where("is_default = ?", true).First(&st).Error
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE is_default = 1
LIMIT 1
`).Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &st, nil return &st, nil
} }
// SetActive set active strategy (will first deactivate other strategies) // SetActive set active strategy (will first deactivate other strategies)
func (s *StrategyStore) SetActive(userID, strategyID string) error { func (s *StrategyStore) SetActive(userID, strategyID string) error {
// begin transaction return s.db.Transaction(func(tx *gorm.DB) error {
tx, err := s.db.Begin() // first deactivate all strategies for the user
if err != nil { if err := tx.Model(&Strategy{}).Where("user_id = ?", userID).
return err Update("is_active", false).Error; err != nil {
} return err
defer tx.Rollback() }
// first deactivate all strategies for the user // activate specified strategy
_, err = tx.Exec(`UPDATE strategies SET is_active = 0 WHERE user_id = ?`, userID) return tx.Model(&Strategy{}).
if err != nil { Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
return err Update("is_active", true).Error
} })
// activate specified strategy
_, err = tx.Exec(`UPDATE strategies SET is_active = 1 WHERE id = ? AND (user_id = ? OR is_default = 1)`, strategyID, userID)
if err != nil {
return err
}
return tx.Commit()
} }
// Duplicate duplicate a strategy (used to create custom strategy based on default strategy) // Duplicate duplicate a strategy (used to create custom strategy based on default strategy)
+111 -374
View File
@@ -1,43 +1,52 @@
package store package store
import ( import (
"database/sql"
"fmt" "fmt"
"strings"
"time" "time"
"gorm.io/gorm"
) )
// TraderStore trader storage // TraderStore trader storage
type TraderStore struct { type TraderStore struct {
db *sql.DB db *gorm.DB
decryptFunc func(string) string }
// NewTraderStore creates a new trader store
func NewTraderStore(db *gorm.DB) *TraderStore {
return &TraderStore{db: db}
} }
// Trader trader configuration // Trader trader configuration
type Trader struct { type Trader struct {
ID string `json:"id"` ID string `gorm:"primaryKey" json:"id"`
UserID string `json:"user_id"` UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `json:"name"` Name string `gorm:"column:name;not null" json:"name"`
AIModelID string `json:"ai_model_id"` AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
ExchangeID string `json:"exchange_id"` ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"`
StrategyID string `json:"strategy_id"` // Associated strategy ID StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"`
InitialBalance float64 `json:"initial_balance"` InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"`
ScanIntervalMinutes int `json:"scan_interval_minutes"` ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"`
IsRunning bool `json:"is_running"` IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"`
IsCrossMargin bool `json:"is_cross_margin"` IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"`
ShowInCompetition bool `json:"show_in_competition"` // Whether to show in competition page ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID // Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"` BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"`
AltcoinLeverage int `json:"altcoin_leverage,omitempty"` AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"`
TradingSymbols string `json:"trading_symbols,omitempty"` TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"`
UseCoinPool bool `json:"use_coin_pool,omitempty"` UseCoinPool bool `gorm:"column:use_coin_pool;default:false" json:"use_coin_pool,omitempty"`
UseOITop bool `json:"use_oi_top,omitempty"` UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"`
CustomPrompt string `json:"custom_prompt,omitempty"` CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"`
OverrideBasePrompt bool `json:"override_base_prompt,omitempty"` OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"`
SystemPromptTemplate string `json:"system_prompt_template,omitempty"` SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"`
}
// TableName returns the table name for Trader
func (Trader) TableName() string {
return "traders"
} }
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy) // TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
@@ -45,331 +54,130 @@ type TraderFullConfig struct {
Trader *Trader Trader *Trader
AIModel *AIModel AIModel *AIModel
Exchange *Exchange Exchange *Exchange
Strategy *Strategy // Associated strategy configuration Strategy *Strategy
} }
func (s *TraderStore) initTables() error { func (s *TraderStore) initTables() error {
_, err := s.db.Exec(` // For PostgreSQL with existing table, skip AutoMigrate
CREATE TABLE IF NOT EXISTS traders ( if s.db.Dialector.Name() == "postgres" {
id TEXT PRIMARY KEY, var tableExists int64
user_id TEXT NOT NULL DEFAULT 'default', s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists)
name TEXT NOT NULL, if tableExists > 0 {
ai_model_id TEXT NOT NULL, return nil
exchange_id TEXT NOT NULL, }
initial_balance REAL NOT NULL,
scan_interval_minutes INTEGER DEFAULT 3,
is_running BOOLEAN DEFAULT 0,
btc_eth_leverage INTEGER DEFAULT 5,
altcoin_leverage INTEGER DEFAULT 5,
trading_symbols TEXT DEFAULT '',
use_coin_pool BOOLEAN DEFAULT 0,
use_oi_top BOOLEAN DEFAULT 0,
custom_prompt TEXT DEFAULT '',
override_base_prompt BOOLEAN DEFAULT 0,
system_prompt_template TEXT DEFAULT 'default',
is_cross_margin BOOLEAN DEFAULT 1,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
} }
// Use GORM AutoMigrate
// Trigger if err := s.db.AutoMigrate(&Trader{}); err != nil {
_, err = s.db.Exec(` return fmt.Errorf("failed to migrate traders table: %w", err)
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
AFTER UPDATE ON traders
BEGIN
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
if err != nil {
return err
} }
// Backward compatibility
alterQueries := []string{
`ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`,
`ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`,
`ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`,
`ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`,
`ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`,
`ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`,
`ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`,
`ALTER TABLE traders ADD COLUMN strategy_id TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN show_in_competition BOOLEAN DEFAULT 1`,
}
for _, q := range alterQueries {
s.db.Exec(q)
}
// Migration: Remove FOREIGN KEY constraint from existing traders table
// SQLite doesn't support ALTER TABLE DROP CONSTRAINT, so we need to recreate the table
if err := s.migrateTradersRemoveFK(); err != nil {
// Log but don't fail - this is a best-effort migration
// The constraint may not exist in older databases
}
return nil return nil
} }
// migrateTradersRemoveFK removes FOREIGN KEY constraint from traders table if it exists
func (s *TraderStore) migrateTradersRemoveFK() error {
// Check if the table has a foreign key constraint by examining the schema
var sql string
err := s.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='traders'`).Scan(&sql)
if err != nil {
return err
}
// If no FOREIGN KEY in schema, no migration needed
if !strings.Contains(sql, "FOREIGN KEY") {
return nil
}
// Recreate table without FOREIGN KEY constraint
_, err = s.db.Exec(`
-- Create new table without FOREIGN KEY
CREATE TABLE IF NOT EXISTS traders_new (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT 'default',
name TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
exchange_id TEXT NOT NULL,
initial_balance REAL NOT NULL,
scan_interval_minutes INTEGER DEFAULT 3,
is_running BOOLEAN DEFAULT 0,
btc_eth_leverage INTEGER DEFAULT 5,
altcoin_leverage INTEGER DEFAULT 5,
trading_symbols TEXT DEFAULT '',
use_coin_pool BOOLEAN DEFAULT 0,
use_oi_top BOOLEAN DEFAULT 0,
custom_prompt TEXT DEFAULT '',
override_base_prompt BOOLEAN DEFAULT 0,
system_prompt_template TEXT DEFAULT 'default',
is_cross_margin BOOLEAN DEFAULT 1,
strategy_id TEXT DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Copy data from old table
INSERT OR IGNORE INTO traders_new
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance,
scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage,
trading_symbols, use_coin_pool, use_oi_top, custom_prompt,
override_base_prompt, system_prompt_template, is_cross_margin,
COALESCE(strategy_id, ''), created_at, updated_at
FROM traders;
-- Drop old table
DROP TABLE traders;
-- Rename new table
ALTER TABLE traders_new RENAME TO traders;
`)
if err != nil {
return err
}
// Recreate trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
AFTER UPDATE ON traders
BEGIN
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
return err
}
func (s *TraderStore) decrypt(encrypted string) string {
if s.decryptFunc != nil {
return s.decryptFunc(encrypted)
}
return encrypted
}
// Create creates trader // Create creates trader
func (s *TraderStore) Create(trader *Trader) error { func (s *TraderStore) Create(trader *Trader) error {
_, err := s.db.Exec(` return s.db.Create(trader).Error
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, strategy_id, initial_balance,
scan_interval_minutes, is_running, is_cross_margin, show_in_competition,
btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool,
use_oi_top, custom_prompt, override_base_prompt, system_prompt_template)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning, trader.IsCrossMargin, trader.ShowInCompetition,
trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, trader.UseCoinPool,
trader.UseOITop, trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate)
return err
} }
// List gets user's trader list // List gets user's trader list
func (s *TraderStore) List(userID string) ([]*Trader, error) { func (s *TraderStore) List(userID string) ([]*Trader, error) {
rows, err := s.db.Query(` var traders []*Trader
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''), err := s.db.Where("user_id = ?", userID).
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1), Order("created_at DESC").
COALESCE(show_in_competition, 1), Find(&traders).Error
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
created_at, updated_at
FROM traders WHERE user_id = ? ORDER BY created_at DESC
`, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var traders []*Trader
for rows.Next() {
var t Trader
var createdAt, updatedAt string
err := rows.Scan(
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
&t.ShowInCompetition,
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
&t.SystemPromptTemplate, &createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
traders = append(traders, &t)
}
return traders, nil return traders, nil
} }
// UpdateStatus updates trader running status // UpdateStatus updates trader running status
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error { func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
_, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID) return s.db.Model(&Trader{}).
return err Where("id = ? AND user_id = ?", id, userID).
Update("is_running", isRunning).Error
} }
// UpdateShowInCompetition updates trader competition visibility // UpdateShowInCompetition updates trader competition visibility
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error { func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
_, err := s.db.Exec(`UPDATE traders SET show_in_competition = ? WHERE id = ? AND user_id = ?`, showInCompetition, id, userID) return s.db.Model(&Trader{}).
return err Where("id = ? AND user_id = ?", id, userID).
Update("show_in_competition", showInCompetition).Error
} }
// Update updates trader configuration // Update updates trader configuration
func (s *TraderStore) Update(trader *Trader) error { func (s *TraderStore) Update(trader *Trader) error {
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n", fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID) trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
_, err := s.db.Exec(`
UPDATE traders SET updates := map[string]interface{}{
name = ?, "name": trader.Name,
ai_model_id = ?, "ai_model_id": trader.AIModelID,
exchange_id = ?, "exchange_id": trader.ExchangeID,
strategy_id = ?, "strategy_id": trader.StrategyID,
initial_balance = CASE WHEN ? > 0 THEN ? ELSE initial_balance END, "is_cross_margin": trader.IsCrossMargin,
scan_interval_minutes = CASE WHEN ? > 0 THEN ? ELSE scan_interval_minutes END, "show_in_competition": trader.ShowInCompetition,
is_cross_margin = ?, }
show_in_competition = ?,
updated_at = CURRENT_TIMESTAMP // Only update these if > 0
WHERE id = ? AND user_id = ? if trader.InitialBalance > 0 {
`, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID, updates["initial_balance"] = trader.InitialBalance
trader.InitialBalance, trader.InitialBalance, }
trader.ScanIntervalMinutes, trader.ScanIntervalMinutes, if trader.ScanIntervalMinutes > 0 {
trader.IsCrossMargin, trader.ShowInCompetition, updates["scan_interval_minutes"] = trader.ScanIntervalMinutes
trader.ID, trader.UserID) }
return err
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", trader.ID, trader.UserID).
Updates(updates).Error
} }
// UpdateInitialBalance updates initial balance // UpdateInitialBalance updates initial balance
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error { func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
_, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID) return s.db.Model(&Trader{}).
return err Where("id = ? AND user_id = ?", id, userID).
Update("initial_balance", newBalance).Error
} }
// UpdateCustomPrompt updates custom prompt // UpdateCustomPrompt updates custom prompt
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error { func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
_, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`, return s.db.Model(&Trader{}).
customPrompt, overrideBase, id, userID) Where("id = ? AND user_id = ?", id, userID).
return err Updates(map[string]interface{}{
"custom_prompt": customPrompt,
"override_base_prompt": overrideBase,
}).Error
} }
// Delete deletes trader and associated data // Delete deletes trader and associated data
func (s *TraderStore) Delete(userID, id string) error { func (s *TraderStore) Delete(userID, id string) error {
// Delete associated equity snapshots first // Delete associated equity snapshots first
_, _ = s.db.Exec(`DELETE FROM trader_equity_snapshots WHERE trader_id = ?`, id) s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{})
// Delete the trader // Delete the trader
_, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID) return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error
return err
} }
// GetFullConfig gets trader full configuration // GetFullConfig gets trader full configuration
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) { func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
var trader Trader var trader Trader
var aiModel AIModel err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error
var exchange Exchange
var traderCreatedAt, traderUpdatedAt string
var aiModelCreatedAt, aiModelUpdatedAt string
var exchangeCreatedAt, exchangeUpdatedAt string
err := s.db.QueryRow(`
SELECT
t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, COALESCE(t.strategy_id, ''),
t.initial_balance, t.scan_interval_minutes, t.is_running, COALESCE(t.is_cross_margin, 1),
COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''),
COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''),
COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'),
t.created_at, t.updated_at,
a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key,
COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at,
e.id, COALESCE(e.exchange_type, '') as exchange_type, COALESCE(e.account_name, '') as account_name,
e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, COALESCE(e.passphrase, ''), e.testnet,
COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''),
COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''),
COALESCE(e.lighter_api_key_private_key, ''), COALESCE(e.lighter_api_key_index, 0), e.created_at, e.updated_at
FROM traders t
JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id
JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id
WHERE t.id = ? AND t.user_id = ?
`, traderID, userID).Scan(
&trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, &trader.StrategyID,
&trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, &trader.IsCrossMargin,
&trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols,
&trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt,
&trader.SystemPromptTemplate, &traderCreatedAt, &traderUpdatedAt,
&aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey,
&aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt,
&exchange.ID, &exchange.ExchangeType, &exchange.AccountName,
&exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled,
&exchange.APIKey, &exchange.SecretKey, &exchange.Passphrase, &exchange.Testnet, &exchange.HyperliquidWalletAddr,
&exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey,
&exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, &exchange.LighterAPIKeyIndex,
&exchangeCreatedAt, &exchangeUpdatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt) // Get AI model
trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt) var aiModel AIModel
aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt) err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error
aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt) if err != nil {
exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt) return nil, fmt.Errorf("failed to get AI model: %w", err)
exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt) }
// Decrypt // Get exchange
aiModel.APIKey = s.decrypt(aiModel.APIKey) var exchange Exchange
exchange.APIKey = s.decrypt(exchange.APIKey) err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error
exchange.SecretKey = s.decrypt(exchange.SecretKey) if err != nil {
exchange.Passphrase = s.decrypt(exchange.Passphrase) return nil, fmt.Errorf("failed to get exchange: %w", err)
exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey) }
exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey)
exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey)
// Load associated strategy // Load associated strategy
var strategy *Strategy var strategy *Strategy
@@ -392,119 +200,48 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
// getStrategyByID internal method: gets strategy by ID // getStrategyByID internal method: gets strategy by ID
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) { func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
var strategy Strategy var strategy Strategy
var createdAt, updatedAt string err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
err := s.db.QueryRow(` First(&strategy).Error
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE id = ? AND (user_id = ? OR is_default = 1)
`, strategyID, userID).Scan(
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &strategy, nil return &strategy, nil
} }
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy // getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) { func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
var strategy Strategy var strategy Strategy
var createdAt, updatedAt string
// First try to get user's active strategy // First try to get user's active strategy
err := s.db.QueryRow(` err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE user_id = ? AND is_active = 1
`, userID).Scan(
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
)
if err == nil { if err == nil {
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &strategy, nil return &strategy, nil
} }
// Fallback to system default strategy // Fallback to system default strategy
err = s.db.QueryRow(` err = s.db.Where("is_default = ?", true).First(&strategy).Error
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE is_default = 1 LIMIT 1
`).Scan(
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &strategy, nil return &strategy, nil
} }
// ListAll gets all users' trader list
// GetByID gets a trader by ID without requiring userID (for public APIs) // GetByID gets a trader by ID without requiring userID (for public APIs)
func (s *TraderStore) GetByID(traderID string) (*Trader, error) { func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
var t Trader var trader Trader
var createdAt, updatedAt string err := s.db.Where("id = ?", traderID).First(&trader).Error
err := s.db.QueryRow(`
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
created_at, updated_at
FROM traders WHERE id = ?
`, traderID).Scan(
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
&t.SystemPromptTemplate, &createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) return &trader, nil
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &t, nil
} }
// ListAll gets all traders
func (s *TraderStore) ListAll() ([]*Trader, error) { func (s *TraderStore) ListAll() ([]*Trader, error) {
rows, err := s.db.Query(` var traders []*Trader
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''), err := s.db.Order("created_at DESC").Find(&traders).Error
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
COALESCE(show_in_competition, 1),
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
created_at, updated_at
FROM traders ORDER BY created_at DESC
`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
var traders []*Trader
for rows.Next() {
var t Trader
var createdAt, updatedAt string
err := rows.Scan(
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
&t.ShowInCompetition,
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
&t.SystemPromptTemplate, &createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
traders = append(traders, &t)
}
return traders, nil return traders, nil
} }
+60 -87
View File
@@ -2,27 +2,30 @@ package store
import ( import (
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/base32" "encoding/base32"
"time" "time"
"gorm.io/gorm"
) )
// UserStore user storage // UserStore user storage
type UserStore struct { type UserStore struct {
db *sql.DB db *gorm.DB
} }
// User user // User user model
type User struct { type User struct {
ID string `json:"id"` ID string `gorm:"primaryKey" json:"id"`
Email string `json:"email"` Email string `gorm:"uniqueIndex:idx_users_email;not null" json:"email"`
PasswordHash string `json:"-"` PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
OTPSecret string `json:"-"` OTPSecret string `gorm:"column:otp_secret" json:"-"`
OTPVerified bool `json:"otp_verified"` OTPVerified bool `gorm:"column:otp_verified;default:false" json:"otp_verified"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
func (User) TableName() string { return "users" }
// GenerateOTPSecret generates OTP secret // GenerateOTPSecret generates OTP secret
func GenerateOTPSecret() (string, error) { func GenerateOTPSecret() (string, error) {
secret := make([]byte, 20) secret := make([]byte, 20)
@@ -33,131 +36,101 @@ func GenerateOTPSecret() (string, error) {
return base32.StdEncoding.EncodeToString(secret), nil return base32.StdEncoding.EncodeToString(secret), nil
} }
// NewUserStore creates a new UserStore
func NewUserStore(db *gorm.DB) *UserStore {
return &UserStore{db: db}
}
func (s *UserStore) initTables() error { func (s *UserStore) initTables() error {
_, err := s.db.Exec(` // For PostgreSQL with existing table, skip AutoMigrate to avoid index conflicts
CREATE TABLE IF NOT EXISTS users ( if s.db.Dialector.Name() == "postgres" {
id TEXT PRIMARY KEY, var tableExists int64
email TEXT UNIQUE NOT NULL, s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'users'`).Scan(&tableExists)
password_hash TEXT NOT NULL,
otp_secret TEXT,
otp_verified BOOLEAN DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
// Trigger if tableExists > 0 {
_, err = s.db.Exec(` // Table exists - manually ensure all columns exist
CREATE TRIGGER IF NOT EXISTS update_users_updated_at // Core columns (should already exist)
AFTER UPDATE ON users s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT NOT NULL DEFAULT ''`)
BEGIN s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS password_hash TEXT NOT NULL DEFAULT ''`)
UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
END s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
`) // OTP columns (added later)
if err != nil { s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_secret TEXT DEFAULT ''`)
return err s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_verified BOOLEAN DEFAULT FALSE`)
}
return nil // Ensure unique index exists on email (don't care about the name)
var indexExists int64
s.db.Raw(`
SELECT COUNT(*) FROM pg_indexes
WHERE tablename = 'users' AND indexdef LIKE '%email%' AND indexdef LIKE '%UNIQUE%'
`).Scan(&indexExists)
if indexExists == 0 {
s.db.Exec("CREATE UNIQUE INDEX idx_users_email ON users(email)")
}
return nil
}
}
return s.db.AutoMigrate(&User{})
} }
// Create creates user // Create creates user
func (s *UserStore) Create(user *User) error { func (s *UserStore) Create(user *User) error {
_, err := s.db.Exec(` return s.db.Create(user).Error
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
VALUES (?, ?, ?, ?, ?)
`, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified)
return err
} }
// GetByEmail gets user by email // GetByEmail gets user by email
func (s *UserStore) GetByEmail(email string) (*User, error) { func (s *UserStore) GetByEmail(email string) (*User, error) {
var user User var user User
var createdAt, updatedAt string err := s.db.Where("email = ?", email).First(&user).Error
err := s.db.QueryRow(`
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
FROM users WHERE email = ?
`, email).Scan(
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
&user.OTPVerified, &createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &user, nil return &user, nil
} }
// GetByID gets user by ID // GetByID gets user by ID
func (s *UserStore) GetByID(userID string) (*User, error) { func (s *UserStore) GetByID(userID string) (*User, error) {
var user User var user User
var createdAt, updatedAt string err := s.db.Where("id = ?", userID).First(&user).Error
err := s.db.QueryRow(`
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
FROM users WHERE id = ?
`, userID).Scan(
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
&user.OTPVerified, &createdAt, &updatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &user, nil return &user, nil
} }
// Count returns the total number of users // Count returns the total number of users
func (s *UserStore) Count() (int, error) { func (s *UserStore) Count() (int, error) {
var count int var count int64
err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count) err := s.db.Model(&User{}).Count(&count).Error
return count, err return int(count), err
} }
// GetAllIDs gets all user IDs // GetAllIDs gets all user IDs
func (s *UserStore) GetAllIDs() ([]string, error) { func (s *UserStore) GetAllIDs() ([]string, error) {
rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`)
if err != nil {
return nil, err
}
defer rows.Close()
var userIDs []string var userIDs []string
for rows.Next() { err := s.db.Model(&User{}).Order("id").Pluck("id", &userIDs).Error
var userID string return userIDs, err
if err := rows.Scan(&userID); err != nil {
return nil, err
}
userIDs = append(userIDs, userID)
}
return userIDs, nil
} }
// UpdateOTPVerified updates OTP verification status // UpdateOTPVerified updates OTP verification status
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error { func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
_, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID) return s.db.Model(&User{}).Where("id = ?", userID).Update("otp_verified", verified).Error
return err
} }
// UpdatePassword updates password // UpdatePassword updates password
func (s *UserStore) UpdatePassword(userID, passwordHash string) error { func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
_, err := s.db.Exec(` return s.db.Model(&User{}).Where("id = ?", userID).Updates(map[string]interface{}{
UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ? "password_hash": passwordHash,
`, passwordHash, userID) "updated_at": time.Now(),
return err }).Error
} }
// EnsureAdmin ensures admin user exists // EnsureAdmin ensures admin user exists
func (s *UserStore) EnsureAdmin() error { func (s *UserStore) EnsureAdmin() error {
var count int var count int64
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count) s.db.Model(&User{}).Where("id = ?", "admin").Count(&count)
if err != nil {
return err
}
if count > 0 { if count > 0 {
return nil return nil
} }
+12 -8
View File
@@ -1,12 +1,13 @@
package trader package trader
import ( import (
"database/sql"
"nofx/store" "nofx/store"
"testing" "testing"
"time" "time"
_ "github.com/mattn/go-sqlite3" "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
) )
// TestScenario represents a trading scenario to test // TestScenario represents a trading scenario to test
@@ -116,11 +117,12 @@ func runStandardTests(t *testing.T, exchangeName string) {
for _, scenario := range scenarios { for _, scenario := range scenarios {
t.Run(scenario.Name, func(t *testing.T) { t.Run(scenario.Name, func(t *testing.T) {
// Setup database // Setup database
db, err := sql.Open("sqlite3", ":memory:") db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil { if err != nil {
t.Fatalf("Failed to create test database: %v", err) t.Fatalf("Failed to create test database: %v", err)
} }
defer db.Close()
positionStore := store.NewPositionStore(db) positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil { if err := positionStore.InitTables(); err != nil {
@@ -199,11 +201,12 @@ func TestAllExchangesStandardScenarios(t *testing.T) {
// TestPositionAccumulationBug tests that positions don't accumulate incorrectly // TestPositionAccumulationBug tests that positions don't accumulate incorrectly
func TestPositionAccumulationBug(t *testing.T) { func TestPositionAccumulationBug(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:") db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil { if err != nil {
t.Fatalf("Failed to create test database: %v", err) t.Fatalf("Failed to create test database: %v", err)
} }
defer db.Close()
positionStore := store.NewPositionStore(db) positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil { if err := positionStore.InitTables(); err != nil {
@@ -283,11 +286,12 @@ func TestPositionAccumulationBug(t *testing.T) {
// TestQuantityPrecision tests handling of quantity precision issues // TestQuantityPrecision tests handling of quantity precision issues
func TestQuantityPrecision(t *testing.T) { func TestQuantityPrecision(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:") db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil { if err != nil {
t.Fatalf("Failed to create test database: %v", err) t.Fatalf("Failed to create test database: %v", err)
} }
defer db.Close()
positionStore := store.NewPositionStore(db) positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil { if err := positionStore.InitTables(); err != nil {
+9 -6
View File
@@ -1,13 +1,14 @@
package trader package trader
import ( import (
"database/sql"
"math" "math"
"nofx/store" "nofx/store"
"testing" "testing"
"time" "time"
_ "github.com/mattn/go-sqlite3" "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
) )
// TestHyperliquidOrderDirectionParsing tests Dir field parsing // TestHyperliquidOrderDirectionParsing tests Dir field parsing
@@ -75,11 +76,12 @@ func TestHyperliquidOrderDirectionParsing(t *testing.T) {
// TestHyperliquidPositionBuilding tests the complete flow of position building // TestHyperliquidPositionBuilding tests the complete flow of position building
func TestHyperliquidPositionBuilding(t *testing.T) { func TestHyperliquidPositionBuilding(t *testing.T) {
// Setup in-memory database // Setup in-memory database
db, err := sql.Open("sqlite3", ":memory:") db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil { if err != nil {
t.Fatalf("Failed to create test database: %v", err) t.Fatalf("Failed to create test database: %v", err)
} }
defer db.Close()
// Initialize stores // Initialize stores
positionStore := store.NewPositionStore(db) positionStore := store.NewPositionStore(db)
@@ -304,11 +306,12 @@ func TestHyperliquidPositionBuilding(t *testing.T) {
// TestHyperliquidBugScenario tests the exact bug we fixed // TestHyperliquidBugScenario tests the exact bug we fixed
func TestHyperliquidBugScenario(t *testing.T) { func TestHyperliquidBugScenario(t *testing.T) {
// Setup database // Setup database
db, err := sql.Open("sqlite3", ":memory:") db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil { if err != nil {
t.Fatalf("Failed to create test database: %v", err) t.Fatalf("Failed to create test database: %v", err)
} }
defer db.Close()
positionStore := store.NewPositionStore(db) positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil { if err := positionStore.InitTables(); err != nil {
@@ -1,5 +1,5 @@
import { motion } from 'framer-motion' import { motion } from 'framer-motion'
import { Bot, TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react' import { TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
const agents = [ const agents = [
{ {
+10 -3
View File
@@ -1,8 +1,15 @@
import { motion } from 'framer-motion' import { motion } from 'framer-motion'
import { Activity, BarChart3, Globe, Wifi, Server, Database, Lock } from 'lucide-react'
import { useState, useEffect } from 'react' import { useState, useEffect } from 'react'
const generateLog = (id) => { interface LogEntry {
id: number
time: string
type: string
msg: string
color: string
}
const generateLog = (id: number): LogEntry => {
const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS'] const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS']
const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD'] const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD']
const actions = ['BUY', 'SELL', 'SHORT', 'LONG'] const actions = ['BUY', 'SELL', 'SHORT', 'LONG']
@@ -37,7 +44,7 @@ const generateLog = (id) => {
} }
export default function LiveFeed() { export default function LiveFeed() {
const [logs, setLogs] = useState([]) const [logs, setLogs] = useState<LogEntry[]>([])
useEffect(() => { useEffect(() => {
// Initial population // Initial population
@@ -159,11 +159,24 @@ export default function TerminalHero() {
<span className="text-stroke-1 text-transparent bg-clip-text bg-gradient-to-r from-nofx-gold via-white to-nofx-gold animate-shimmer bg-[length:200%_auto]">TRADING</span> <span className="text-stroke-1 text-transparent bg-clip-text bg-gradient-to-r from-nofx-gold via-white to-nofx-gold animate-shimmer bg-[length:200%_auto]">TRADING</span>
</h1> </h1>
<p className="max-w-xl text-zinc-400 text-lg mb-10 font-light leading-relaxed"> <p className="max-w-xl text-zinc-400 text-lg mb-6 font-light leading-relaxed">
The World's First Open-Source Agentic Trading OS. The World's First Open-Source Agentic Trading OS.
Deploy autonomous high-frequency trading agents powered by advanced LLMs. Deploy autonomous high-frequency trading agents powered by advanced LLMs.
</p> </p>
{/* Market Access Strip - Prominent Display */}
<div className="flex flex-wrap gap-4 mb-12 font-mono">
{['CRYPTO', 'US STOCKS', 'FOREX', 'METALS'].map((market) => (
<div key={market} className="flex items-center gap-3 px-4 py-2 rounded bg-zinc-900 border border-zinc-700 text-white font-bold tracking-wider hover:border-nofx-gold hover:shadow-[0_0_15px_rgba(255,215,0,0.3)] transition-all duration-300">
<span className="relative flex h-2 w-2">
<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-nofx-success opacity-75"></span>
<span className="relative inline-flex rounded-full h-2 w-2 bg-nofx-success"></span>
</span>
{market}
</div>
))}
</div>
{/* Command Line Input Simulation */} {/* Command Line Input Simulation */}
<div className="w-full max-w-lg h-12 bg-black/50 border border-zinc-800 rounded flex items-center px-4 mb-10 font-mono text-sm shadow-2xl backdrop-blur-sm group hover:border-nofx-gold/50 transition-colors cursor-text" onClick={() => document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}> <div className="w-full max-w-lg h-12 bg-black/50 border border-zinc-800 rounded flex items-center px-4 mb-10 font-mono text-sm shadow-2xl backdrop-blur-sm group hover:border-nofx-gold/50 transition-colors cursor-text" onClick={() => document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}>
<span className="text-nofx-success mr-2"></span> <span className="text-nofx-success mr-2"></span>
@@ -269,7 +282,7 @@ export default function TerminalHero() {
import { OFFICIAL_LINKS } from '../../../constants/branding' import { OFFICIAL_LINKS } from '../../../constants/branding'
function CommunityStats() { function CommunityStats() {
const { stars, forks, contributors, isLoading, error } = useGitHubStats('tinkle-community', 'nofx') const { stars, forks, contributors, isLoading, error } = useGitHubStats('NoFxAiOS', 'nofx')
const stats = [ const stats = [
{ {