mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 01:48:22 +08:00
feat: migrate store layer to GORM with PostgreSQL support
- Migrate all store packages from raw database/sql to GORM ORM - Add PostgreSQL support alongside SQLite - Move EncryptedString type to crypto package for cleaner architecture - Add automatic encryption/decryption for sensitive fields (API keys, secrets) - Fix PostgreSQL AutoMigrate conflicts by skipping existing tables - Fix duplicate /klines route registration - Update tests to use GORM database connections - Add database configuration support in config package
This commit is contained in:
@@ -55,3 +55,16 @@ TRANSPORT_ENCRYPTION=false
|
||||
# Telegram notifications (optional)
|
||||
# TELEGRAM_BOT_TOKEN=your-bot-token
|
||||
# TELEGRAM_CHAT_ID=your-chat-id
|
||||
|
||||
DB_TYPE=postgres
|
||||
DB_HOST=10.
|
||||
DB_PORT=5432
|
||||
DB_USER=nofx_user
|
||||
DB_PASSWORD=
|
||||
DB_NAME=nofx
|
||||
DB_SSLMODE=disable
|
||||
|
||||
|
||||
# 数据库配置 - SQLite(默认)
|
||||
DB_TYPE=sqlite
|
||||
DB_PATH=data/data.db
|
||||
+1
-1
@@ -824,7 +824,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
|
||||
return fmt.Errorf("AI model %s is not enabled yet", model.Name)
|
||||
}
|
||||
|
||||
apiKey := strings.TrimSpace(model.APIKey)
|
||||
apiKey := strings.TrimSpace(string(model.APIKey))
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
|
||||
}
|
||||
|
||||
+43
-41
@@ -54,7 +54,7 @@ func NewServer(traderManager *manager.TraderManager, st *store.Store, cryptoServ
|
||||
cryptoHandler := NewCryptoHandler(cryptoService)
|
||||
|
||||
// Create debate store and handler
|
||||
debateStore := store.NewDebateStore(st.DB())
|
||||
debateStore := store.NewDebateStore(st.GormDB())
|
||||
if err := debateStore.InitSchema(); err != nil {
|
||||
logger.Errorf("Failed to initialize debate schema: %v", err)
|
||||
}
|
||||
@@ -125,7 +125,6 @@ func (s *Server) setupRoutes() {
|
||||
|
||||
// Market data (no authentication required)
|
||||
api.GET("/klines", s.handleKlines)
|
||||
api.GET("/klines", s.handleKlines)
|
||||
api.GET("/symbols", s.handleSymbols)
|
||||
|
||||
// Authentication related routes (no authentication required)
|
||||
@@ -576,12 +575,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
var createErr error
|
||||
|
||||
// Use ExchangeType (e.g., "binance") instead of ID (UUID)
|
||||
// Convert EncryptedString fields to string
|
||||
switch exchangeCfg.ExchangeType {
|
||||
case "binance":
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
||||
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
|
||||
case "hyperliquid":
|
||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||
exchangeCfg.APIKey, // private key
|
||||
string(exchangeCfg.APIKey), // private key
|
||||
exchangeCfg.HyperliquidWalletAddr,
|
||||
exchangeCfg.Testnet,
|
||||
)
|
||||
@@ -589,31 +589,31 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
tempTrader, createErr = trader.NewAsterTrader(
|
||||
exchangeCfg.AsterUser,
|
||||
exchangeCfg.AsterSigner,
|
||||
exchangeCfg.AsterPrivateKey,
|
||||
string(exchangeCfg.AsterPrivateKey),
|
||||
)
|
||||
case "bybit":
|
||||
tempTrader = trader.NewBybitTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
)
|
||||
case "okx":
|
||||
tempTrader = trader.NewOKXTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
exchangeCfg.Passphrase,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
string(exchangeCfg.Passphrase),
|
||||
)
|
||||
case "bitget":
|
||||
tempTrader = trader.NewBitgetTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
exchangeCfg.Passphrase,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
string(exchangeCfg.Passphrase),
|
||||
)
|
||||
case "lighter":
|
||||
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
|
||||
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
|
||||
// Lighter only supports mainnet
|
||||
tempTrader, createErr = trader.NewLighterTraderV2(
|
||||
exchangeCfg.LighterWalletAddr,
|
||||
exchangeCfg.LighterAPIKeyPrivateKey,
|
||||
string(exchangeCfg.LighterAPIKeyPrivateKey),
|
||||
exchangeCfg.LighterAPIKeyIndex,
|
||||
false, // Always use mainnet for Lighter
|
||||
)
|
||||
@@ -1095,12 +1095,13 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
|
||||
var createErr error
|
||||
|
||||
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
|
||||
// Convert EncryptedString fields to string
|
||||
switch exchangeCfg.ExchangeType {
|
||||
case "binance":
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
||||
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
|
||||
case "hyperliquid":
|
||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||
exchangeCfg.APIKey,
|
||||
string(exchangeCfg.APIKey),
|
||||
exchangeCfg.HyperliquidWalletAddr,
|
||||
exchangeCfg.Testnet,
|
||||
)
|
||||
@@ -1108,31 +1109,31 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
|
||||
tempTrader, createErr = trader.NewAsterTrader(
|
||||
exchangeCfg.AsterUser,
|
||||
exchangeCfg.AsterSigner,
|
||||
exchangeCfg.AsterPrivateKey,
|
||||
string(exchangeCfg.AsterPrivateKey),
|
||||
)
|
||||
case "bybit":
|
||||
tempTrader = trader.NewBybitTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
)
|
||||
case "okx":
|
||||
tempTrader = trader.NewOKXTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
exchangeCfg.Passphrase,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
string(exchangeCfg.Passphrase),
|
||||
)
|
||||
case "bitget":
|
||||
tempTrader = trader.NewBitgetTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
exchangeCfg.Passphrase,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
string(exchangeCfg.Passphrase),
|
||||
)
|
||||
case "lighter":
|
||||
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
|
||||
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
|
||||
// Lighter only supports mainnet
|
||||
tempTrader, createErr = trader.NewLighterTraderV2(
|
||||
exchangeCfg.LighterWalletAddr,
|
||||
exchangeCfg.LighterAPIKeyPrivateKey,
|
||||
string(exchangeCfg.LighterAPIKeyPrivateKey),
|
||||
exchangeCfg.LighterAPIKeyIndex,
|
||||
false, // Always use mainnet for Lighter
|
||||
)
|
||||
@@ -1246,12 +1247,13 @@ func (s *Server) handleClosePosition(c *gin.Context) {
|
||||
var createErr error
|
||||
|
||||
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
|
||||
// Convert EncryptedString fields to string
|
||||
switch exchangeCfg.ExchangeType {
|
||||
case "binance":
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
||||
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
|
||||
case "hyperliquid":
|
||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||
exchangeCfg.APIKey,
|
||||
string(exchangeCfg.APIKey),
|
||||
exchangeCfg.HyperliquidWalletAddr,
|
||||
exchangeCfg.Testnet,
|
||||
)
|
||||
@@ -1259,31 +1261,31 @@ func (s *Server) handleClosePosition(c *gin.Context) {
|
||||
tempTrader, createErr = trader.NewAsterTrader(
|
||||
exchangeCfg.AsterUser,
|
||||
exchangeCfg.AsterSigner,
|
||||
exchangeCfg.AsterPrivateKey,
|
||||
string(exchangeCfg.AsterPrivateKey),
|
||||
)
|
||||
case "bybit":
|
||||
tempTrader = trader.NewBybitTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
)
|
||||
case "okx":
|
||||
tempTrader = trader.NewOKXTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
exchangeCfg.Passphrase,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
string(exchangeCfg.Passphrase),
|
||||
)
|
||||
case "bitget":
|
||||
tempTrader = trader.NewBitgetTrader(
|
||||
exchangeCfg.APIKey,
|
||||
exchangeCfg.SecretKey,
|
||||
exchangeCfg.Passphrase,
|
||||
string(exchangeCfg.APIKey),
|
||||
string(exchangeCfg.SecretKey),
|
||||
string(exchangeCfg.Passphrase),
|
||||
)
|
||||
case "lighter":
|
||||
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
|
||||
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
|
||||
// Lighter only supports mainnet
|
||||
tempTrader, createErr = trader.NewLighterTraderV2(
|
||||
exchangeCfg.LighterWalletAddr,
|
||||
exchangeCfg.LighterAPIKeyPrivateKey,
|
||||
string(exchangeCfg.LighterAPIKeyPrivateKey),
|
||||
exchangeCfg.LighterAPIKeyIndex,
|
||||
false, // Always use mainnet for Lighter
|
||||
)
|
||||
|
||||
+10
-8
@@ -549,32 +549,34 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
|
||||
var aiClient mcp.AIClient
|
||||
provider := model.Provider
|
||||
|
||||
// Convert EncryptedString to string for API key
|
||||
apiKey := string(model.APIKey)
|
||||
switch provider {
|
||||
case "qwen":
|
||||
aiClient = mcp.NewQwenClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "deepseek":
|
||||
aiClient = mcp.NewDeepSeekClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "claude":
|
||||
aiClient = mcp.NewClaudeClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "kimi":
|
||||
aiClient = mcp.NewKimiClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "gemini":
|
||||
aiClient = mcp.NewGeminiClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "grok":
|
||||
aiClient = mcp.NewGrokClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "openai":
|
||||
aiClient = mcp.NewOpenAIClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
default:
|
||||
// Use generic client
|
||||
aiClient = mcp.NewClient()
|
||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
}
|
||||
|
||||
// Call AI API
|
||||
|
||||
@@ -76,8 +76,8 @@ func enforceRetentionDB(maxRuns int) {
|
||||
query := `
|
||||
SELECT run_id FROM backtest_runs
|
||||
WHERE state IN (?, ?, ?, ?)
|
||||
ORDER BY datetime(updated_at) DESC
|
||||
LIMIT -1 OFFSET ?
|
||||
ORDER BY updated_at DESC
|
||||
OFFSET ?
|
||||
`
|
||||
rows, err := persistenceDB.Query(query,
|
||||
finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns)
|
||||
|
||||
@@ -166,7 +166,7 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) {
|
||||
}
|
||||
|
||||
func loadRunIDsDB() ([]string, error) {
|
||||
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
|
||||
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY updated_at DESC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -278,9 +278,9 @@ func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if cycle > 0 {
|
||||
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`, runID, cycle)
|
||||
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle)
|
||||
} else {
|
||||
rows, err = persistenceDB.Query(query+` ORDER BY datetime(created_at) DESC LIMIT 1`, runID)
|
||||
rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -461,7 +461,7 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) {
|
||||
rows, err := persistenceDB.Query(`
|
||||
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json
|
||||
FROM backtest_runs
|
||||
ORDER BY datetime(updated_at) DESC
|
||||
ORDER BY updated_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -20,6 +20,16 @@ type Config struct {
|
||||
RegistrationEnabled bool
|
||||
MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10)
|
||||
|
||||
// Database configuration
|
||||
DBType string // sqlite or postgres
|
||||
DBPath string // SQLite database file path
|
||||
DBHost string // PostgreSQL host
|
||||
DBPort int // PostgreSQL port
|
||||
DBUser string // PostgreSQL user
|
||||
DBPassword string // PostgreSQL password
|
||||
DBName string // PostgreSQL database name
|
||||
DBSSLMode string // PostgreSQL SSL mode
|
||||
|
||||
// Security configuration
|
||||
// TransportEncryption enables browser-side encryption for API keys
|
||||
// Requires HTTPS or localhost. Set to false for HTTP access via IP.
|
||||
@@ -43,6 +53,14 @@ func Init() {
|
||||
RegistrationEnabled: true,
|
||||
MaxUsers: 10, // Default: 10 users allowed
|
||||
ExperienceImprovement: true, // Default: enabled to help improve the product
|
||||
// Database defaults
|
||||
DBType: "sqlite",
|
||||
DBPath: "data/data.db",
|
||||
DBHost: "localhost",
|
||||
DBPort: 5432,
|
||||
DBUser: "postgres",
|
||||
DBName: "nofx",
|
||||
DBSSLMode: "disable",
|
||||
}
|
||||
|
||||
// Load from environment variables
|
||||
@@ -86,6 +104,34 @@ func Init() {
|
||||
cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY")
|
||||
cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY")
|
||||
|
||||
// Database configuration
|
||||
if v := os.Getenv("DB_TYPE"); v != "" {
|
||||
cfg.DBType = strings.ToLower(v)
|
||||
}
|
||||
if v := os.Getenv("DB_PATH"); v != "" {
|
||||
cfg.DBPath = v
|
||||
}
|
||||
if v := os.Getenv("DB_HOST"); v != "" {
|
||||
cfg.DBHost = v
|
||||
}
|
||||
if v := os.Getenv("DB_PORT"); v != "" {
|
||||
if port, err := strconv.Atoi(v); err == nil && port > 0 {
|
||||
cfg.DBPort = port
|
||||
}
|
||||
}
|
||||
if v := os.Getenv("DB_USER"); v != "" {
|
||||
cfg.DBUser = v
|
||||
}
|
||||
if v := os.Getenv("DB_PASSWORD"); v != "" {
|
||||
cfg.DBPassword = v
|
||||
}
|
||||
if v := os.Getenv("DB_NAME"); v != "" {
|
||||
cfg.DBName = v
|
||||
}
|
||||
if v := os.Getenv("DB_SSLMODE"); v != "" {
|
||||
cfg.DBSSLMode = v
|
||||
}
|
||||
|
||||
global = cfg
|
||||
|
||||
// Initialize experience improvement (installation ID will be set after database init)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -392,3 +393,77 @@ func GenerateDataKey() (string, error) {
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// EncryptedString - GORM custom type for automatic encryption/decryption
|
||||
// ============================================================================
|
||||
|
||||
// Global crypto service for EncryptedString
|
||||
var globalCryptoService *CryptoService
|
||||
|
||||
// SetGlobalCryptoService sets the global crypto service for EncryptedString
|
||||
func SetGlobalCryptoService(cs *CryptoService) {
|
||||
globalCryptoService = cs
|
||||
}
|
||||
|
||||
// EncryptedString is a custom type that automatically encrypts on save and decrypts on load
|
||||
// Usage: Use EncryptedString instead of string for sensitive fields in GORM models
|
||||
type EncryptedString string
|
||||
|
||||
// Scan implements sql.Scanner - called when reading from database
|
||||
// Automatically decrypts the value
|
||||
func (es *EncryptedString) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*es = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
var str string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
str = v
|
||||
case []byte:
|
||||
str = string(v)
|
||||
default:
|
||||
*es = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt if crypto service is set
|
||||
if globalCryptoService != nil && str != "" && globalCryptoService.IsEncryptedStorageValue(str) {
|
||||
decrypted, err := globalCryptoService.DecryptFromStorage(str)
|
||||
if err != nil {
|
||||
// If decryption fails, return the original value
|
||||
*es = EncryptedString(str)
|
||||
} else {
|
||||
*es = EncryptedString(decrypted)
|
||||
}
|
||||
} else {
|
||||
*es = EncryptedString(str)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer - called when writing to database
|
||||
// Automatically encrypts the value
|
||||
func (es EncryptedString) Value() (driver.Value, error) {
|
||||
if es == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Encrypt if crypto service is set
|
||||
if globalCryptoService != nil {
|
||||
encrypted, err := globalCryptoService.EncryptForStorage(string(es))
|
||||
if err != nil {
|
||||
// If encryption fails, return the original value
|
||||
return string(es), nil
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
return string(es), nil
|
||||
}
|
||||
|
||||
// String returns the plaintext string value
|
||||
func (es EncryptedString) String() string {
|
||||
return string(es)
|
||||
}
|
||||
|
||||
+2
-2
@@ -101,8 +101,8 @@ func (e *DebateEngine) InitializeClients(participants []*store.DebateParticipant
|
||||
client = mcp.New()
|
||||
}
|
||||
|
||||
// Configure client
|
||||
client.SetAPIKey(aiModel.APIKey, aiModel.CustomAPIURL, aiModel.CustomModelName)
|
||||
// Configure client (convert EncryptedString to string)
|
||||
client.SetAPIKey(string(aiModel.APIKey), aiModel.CustomAPIURL, aiModel.CustomModelName)
|
||||
|
||||
e.clients[p.AIModelID] = client
|
||||
}
|
||||
|
||||
@@ -51,6 +51,12 @@ require (
|
||||
github.com/goccy/go-json v0.10.4 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/holiman/uint256 v1.3.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/jpillora/backoff v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
@@ -94,6 +100,9 @@ require (
|
||||
golang.org/x/tools v0.36.0 // indirect
|
||||
google.golang.org/protobuf v1.36.9 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gorm.io/driver/postgres v1.6.0 // indirect
|
||||
gorm.io/driver/sqlite v1.6.0 // indirect
|
||||
gorm.io/gorm v1.31.1 // indirect
|
||||
howett.net/plist v1.0.1 // indirect
|
||||
modernc.org/libc v1.66.10 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
@@ -111,7 +111,19 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
|
||||
github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
@@ -284,6 +296,12 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
|
||||
|
||||
@@ -36,30 +36,44 @@ func main() {
|
||||
cfg := config.Get()
|
||||
logger.Info("✅ Configuration loaded")
|
||||
|
||||
// Initialize database from environment variables
|
||||
// DB_TYPE: sqlite (default) or postgres
|
||||
// For SQLite: DB_PATH (default: data/data.db)
|
||||
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
|
||||
dbPath := os.Getenv("DB_PATH")
|
||||
if dbPath == "" {
|
||||
dbPath = "data/data.db"
|
||||
// Initialize encryption service BEFORE database (so EncryptedString can decrypt on read)
|
||||
logger.Info("🔐 Initializing encryption service...")
|
||||
cryptoService, err := crypto.NewCryptoService()
|
||||
if err != nil {
|
||||
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
|
||||
}
|
||||
// For backward compatibility: command line arg overrides env var (SQLite only)
|
||||
crypto.SetGlobalCryptoService(cryptoService)
|
||||
logger.Info("✅ Encryption service initialized successfully")
|
||||
|
||||
// Initialize database from configuration
|
||||
// For backward compatibility: command line arg overrides config (SQLite only)
|
||||
if len(os.Args) > 1 {
|
||||
dbPath = os.Args[1]
|
||||
os.Setenv("DB_PATH", dbPath)
|
||||
cfg.DBPath = os.Args[1]
|
||||
}
|
||||
// Ensure data directory exists (for SQLite)
|
||||
if os.Getenv("DB_TYPE") == "" || os.Getenv("DB_TYPE") == "sqlite" {
|
||||
if dir := filepath.Dir(dbPath); dir != "." {
|
||||
if cfg.DBType == "sqlite" {
|
||||
if dir := filepath.Dir(cfg.DBPath); dir != "." {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
logger.Errorf("Failed to create data directory: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("📋 Initializing database...")
|
||||
st, err := store.NewFromEnv()
|
||||
logger.Infof("📋 Initializing database (%s)...", cfg.DBType)
|
||||
dbType := store.DBTypeSQLite
|
||||
if cfg.DBType == "postgres" {
|
||||
dbType = store.DBTypePostgres
|
||||
}
|
||||
st, err := store.NewWithConfig(store.DBConfig{
|
||||
Type: dbType,
|
||||
Path: cfg.DBPath,
|
||||
Host: cfg.DBHost,
|
||||
Port: cfg.DBPort,
|
||||
User: cfg.DBUser,
|
||||
Password: cfg.DBPassword,
|
||||
DBName: cfg.DBName,
|
||||
SSLMode: cfg.DBSSLMode,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Fatalf("❌ Failed to initialize database: %v", err)
|
||||
}
|
||||
@@ -69,40 +83,6 @@ func main() {
|
||||
// Initialize installation ID for experience improvement (anonymous statistics)
|
||||
initInstallationID(st)
|
||||
|
||||
// Initialize encryption service
|
||||
logger.Info("🔐 Initializing encryption service...")
|
||||
cryptoService, err := crypto.NewCryptoService()
|
||||
if err != nil {
|
||||
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
|
||||
}
|
||||
encryptFunc := func(plaintext string) string {
|
||||
if plaintext == "" {
|
||||
return plaintext
|
||||
}
|
||||
encrypted, err := cryptoService.EncryptForStorage(plaintext)
|
||||
if err != nil {
|
||||
logger.Warnf("⚠️ Encryption failed: %v", err)
|
||||
return plaintext
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
decryptFunc := func(encrypted string) string {
|
||||
if encrypted == "" {
|
||||
return encrypted
|
||||
}
|
||||
if !cryptoService.IsEncryptedStorageValue(encrypted) {
|
||||
return encrypted
|
||||
}
|
||||
decrypted, err := cryptoService.DecryptFromStorage(encrypted)
|
||||
if err != nil {
|
||||
logger.Warnf("⚠️ Decryption failed: %v", err)
|
||||
return encrypted
|
||||
}
|
||||
return decrypted
|
||||
}
|
||||
st.SetCryptoFuncs(encryptFunc, decryptFunc)
|
||||
logger.Info("✅ Encryption service initialized successfully")
|
||||
|
||||
// Set JWT secret
|
||||
auth.SetJWTSecret(cfg.JWTSecret)
|
||||
logger.Info("🔑 JWT secret configured")
|
||||
|
||||
+19
-19
@@ -664,46 +664,46 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
|
||||
StrategyConfig: strategyConfig,
|
||||
}
|
||||
|
||||
// Set API keys based on exchange type
|
||||
// Set API keys based on exchange type (convert EncryptedString to string)
|
||||
switch exchangeCfg.ExchangeType {
|
||||
case "binance":
|
||||
traderConfig.BinanceAPIKey = exchangeCfg.APIKey
|
||||
traderConfig.BinanceSecretKey = exchangeCfg.SecretKey
|
||||
traderConfig.BinanceAPIKey = string(exchangeCfg.APIKey)
|
||||
traderConfig.BinanceSecretKey = string(exchangeCfg.SecretKey)
|
||||
case "bybit":
|
||||
traderConfig.BybitAPIKey = exchangeCfg.APIKey
|
||||
traderConfig.BybitSecretKey = exchangeCfg.SecretKey
|
||||
traderConfig.BybitAPIKey = string(exchangeCfg.APIKey)
|
||||
traderConfig.BybitSecretKey = string(exchangeCfg.SecretKey)
|
||||
case "okx":
|
||||
traderConfig.OKXAPIKey = exchangeCfg.APIKey
|
||||
traderConfig.OKXSecretKey = exchangeCfg.SecretKey
|
||||
traderConfig.OKXPassphrase = exchangeCfg.Passphrase
|
||||
traderConfig.OKXAPIKey = string(exchangeCfg.APIKey)
|
||||
traderConfig.OKXSecretKey = string(exchangeCfg.SecretKey)
|
||||
traderConfig.OKXPassphrase = string(exchangeCfg.Passphrase)
|
||||
case "bitget":
|
||||
traderConfig.BitgetAPIKey = exchangeCfg.APIKey
|
||||
traderConfig.BitgetSecretKey = exchangeCfg.SecretKey
|
||||
traderConfig.BitgetPassphrase = exchangeCfg.Passphrase
|
||||
traderConfig.BitgetAPIKey = string(exchangeCfg.APIKey)
|
||||
traderConfig.BitgetSecretKey = string(exchangeCfg.SecretKey)
|
||||
traderConfig.BitgetPassphrase = string(exchangeCfg.Passphrase)
|
||||
case "hyperliquid":
|
||||
traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey
|
||||
traderConfig.HyperliquidPrivateKey = string(exchangeCfg.APIKey)
|
||||
traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr
|
||||
case "aster":
|
||||
traderConfig.AsterUser = exchangeCfg.AsterUser
|
||||
traderConfig.AsterSigner = exchangeCfg.AsterSigner
|
||||
traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey
|
||||
traderConfig.AsterPrivateKey = string(exchangeCfg.AsterPrivateKey)
|
||||
case "lighter":
|
||||
traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey
|
||||
traderConfig.LighterPrivateKey = string(exchangeCfg.LighterPrivateKey)
|
||||
traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr
|
||||
traderConfig.LighterAPIKeyPrivateKey = exchangeCfg.LighterAPIKeyPrivateKey
|
||||
traderConfig.LighterAPIKeyPrivateKey = string(exchangeCfg.LighterAPIKeyPrivateKey)
|
||||
traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex
|
||||
traderConfig.LighterTestnet = exchangeCfg.Testnet
|
||||
}
|
||||
|
||||
// Set API keys based on AI model
|
||||
// Set API keys based on AI model (convert EncryptedString to string)
|
||||
switch aiModelCfg.Provider {
|
||||
case "qwen":
|
||||
traderConfig.QwenKey = aiModelCfg.APIKey
|
||||
traderConfig.QwenKey = string(aiModelCfg.APIKey)
|
||||
case "deepseek":
|
||||
traderConfig.DeepSeekKey = aiModelCfg.APIKey
|
||||
traderConfig.DeepSeekKey = string(aiModelCfg.APIKey)
|
||||
default:
|
||||
// For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey
|
||||
traderConfig.CustomAPIKey = aiModelCfg.APIKey
|
||||
traderConfig.CustomAPIKey = string(aiModelCfg.APIKey)
|
||||
}
|
||||
|
||||
// Create trader instance
|
||||
|
||||
+88
-172
@@ -1,123 +1,66 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"nofx/crypto"
|
||||
"nofx/logger"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AIModelStore AI model storage
|
||||
type AIModelStore struct {
|
||||
db *sql.DB
|
||||
encryptFunc func(string) string
|
||||
decryptFunc func(string) string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// AIModel AI model configuration
|
||||
type AIModel struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
CustomAPIURL string `json:"customApiUrl"`
|
||||
CustomModelName string `json:"customModelName"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Provider string `gorm:"not null" json:"provider"`
|
||||
Enabled bool `gorm:"default:false" json:"enabled"`
|
||||
APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
|
||||
CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"`
|
||||
CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (AIModel) TableName() string { return "ai_models" }
|
||||
|
||||
// NewAIModelStore creates a new AIModelStore
|
||||
func NewAIModelStore(db *gorm.DB) *AIModelStore {
|
||||
return &AIModelStore{db: db}
|
||||
}
|
||||
|
||||
func (s *AIModelStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS ai_models (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT 0,
|
||||
api_key TEXT DEFAULT '',
|
||||
custom_api_url TEXT DEFAULT '',
|
||||
custom_model_name TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at
|
||||
AFTER UPDATE ON ai_models
|
||||
BEGIN
|
||||
UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Backward compatibility: add potentially missing columns
|
||||
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`)
|
||||
|
||||
// For PostgreSQL with existing table, skip AutoMigrate
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'ai_models'`).Scan(&tableExists)
|
||||
if tableExists > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return s.db.AutoMigrate(&AIModel{})
|
||||
}
|
||||
|
||||
func (s *AIModelStore) initDefaultData() error {
|
||||
// No longer pre-populate AI models - create on demand when user configures
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AIModelStore) encrypt(plaintext string) string {
|
||||
if s.encryptFunc != nil {
|
||||
return s.encryptFunc(plaintext)
|
||||
}
|
||||
return plaintext
|
||||
}
|
||||
|
||||
func (s *AIModelStore) decrypt(encrypted string) string {
|
||||
if s.decryptFunc != nil {
|
||||
return s.decryptFunc(encrypted)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
|
||||
// List retrieves user's AI model list
|
||||
func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, '') as custom_api_url,
|
||||
COALESCE(custom_model_name, '') as custom_model_name,
|
||||
created_at, updated_at
|
||||
FROM ai_models WHERE user_id = ? ORDER BY id
|
||||
`, userID)
|
||||
var models []*AIModel
|
||||
err := s.db.Where("user_id = ?", userID).Order("id").Find(&models).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
models := make([]*AIModel, 0)
|
||||
for rows.Next() {
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
models = append(models, &model)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
@@ -140,27 +83,15 @@ func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
|
||||
|
||||
for _, uid := range candidates {
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
||||
FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1
|
||||
`, uid, modelID).Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("user_id = ? AND id = ?", uid, modelID).First(&model).Error
|
||||
if err == nil {
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
return &model, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, sql.ErrNoRows
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
// GetByID retrieves an AI model by ID only (for debate engine)
|
||||
@@ -170,22 +101,10 @@ func (s *AIModelStore) GetByID(modelID string) (*AIModel, error) {
|
||||
}
|
||||
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
||||
FROM ai_models WHERE id = ? LIMIT 1
|
||||
`, modelID).Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("id = ?", modelID).First(&model).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
@@ -198,7 +117,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
||||
if err == nil {
|
||||
return model, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
if userID != "default" {
|
||||
@@ -209,23 +128,12 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
||||
|
||||
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
||||
FROM ai_models WHERE user_id = ? AND enabled = 1
|
||||
ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1
|
||||
`, userID).Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("user_id = ? AND enabled = ?", userID, true).
|
||||
Order("updated_at DESC, id ASC").
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
@@ -233,44 +141,38 @@ func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
||||
// IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten)
|
||||
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
|
||||
// Try exact ID match first
|
||||
var existingID string
|
||||
err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID)
|
||||
var existingModel AIModel
|
||||
err := s.db.Where("user_id = ? AND id = ?", userID, id).First(&existingModel).Error
|
||||
if err == nil {
|
||||
// If apiKey is empty, preserve the existing API key
|
||||
if apiKey == "" {
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, enabled, customAPIURL, customModelName, existingID, userID)
|
||||
} else {
|
||||
encryptedAPIKey := s.encrypt(apiKey)
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
|
||||
// Update existing model
|
||||
updates := map[string]interface{}{
|
||||
"enabled": enabled,
|
||||
"custom_api_url": customAPIURL,
|
||||
"custom_model_name": customModelName,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
return err
|
||||
// If apiKey is not empty, update it (encryption handled by crypto.EncryptedString)
|
||||
if apiKey != "" {
|
||||
updates["api_key"] = crypto.EncryptedString(apiKey)
|
||||
}
|
||||
return s.db.Model(&existingModel).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Try legacy logic compatibility: use id as provider to search
|
||||
provider := id
|
||||
err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID)
|
||||
err = s.db.Where("user_id = ? AND provider = ?", userID, provider).First(&existingModel).Error
|
||||
if err == nil {
|
||||
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingID)
|
||||
// If apiKey is empty, preserve the existing API key
|
||||
if apiKey == "" {
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, enabled, customAPIURL, customModelName, existingID, userID)
|
||||
} else {
|
||||
encryptedAPIKey := s.encrypt(apiKey)
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
|
||||
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingModel.ID)
|
||||
updates := map[string]interface{}{
|
||||
"enabled": enabled,
|
||||
"custom_api_url": customAPIURL,
|
||||
"custom_model_name": customModelName,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
return err
|
||||
if apiKey != "" {
|
||||
updates["api_key"] = crypto.EncryptedString(apiKey)
|
||||
}
|
||||
return s.db.Model(&existingModel).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Create new record
|
||||
@@ -285,9 +187,12 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get name from existing model with same provider
|
||||
var refModel AIModel
|
||||
var name string
|
||||
err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name)
|
||||
if err != nil {
|
||||
if err := s.db.Where("provider = ?", provider).First(&refModel).Error; err == nil {
|
||||
name = refModel.Name
|
||||
} else {
|
||||
if provider == "deepseek" {
|
||||
name = "DeepSeek AI"
|
||||
} else if provider == "qwen" {
|
||||
@@ -303,19 +208,30 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
|
||||
}
|
||||
|
||||
logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
|
||||
encryptedAPIKey := s.encrypt(apiKey)
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
|
||||
`, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName)
|
||||
return err
|
||||
newModel := &AIModel{
|
||||
ID: newModelID,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
Provider: provider,
|
||||
Enabled: enabled,
|
||||
APIKey: crypto.EncryptedString(apiKey),
|
||||
CustomAPIURL: customAPIURL,
|
||||
CustomModelName: customModelName,
|
||||
}
|
||||
return s.db.Create(newModel).Error
|
||||
}
|
||||
|
||||
// Create creates an AI model
|
||||
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`, id, userID, name, provider, enabled, apiKey, customAPIURL)
|
||||
return err
|
||||
model := &AIModel{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
Provider: provider,
|
||||
Enabled: enabled,
|
||||
APIKey: crypto.EncryptedString(apiKey),
|
||||
CustomAPIURL: customAPIURL,
|
||||
}
|
||||
// Use FirstOrCreate to ignore if already exists
|
||||
return s.db.Where("id = ?", id).FirstOrCreate(model).Error
|
||||
}
|
||||
|
||||
+339
-351
@@ -1,15 +1,26 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BacktestStore backtest data storage
|
||||
type BacktestStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewBacktestStore creates a new backtest store
|
||||
func NewBacktestStore(db *gorm.DB) *BacktestStore {
|
||||
return &BacktestStore{db: db}
|
||||
}
|
||||
|
||||
// isPostgres checks if the database is PostgreSQL
|
||||
func (s *BacktestStore) isPostgres() bool {
|
||||
return s.db.Dialector.Name() == "postgres"
|
||||
}
|
||||
|
||||
// RunState backtest state
|
||||
@@ -92,492 +103,469 @@ type RunIndexEntry struct {
|
||||
UpdatedAtISO string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// BacktestRun GORM model for backtest_runs table
|
||||
type BacktestRun struct {
|
||||
RunID string `gorm:"column:run_id;primaryKey"`
|
||||
UserID string `gorm:"column:user_id;not null;default:''"`
|
||||
ConfigJSON []byte `gorm:"column:config_json"`
|
||||
State string `gorm:"column:state;not null;default:created"`
|
||||
Label string `gorm:"column:label;default:''"`
|
||||
SymbolCount int `gorm:"column:symbol_count;default:0"`
|
||||
DecisionTF string `gorm:"column:decision_tf;default:''"`
|
||||
ProcessedBars int `gorm:"column:processed_bars;default:0"`
|
||||
ProgressPct float64 `gorm:"column:progress_pct;default:0"`
|
||||
EquityLast float64 `gorm:"column:equity_last;default:0"`
|
||||
MaxDrawdownPct float64 `gorm:"column:max_drawdown_pct;default:0"`
|
||||
Liquidated bool `gorm:"column:liquidated;default:false"`
|
||||
LiquidationNote string `gorm:"column:liquidation_note;default:''"`
|
||||
PromptTemplate string `gorm:"column:prompt_template;default:''"`
|
||||
CustomPrompt string `gorm:"column:custom_prompt;default:''"`
|
||||
OverridePrompt bool `gorm:"column:override_prompt;default:false"`
|
||||
AIProvider string `gorm:"column:ai_provider;default:''"`
|
||||
AIModel string `gorm:"column:ai_model;default:''"`
|
||||
LastError string `gorm:"column:last_error;default:''"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||
}
|
||||
|
||||
func (BacktestRun) TableName() string {
|
||||
return "backtest_runs"
|
||||
}
|
||||
|
||||
// BacktestCheckpoint GORM model
|
||||
type BacktestCheckpoint struct {
|
||||
RunID string `gorm:"column:run_id;primaryKey"`
|
||||
Payload []byte `gorm:"column:payload;not null"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||
}
|
||||
|
||||
func (BacktestCheckpoint) TableName() string {
|
||||
return "backtest_checkpoints"
|
||||
}
|
||||
|
||||
// BacktestEquity GORM model
|
||||
type BacktestEquity struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"`
|
||||
TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"`
|
||||
Equity float64 `gorm:"column:equity;not null"`
|
||||
Available float64 `gorm:"column:available;not null"`
|
||||
PnL float64 `gorm:"column:pnl;not null"`
|
||||
PnLPct float64 `gorm:"column:pnl_pct;not null"`
|
||||
DDPct float64 `gorm:"column:dd_pct;not null"`
|
||||
Cycle int `gorm:"column:cycle;not null"`
|
||||
}
|
||||
|
||||
func (BacktestEquity) TableName() string {
|
||||
return "backtest_equity"
|
||||
}
|
||||
|
||||
// BacktestTrade GORM model
|
||||
type BacktestTrade struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"`
|
||||
TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"`
|
||||
Symbol string `gorm:"column:symbol;not null"`
|
||||
Action string `gorm:"column:action;not null"`
|
||||
Side string `gorm:"column:side;default:''"`
|
||||
Qty float64 `gorm:"column:qty;default:0"`
|
||||
Price float64 `gorm:"column:price;default:0"`
|
||||
Fee float64 `gorm:"column:fee;default:0"`
|
||||
Slippage float64 `gorm:"column:slippage;default:0"`
|
||||
OrderValue float64 `gorm:"column:order_value;default:0"`
|
||||
RealizedPnL float64 `gorm:"column:realized_pnl;default:0"`
|
||||
Leverage int `gorm:"column:leverage;default:0"`
|
||||
Cycle int `gorm:"column:cycle;default:0"`
|
||||
PositionAfter float64 `gorm:"column:position_after;default:0"`
|
||||
Liquidation bool `gorm:"column:liquidation;default:false"`
|
||||
Note string `gorm:"column:note;default:''"`
|
||||
}
|
||||
|
||||
func (BacktestTrade) TableName() string {
|
||||
return "backtest_trades"
|
||||
}
|
||||
|
||||
// BacktestMetrics GORM model
|
||||
type BacktestMetrics struct {
|
||||
RunID string `gorm:"column:run_id;primaryKey"`
|
||||
Payload []byte `gorm:"column:payload;not null"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||
}
|
||||
|
||||
func (BacktestMetrics) TableName() string {
|
||||
return "backtest_metrics"
|
||||
}
|
||||
|
||||
// BacktestDecision GORM model
|
||||
type BacktestDecision struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RunID string `gorm:"column:run_id;not null;index:idx_backtest_decisions_run_cycle"`
|
||||
Cycle int `gorm:"column:cycle;not null;index:idx_backtest_decisions_run_cycle"`
|
||||
Payload []byte `gorm:"column:payload;not null"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||
}
|
||||
|
||||
func (BacktestDecision) TableName() string {
|
||||
return "backtest_decisions"
|
||||
}
|
||||
|
||||
// initTables initializes backtest related tables
|
||||
func (s *BacktestStore) initTables() error {
|
||||
queries := []string{
|
||||
// Backtest runs main table
|
||||
`CREATE TABLE IF NOT EXISTS backtest_runs (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT '',
|
||||
config_json TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT 'created',
|
||||
label TEXT DEFAULT '',
|
||||
symbol_count INTEGER DEFAULT 0,
|
||||
decision_tf TEXT DEFAULT '',
|
||||
processed_bars INTEGER DEFAULT 0,
|
||||
progress_pct REAL DEFAULT 0,
|
||||
equity_last REAL DEFAULT 0,
|
||||
max_drawdown_pct REAL DEFAULT 0,
|
||||
liquidated BOOLEAN DEFAULT 0,
|
||||
liquidation_note TEXT DEFAULT '',
|
||||
prompt_template TEXT DEFAULT '',
|
||||
custom_prompt TEXT DEFAULT '',
|
||||
override_prompt BOOLEAN DEFAULT 0,
|
||||
ai_provider TEXT DEFAULT '',
|
||||
ai_model TEXT DEFAULT '',
|
||||
last_error TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
// For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists)
|
||||
|
||||
// Backtest checkpoints
|
||||
`CREATE TABLE IF NOT EXISTS backtest_checkpoints (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
payload BLOB NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// Backtest equity curve
|
||||
`CREATE TABLE IF NOT EXISTS backtest_equity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL,
|
||||
ts INTEGER NOT NULL,
|
||||
equity REAL NOT NULL,
|
||||
available REAL NOT NULL,
|
||||
pnl REAL NOT NULL,
|
||||
pnl_pct REAL NOT NULL,
|
||||
dd_pct REAL NOT NULL,
|
||||
cycle INTEGER NOT NULL,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// Backtest trade records
|
||||
`CREATE TABLE IF NOT EXISTS backtest_trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL,
|
||||
ts INTEGER NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
side TEXT DEFAULT '',
|
||||
qty REAL DEFAULT 0,
|
||||
price REAL DEFAULT 0,
|
||||
fee REAL DEFAULT 0,
|
||||
slippage REAL DEFAULT 0,
|
||||
order_value REAL DEFAULT 0,
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
leverage INTEGER DEFAULT 0,
|
||||
cycle INTEGER DEFAULT 0,
|
||||
position_after REAL DEFAULT 0,
|
||||
liquidation BOOLEAN DEFAULT 0,
|
||||
note TEXT DEFAULT '',
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// Backtest metrics
|
||||
`CREATE TABLE IF NOT EXISTS backtest_metrics (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
payload BLOB NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// Backtest decision logs
|
||||
`CREATE TABLE IF NOT EXISTS backtest_decisions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL,
|
||||
cycle INTEGER NOT NULL,
|
||||
payload BLOB NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// Indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`,
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := s.db.Exec(query); err != nil {
|
||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
||||
if tableExists > 0 {
|
||||
// Tables exist - just ensure indexes exist
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add potentially missing columns (backward compatibility)
|
||||
s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''")
|
||||
s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''")
|
||||
s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0")
|
||||
// AutoMigrate all backtest tables
|
||||
if err := s.db.AutoMigrate(
|
||||
&BacktestRun{},
|
||||
&BacktestCheckpoint{},
|
||||
&BacktestEquity{},
|
||||
&BacktestTrade{},
|
||||
&BacktestMetrics{},
|
||||
&BacktestDecision{},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to migrate backtest tables: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) {
|
||||
rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, ctype string
|
||||
var notnull, pk int
|
||||
var dflt interface{}
|
||||
if err := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); err != nil {
|
||||
continue
|
||||
}
|
||||
if name == column {
|
||||
return // Column already exists
|
||||
}
|
||||
}
|
||||
|
||||
s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
|
||||
}
|
||||
|
||||
// SaveCheckpoint saves checkpoint
|
||||
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_checkpoints (run_id, payload, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
||||
`, runID, payload)
|
||||
return err
|
||||
checkpoint := BacktestCheckpoint{
|
||||
RunID: runID,
|
||||
Payload: payload,
|
||||
}
|
||||
return s.db.Save(&checkpoint).Error
|
||||
}
|
||||
|
||||
// LoadCheckpoint loads checkpoint
|
||||
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload)
|
||||
return payload, err
|
||||
var checkpoint BacktestCheckpoint
|
||||
err := s.db.Where("run_id = ?", runID).First(&checkpoint).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return checkpoint.Payload, nil
|
||||
}
|
||||
|
||||
// SaveRunMetadata saves run metadata
|
||||
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
|
||||
created := meta.CreatedAt.UTC().Format(time.RFC3339)
|
||||
updated := meta.UpdatedAt.UTC().Format(time.RFC3339)
|
||||
userID := meta.UserID
|
||||
|
||||
if _, err := s.db.Exec(`
|
||||
INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(run_id) DO NOTHING
|
||||
`, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil {
|
||||
return err
|
||||
run := BacktestRun{
|
||||
RunID: meta.RunID,
|
||||
UserID: meta.UserID,
|
||||
State: string(meta.State),
|
||||
Label: meta.Label,
|
||||
LastError: meta.LastError,
|
||||
SymbolCount: meta.Summary.SymbolCount,
|
||||
DecisionTF: meta.Summary.DecisionTF,
|
||||
ProcessedBars: meta.Summary.ProcessedBars,
|
||||
ProgressPct: meta.Summary.ProgressPct,
|
||||
EquityLast: meta.Summary.EquityLast,
|
||||
MaxDrawdownPct: meta.Summary.MaxDrawdownPct,
|
||||
Liquidated: meta.Summary.Liquidated,
|
||||
LiquidationNote: meta.Summary.LiquidationNote,
|
||||
CreatedAt: meta.CreatedAt,
|
||||
UpdatedAt: meta.UpdatedAt,
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE backtest_runs
|
||||
SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?,
|
||||
progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?,
|
||||
liquidation_note = ?, label = ?, last_error = ?, updated_at = ?
|
||||
WHERE run_id = ?
|
||||
`, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF,
|
||||
meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast,
|
||||
meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote,
|
||||
meta.Label, meta.LastError, updated, meta.RunID)
|
||||
return err
|
||||
return s.db.Save(&run).Error
|
||||
}
|
||||
|
||||
// LoadRunMetadata loads run metadata
|
||||
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
|
||||
var (
|
||||
userID string
|
||||
state string
|
||||
label string
|
||||
lastErr string
|
||||
symbolCount int
|
||||
decisionTF string
|
||||
processedBars int
|
||||
progressPct float64
|
||||
equityLast float64
|
||||
maxDD float64
|
||||
liquidated bool
|
||||
liquidationNote string
|
||||
createdISO string
|
||||
updatedISO string
|
||||
)
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars,
|
||||
progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note,
|
||||
created_at, updated_at
|
||||
FROM backtest_runs WHERE run_id = ?
|
||||
`, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF,
|
||||
&processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote,
|
||||
&createdISO, &updatedISO)
|
||||
var run BacktestRun
|
||||
err := s.db.Where("run_id = ?", runID).First(&run).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := &RunMetadata{
|
||||
RunID: runID,
|
||||
UserID: userID,
|
||||
return &RunMetadata{
|
||||
RunID: run.RunID,
|
||||
UserID: run.UserID,
|
||||
Version: 1,
|
||||
State: RunState(state),
|
||||
Label: label,
|
||||
LastError: lastErr,
|
||||
State: RunState(run.State),
|
||||
Label: run.Label,
|
||||
LastError: run.LastError,
|
||||
Summary: RunSummary{
|
||||
SymbolCount: symbolCount,
|
||||
DecisionTF: decisionTF,
|
||||
ProcessedBars: processedBars,
|
||||
ProgressPct: progressPct,
|
||||
EquityLast: equityLast,
|
||||
MaxDrawdownPct: maxDD,
|
||||
Liquidated: liquidated,
|
||||
LiquidationNote: liquidationNote,
|
||||
SymbolCount: run.SymbolCount,
|
||||
DecisionTF: run.DecisionTF,
|
||||
ProcessedBars: run.ProcessedBars,
|
||||
ProgressPct: run.ProgressPct,
|
||||
EquityLast: run.EquityLast,
|
||||
MaxDrawdownPct: run.MaxDrawdownPct,
|
||||
Liquidated: run.Liquidated,
|
||||
LiquidationNote: run.LiquidationNote,
|
||||
},
|
||||
}
|
||||
|
||||
meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO)
|
||||
meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO)
|
||||
|
||||
return meta, nil
|
||||
CreatedAt: run.CreatedAt,
|
||||
UpdatedAt: run.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListRunIDs lists all run IDs
|
||||
func (s *BacktestStore) ListRunIDs() ([]string, error) {
|
||||
rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
|
||||
var runs []BacktestRun
|
||||
err := s.db.Order("updated_at DESC").Find(&runs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []string
|
||||
for rows.Next() {
|
||||
var runID string
|
||||
if err := rows.Scan(&runID); err != nil {
|
||||
return nil, err
|
||||
ids := make([]string, len(runs))
|
||||
for i, run := range runs {
|
||||
ids[i] = run.RunID
|
||||
}
|
||||
ids = append(ids, runID)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// AppendEquityPoint appends equity point
|
||||
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, runID, point.Timestamp, point.Equity, point.Available, point.PnL,
|
||||
point.PnLPct, point.DrawdownPct, point.Cycle)
|
||||
return err
|
||||
eq := BacktestEquity{
|
||||
RunID: runID,
|
||||
TS: point.Timestamp,
|
||||
Equity: point.Equity,
|
||||
Available: point.Available,
|
||||
PnL: point.PnL,
|
||||
PnLPct: point.PnLPct,
|
||||
DDPct: point.DrawdownPct,
|
||||
Cycle: point.Cycle,
|
||||
}
|
||||
return s.db.Create(&eq).Error
|
||||
}
|
||||
|
||||
// LoadEquityPoints loads equity points
|
||||
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle
|
||||
FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC
|
||||
`, runID)
|
||||
var eqs []BacktestEquity
|
||||
err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&eqs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
points := make([]EquityPoint, 0)
|
||||
for rows.Next() {
|
||||
var point EquityPoint
|
||||
if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available,
|
||||
&point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil {
|
||||
return nil, err
|
||||
points := make([]EquityPoint, len(eqs))
|
||||
for i, eq := range eqs {
|
||||
points[i] = EquityPoint{
|
||||
Timestamp: eq.TS,
|
||||
Equity: eq.Equity,
|
||||
Available: eq.Available,
|
||||
PnL: eq.PnL,
|
||||
PnLPct: eq.PnLPct,
|
||||
DrawdownPct: eq.DDPct,
|
||||
Cycle: eq.Cycle,
|
||||
}
|
||||
points = append(points, point)
|
||||
}
|
||||
return points, rows.Err()
|
||||
return points, nil
|
||||
}
|
||||
|
||||
// AppendTradeEvent appends trade event
|
||||
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee,
|
||||
slippage, order_value, realized_pnl, leverage, cycle,
|
||||
position_after, liquidation, note)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity,
|
||||
event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL,
|
||||
event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note)
|
||||
return err
|
||||
trade := BacktestTrade{
|
||||
RunID: runID,
|
||||
TS: event.Timestamp,
|
||||
Symbol: event.Symbol,
|
||||
Action: event.Action,
|
||||
Side: event.Side,
|
||||
Qty: event.Quantity,
|
||||
Price: event.Price,
|
||||
Fee: event.Fee,
|
||||
Slippage: event.Slippage,
|
||||
OrderValue: event.OrderValue,
|
||||
RealizedPnL: event.RealizedPnL,
|
||||
Leverage: event.Leverage,
|
||||
Cycle: event.Cycle,
|
||||
PositionAfter: event.PositionAfter,
|
||||
Liquidation: event.LiquidationFlag,
|
||||
Note: event.Note,
|
||||
}
|
||||
return s.db.Create(&trade).Error
|
||||
}
|
||||
|
||||
// LoadTradeEvents loads trade events
|
||||
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value,
|
||||
realized_pnl, leverage, cycle, position_after, liquidation, note
|
||||
FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC
|
||||
`, runID)
|
||||
var trades []BacktestTrade
|
||||
err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&trades).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
events := make([]TradeEvent, 0)
|
||||
for rows.Next() {
|
||||
var event TradeEvent
|
||||
if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side,
|
||||
&event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue,
|
||||
&event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter,
|
||||
&event.LiquidationFlag, &event.Note); err != nil {
|
||||
return nil, err
|
||||
events := make([]TradeEvent, len(trades))
|
||||
for i, trade := range trades {
|
||||
events[i] = TradeEvent{
|
||||
Timestamp: trade.TS,
|
||||
Symbol: trade.Symbol,
|
||||
Action: trade.Action,
|
||||
Side: trade.Side,
|
||||
Quantity: trade.Qty,
|
||||
Price: trade.Price,
|
||||
Fee: trade.Fee,
|
||||
Slippage: trade.Slippage,
|
||||
OrderValue: trade.OrderValue,
|
||||
RealizedPnL: trade.RealizedPnL,
|
||||
Leverage: trade.Leverage,
|
||||
Cycle: trade.Cycle,
|
||||
PositionAfter: trade.PositionAfter,
|
||||
LiquidationFlag: trade.Liquidation,
|
||||
Note: trade.Note,
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
return events, rows.Err()
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// SaveMetrics saves metrics
|
||||
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_metrics (run_id, payload, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
||||
`, runID, payload)
|
||||
return err
|
||||
metrics := BacktestMetrics{
|
||||
RunID: runID,
|
||||
Payload: payload,
|
||||
}
|
||||
return s.db.Save(&metrics).Error
|
||||
}
|
||||
|
||||
// LoadMetrics loads metrics
|
||||
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload)
|
||||
return payload, err
|
||||
var metrics BacktestMetrics
|
||||
err := s.db.Where("run_id = ?", runID).First(&metrics).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return metrics.Payload, nil
|
||||
}
|
||||
|
||||
// SaveDecisionRecord saves decision record
|
||||
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_decisions (run_id, cycle, payload)
|
||||
VALUES (?, ?, ?)
|
||||
`, runID, cycle, payload)
|
||||
return err
|
||||
decision := BacktestDecision{
|
||||
RunID: runID,
|
||||
Cycle: cycle,
|
||||
Payload: payload,
|
||||
}
|
||||
return s.db.Create(&decision).Error
|
||||
}
|
||||
|
||||
// LoadDecisionRecords loads decision records
|
||||
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT payload FROM backtest_decisions
|
||||
WHERE run_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`, runID, limit, offset)
|
||||
var decisions []BacktestDecision
|
||||
err := s.db.Where("run_id = ?", runID).
|
||||
Order("id DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&decisions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
records := make([]json.RawMessage, 0, limit)
|
||||
for rows.Next() {
|
||||
var payload []byte
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
records := make([]json.RawMessage, len(decisions))
|
||||
for i, d := range decisions {
|
||||
records[i] = json.RawMessage(d.Payload)
|
||||
}
|
||||
records = append(records, json.RawMessage(payload))
|
||||
}
|
||||
return records, rows.Err()
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// LoadLatestDecision loads latest decision
|
||||
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
var decision BacktestDecision
|
||||
query := s.db.Where("run_id = ?", runID)
|
||||
if cycle > 0 {
|
||||
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`
|
||||
args = []interface{}{runID, cycle}
|
||||
} else {
|
||||
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1`
|
||||
args = []interface{}{runID}
|
||||
query = query.Where("cycle = ?", cycle)
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(query, args...).Scan(&payload)
|
||||
return payload, err
|
||||
err := query.Order("created_at DESC").First(&decision).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decision.Payload, nil
|
||||
}
|
||||
|
||||
// UpdateProgress updates progress
|
||||
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE backtest_runs
|
||||
SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE run_id = ?
|
||||
`, progressPct, equity, barIndex, liquidated, runID)
|
||||
return err
|
||||
return s.db.Model(&BacktestRun{}).Where("run_id = ?", runID).Updates(map[string]interface{}{
|
||||
"progress_pct": progressPct,
|
||||
"equity_last": equity,
|
||||
"processed_bars": barIndex,
|
||||
"liquidated": liquidated,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ListIndexEntries lists index entries
|
||||
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct,
|
||||
created_at, updated_at, config_json
|
||||
FROM backtest_runs
|
||||
ORDER BY datetime(updated_at) DESC
|
||||
`)
|
||||
var runs []BacktestRun
|
||||
err := s.db.Order("updated_at DESC").Find(&runs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []RunIndexEntry
|
||||
for rows.Next() {
|
||||
var entry RunIndexEntry
|
||||
var symbolCnt int
|
||||
var cfgJSON []byte
|
||||
var createdISO, updatedISO string
|
||||
|
||||
if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF,
|
||||
&entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil {
|
||||
return nil, err
|
||||
entries := make([]RunIndexEntry, len(runs))
|
||||
for i, run := range runs {
|
||||
entry := RunIndexEntry{
|
||||
RunID: run.RunID,
|
||||
State: run.State,
|
||||
DecisionTF: run.DecisionTF,
|
||||
EquityLast: run.EquityLast,
|
||||
MaxDrawdownPct: run.MaxDrawdownPct,
|
||||
CreatedAtISO: run.CreatedAt.Format(time.RFC3339),
|
||||
UpdatedAtISO: run.UpdatedAt.Format(time.RFC3339),
|
||||
Symbols: make([]string, 0, run.SymbolCount),
|
||||
}
|
||||
|
||||
entry.CreatedAtISO = createdISO
|
||||
entry.UpdatedAtISO = updatedISO
|
||||
entry.Symbols = make([]string, 0, symbolCnt)
|
||||
|
||||
// Try to extract more information from config
|
||||
if len(cfgJSON) > 0 {
|
||||
if len(run.ConfigJSON) > 0 {
|
||||
var cfg struct {
|
||||
Symbols []string `json:"symbols"`
|
||||
StartTS int64 `json:"start_ts"`
|
||||
EndTS int64 `json:"end_ts"`
|
||||
}
|
||||
if json.Unmarshal(cfgJSON, &cfg) == nil {
|
||||
if json.Unmarshal(run.ConfigJSON, &cfg) == nil {
|
||||
entry.Symbols = cfg.Symbols
|
||||
entry.StartTS = cfg.StartTS
|
||||
entry.EndTS = cfg.EndTS
|
||||
}
|
||||
}
|
||||
|
||||
entries = append(entries, entry)
|
||||
entries[i] = entry
|
||||
}
|
||||
return entries, rows.Err()
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// DeleteRun deletes run
|
||||
func (s *BacktestStore) DeleteRun(runID string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID)
|
||||
return err
|
||||
// Delete related records first (cascade may not work in all cases)
|
||||
s.db.Where("run_id = ?", runID).Delete(&BacktestCheckpoint{})
|
||||
s.db.Where("run_id = ?", runID).Delete(&BacktestEquity{})
|
||||
s.db.Where("run_id = ?", runID).Delete(&BacktestTrade{})
|
||||
s.db.Where("run_id = ?", runID).Delete(&BacktestMetrics{})
|
||||
s.db.Where("run_id = ?", runID).Delete(&BacktestDecision{})
|
||||
|
||||
return s.db.Where("run_id = ?", runID).Delete(&BacktestRun{}).Error
|
||||
}
|
||||
|
||||
// SaveConfig saves config
|
||||
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
if userID == "" {
|
||||
userID = "default"
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt,
|
||||
override_prompt, ai_provider, ai_model, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(run_id) DO NOTHING
|
||||
`, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now)
|
||||
if err != nil {
|
||||
return err
|
||||
run := BacktestRun{
|
||||
RunID: runID,
|
||||
UserID: userID,
|
||||
ConfigJSON: configJSON,
|
||||
PromptTemplate: template,
|
||||
CustomPrompt: customPrompt,
|
||||
OverridePrompt: override,
|
||||
AIProvider: provider,
|
||||
AIModel: model,
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE backtest_runs
|
||||
SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?,
|
||||
override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE run_id = ?
|
||||
`, userID, configJSON, template, customPrompt, override, provider, model, runID)
|
||||
return err
|
||||
return s.db.Save(&run).Error
|
||||
}
|
||||
|
||||
// LoadConfig loads config
|
||||
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload)
|
||||
return payload, err
|
||||
var run BacktestRun
|
||||
err := s.db.Where("run_id = ?", runID).First(&run).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return run.ConfigJSON, nil
|
||||
}
|
||||
|
||||
+215
-481
@@ -1,12 +1,11 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DebateStatus represents the status of a debate session
|
||||
@@ -49,7 +48,26 @@ var PersonalityEmojis = map[DebatePersonality]string{
|
||||
PersonalityRiskManager: "🛡️",
|
||||
}
|
||||
|
||||
// DebateSession represents a debate session
|
||||
// DebateDecision represents a trading decision from the debate
|
||||
type DebateDecision struct {
|
||||
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
|
||||
Symbol string `json:"symbol"` // Trading pair
|
||||
Confidence int `json:"confidence"` // 0-100
|
||||
Leverage int `json:"leverage"` // Recommended leverage
|
||||
PositionPct float64 `json:"position_pct"` // Position size as percentage of equity (0.0-1.0)
|
||||
PositionSizeUSD float64 `json:"position_size_usd"` // Position size in USD (calculated from pct)
|
||||
StopLoss float64 `json:"stop_loss"` // Stop loss price
|
||||
TakeProfit float64 `json:"take_profit"` // Take profit price
|
||||
Reasoning string `json:"reasoning"` // Brief reasoning
|
||||
|
||||
// Execution tracking
|
||||
Executed bool `json:"executed"` // Whether this decision was executed
|
||||
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
|
||||
OrderID string `json:"order_id,omitempty"` // Exchange order ID
|
||||
Error string `json:"error,omitempty"` // Execution error if any
|
||||
}
|
||||
|
||||
// DebateSession represents a debate session (API struct)
|
||||
type DebateSession struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
@@ -73,191 +91,157 @@ type DebateSession struct {
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// DebateDecision represents a trading decision from the debate
|
||||
type DebateDecision struct {
|
||||
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
|
||||
Symbol string `json:"symbol"` // Trading pair
|
||||
Confidence int `json:"confidence"` // 0-100
|
||||
Leverage int `json:"leverage"` // Recommended leverage
|
||||
PositionPct float64 `json:"position_pct"` // Position size as percentage of equity (0.0-1.0)
|
||||
PositionSizeUSD float64 `json:"position_size_usd"` // Position size in USD (calculated from pct)
|
||||
StopLoss float64 `json:"stop_loss"` // Stop loss price
|
||||
TakeProfit float64 `json:"take_profit"` // Take profit price
|
||||
Reasoning string `json:"reasoning"` // Brief reasoning
|
||||
// DebateSessionDB is the GORM model for debate_sessions
|
||||
type DebateSessionDB struct {
|
||||
ID string `gorm:"column:id;primaryKey"`
|
||||
UserID string `gorm:"column:user_id;not null;index"`
|
||||
Name string `gorm:"column:name;not null"`
|
||||
StrategyID string `gorm:"column:strategy_id;not null"`
|
||||
Status DebateStatus `gorm:"column:status;not null;default:pending;index"`
|
||||
Symbol string `gorm:"column:symbol;not null"`
|
||||
MaxRounds int `gorm:"column:max_rounds;default:3"`
|
||||
CurrentRound int `gorm:"column:current_round;default:0"`
|
||||
IntervalMinutes int `gorm:"column:interval_minutes;default:5"`
|
||||
PromptVariant string `gorm:"column:prompt_variant;default:balanced"`
|
||||
FinalDecision string `gorm:"column:final_decision"` // JSON string
|
||||
AutoExecute bool `gorm:"column:auto_execute;default:false"`
|
||||
TraderID string `gorm:"column:trader_id"`
|
||||
EnableOIRanking bool `gorm:"column:enable_oi_ranking;default:false"`
|
||||
OIRankingLimit int `gorm:"column:oi_ranking_limit;default:10"`
|
||||
OIDuration string `gorm:"column:oi_duration;default:1h"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||
}
|
||||
|
||||
// Execution tracking
|
||||
Executed bool `json:"executed"` // Whether this decision was executed
|
||||
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
|
||||
OrderID string `json:"order_id,omitempty"` // Exchange order ID
|
||||
Error string `json:"error,omitempty"` // Execution error if any
|
||||
func (DebateSessionDB) TableName() string {
|
||||
return "debate_sessions"
|
||||
}
|
||||
|
||||
func (db *DebateSessionDB) toSession() *DebateSession {
|
||||
s := &DebateSession{
|
||||
ID: db.ID,
|
||||
UserID: db.UserID,
|
||||
Name: db.Name,
|
||||
StrategyID: db.StrategyID,
|
||||
Status: db.Status,
|
||||
Symbol: db.Symbol,
|
||||
MaxRounds: db.MaxRounds,
|
||||
CurrentRound: db.CurrentRound,
|
||||
IntervalMinutes: db.IntervalMinutes,
|
||||
PromptVariant: db.PromptVariant,
|
||||
AutoExecute: db.AutoExecute,
|
||||
TraderID: db.TraderID,
|
||||
EnableOIRanking: db.EnableOIRanking,
|
||||
OIRankingLimit: db.OIRankingLimit,
|
||||
OIDuration: db.OIDuration,
|
||||
CreatedAt: db.CreatedAt,
|
||||
UpdatedAt: db.UpdatedAt,
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if s.IntervalMinutes == 0 {
|
||||
s.IntervalMinutes = 5
|
||||
}
|
||||
if s.PromptVariant == "" {
|
||||
s.PromptVariant = "balanced"
|
||||
}
|
||||
if s.OIRankingLimit == 0 {
|
||||
s.OIRankingLimit = 10
|
||||
}
|
||||
if s.OIDuration == "" {
|
||||
s.OIDuration = "1h"
|
||||
}
|
||||
|
||||
// Parse final decision
|
||||
if db.FinalDecision != "" {
|
||||
var decision DebateDecision
|
||||
if json.Unmarshal([]byte(db.FinalDecision), &decision) == nil {
|
||||
s.FinalDecision = &decision
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// DebateParticipant represents an AI participant in a debate
|
||||
type DebateParticipant struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
AIModelName string `json:"ai_model_name"`
|
||||
Provider string `json:"provider"`
|
||||
Personality DebatePersonality `json:"personality"`
|
||||
Color string `json:"color"`
|
||||
SpeakOrder int `json:"speak_order"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `gorm:"column:id;primaryKey" json:"id"`
|
||||
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
|
||||
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
|
||||
Provider string `gorm:"column:provider;not null" json:"provider"`
|
||||
Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
|
||||
Color string `gorm:"column:color;not null" json:"color"`
|
||||
SpeakOrder int `gorm:"column:speak_order;default:0" json:"speak_order"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
}
|
||||
|
||||
func (DebateParticipant) TableName() string {
|
||||
return "debate_participants"
|
||||
}
|
||||
|
||||
// DebateMessage represents a message in the debate
|
||||
type DebateMessage struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Round int `json:"round"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
AIModelName string `json:"ai_model_name"`
|
||||
Provider string `json:"provider"`
|
||||
Personality DebatePersonality `json:"personality"`
|
||||
MessageType string `json:"message_type"` // analysis/rebuttal/final/vote
|
||||
Content string `json:"content"`
|
||||
Decision *DebateDecision `json:"decision,omitempty"` // Single decision (backward compat)
|
||||
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions
|
||||
Confidence int `json:"confidence"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `gorm:"column:id;primaryKey" json:"id"`
|
||||
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
|
||||
Round int `gorm:"column:round;not null" json:"round"`
|
||||
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
|
||||
Provider string `gorm:"column:provider;not null" json:"provider"`
|
||||
Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
|
||||
MessageType string `gorm:"column:message_type;not null" json:"message_type"` // analysis/rebuttal/final/vote
|
||||
Content string `gorm:"column:content;not null" json:"content"`
|
||||
DecisionRaw string `gorm:"column:decision" json:"-"` // JSON string in DB
|
||||
Decision *DebateDecision `gorm:"-" json:"decision,omitempty"` // Parsed for API
|
||||
Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
|
||||
Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
}
|
||||
|
||||
func (DebateMessage) TableName() string {
|
||||
return "debate_messages"
|
||||
}
|
||||
|
||||
// DebateVote represents a final vote from an AI (can contain multiple coin decisions)
|
||||
type DebateVote struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
AIModelName string `json:"ai_model_name"`
|
||||
Action string `json:"action"` // Primary action (backward compat)
|
||||
Symbol string `json:"symbol"` // Primary symbol (backward compat)
|
||||
Confidence int `json:"confidence"`
|
||||
Leverage int `json:"leverage"`
|
||||
PositionPct float64 `json:"position_pct"`
|
||||
StopLossPct float64 `json:"stop_loss_pct"`
|
||||
TakeProfitPct float64 `json:"take_profit_pct"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `gorm:"column:id;primaryKey" json:"id"`
|
||||
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
|
||||
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
|
||||
Action string `gorm:"column:action;not null" json:"action"` // Primary action (backward compat)
|
||||
Symbol string `gorm:"column:symbol;not null" json:"symbol"` // Primary symbol (backward compat)
|
||||
Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
|
||||
Leverage int `gorm:"column:leverage;default:5" json:"leverage"`
|
||||
PositionPct float64 `gorm:"column:position_pct;default:0.2" json:"position_pct"`
|
||||
StopLossPct float64 `gorm:"column:stop_loss_pct;default:0.03" json:"stop_loss_pct"`
|
||||
TakeProfitPct float64 `gorm:"column:take_profit_pct;default:0.06" json:"take_profit_pct"`
|
||||
Reasoning string `gorm:"column:reasoning" json:"reasoning"`
|
||||
Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
}
|
||||
|
||||
func (DebateVote) TableName() string {
|
||||
return "debate_votes"
|
||||
}
|
||||
|
||||
// DebateStore handles database operations for debates
|
||||
type DebateStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewDebateStore creates a new DebateStore
|
||||
func NewDebateStore(db *sql.DB) *DebateStore {
|
||||
func NewDebateStore(db *gorm.DB) *DebateStore {
|
||||
return &DebateStore{db: db}
|
||||
}
|
||||
|
||||
// InitSchema creates the debate tables
|
||||
// InitSchema creates the debate tables using GORM AutoMigrate
|
||||
func (s *DebateStore) InitSchema() error {
|
||||
schemas := []string{
|
||||
`CREATE TABLE IF NOT EXISTS debate_sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
strategy_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
symbol TEXT NOT NULL,
|
||||
max_rounds INTEGER DEFAULT 3,
|
||||
current_round INTEGER DEFAULT 0,
|
||||
interval_minutes INTEGER DEFAULT 5,
|
||||
prompt_variant TEXT DEFAULT 'balanced',
|
||||
final_decision TEXT,
|
||||
auto_execute BOOLEAN DEFAULT 0,
|
||||
trader_id TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_user_id ON debate_sessions(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_status ON debate_sessions(status)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS debate_participants (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
ai_model_id TEXT NOT NULL,
|
||||
ai_model_name TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
personality TEXT NOT NULL,
|
||||
color TEXT NOT NULL,
|
||||
speak_order INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_debate_participants_session ON debate_participants(session_id)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS debate_messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
round INTEGER NOT NULL,
|
||||
ai_model_id TEXT NOT NULL,
|
||||
ai_model_name TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
personality TEXT NOT NULL,
|
||||
message_type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
decision TEXT,
|
||||
confidence INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_debate_messages_session ON debate_messages(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_debate_messages_round ON debate_messages(session_id, round)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS debate_votes (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
ai_model_id TEXT NOT NULL,
|
||||
ai_model_name TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
confidence INTEGER DEFAULT 0,
|
||||
leverage INTEGER DEFAULT 5,
|
||||
position_pct REAL DEFAULT 0.2,
|
||||
stop_loss_pct REAL DEFAULT 0.03,
|
||||
take_profit_pct REAL DEFAULT 0.06,
|
||||
reasoning TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_debate_votes_session ON debate_votes(session_id)`,
|
||||
|
||||
// Trigger to update updated_at
|
||||
`CREATE TRIGGER IF NOT EXISTS update_debate_sessions_timestamp
|
||||
AFTER UPDATE ON debate_sessions
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE debate_sessions SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id;
|
||||
END`,
|
||||
}
|
||||
|
||||
for _, schema := range schemas {
|
||||
if _, err := s.db.Exec(schema); err != nil {
|
||||
return fmt.Errorf("failed to create debate schema: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate: Add new columns to existing tables (ignore errors if columns already exist)
|
||||
migrations := []string{
|
||||
`ALTER TABLE debate_sessions ADD COLUMN interval_minutes INTEGER DEFAULT 5`,
|
||||
`ALTER TABLE debate_sessions ADD COLUMN prompt_variant TEXT DEFAULT 'balanced'`,
|
||||
`ALTER TABLE debate_sessions ADD COLUMN trader_id TEXT`,
|
||||
`ALTER TABLE debate_sessions ADD COLUMN enable_oi_ranking BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE debate_sessions ADD COLUMN oi_ranking_limit INTEGER DEFAULT 10`,
|
||||
`ALTER TABLE debate_sessions ADD COLUMN oi_duration TEXT DEFAULT '1h'`,
|
||||
`ALTER TABLE debate_votes ADD COLUMN leverage INTEGER DEFAULT 5`,
|
||||
`ALTER TABLE debate_votes ADD COLUMN position_pct REAL DEFAULT 0.2`,
|
||||
`ALTER TABLE debate_votes ADD COLUMN stop_loss_pct REAL DEFAULT 0.03`,
|
||||
`ALTER TABLE debate_votes ADD COLUMN take_profit_pct REAL DEFAULT 0.06`,
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
// Ignore errors - column may already exist
|
||||
s.db.Exec(migration)
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.db.AutoMigrate(
|
||||
&DebateSessionDB{},
|
||||
&DebateParticipant{},
|
||||
&DebateMessage{},
|
||||
&DebateVote{},
|
||||
)
|
||||
}
|
||||
|
||||
// CreateSession creates a new debate session
|
||||
@@ -279,227 +263,73 @@ func (s *DebateStore) CreateSession(session *DebateSession) error {
|
||||
if session.OIDuration == "" {
|
||||
session.OIDuration = "1h"
|
||||
}
|
||||
session.CreatedAt = time.Now()
|
||||
session.UpdatedAt = time.Now()
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO debate_sessions (id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, interval_minutes, prompt_variant, auto_execute, trader_id, enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.Name, session.StrategyID, session.Status,
|
||||
session.Symbol, session.MaxRounds, session.CurrentRound, session.IntervalMinutes, session.PromptVariant,
|
||||
session.AutoExecute, session.TraderID, session.EnableOIRanking, session.OIRankingLimit, session.OIDuration,
|
||||
session.CreatedAt, session.UpdatedAt,
|
||||
)
|
||||
return err
|
||||
db := &DebateSessionDB{
|
||||
ID: session.ID,
|
||||
UserID: session.UserID,
|
||||
Name: session.Name,
|
||||
StrategyID: session.StrategyID,
|
||||
Status: session.Status,
|
||||
Symbol: session.Symbol,
|
||||
MaxRounds: session.MaxRounds,
|
||||
CurrentRound: session.CurrentRound,
|
||||
IntervalMinutes: session.IntervalMinutes,
|
||||
PromptVariant: session.PromptVariant,
|
||||
AutoExecute: session.AutoExecute,
|
||||
TraderID: session.TraderID,
|
||||
EnableOIRanking: session.EnableOIRanking,
|
||||
OIRankingLimit: session.OIRankingLimit,
|
||||
OIDuration: session.OIDuration,
|
||||
}
|
||||
|
||||
return s.db.Create(db).Error
|
||||
}
|
||||
|
||||
// GetSession gets a debate session by ID
|
||||
func (s *DebateStore) GetSession(id string) (*DebateSession, error) {
|
||||
var session DebateSession
|
||||
var finalDecisionJSON sql.NullString
|
||||
var traderID sql.NullString
|
||||
var intervalMinutes sql.NullInt64
|
||||
var promptVariant sql.NullString
|
||||
var enableOIRanking sql.NullBool
|
||||
var oiRankingLimit sql.NullInt64
|
||||
var oiDuration sql.NullString
|
||||
|
||||
// Try new schema first
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
||||
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id,
|
||||
enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at
|
||||
FROM debate_sessions WHERE id = ?`, id,
|
||||
).Scan(
|
||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
||||
&intervalMinutes, &promptVariant,
|
||||
&finalDecisionJSON, &session.AutoExecute, &traderID,
|
||||
&enableOIRanking, &oiRankingLimit, &oiDuration, &session.CreatedAt, &session.UpdatedAt,
|
||||
)
|
||||
|
||||
// Fallback to basic schema if new columns don't exist
|
||||
if err != nil {
|
||||
err = s.db.QueryRow(`
|
||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
||||
final_decision, auto_execute, created_at, updated_at
|
||||
FROM debate_sessions WHERE id = ?`, id,
|
||||
).Scan(
|
||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
||||
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
var db DebateSessionDB
|
||||
if err := s.db.Where("id = ?", id).First(&db).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Set defaults for new fields
|
||||
session.IntervalMinutes = 5
|
||||
session.PromptVariant = "balanced"
|
||||
session.OIRankingLimit = 10
|
||||
session.OIDuration = "1h"
|
||||
} else {
|
||||
// Set defaults for nullable fields
|
||||
session.IntervalMinutes = 5
|
||||
if intervalMinutes.Valid {
|
||||
session.IntervalMinutes = int(intervalMinutes.Int64)
|
||||
}
|
||||
session.PromptVariant = "balanced"
|
||||
if promptVariant.Valid {
|
||||
session.PromptVariant = promptVariant.String
|
||||
}
|
||||
if traderID.Valid {
|
||||
session.TraderID = traderID.String
|
||||
}
|
||||
if enableOIRanking.Valid {
|
||||
session.EnableOIRanking = enableOIRanking.Bool
|
||||
}
|
||||
session.OIRankingLimit = 10
|
||||
if oiRankingLimit.Valid {
|
||||
session.OIRankingLimit = int(oiRankingLimit.Int64)
|
||||
}
|
||||
session.OIDuration = "1h"
|
||||
if oiDuration.Valid {
|
||||
session.OIDuration = oiDuration.String
|
||||
}
|
||||
}
|
||||
|
||||
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
|
||||
var decision DebateDecision
|
||||
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
|
||||
session.FinalDecision = &decision
|
||||
}
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
return db.toSession(), nil
|
||||
}
|
||||
|
||||
// GetSessionsByUser gets all debate sessions for a user
|
||||
func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) {
|
||||
// First try the new schema with all columns
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
||||
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id, created_at, updated_at
|
||||
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
|
||||
)
|
||||
|
||||
// If query fails (likely due to missing columns), try basic query
|
||||
if err != nil {
|
||||
return s.getSessionsByUserBasic(userID)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*DebateSession
|
||||
for rows.Next() {
|
||||
var session DebateSession
|
||||
var finalDecisionJSON sql.NullString
|
||||
var traderID sql.NullString
|
||||
var intervalMinutes sql.NullInt64
|
||||
var promptVariant sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
||||
&intervalMinutes, &promptVariant,
|
||||
&finalDecisionJSON, &session.AutoExecute, &traderID, &session.CreatedAt, &session.UpdatedAt,
|
||||
); err != nil {
|
||||
var dbs []DebateSessionDB
|
||||
if err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&dbs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set defaults for nullable fields
|
||||
session.IntervalMinutes = 5
|
||||
if intervalMinutes.Valid {
|
||||
session.IntervalMinutes = int(intervalMinutes.Int64)
|
||||
}
|
||||
session.PromptVariant = "balanced"
|
||||
if promptVariant.Valid {
|
||||
session.PromptVariant = promptVariant.String
|
||||
}
|
||||
|
||||
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
|
||||
var decision DebateDecision
|
||||
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
|
||||
session.FinalDecision = &decision
|
||||
}
|
||||
}
|
||||
if traderID.Valid {
|
||||
session.TraderID = traderID.String
|
||||
}
|
||||
|
||||
sessions = append(sessions, &session)
|
||||
sessions := make([]*DebateSession, len(dbs))
|
||||
for i, db := range dbs {
|
||||
sessions[i] = db.toSession()
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// ListAllSessions returns all debate sessions (for cleanup on startup)
|
||||
func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) {
|
||||
rows, err := s.db.Query(`SELECT id, status FROM debate_sessions`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*DebateSession
|
||||
for rows.Next() {
|
||||
var session DebateSession
|
||||
if err := rows.Scan(&session.ID, &session.Status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, &session)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// getSessionsByUserBasic is a fallback for old schema without new columns
|
||||
func (s *DebateStore) getSessionsByUserBasic(userID string) ([]*DebateSession, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
||||
final_decision, auto_execute, created_at, updated_at
|
||||
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*DebateSession
|
||||
for rows.Next() {
|
||||
var session DebateSession
|
||||
var finalDecisionJSON sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
||||
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
|
||||
); err != nil {
|
||||
var dbs []DebateSessionDB
|
||||
if err := s.db.Select("id, status").Find(&dbs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set defaults for new fields
|
||||
session.IntervalMinutes = 5
|
||||
session.PromptVariant = "balanced"
|
||||
|
||||
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
|
||||
var decision DebateDecision
|
||||
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
|
||||
session.FinalDecision = &decision
|
||||
}
|
||||
}
|
||||
|
||||
sessions = append(sessions, &session)
|
||||
sessions := make([]*DebateSession, len(dbs))
|
||||
for i, db := range dbs {
|
||||
sessions[i] = &DebateSession{ID: db.ID, Status: db.Status}
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// UpdateSessionStatus updates the status of a debate session
|
||||
func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error {
|
||||
_, err := s.db.Exec(`UPDATE debate_sessions SET status = ? WHERE id = ?`, status, id)
|
||||
return err
|
||||
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateSessionRound updates the current round of a debate session
|
||||
func (s *DebateStore) UpdateSessionRound(id string, round int) error {
|
||||
_, err := s.db.Exec(`UPDATE debate_sessions SET current_round = ? WHERE id = ?`, round, id)
|
||||
return err
|
||||
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("current_round", round).Error
|
||||
}
|
||||
|
||||
// UpdateSessionFinalDecision updates the final decision of a debate session (single decision)
|
||||
@@ -508,30 +338,31 @@ func (s *DebateStore) UpdateSessionFinalDecision(id string, decision *DebateDeci
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`,
|
||||
string(decisionJSON), DebateStatusCompleted, id)
|
||||
return err
|
||||
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||
"final_decision": string(decisionJSON),
|
||||
"status": DebateStatusCompleted,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateSessionFinalDecisions updates both single and multi-coin final decisions
|
||||
func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error {
|
||||
// Always store primary decision as a single object (for backward compat)
|
||||
// This ensures GetSession can deserialize it correctly
|
||||
primaryJSON, err := json.Marshal(primaryDecision)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update final_decision with primary decision and set status to completed
|
||||
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`,
|
||||
string(primaryJSON), DebateStatusCompleted, id)
|
||||
return err
|
||||
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||
"final_decision": string(primaryJSON),
|
||||
"status": DebateStatusCompleted,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// DeleteSession deletes a debate session and all related data
|
||||
func (s *DebateStore) DeleteSession(id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM debate_sessions WHERE id = ?`, id)
|
||||
return err
|
||||
// Delete related data first
|
||||
s.db.Where("session_id = ?", id).Delete(&DebateParticipant{})
|
||||
s.db.Where("session_id = ?", id).Delete(&DebateMessage{})
|
||||
s.db.Where("session_id = ?", id).Delete(&DebateVote{})
|
||||
return s.db.Where("id = ?", id).Delete(&DebateSessionDB{}).Error
|
||||
}
|
||||
|
||||
// AddParticipant adds a participant to a debate session
|
||||
@@ -539,9 +370,6 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
|
||||
if participant.ID == "" {
|
||||
participant.ID = uuid.New().String()
|
||||
}
|
||||
participant.CreatedAt = time.Now()
|
||||
|
||||
// Set color based on personality if not provided
|
||||
if participant.Color == "" {
|
||||
if color, ok := PersonalityColors[participant.Personality]; ok {
|
||||
participant.Color = color
|
||||
@@ -549,39 +377,14 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
|
||||
participant.Color = "#6B7280" // Default gray
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO debate_participants (id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
participant.ID, participant.SessionID, participant.AIModelID, participant.AIModelName,
|
||||
participant.Provider, participant.Personality, participant.Color, participant.SpeakOrder, participant.CreatedAt,
|
||||
)
|
||||
return err
|
||||
return s.db.Create(participant).Error
|
||||
}
|
||||
|
||||
// GetParticipants gets all participants for a debate session
|
||||
func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at
|
||||
FROM debate_participants WHERE session_id = ? ORDER BY speak_order`, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var participants []*DebateParticipant
|
||||
for rows.Next() {
|
||||
var p DebateParticipant
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.SessionID, &p.AIModelID, &p.AIModelName,
|
||||
&p.Provider, &p.Personality, &p.Color, &p.SpeakOrder, &p.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
participants = append(participants, &p)
|
||||
}
|
||||
return participants, nil
|
||||
err := s.db.Where("session_id = ?", sessionID).Order("speak_order").Find(&participants).Error
|
||||
return participants, err
|
||||
}
|
||||
|
||||
// AddMessage adds a message to a debate session
|
||||
@@ -589,95 +392,52 @@ func (s *DebateStore) AddMessage(msg *DebateMessage) error {
|
||||
if msg.ID == "" {
|
||||
msg.ID = uuid.New().String()
|
||||
}
|
||||
msg.CreatedAt = time.Now()
|
||||
|
||||
var decisionJSON sql.NullString
|
||||
if msg.Decision != nil {
|
||||
data, err := json.Marshal(msg.Decision)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decisionJSON = sql.NullString{String: string(data), Valid: true}
|
||||
msg.DecisionRaw = string(data)
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO debate_messages (id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
msg.ID, msg.SessionID, msg.Round, msg.AIModelID, msg.AIModelName,
|
||||
msg.Provider, msg.Personality, msg.MessageType, msg.Content,
|
||||
decisionJSON, msg.Confidence, msg.CreatedAt,
|
||||
)
|
||||
return err
|
||||
return s.db.Create(msg).Error
|
||||
}
|
||||
|
||||
// GetMessages gets all messages for a debate session
|
||||
func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at
|
||||
FROM debate_messages WHERE session_id = ? ORDER BY round, created_at`, sessionID,
|
||||
)
|
||||
var messages []*DebateMessage
|
||||
err := s.db.Where("session_id = ?", sessionID).Order("round, created_at").Find(&messages).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []*DebateMessage
|
||||
for rows.Next() {
|
||||
var msg DebateMessage
|
||||
var decisionJSON sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
|
||||
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
|
||||
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if decisionJSON.Valid && decisionJSON.String != "" {
|
||||
// Parse decision JSON
|
||||
for _, msg := range messages {
|
||||
if msg.DecisionRaw != "" {
|
||||
var decision DebateDecision
|
||||
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil {
|
||||
if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
|
||||
msg.Decision = &decision
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, &msg)
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// GetMessagesByRound gets messages for a specific round
|
||||
func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at
|
||||
FROM debate_messages WHERE session_id = ? AND round = ? ORDER BY created_at`, sessionID, round,
|
||||
)
|
||||
var messages []*DebateMessage
|
||||
err := s.db.Where("session_id = ? AND round = ?", sessionID, round).Order("created_at").Find(&messages).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []*DebateMessage
|
||||
for rows.Next() {
|
||||
var msg DebateMessage
|
||||
var decisionJSON sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
|
||||
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
|
||||
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if decisionJSON.Valid && decisionJSON.String != "" {
|
||||
// Parse decision JSON
|
||||
for _, msg := range messages {
|
||||
if msg.DecisionRaw != "" {
|
||||
var decision DebateDecision
|
||||
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil {
|
||||
if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
|
||||
msg.Decision = &decision
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, &msg)
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
@@ -687,40 +447,14 @@ func (s *DebateStore) AddVote(vote *DebateVote) error {
|
||||
if vote.ID == "" {
|
||||
vote.ID = uuid.New().String()
|
||||
}
|
||||
vote.CreatedAt = time.Now()
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO debate_votes (id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
vote.ID, vote.SessionID, vote.AIModelID, vote.AIModelName,
|
||||
vote.Action, vote.Symbol, vote.Confidence, vote.Leverage, vote.PositionPct, vote.StopLossPct, vote.TakeProfitPct, vote.Reasoning, vote.CreatedAt,
|
||||
)
|
||||
return err
|
||||
return s.db.Create(vote).Error
|
||||
}
|
||||
|
||||
// GetVotes gets all votes for a debate session
|
||||
func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at
|
||||
FROM debate_votes WHERE session_id = ? ORDER BY created_at`, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var votes []*DebateVote
|
||||
for rows.Next() {
|
||||
var vote DebateVote
|
||||
if err := rows.Scan(
|
||||
&vote.ID, &vote.SessionID, &vote.AIModelID, &vote.AIModelName,
|
||||
&vote.Action, &vote.Symbol, &vote.Confidence, &vote.Leverage, &vote.PositionPct, &vote.StopLossPct, &vote.TakeProfitPct, &vote.Reasoning, &vote.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
votes = append(votes, &vote)
|
||||
}
|
||||
return votes, nil
|
||||
err := s.db.Where("session_id = ?", sessionID).Order("created_at").Find(&votes).Error
|
||||
return votes, err
|
||||
}
|
||||
|
||||
// DebateSessionWithDetails combines session with participants and messages
|
||||
|
||||
+132
-195
@@ -1,18 +1,41 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DecisionStore decision log storage
|
||||
type DecisionStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// DecisionRecord decision record
|
||||
// DecisionRecordDB internal GORM model for decision_records table
|
||||
type DecisionRecordDB struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
TraderID string `gorm:"column:trader_id;not null;index:idx_decision_records_trader_time"`
|
||||
CycleNumber int `gorm:"column:cycle_number;not null"`
|
||||
Timestamp time.Time `gorm:"not null;index:idx_decision_records_trader_time,sort:desc;index:idx_decision_records_timestamp,sort:desc"`
|
||||
SystemPrompt string `gorm:"column:system_prompt;default:''"`
|
||||
InputPrompt string `gorm:"column:input_prompt;default:''"`
|
||||
CoTTrace string `gorm:"column:cot_trace;default:''"`
|
||||
DecisionJSON string `gorm:"column:decision_json;default:''"`
|
||||
RawResponse string `gorm:"column:raw_response;default:''"`
|
||||
CandidateCoins string `gorm:"column:candidate_coins;default:''"`
|
||||
ExecutionLog string `gorm:"column:execution_log;default:''"`
|
||||
Decisions string `gorm:"column:decisions;default:'[]'"`
|
||||
Success bool `gorm:"default:false"`
|
||||
ErrorMessage string `gorm:"column:error_message;default:''"`
|
||||
AIRequestDurationMs int64 `gorm:"column:ai_request_duration_ms;default:0"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (DecisionRecordDB) TableName() string { return "decision_records" }
|
||||
|
||||
// DecisionRecord decision record (external API struct)
|
||||
type DecisionRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"`
|
||||
@@ -81,49 +104,47 @@ type Statistics struct {
|
||||
TotalClosePositions int `json:"total_close_positions"`
|
||||
}
|
||||
|
||||
// NewDecisionStore creates a new DecisionStore
|
||||
func NewDecisionStore(db *gorm.DB) *DecisionStore {
|
||||
return &DecisionStore{db: db}
|
||||
}
|
||||
|
||||
// initTables initializes AI decision log tables
|
||||
// Note: Account equity curve data has been migrated to trader_equity_snapshots table (managed by EquityStore)
|
||||
func (s *DecisionStore) initTables() error {
|
||||
queries := []string{
|
||||
// AI decision log table (records AI input/output, chain of thought, etc.)
|
||||
`CREATE TABLE IF NOT EXISTS decision_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
cycle_number INTEGER NOT NULL,
|
||||
timestamp DATETIME NOT NULL,
|
||||
system_prompt TEXT DEFAULT '',
|
||||
input_prompt TEXT DEFAULT '',
|
||||
cot_trace TEXT DEFAULT '',
|
||||
decision_json TEXT DEFAULT '',
|
||||
raw_response TEXT DEFAULT '',
|
||||
candidate_coins TEXT DEFAULT '',
|
||||
execution_log TEXT DEFAULT '',
|
||||
success BOOLEAN DEFAULT 0,
|
||||
error_message TEXT DEFAULT '',
|
||||
ai_request_duration_ms INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
// Indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`,
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := s.db.Exec(query); err != nil {
|
||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Migration: add raw_response column if not exists
|
||||
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN raw_response TEXT DEFAULT ''`)
|
||||
|
||||
// Migration: add decisions column if not exists
|
||||
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN decisions TEXT DEFAULT '[]'`)
|
||||
|
||||
// For PostgreSQL with existing table, skip AutoMigrate
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'decision_records'`).Scan(&tableExists)
|
||||
if tableExists > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return s.db.AutoMigrate(&DecisionRecordDB{})
|
||||
}
|
||||
|
||||
// LogDecision logs decision (only saves AI decision log, equity curve has been migrated to equity table)
|
||||
// toRecord converts DB model to API struct
|
||||
func (db *DecisionRecordDB) toRecord() *DecisionRecord {
|
||||
record := &DecisionRecord{
|
||||
ID: db.ID,
|
||||
TraderID: db.TraderID,
|
||||
CycleNumber: db.CycleNumber,
|
||||
Timestamp: db.Timestamp,
|
||||
SystemPrompt: db.SystemPrompt,
|
||||
InputPrompt: db.InputPrompt,
|
||||
CoTTrace: db.CoTTrace,
|
||||
DecisionJSON: db.DecisionJSON,
|
||||
RawResponse: db.RawResponse,
|
||||
Success: db.Success,
|
||||
ErrorMessage: db.ErrorMessage,
|
||||
AIRequestDurationMs: db.AIRequestDurationMs,
|
||||
}
|
||||
json.Unmarshal([]byte(db.CandidateCoins), &record.CandidateCoins)
|
||||
json.Unmarshal([]byte(db.ExecutionLog), &record.ExecutionLog)
|
||||
json.Unmarshal([]byte(db.Decisions), &record.Decisions)
|
||||
return record
|
||||
}
|
||||
|
||||
// LogDecision logs decision
|
||||
func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
|
||||
if record.Timestamp.IsZero() {
|
||||
record.Timestamp = time.Now().UTC()
|
||||
@@ -131,65 +152,49 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
|
||||
record.Timestamp = record.Timestamp.UTC()
|
||||
}
|
||||
|
||||
// Serialize candidate coins, execution log and decisions to JSON
|
||||
// Serialize arrays to JSON
|
||||
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
|
||||
executionLogJSON, _ := json.Marshal(record.ExecutionLog)
|
||||
decisionsJSON, _ := json.Marshal(record.Decisions)
|
||||
|
||||
// Insert decision record main table (only save AI decision related content)
|
||||
result, err := s.db.Exec(`
|
||||
INSERT INTO decision_records (
|
||||
trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, raw_response, candidate_coins, execution_log,
|
||||
decisions, success, error_message, ai_request_duration_ms
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339),
|
||||
record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON,
|
||||
record.RawResponse, string(candidateCoinsJSON), string(executionLogJSON),
|
||||
string(decisionsJSON), record.Success, record.ErrorMessage, record.AIRequestDurationMs,
|
||||
)
|
||||
if err != nil {
|
||||
dbRecord := &DecisionRecordDB{
|
||||
TraderID: record.TraderID,
|
||||
CycleNumber: record.CycleNumber,
|
||||
Timestamp: record.Timestamp,
|
||||
SystemPrompt: record.SystemPrompt,
|
||||
InputPrompt: record.InputPrompt,
|
||||
CoTTrace: record.CoTTrace,
|
||||
DecisionJSON: record.DecisionJSON,
|
||||
RawResponse: record.RawResponse,
|
||||
CandidateCoins: string(candidateCoinsJSON),
|
||||
ExecutionLog: string(executionLogJSON),
|
||||
Decisions: string(decisionsJSON),
|
||||
Success: record.Success,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
AIRequestDurationMs: record.AIRequestDurationMs,
|
||||
}
|
||||
|
||||
if err := s.db.Create(dbRecord).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert decision record: %w", err)
|
||||
}
|
||||
|
||||
decisionID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get decision ID: %w", err)
|
||||
}
|
||||
record.ID = decisionID
|
||||
|
||||
record.ID = dbRecord.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new)
|
||||
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
|
||||
FROM decision_records
|
||||
WHERE trader_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`, traderID, n)
|
||||
var dbRecords []*DecisionRecordDB
|
||||
err := s.db.Where("trader_id = ?", traderID).
|
||||
Order("timestamp DESC").
|
||||
Limit(n).
|
||||
Find(&dbRecords).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*DecisionRecord
|
||||
for rows.Next() {
|
||||
record, err := s.scanDecisionRecord(rows)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
// Fill associated data
|
||||
for _, record := range records {
|
||||
s.fillRecordDetails(record)
|
||||
records := make([]*DecisionRecord, len(dbRecords))
|
||||
for i, db := range dbRecords {
|
||||
records[i] = db.toRecord()
|
||||
}
|
||||
|
||||
// Reverse array to sort time from old to new
|
||||
@@ -202,26 +207,15 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
|
||||
|
||||
// GetAllLatestRecords gets the latest N records for all traders
|
||||
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
|
||||
FROM decision_records
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`, n)
|
||||
var dbRecords []*DecisionRecordDB
|
||||
err := s.db.Order("timestamp DESC").Limit(n).Find(&dbRecords).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*DecisionRecord
|
||||
for rows.Next() {
|
||||
record, err := s.scanDecisionRecord(rows)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
records = append(records, record)
|
||||
records := make([]*DecisionRecord, len(dbRecords))
|
||||
for i, db := range dbRecords {
|
||||
records[i] = db.toRecord()
|
||||
}
|
||||
|
||||
// Reverse array
|
||||
@@ -236,26 +230,17 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
|
||||
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
|
||||
dateStr := date.Format("2006-01-02")
|
||||
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
|
||||
FROM decision_records
|
||||
WHERE trader_id = ? AND DATE(timestamp) = ?
|
||||
ORDER BY timestamp ASC
|
||||
`, traderID, dateStr)
|
||||
var dbRecords []*DecisionRecordDB
|
||||
err := s.db.Where("trader_id = ? AND DATE(timestamp) = ?", traderID, dateStr).
|
||||
Order("timestamp ASC").
|
||||
Find(&dbRecords).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*DecisionRecord
|
||||
for rows.Next() {
|
||||
record, err := s.scanDecisionRecord(rows)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
records = append(records, record)
|
||||
records := make([]*DecisionRecord, len(dbRecords))
|
||||
for i, db := range dbRecords {
|
||||
records[i] = db.toRecord()
|
||||
}
|
||||
|
||||
return records, nil
|
||||
@@ -263,48 +248,31 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De
|
||||
|
||||
// CleanOldRecords cleans old records from N days ago
|
||||
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
DELETE FROM decision_records
|
||||
WHERE trader_id = ? AND timestamp < ?
|
||||
`, traderID, cutoffTime)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to clean old records: %w", err)
|
||||
result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
|
||||
Delete(&DecisionRecordDB{})
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// GetStatistics gets statistics information for specified trader
|
||||
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
|
||||
stats := &Statistics{}
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_records WHERE trader_id = ?
|
||||
`, traderID).Scan(&stats.TotalCycles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query total cycles: %w", err)
|
||||
}
|
||||
var totalCount, successCount int64
|
||||
s.db.Model(&DecisionRecordDB{}).Where("trader_id = ?", traderID).Count(&totalCount)
|
||||
s.db.Model(&DecisionRecordDB{}).Where("trader_id = ? AND success = ?", traderID, true).Count(&successCount)
|
||||
|
||||
err = s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1
|
||||
`, traderID).Scan(&stats.SuccessfulCycles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query successful cycles: %w", err)
|
||||
}
|
||||
stats.TotalCycles = int(totalCount)
|
||||
stats.SuccessfulCycles = int(successCount)
|
||||
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
||||
|
||||
// Count from trader_positions table
|
||||
s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM trader_positions
|
||||
WHERE trader_id = ?
|
||||
`, traderID).Scan(&stats.TotalOpenPositions)
|
||||
|
||||
s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM trader_positions
|
||||
WHERE trader_id = ? AND status = 'CLOSED'
|
||||
`, traderID).Scan(&stats.TotalClosePositions)
|
||||
// Count from trader_positions table using raw query for cross-table
|
||||
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ?", traderID).Scan(&stats.TotalOpenPositions)
|
||||
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ? AND status = 'CLOSED'", traderID).Scan(&stats.TotalClosePositions)
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
@@ -313,64 +281,33 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
|
||||
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
|
||||
stats := &Statistics{}
|
||||
|
||||
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles)
|
||||
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles)
|
||||
var totalCount, successCount int64
|
||||
s.db.Model(&DecisionRecordDB{}).Count(&totalCount)
|
||||
s.db.Model(&DecisionRecordDB{}).Where("success = ?", true).Count(&successCount)
|
||||
|
||||
stats.TotalCycles = int(totalCount)
|
||||
stats.SuccessfulCycles = int(successCount)
|
||||
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
||||
|
||||
// Count from trader_positions table
|
||||
s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM trader_positions
|
||||
`).Scan(&stats.TotalOpenPositions)
|
||||
|
||||
s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM trader_positions
|
||||
WHERE status = 'CLOSED'
|
||||
`).Scan(&stats.TotalClosePositions)
|
||||
s.db.Raw("SELECT COUNT(*) FROM trader_positions").Scan(&stats.TotalOpenPositions)
|
||||
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE status = 'CLOSED'").Scan(&stats.TotalClosePositions)
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetLastCycleNumber gets the last cycle number for specified trader
|
||||
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
|
||||
var cycleNumber int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ?
|
||||
`, traderID).Scan(&cycleNumber)
|
||||
var cycleNumber *int
|
||||
err := s.db.Model(&DecisionRecordDB{}).
|
||||
Where("trader_id = ?", traderID).
|
||||
Select("MAX(cycle_number)").
|
||||
Scan(&cycleNumber).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return cycleNumber, nil
|
||||
if cycleNumber == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// scanDecisionRecord scans decision record from row
|
||||
func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) {
|
||||
var record DecisionRecord
|
||||
var timestampStr string
|
||||
var candidateCoinsJSON, executionLogJSON, decisionsJSON string
|
||||
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.TraderID, &record.CycleNumber, ×tampStr,
|
||||
&record.SystemPrompt, &record.InputPrompt, &record.CoTTrace,
|
||||
&record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON,
|
||||
&decisionsJSON, &record.Success, &record.ErrorMessage, &record.AIRequestDurationMs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||||
json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins)
|
||||
json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog)
|
||||
json.Unmarshal([]byte(decisionsJSON), &record.Decisions)
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// fillRecordDetails fills associated data for decision record (old associated tables removed, this function kept for compatibility)
|
||||
// Note: Account snapshot, position snapshot, decision action data are no longer stored in decision related tables
|
||||
// - For equity data use EquityStore.GetLatest()
|
||||
// - For order data use OrderStore
|
||||
func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) {
|
||||
// Old associated tables removed, no longer need to fill
|
||||
// AccountState, Positions, Decisions fields will remain at zero values
|
||||
return *cycleNumber, nil
|
||||
}
|
||||
|
||||
@@ -238,3 +238,44 @@ func getEnv(key, defaultValue string) string {
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// convertQuery converts ? placeholders to $1, $2 for PostgreSQL
|
||||
// and handles other database-specific syntax differences
|
||||
func convertQuery(query string, dbType DBType) string {
|
||||
if dbType != DBTypePostgres {
|
||||
return query
|
||||
}
|
||||
result := query
|
||||
|
||||
// Convert ? to $1, $2, etc. for PostgreSQL
|
||||
index := 1
|
||||
for strings.Contains(result, "?") {
|
||||
result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1)
|
||||
index++
|
||||
}
|
||||
|
||||
// Convert datetime('now') to CURRENT_TIMESTAMP
|
||||
result = strings.ReplaceAll(result, "datetime('now')", "CURRENT_TIMESTAMP")
|
||||
|
||||
// Remove datetime() wrapper for ORDER BY (PostgreSQL timestamps sort correctly)
|
||||
// This handles patterns like "ORDER BY datetime(column) DESC"
|
||||
result = strings.ReplaceAll(result, "datetime(updated_at)", "updated_at")
|
||||
result = strings.ReplaceAll(result, "datetime(created_at)", "created_at")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// boolDefault returns database-appropriate boolean default for COALESCE
|
||||
// Use in queries like: COALESCE(column, %s)
|
||||
func boolDefault(dbType DBType, value bool) string {
|
||||
if dbType == DBTypePostgres {
|
||||
if value {
|
||||
return "TRUE"
|
||||
}
|
||||
return "FALSE"
|
||||
}
|
||||
if value {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
+62
-140
@@ -1,56 +1,49 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// EquityStore account equity storage (for plotting return curves)
|
||||
type EquityStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// EquitySnapshot equity snapshot
|
||||
type EquitySnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL)
|
||||
Balance float64 `json:"balance"` // Account balance
|
||||
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss
|
||||
PositionCount int `json:"position_count"` // Position count
|
||||
MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage
|
||||
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
TraderID string `gorm:"column:trader_id;not null;index:idx_equity_trader_time" json:"trader_id"`
|
||||
Timestamp time.Time `gorm:"not null;index:idx_equity_trader_time,sort:desc;index:idx_equity_timestamp,sort:desc" json:"timestamp"`
|
||||
TotalEquity float64 `gorm:"column:total_equity;not null;default:0" json:"total_equity"`
|
||||
Balance float64 `gorm:"not null;default:0" json:"balance"`
|
||||
UnrealizedPnL float64 `gorm:"column:unrealized_pnl;not null;default:0" json:"unrealized_pnl"`
|
||||
PositionCount int `gorm:"column:position_count;default:0" json:"position_count"`
|
||||
MarginUsedPct float64 `gorm:"column:margin_used_pct;default:0" json:"margin_used_pct"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (EquitySnapshot) TableName() string { return "trader_equity_snapshots" }
|
||||
|
||||
// NewEquityStore creates a new EquityStore
|
||||
func NewEquityStore(db *gorm.DB) *EquityStore {
|
||||
return &EquityStore{db: db}
|
||||
}
|
||||
|
||||
// initTables initializes equity tables
|
||||
func (s *EquityStore) initTables() error {
|
||||
queries := []string{
|
||||
// Equity snapshot table - specifically for return curves
|
||||
`CREATE TABLE IF NOT EXISTS trader_equity_snapshots (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
timestamp DATETIME NOT NULL,
|
||||
total_equity REAL NOT NULL DEFAULT 0,
|
||||
balance REAL NOT NULL DEFAULT 0,
|
||||
unrealized_pnl REAL NOT NULL DEFAULT 0,
|
||||
position_count INTEGER DEFAULT 0,
|
||||
margin_used_pct REAL DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
// Indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`,
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := s.db.Exec(query); err != nil {
|
||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// For PostgreSQL with existing table, skip AutoMigrate
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_equity_snapshots'`).Scan(&tableExists)
|
||||
if tableExists > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return s.db.AutoMigrate(&EquitySnapshot{})
|
||||
}
|
||||
|
||||
// Save saves equity snapshot
|
||||
func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
|
||||
@@ -60,58 +53,22 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
|
||||
snapshot.Timestamp = snapshot.Timestamp.UTC()
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
INSERT INTO trader_equity_snapshots (
|
||||
trader_id, timestamp, total_equity, balance,
|
||||
unrealized_pnl, position_count, margin_used_pct
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
snapshot.TraderID,
|
||||
snapshot.Timestamp.Format(time.RFC3339),
|
||||
snapshot.TotalEquity,
|
||||
snapshot.Balance,
|
||||
snapshot.UnrealizedPnL,
|
||||
snapshot.PositionCount,
|
||||
snapshot.MarginUsedPct,
|
||||
)
|
||||
if err != nil {
|
||||
if err := s.db.Create(snapshot).Error; err != nil {
|
||||
return fmt.Errorf("failed to save equity snapshot: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
snapshot.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new)
|
||||
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, timestamp, total_equity, balance,
|
||||
unrealized_pnl, position_count, margin_used_pct
|
||||
FROM trader_equity_snapshots
|
||||
WHERE trader_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`, traderID, limit)
|
||||
var snapshots []*EquitySnapshot
|
||||
err := s.db.Where("trader_id = ?", traderID).
|
||||
Order("timestamp DESC").
|
||||
Limit(limit).
|
||||
Find(&snapshots).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query equity records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var snapshots []*EquitySnapshot
|
||||
for rows.Next() {
|
||||
snap := &EquitySnapshot{}
|
||||
var timestampStr string
|
||||
err := rows.Scan(
|
||||
&snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity,
|
||||
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||||
snapshots = append(snapshots, snap)
|
||||
}
|
||||
|
||||
// Reverse the array to sort time from old to new (suitable for plotting curves)
|
||||
for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
|
||||
@@ -123,116 +80,81 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
|
||||
|
||||
// GetByTimeRange gets equity records within specified time range
|
||||
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, timestamp, total_equity, balance,
|
||||
unrealized_pnl, position_count, margin_used_pct
|
||||
FROM trader_equity_snapshots
|
||||
WHERE trader_id = ? AND timestamp >= ? AND timestamp <= ?
|
||||
ORDER BY timestamp ASC
|
||||
`, traderID, start.Format(time.RFC3339), end.Format(time.RFC3339))
|
||||
var snapshots []*EquitySnapshot
|
||||
err := s.db.Where("trader_id = ? AND timestamp >= ? AND timestamp <= ?", traderID, start, end).
|
||||
Order("timestamp ASC").
|
||||
Find(&snapshots).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query equity records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var snapshots []*EquitySnapshot
|
||||
for rows.Next() {
|
||||
snap := &EquitySnapshot{}
|
||||
var timestampStr string
|
||||
err := rows.Scan(
|
||||
&snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity,
|
||||
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||||
snapshots = append(snapshots, snap)
|
||||
}
|
||||
|
||||
return snapshots, nil
|
||||
}
|
||||
|
||||
// GetAllTradersLatest gets latest equity for all traders (for leaderboards)
|
||||
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
|
||||
rows, err := s.db.Query(`
|
||||
// Use raw SQL for this complex query with subquery
|
||||
var snapshots []*EquitySnapshot
|
||||
err := s.db.Raw(`
|
||||
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
|
||||
e.unrealized_pnl, e.position_count, e.margin_used_pct
|
||||
e.unrealized_pnl, e.position_count, e.margin_used_pct, e.created_at
|
||||
FROM trader_equity_snapshots e
|
||||
INNER JOIN (
|
||||
SELECT trader_id, MAX(timestamp) as max_ts
|
||||
FROM trader_equity_snapshots
|
||||
GROUP BY trader_id
|
||||
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
|
||||
`)
|
||||
`).Scan(&snapshots).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query latest equity: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[string]*EquitySnapshot)
|
||||
for rows.Next() {
|
||||
snap := &EquitySnapshot{}
|
||||
var timestampStr string
|
||||
err := rows.Scan(
|
||||
&snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity,
|
||||
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||||
for _, snap := range snapshots {
|
||||
result[snap.TraderID] = snap
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CleanOldRecords cleans old records from N days ago
|
||||
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
DELETE FROM trader_equity_snapshots
|
||||
WHERE trader_id = ? AND timestamp < ?
|
||||
`, traderID, cutoffTime)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to clean old records: %w", err)
|
||||
result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
|
||||
Delete(&EquitySnapshot{})
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// GetCount gets record count for specified trader
|
||||
func (s *EquityStore) GetCount(traderID string) (int, error) {
|
||||
var count int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM trader_equity_snapshots WHERE trader_id = ?
|
||||
`, traderID).Scan(&count)
|
||||
return count, err
|
||||
var count int64
|
||||
err := s.db.Model(&EquitySnapshot{}).Where("trader_id = ?", traderID).Count(&count).Error
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// MigrateFromDecision migrates data from old decision_account_snapshots table
|
||||
func (s *EquityStore) MigrateFromDecision() (int64, error) {
|
||||
// Check if migration is needed (whether new table is empty)
|
||||
var count int
|
||||
s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count)
|
||||
var count int64
|
||||
s.db.Model(&EquitySnapshot{}).Count(&count)
|
||||
if count > 0 {
|
||||
return 0, nil // Already has data, skip migration
|
||||
}
|
||||
|
||||
// Check if old table exists
|
||||
// Check if old table exists (SQLite specific check, but works for migration)
|
||||
var tableName string
|
||||
err := s.db.QueryRow(`
|
||||
err := s.db.Raw(`
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='decision_account_snapshots'
|
||||
`).Scan(&tableName)
|
||||
if err != nil {
|
||||
`).Scan(&tableName).Error
|
||||
if err != nil || tableName == "" {
|
||||
return 0, nil // Old table doesn't exist, skip
|
||||
}
|
||||
|
||||
// Migrate data: join query from decision_records + decision_account_snapshots
|
||||
result, err := s.db.Exec(`
|
||||
result := s.db.Exec(`
|
||||
INSERT INTO trader_equity_snapshots (
|
||||
trader_id, timestamp, total_equity, balance,
|
||||
unrealized_pnl, position_count, margin_used_pct
|
||||
@@ -249,9 +171,9 @@ func (s *EquityStore) MigrateFromDecision() (int64, error) {
|
||||
JOIN decision_account_snapshots das ON dr.id = das.decision_id
|
||||
ORDER BY dr.timestamp ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to migrate data: %w", err)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to migrate data: %w", result.Error)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
+140
-289
@@ -1,83 +1,68 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"nofx/crypto"
|
||||
"nofx/logger"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ExchangeStore exchange storage
|
||||
type ExchangeStore struct {
|
||||
db *sql.DB
|
||||
encryptFunc func(string) string
|
||||
decryptFunc func(string) string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// Exchange exchange configuration
|
||||
type Exchange struct {
|
||||
ID string `json:"id"` // UUID
|
||||
ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
|
||||
AccountName string `json:"account_name"` // User-defined account name
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"` // Display name (auto-generated or user-defined)
|
||||
Type string `json:"type"` // "cex" or "dex"
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
SecretKey string `json:"secretKey"`
|
||||
Passphrase string `json:"passphrase"` // OKX-specific
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
|
||||
AsterUser string `json:"asterUser"`
|
||||
AsterSigner string `json:"asterSigner"`
|
||||
AsterPrivateKey string `json:"asterPrivateKey"`
|
||||
LighterWalletAddr string `json:"lighterWalletAddr"`
|
||||
LighterPrivateKey string `json:"lighterPrivateKey"`
|
||||
LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"`
|
||||
LighterAPIKeyIndex int `json:"lighterAPIKeyIndex"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
|
||||
AccountName string `gorm:"column:account_name;not null;default:''" json:"account_name"`
|
||||
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type string `gorm:"not null" json:"type"` // "cex" or "dex"
|
||||
Enabled bool `gorm:"default:false" json:"enabled"`
|
||||
APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
|
||||
SecretKey crypto.EncryptedString `gorm:"column:secret_key;default:''" json:"secretKey"`
|
||||
Passphrase crypto.EncryptedString `gorm:"column:passphrase;default:''" json:"passphrase"`
|
||||
Testnet bool `gorm:"default:false" json:"testnet"`
|
||||
HyperliquidWalletAddr string `gorm:"column:hyperliquid_wallet_addr;default:''" json:"hyperliquidWalletAddr"`
|
||||
AsterUser string `gorm:"column:aster_user;default:''" json:"asterUser"`
|
||||
AsterSigner string `gorm:"column:aster_signer;default:''" json:"asterSigner"`
|
||||
AsterPrivateKey crypto.EncryptedString `gorm:"column:aster_private_key;default:''" json:"asterPrivateKey"`
|
||||
LighterWalletAddr string `gorm:"column:lighter_wallet_addr;default:''" json:"lighterWalletAddr"`
|
||||
LighterPrivateKey crypto.EncryptedString `gorm:"column:lighter_private_key;default:''" json:"lighterPrivateKey"`
|
||||
LighterAPIKeyPrivateKey crypto.EncryptedString `gorm:"column:lighter_api_key_private_key;default:''" json:"lighterAPIKeyPrivateKey"`
|
||||
LighterAPIKeyIndex int `gorm:"column:lighter_api_key_index;default:0" json:"lighterAPIKeyIndex"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) initTables() error {
|
||||
// Create new table structure with UUID as primary key
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS exchanges (
|
||||
id TEXT PRIMARY KEY,
|
||||
exchange_type TEXT NOT NULL DEFAULT '',
|
||||
account_name TEXT NOT NULL DEFAULT '',
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT 0,
|
||||
api_key TEXT DEFAULT '',
|
||||
secret_key TEXT DEFAULT '',
|
||||
passphrase TEXT DEFAULT '',
|
||||
testnet BOOLEAN DEFAULT 0,
|
||||
hyperliquid_wallet_addr TEXT DEFAULT '',
|
||||
aster_user TEXT DEFAULT '',
|
||||
aster_signer TEXT DEFAULT '',
|
||||
aster_private_key TEXT DEFAULT '',
|
||||
lighter_wallet_addr TEXT DEFAULT '',
|
||||
lighter_private_key TEXT DEFAULT '',
|
||||
lighter_api_key_private_key TEXT DEFAULT '',
|
||||
lighter_api_key_index INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
func (Exchange) TableName() string { return "exchanges" }
|
||||
|
||||
// NewExchangeStore creates a new ExchangeStore
|
||||
func NewExchangeStore(db *gorm.DB) *ExchangeStore {
|
||||
return &ExchangeStore{db: db}
|
||||
}
|
||||
|
||||
// Migration: add new columns if not exists
|
||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN passphrase TEXT DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN exchange_type TEXT NOT NULL DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN account_name TEXT NOT NULL DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN lighter_api_key_index INTEGER DEFAULT 0`)
|
||||
func (s *ExchangeStore) initTables() error {
|
||||
// For PostgreSQL with existing table, skip AutoMigrate
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'exchanges'`).Scan(&tableExists)
|
||||
if tableExists > 0 {
|
||||
// Still run data migrations
|
||||
s.migrateToMultiAccount()
|
||||
s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.AutoMigrate(&Exchange{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run migration to multi-account if needed
|
||||
if err := s.migrateToMultiAccount(); err != nil {
|
||||
@@ -85,120 +70,65 @@ func (s *ExchangeStore) initTables() error {
|
||||
}
|
||||
|
||||
// Fix empty account_name for existing records
|
||||
s.db.Exec(`UPDATE exchanges SET account_name = 'Default' WHERE account_name = '' OR account_name IS NULL`)
|
||||
s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
|
||||
|
||||
// Update trigger for new schema
|
||||
s.db.Exec(`DROP TRIGGER IF EXISTS update_exchanges_updated_at`)
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at
|
||||
AFTER UPDATE ON exchanges
|
||||
BEGIN
|
||||
UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID)
|
||||
func (s *ExchangeStore) migrateToMultiAccount() error {
|
||||
// Check if migration is needed by looking for old-style IDs (non-UUID)
|
||||
var count int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM exchanges
|
||||
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
|
||||
`).Scan(&count)
|
||||
var count int64
|
||||
err := s.db.Model(&Exchange{}).
|
||||
Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
// No migration needed
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count)
|
||||
|
||||
// Get all old records
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, type, enabled, api_key, secret_key,
|
||||
COALESCE(passphrase, '') as passphrase, testnet,
|
||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
||||
COALESCE(aster_user, '') as aster_user,
|
||||
COALESCE(aster_signer, '') as aster_signer,
|
||||
COALESCE(aster_private_key, '') as aster_private_key,
|
||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key
|
||||
FROM exchanges
|
||||
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
|
||||
`)
|
||||
var records []Exchange
|
||||
err = s.db.Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
|
||||
Find(&records).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type oldRecord struct {
|
||||
id, userID, name, typ string
|
||||
enabled, testnet bool
|
||||
apiKey, secretKey, passphrase string
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string
|
||||
lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string
|
||||
}
|
||||
|
||||
var records []oldRecord
|
||||
for rows.Next() {
|
||||
var r oldRecord
|
||||
if err := rows.Scan(&r.id, &r.userID, &r.name, &r.typ, &r.enabled,
|
||||
&r.apiKey, &r.secretKey, &r.passphrase, &r.testnet,
|
||||
&r.hyperliquidWalletAddr, &r.asterUser, &r.asterSigner, &r.asterPrivateKey,
|
||||
&r.lighterWalletAddr, &r.lighterPrivateKey, &r.lighterApiKeyPrivateKey); err != nil {
|
||||
return err
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
|
||||
// Begin transaction
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Migrate each record
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, r := range records {
|
||||
newID := uuid.New().String()
|
||||
oldID := r.id // This is the exchange type (e.g., "binance")
|
||||
oldID := r.ID // This is the exchange type (e.g., "binance")
|
||||
|
||||
// Update traders table to use new UUID
|
||||
_, err = tx.Exec(`UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?`,
|
||||
newID, oldID, r.userID)
|
||||
if err != nil {
|
||||
if err := tx.Exec("UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?",
|
||||
newID, oldID, r.UserID).Error; err != nil {
|
||||
logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the exchange record
|
||||
_, err = tx.Exec(`
|
||||
UPDATE exchanges SET
|
||||
id = ?,
|
||||
exchange_type = ?,
|
||||
account_name = ?
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, newID, oldID, "Default", oldID, r.userID)
|
||||
if err != nil {
|
||||
if err := tx.Model(&Exchange{}).
|
||||
Where("id = ? AND user_id = ?", oldID, r.UserID).
|
||||
Updates(map[string]interface{}{
|
||||
"id": newID,
|
||||
"exchange_type": oldID,
|
||||
"account_name": "Default",
|
||||
}).Error; err != nil {
|
||||
logger.Errorf("Failed to migrate exchange %s: %v", oldID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.userID)
|
||||
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.UserID)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("✅ Multi-account migration completed successfully")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) initDefaultData() error {
|
||||
@@ -206,108 +136,24 @@ func (s *ExchangeStore) initDefaultData() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) encrypt(plaintext string) string {
|
||||
if s.encryptFunc != nil {
|
||||
return s.encryptFunc(plaintext)
|
||||
}
|
||||
return plaintext
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) decrypt(encrypted string) string {
|
||||
if s.decryptFunc != nil {
|
||||
return s.decryptFunc(encrypted)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
|
||||
// List gets user's exchange list
|
||||
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
|
||||
user_id, name, type, enabled, api_key, secret_key,
|
||||
COALESCE(passphrase, '') as passphrase, testnet,
|
||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
||||
COALESCE(aster_user, '') as aster_user,
|
||||
COALESCE(aster_signer, '') as aster_signer,
|
||||
COALESCE(aster_private_key, '') as aster_private_key,
|
||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
|
||||
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
|
||||
created_at, updated_at
|
||||
FROM exchanges WHERE user_id = ? ORDER BY exchange_type, account_name
|
||||
`, userID)
|
||||
var exchanges []*Exchange
|
||||
err := s.db.Where("user_id = ?", userID).Order("exchange_type, account_name").Find(&exchanges).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
exchanges := make([]*Exchange, 0)
|
||||
for rows.Next() {
|
||||
var e Exchange
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&e.ID, &e.ExchangeType, &e.AccountName,
|
||||
&e.UserID, &e.Name, &e.Type,
|
||||
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
|
||||
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
|
||||
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
e.APIKey = s.decrypt(e.APIKey)
|
||||
e.SecretKey = s.decrypt(e.SecretKey)
|
||||
e.Passphrase = s.decrypt(e.Passphrase)
|
||||
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
|
||||
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
|
||||
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
|
||||
exchanges = append(exchanges, &e)
|
||||
}
|
||||
return exchanges, nil
|
||||
}
|
||||
|
||||
// GetByID gets a specific exchange by UUID
|
||||
func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) {
|
||||
var e Exchange
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
|
||||
user_id, name, type, enabled, api_key, secret_key,
|
||||
COALESCE(passphrase, '') as passphrase, testnet,
|
||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
||||
COALESCE(aster_user, '') as aster_user,
|
||||
COALESCE(aster_signer, '') as aster_signer,
|
||||
COALESCE(aster_private_key, '') as aster_private_key,
|
||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
|
||||
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
|
||||
created_at, updated_at
|
||||
FROM exchanges WHERE id = ? AND user_id = ?
|
||||
`, id, userID).Scan(
|
||||
&e.ID, &e.ExchangeType, &e.AccountName,
|
||||
&e.UserID, &e.Name, &e.Type,
|
||||
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
|
||||
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
|
||||
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
var exchange Exchange
|
||||
err := s.db.Where("id = ? AND user_id = ?", id, userID).First(&exchange).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
e.APIKey = s.decrypt(e.APIKey)
|
||||
e.SecretKey = s.decrypt(e.SecretKey)
|
||||
e.Passphrase = s.decrypt(e.Passphrase)
|
||||
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
|
||||
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
|
||||
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
|
||||
return &e, nil
|
||||
return &exchange, nil
|
||||
}
|
||||
|
||||
// getExchangeNameAndType returns the display name and type for an exchange type
|
||||
@@ -341,7 +187,6 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
|
||||
id := uuid.New().String()
|
||||
name, typ := getExchangeNameAndType(exchangeType)
|
||||
|
||||
// If account name is empty, use "Default"
|
||||
if accountName == "" {
|
||||
accountName = "Default"
|
||||
}
|
||||
@@ -349,19 +194,29 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
|
||||
logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s",
|
||||
userID, exchangeType, accountName, id)
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled,
|
||||
api_key, secret_key, passphrase, testnet,
|
||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
||||
lighter_wallet_addr, lighter_private_key, lighter_api_key_private_key, lighter_api_key_index,
|
||||
created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
|
||||
`, id, exchangeType, accountName, userID, name, typ, enabled,
|
||||
s.encrypt(apiKey), s.encrypt(secretKey), s.encrypt(passphrase), testnet,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey),
|
||||
lighterWalletAddr, s.encrypt(lighterPrivateKey), s.encrypt(lighterApiKeyPrivateKey), lighterApiKeyIndex)
|
||||
exchange := &Exchange{
|
||||
ID: id,
|
||||
ExchangeType: exchangeType,
|
||||
AccountName: accountName,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
Type: typ,
|
||||
Enabled: enabled,
|
||||
APIKey: crypto.EncryptedString(apiKey),
|
||||
SecretKey: crypto.EncryptedString(secretKey),
|
||||
Passphrase: crypto.EncryptedString(passphrase),
|
||||
Testnet: testnet,
|
||||
HyperliquidWalletAddr: hyperliquidWalletAddr,
|
||||
AsterUser: asterUser,
|
||||
AsterSigner: asterSigner,
|
||||
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
|
||||
LighterWalletAddr: lighterWalletAddr,
|
||||
LighterPrivateKey: crypto.EncryptedString(lighterPrivateKey),
|
||||
LighterAPIKeyPrivateKey: crypto.EncryptedString(lighterApiKeyPrivateKey),
|
||||
LighterAPIKeyIndex: lighterApiKeyIndex,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err := s.db.Create(exchange).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
@@ -373,53 +228,42 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
|
||||
|
||||
logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled)
|
||||
|
||||
setClauses := []string{
|
||||
"enabled = ?",
|
||||
"testnet = ?",
|
||||
"hyperliquid_wallet_addr = ?",
|
||||
"aster_user = ?",
|
||||
"aster_signer = ?",
|
||||
"lighter_wallet_addr = ?",
|
||||
"lighter_api_key_index = ?",
|
||||
"updated_at = datetime('now')",
|
||||
updates := map[string]interface{}{
|
||||
"enabled": enabled,
|
||||
"testnet": testnet,
|
||||
"hyperliquid_wallet_addr": hyperliquidWalletAddr,
|
||||
"aster_user": asterUser,
|
||||
"aster_signer": asterSigner,
|
||||
"lighter_wallet_addr": lighterWalletAddr,
|
||||
"lighter_api_key_index": lighterApiKeyIndex,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr, lighterApiKeyIndex}
|
||||
|
||||
// Only update encrypted fields if not empty
|
||||
if apiKey != "" {
|
||||
setClauses = append(setClauses, "api_key = ?")
|
||||
args = append(args, s.encrypt(apiKey))
|
||||
updates["api_key"] = crypto.EncryptedString(apiKey)
|
||||
}
|
||||
if secretKey != "" {
|
||||
setClauses = append(setClauses, "secret_key = ?")
|
||||
args = append(args, s.encrypt(secretKey))
|
||||
updates["secret_key"] = crypto.EncryptedString(secretKey)
|
||||
}
|
||||
if passphrase != "" {
|
||||
setClauses = append(setClauses, "passphrase = ?")
|
||||
args = append(args, s.encrypt(passphrase))
|
||||
updates["passphrase"] = crypto.EncryptedString(passphrase)
|
||||
}
|
||||
if asterPrivateKey != "" {
|
||||
setClauses = append(setClauses, "aster_private_key = ?")
|
||||
args = append(args, s.encrypt(asterPrivateKey))
|
||||
updates["aster_private_key"] = crypto.EncryptedString(asterPrivateKey)
|
||||
}
|
||||
if lighterPrivateKey != "" {
|
||||
setClauses = append(setClauses, "lighter_private_key = ?")
|
||||
args = append(args, s.encrypt(lighterPrivateKey))
|
||||
updates["lighter_private_key"] = crypto.EncryptedString(lighterPrivateKey)
|
||||
}
|
||||
if lighterApiKeyPrivateKey != "" {
|
||||
setClauses = append(setClauses, "lighter_api_key_private_key = ?")
|
||||
args = append(args, s.encrypt(lighterApiKeyPrivateKey))
|
||||
updates["lighter_api_key_private_key"] = crypto.EncryptedString(lighterApiKeyPrivateKey)
|
||||
}
|
||||
|
||||
args = append(args, id, userID)
|
||||
query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", "))
|
||||
|
||||
result, err := s.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
result := s.db.Model(&Exchange{}).Where("id = ? AND user_id = ?", id, userID).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
||||
}
|
||||
return nil
|
||||
@@ -427,13 +271,16 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
|
||||
|
||||
// UpdateAccountName updates the account name for an exchange
|
||||
func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error {
|
||||
result, err := s.db.Exec(`UPDATE exchanges SET account_name = ?, updated_at = datetime('now') WHERE id = ? AND user_id = ?`,
|
||||
accountName, id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
result := s.db.Model(&Exchange{}).
|
||||
Where("id = ? AND user_id = ?", id, userID).
|
||||
Updates(map[string]interface{}{
|
||||
"account_name": accountName,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
||||
}
|
||||
return nil
|
||||
@@ -441,12 +288,11 @@ func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error
|
||||
|
||||
// Delete deletes an exchange account
|
||||
func (s *ExchangeStore) Delete(userID, id string) error {
|
||||
result, err := s.db.Exec(`DELETE FROM exchanges WHERE id = ? AND user_id = ?`, id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
result := s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Exchange{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
||||
}
|
||||
logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID)
|
||||
@@ -460,20 +306,25 @@ func (s *ExchangeStore) CreateLegacy(userID, id, name, typ string, enabled bool,
|
||||
|
||||
// Check if this is an old-style ID (exchange type as ID)
|
||||
if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" {
|
||||
// Use new Create method with exchange type
|
||||
_, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0)
|
||||
return err
|
||||
}
|
||||
|
||||
// Otherwise assume it's already a UUID
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled,
|
||||
api_key, secret_key, testnet,
|
||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
||||
lighter_wallet_addr, lighter_private_key)
|
||||
VALUES (?, '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '')
|
||||
`, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey))
|
||||
return err
|
||||
exchange := &Exchange{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
Type: typ,
|
||||
Enabled: enabled,
|
||||
APIKey: crypto.EncryptedString(apiKey),
|
||||
SecretKey: crypto.EncryptedString(secretKey),
|
||||
Testnet: testnet,
|
||||
HyperliquidWalletAddr: hyperliquidWalletAddr,
|
||||
AsterUser: asterUser,
|
||||
AsterSigner: asterSigner,
|
||||
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
|
||||
}
|
||||
return s.db.Where("id = ?", id).FirstOrCreate(exchange).Error
|
||||
}
|
||||
|
||||
+146
@@ -0,0 +1,146 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// GormDB is the global GORM database connection
|
||||
var gormDB *gorm.DB
|
||||
|
||||
// DB returns the GORM database connection
|
||||
func DB() *gorm.DB {
|
||||
return gormDB
|
||||
}
|
||||
|
||||
// InitGorm initializes GORM with SQLite
|
||||
func InitGorm(dbPath string) (*gorm.DB, error) {
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool for SQLite
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
|
||||
// Enable foreign keys for SQLite
|
||||
db.Exec("PRAGMA foreign_keys = ON")
|
||||
db.Exec("PRAGMA journal_mode = DELETE")
|
||||
db.Exec("PRAGMA synchronous = FULL")
|
||||
db.Exec("PRAGMA busy_timeout = 5000")
|
||||
|
||||
gormDB = db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// InitGormPostgres initializes GORM with PostgreSQL
|
||||
func InitGormPostgres(host string, port int, user, password, dbname, sslmode string) (*gorm.DB, error) {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
host, port, user, password, dbname, sslmode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool for PostgreSQL
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(25)
|
||||
sqlDB.SetMaxIdleConns(5)
|
||||
|
||||
gormDB = db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// InitGormWithConfig initializes GORM with provided configuration
|
||||
// Uses DBConfig from driver.go
|
||||
func InitGormWithConfig(cfg DBConfig) (*gorm.DB, error) {
|
||||
switch cfg.Type {
|
||||
case DBTypeSQLite:
|
||||
return InitGorm(cfg.Path)
|
||||
|
||||
case DBTypePostgres:
|
||||
return InitGormPostgres(
|
||||
cfg.Host,
|
||||
cfg.Port,
|
||||
cfg.User,
|
||||
cfg.Password,
|
||||
cfg.DBName,
|
||||
cfg.SSLMode,
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported DB_TYPE: %s (use 'sqlite' or 'postgres')", cfg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Query Scopes - Reusable query helpers
|
||||
// ============================================================================
|
||||
|
||||
// ForUser returns a scope that filters by user_id
|
||||
func ForUser(userID string) func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("user_id = ?", userID)
|
||||
}
|
||||
}
|
||||
|
||||
// ForTrader returns a scope that filters by trader_id
|
||||
func ForTrader(traderID string) func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("trader_id = ?", traderID)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenPositions returns a scope for open positions
|
||||
func OpenPositions() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("status = ?", "OPEN")
|
||||
}
|
||||
}
|
||||
|
||||
// ClosedPositions returns a scope for closed positions
|
||||
func ClosedPositions() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("status = ?", "CLOSED")
|
||||
}
|
||||
}
|
||||
|
||||
// OrderByCreatedDesc returns a scope that orders by created_at DESC
|
||||
func OrderByCreatedDesc() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("created_at DESC")
|
||||
}
|
||||
}
|
||||
|
||||
// OrderByUpdatedDesc returns a scope that orders by updated_at DESC
|
||||
func OrderByUpdatedDesc() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("updated_at DESC")
|
||||
}
|
||||
}
|
||||
|
||||
// Paginate returns a scope for pagination
|
||||
func Paginate(limit, offset int) func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Limit(limit).Offset(offset)
|
||||
}
|
||||
}
|
||||
+197
-446
@@ -1,495 +1,255 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TraderOrder 订单记录(完整的订单生命周期追踪)
|
||||
// TraderOrder order record
|
||||
type TraderOrder struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"`
|
||||
ExchangeID string `json:"exchange_id"` // Exchange account UUID
|
||||
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc)
|
||||
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID
|
||||
ClientOrderID string `json:"client_order_id"` // Client order ID
|
||||
Symbol string `json:"symbol"` // Trading pair
|
||||
Side string `json:"side"` // BUY/SELL
|
||||
PositionSide string `json:"position_side"` // LONG/SHORT (hedge mode)
|
||||
Type string `json:"type"` // MARKET/LIMIT/STOP/STOP_MARKET/TAKE_PROFIT/TAKE_PROFIT_MARKET
|
||||
TimeInForce string `json:"time_in_force"` // GTC/IOC/FOK
|
||||
Quantity float64 `json:"quantity"` // 订单数量
|
||||
Price float64 `json:"price"` // 限价单价格
|
||||
StopPrice float64 `json:"stop_price"` // 止损/止盈触发价格
|
||||
Status string `json:"status"` // NEW/PARTIALLY_FILLED/FILLED/CANCELED/REJECTED/EXPIRED
|
||||
FilledQuantity float64 `json:"filled_quantity"` // 已成交数量
|
||||
AvgFillPrice float64 `json:"avg_fill_price"` // 平均成交价格
|
||||
Commission float64 `json:"commission"` // 手续费总额
|
||||
CommissionAsset string `json:"commission_asset"` // 手续费资产(USDT等)
|
||||
Leverage int `json:"leverage"` // 杠杆倍数
|
||||
ReduceOnly bool `json:"reduce_only"` // 是否只减仓
|
||||
ClosePosition bool `json:"close_position"` // 是否平仓单
|
||||
WorkingType string `json:"working_type"` // CONTRACT_PRICE/MARK_PRICE
|
||||
PriceProtect bool `json:"price_protect"` // 价格保护
|
||||
OrderAction string `json:"order_action"` // OPEN_LONG/OPEN_SHORT/CLOSE_LONG/CLOSE_SHORT/ADD_LONG/ADD_SHORT/STOP_LOSS/TAKE_PROFIT
|
||||
RelatedPositionID int64 `json:"related_position_id"` // 关联的仓位ID
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
FilledAt time.Time `json:"filled_at"` // 完全成交时间
|
||||
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
TraderID string `gorm:"column:trader_id;not null;index:idx_orders_trader_id" json:"trader_id"`
|
||||
ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
|
||||
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
|
||||
ExchangeOrderID string `gorm:"column:exchange_order_id;not null;uniqueIndex:idx_orders_exchange_unique,priority:2" json:"exchange_order_id"`
|
||||
ClientOrderID string `gorm:"column:client_order_id;default:''" json:"client_order_id"`
|
||||
Symbol string `gorm:"column:symbol;not null;index:idx_orders_symbol" json:"symbol"`
|
||||
Side string `gorm:"column:side;not null" json:"side"`
|
||||
PositionSide string `gorm:"column:position_side;default:''" json:"position_side"`
|
||||
Type string `gorm:"column:type;not null" json:"type"`
|
||||
TimeInForce string `gorm:"column:time_in_force;default:GTC" json:"time_in_force"`
|
||||
Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
|
||||
Price float64 `gorm:"column:price;default:0" json:"price"`
|
||||
StopPrice float64 `gorm:"column:stop_price;default:0" json:"stop_price"`
|
||||
Status string `gorm:"column:status;not null;default:NEW;index:idx_orders_status" json:"status"`
|
||||
FilledQuantity float64 `gorm:"column:filled_quantity;default:0" json:"filled_quantity"`
|
||||
AvgFillPrice float64 `gorm:"column:avg_fill_price;default:0" json:"avg_fill_price"`
|
||||
Commission float64 `gorm:"column:commission;default:0" json:"commission"`
|
||||
CommissionAsset string `gorm:"column:commission_asset;default:USDT" json:"commission_asset"`
|
||||
Leverage int `gorm:"column:leverage;default:1" json:"leverage"`
|
||||
ReduceOnly bool `gorm:"column:reduce_only;default:false" json:"reduce_only"`
|
||||
ClosePosition bool `gorm:"column:close_position;default:false" json:"close_position"`
|
||||
WorkingType string `gorm:"column:working_type;default:CONTRACT_PRICE" json:"working_type"`
|
||||
PriceProtect bool `gorm:"column:price_protect;default:false" json:"price_protect"`
|
||||
OrderAction string `gorm:"column:order_action;default:''" json:"order_action"`
|
||||
RelatedPositionID int64 `gorm:"column:related_position_id;default:0" json:"related_position_id"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||
FilledAt time.Time `gorm:"column:filled_at" json:"filled_at"`
|
||||
}
|
||||
|
||||
// TraderFill trade record (one order may have multiple fills)
|
||||
// TableName returns the table name for TraderOrder
|
||||
func (TraderOrder) TableName() string {
|
||||
return "trader_orders"
|
||||
}
|
||||
|
||||
// TraderFill trade record
|
||||
type TraderFill struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"`
|
||||
ExchangeID string `json:"exchange_id"` // Exchange account UUID
|
||||
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc)
|
||||
OrderID int64 `json:"order_id"` // Related order ID
|
||||
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID
|
||||
ExchangeTradeID string `json:"exchange_trade_id"` // Exchange trade ID
|
||||
Symbol string `json:"symbol"`
|
||||
Side string `json:"side"` // BUY/SELL
|
||||
Price float64 `json:"price"` // 成交价格
|
||||
Quantity float64 `json:"quantity"` // 成交数量
|
||||
QuoteQuantity float64 `json:"quote_quantity"` // 成交金额(USDT)
|
||||
Commission float64 `json:"commission"` // 手续费
|
||||
CommissionAsset string `json:"commission_asset"`
|
||||
RealizedPnL float64 `json:"realized_pnl"` // 实现盈亏(平仓时)
|
||||
IsMaker bool `json:"is_maker"` // 是否为maker
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
TraderID string `gorm:"column:trader_id;not null;index:idx_fills_trader_id" json:"trader_id"`
|
||||
ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
|
||||
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
|
||||
OrderID int64 `gorm:"column:order_id;not null;index:idx_fills_order_id" json:"order_id"`
|
||||
ExchangeOrderID string `gorm:"column:exchange_order_id;not null" json:"exchange_order_id"`
|
||||
ExchangeTradeID string `gorm:"column:exchange_trade_id;not null;uniqueIndex:idx_fills_exchange_unique,priority:2" json:"exchange_trade_id"`
|
||||
Symbol string `gorm:"column:symbol;not null" json:"symbol"`
|
||||
Side string `gorm:"column:side;not null" json:"side"`
|
||||
Price float64 `gorm:"column:price;not null" json:"price"`
|
||||
Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
|
||||
QuoteQuantity float64 `gorm:"column:quote_quantity;not null" json:"quote_quantity"`
|
||||
Commission float64 `gorm:"column:commission;not null" json:"commission"`
|
||||
CommissionAsset string `gorm:"column:commission_asset;not null" json:"commission_asset"`
|
||||
RealizedPnL float64 `gorm:"column:realized_pnl;default:0" json:"realized_pnl"`
|
||||
IsMaker bool `gorm:"column:is_maker;default:false" json:"is_maker"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
}
|
||||
|
||||
// OrderStore 订单存储
|
||||
// TableName returns the table name for TraderFill
|
||||
func (TraderFill) TableName() string {
|
||||
return "trader_fills"
|
||||
}
|
||||
|
||||
// OrderStore order storage
|
||||
type OrderStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewOrderStore 创建订单存储实例
|
||||
func NewOrderStore(db *sql.DB) *OrderStore {
|
||||
// NewOrderStore creates order storage instance
|
||||
func NewOrderStore(db *gorm.DB) *OrderStore {
|
||||
return &OrderStore{db: db}
|
||||
}
|
||||
|
||||
// InitTables 初始化订单表
|
||||
// InitTables initializes order tables
|
||||
func (s *OrderStore) InitTables() error {
|
||||
// 创建订单表
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS trader_orders (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
exchange_id TEXT NOT NULL DEFAULT '',
|
||||
exchange_type TEXT NOT NULL DEFAULT '',
|
||||
exchange_order_id TEXT NOT NULL,
|
||||
client_order_id TEXT DEFAULT '',
|
||||
symbol TEXT NOT NULL,
|
||||
side TEXT NOT NULL,
|
||||
position_side TEXT DEFAULT '',
|
||||
type TEXT NOT NULL,
|
||||
time_in_force TEXT DEFAULT 'GTC',
|
||||
quantity REAL NOT NULL,
|
||||
price REAL DEFAULT 0,
|
||||
stop_price REAL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'NEW',
|
||||
filled_quantity REAL DEFAULT 0,
|
||||
avg_fill_price REAL DEFAULT 0,
|
||||
commission REAL DEFAULT 0,
|
||||
commission_asset TEXT DEFAULT 'USDT',
|
||||
leverage INTEGER DEFAULT 1,
|
||||
reduce_only INTEGER DEFAULT 0,
|
||||
close_position INTEGER DEFAULT 0,
|
||||
working_type TEXT DEFAULT 'CONTRACT_PRICE',
|
||||
price_protect INTEGER DEFAULT 0,
|
||||
order_action TEXT DEFAULT '',
|
||||
related_position_id INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
filled_at DATETIME,
|
||||
UNIQUE(exchange_id, exchange_order_id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create trader_orders table: %w", err)
|
||||
}
|
||||
// For PostgreSQL, check if tables exist to avoid AutoMigrate index conflicts
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var ordersExist, fillsExist int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_orders'`).Scan(&ordersExist)
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_fills'`).Scan(&fillsExist)
|
||||
|
||||
// 创建成交记录表
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS trader_fills (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
exchange_id TEXT NOT NULL DEFAULT '',
|
||||
exchange_type TEXT NOT NULL DEFAULT '',
|
||||
order_id INTEGER NOT NULL,
|
||||
exchange_order_id TEXT NOT NULL,
|
||||
exchange_trade_id TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
side TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
quantity REAL NOT NULL,
|
||||
quote_quantity REAL NOT NULL,
|
||||
commission REAL NOT NULL,
|
||||
commission_asset TEXT NOT NULL,
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
is_maker INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(exchange_id, exchange_trade_id),
|
||||
FOREIGN KEY (order_id) REFERENCES trader_orders(id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create trader_fills table: %w", err)
|
||||
}
|
||||
|
||||
// 创建索引
|
||||
if ordersExist > 0 && fillsExist > 0 {
|
||||
// Tables exist - just ensure indexes exist, skip AutoMigrate
|
||||
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
|
||||
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_exchange_order_id ON trader_orders(exchange_id, exchange_order_id)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.AutoMigrate(&TraderOrder{}, &TraderFill{}); err != nil {
|
||||
return fmt.Errorf("failed to migrate order tables: %w", err)
|
||||
}
|
||||
|
||||
// Create unique composite index for exchange_id + exchange_order_id
|
||||
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
|
||||
// Create unique composite index for exchange_id + exchange_trade_id
|
||||
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateOrder 创建订单记录(去重:如果订单已存在则返回已有记录)
|
||||
// CreateOrder creates order record
|
||||
func (s *OrderStore) CreateOrder(order *TraderOrder) error {
|
||||
// 1. 先检查订单是否已存在(去重)
|
||||
// Check if order already exists
|
||||
existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing order: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
// 订单已存在,返回已有记录的ID
|
||||
order.ID = existing.ID
|
||||
order.CreatedAt = existing.CreatedAt
|
||||
order.UpdatedAt = existing.UpdatedAt
|
||||
return nil // 不是错误,只是跳过插入
|
||||
}
|
||||
|
||||
// 2. 订单不存在,插入新记录
|
||||
now := time.Now()
|
||||
order.CreatedAt = now
|
||||
order.UpdatedAt = now
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
INSERT INTO trader_orders (
|
||||
trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
|
||||
symbol, side, position_side, type, time_in_force,
|
||||
quantity, price, stop_price, status,
|
||||
filled_quantity, avg_fill_price, commission, commission_asset,
|
||||
leverage, reduce_only, close_position, working_type, price_protect,
|
||||
order_action, related_position_id,
|
||||
created_at, updated_at, filled_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
order.TraderID, order.ExchangeID, order.ExchangeType, order.ExchangeOrderID, order.ClientOrderID,
|
||||
order.Symbol, order.Side, order.PositionSide, order.Type, order.TimeInForce,
|
||||
order.Quantity, order.Price, order.StopPrice, order.Status,
|
||||
order.FilledQuantity, order.AvgFillPrice, order.Commission, order.CommissionAsset,
|
||||
order.Leverage, order.ReduceOnly, order.ClosePosition, order.WorkingType, order.PriceProtect,
|
||||
order.OrderAction, order.RelatedPositionID,
|
||||
formatTimeOrNow(order.CreatedAt, now), formatTimeOrNow(order.UpdatedAt, now),
|
||||
formatTimePtr(order.FilledAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create order: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
order.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOrderStatus 更新订单状态
|
||||
return s.db.Create(order).Error
|
||||
}
|
||||
|
||||
// UpdateOrderStatus updates order status
|
||||
func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error {
|
||||
now := time.Now()
|
||||
updateSQL := `
|
||||
UPDATE trader_orders SET
|
||||
status = ?,
|
||||
filled_quantity = ?,
|
||||
avg_fill_price = ?,
|
||||
commission = ?,
|
||||
updated_at = ?
|
||||
`
|
||||
args := []interface{}{status, filledQty, avgPrice, commission, now.Format(time.RFC3339)}
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"filled_quantity": filledQty,
|
||||
"avg_fill_price": avgPrice,
|
||||
"commission": commission,
|
||||
}
|
||||
|
||||
// 如果完全成交,记录成交时间
|
||||
if status == "FILLED" {
|
||||
updateSQL += `, filled_at = ?`
|
||||
args = append(args, now.Format(time.RFC3339))
|
||||
updates["filled_at"] = time.Now()
|
||||
}
|
||||
|
||||
updateSQL += ` WHERE id = ?`
|
||||
args = append(args, id)
|
||||
|
||||
_, err := s.db.Exec(updateSQL, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update order status: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.Model(&TraderOrder{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// CreateFill 创建成交记录(去重:如果成交已存在则跳过)
|
||||
// CreateFill creates fill record
|
||||
func (s *OrderStore) CreateFill(fill *TraderFill) error {
|
||||
// 1. 先检查成交是否已存在(去重)
|
||||
// Check if fill already exists
|
||||
existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing fill: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
// 成交已存在,返回已有记录的ID
|
||||
fill.ID = existing.ID
|
||||
fill.CreatedAt = existing.CreatedAt
|
||||
return nil // 不是错误,只是跳过插入
|
||||
}
|
||||
|
||||
// 2. 成交不存在,插入新记录
|
||||
now := time.Now()
|
||||
fill.CreatedAt = now
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
INSERT INTO trader_fills (
|
||||
trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
|
||||
symbol, side, price, quantity, quote_quantity,
|
||||
commission, commission_asset, realized_pnl, is_maker,
|
||||
created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
fill.TraderID, fill.ExchangeID, fill.ExchangeType, fill.OrderID, fill.ExchangeOrderID, fill.ExchangeTradeID,
|
||||
fill.Symbol, fill.Side, fill.Price, fill.Quantity, fill.QuoteQuantity,
|
||||
fill.Commission, fill.CommissionAsset, fill.RealizedPnL, fill.IsMaker,
|
||||
now.Format(time.RFC3339),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create fill: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
fill.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFillByExchangeTradeID 根据交易所成交ID获取成交记录
|
||||
func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) {
|
||||
row := s.db.QueryRow(`
|
||||
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
|
||||
symbol, side, price, quantity, quote_quantity,
|
||||
commission, commission_asset, realized_pnl, is_maker,
|
||||
created_at
|
||||
FROM trader_fills
|
||||
WHERE exchange_id = ? AND exchange_trade_id = ?
|
||||
`, exchangeID, exchangeTradeID)
|
||||
return s.db.Create(fill).Error
|
||||
}
|
||||
|
||||
// GetFillByExchangeTradeID gets fill by exchange trade ID
|
||||
func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) {
|
||||
var fill TraderFill
|
||||
var createdAt sql.NullString
|
||||
err := row.Scan(
|
||||
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
|
||||
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
|
||||
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
|
||||
&createdAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
err := s.db.Where("exchange_id = ? AND exchange_trade_id = ?", exchangeID, exchangeTradeID).First(&fill).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fill: %w", err)
|
||||
}
|
||||
|
||||
// Parse time
|
||||
if createdAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
||||
fill.CreatedAt = t
|
||||
}
|
||||
}
|
||||
|
||||
return &fill, nil
|
||||
}
|
||||
|
||||
// GetOrderByExchangeID 根据交易所订单ID获取订单
|
||||
// GetOrderByExchangeID gets order by exchange order ID
|
||||
func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) {
|
||||
row := s.db.QueryRow(`
|
||||
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
|
||||
symbol, side, position_side, type, time_in_force,
|
||||
quantity, price, stop_price, status,
|
||||
filled_quantity, avg_fill_price, commission, commission_asset,
|
||||
leverage, reduce_only, close_position, working_type, price_protect,
|
||||
order_action, related_position_id,
|
||||
created_at, updated_at, filled_at
|
||||
FROM trader_orders
|
||||
WHERE exchange_id = ? AND exchange_order_id = ?
|
||||
`, exchangeID, exchangeOrderID)
|
||||
|
||||
var order TraderOrder
|
||||
var createdAt, updatedAt, filledAt sql.NullString
|
||||
err := row.Scan(
|
||||
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
|
||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
|
||||
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
|
||||
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
|
||||
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
|
||||
&order.OrderAction, &order.RelatedPositionID,
|
||||
&createdAt, &updatedAt, &filledAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
err := s.db.Where("exchange_id = ? AND exchange_order_id = ?", exchangeID, exchangeOrderID).First(&order).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get order: %w", err)
|
||||
}
|
||||
|
||||
// Parse times
|
||||
if createdAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
||||
order.CreatedAt = t
|
||||
}
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
|
||||
order.UpdatedAt = t
|
||||
}
|
||||
}
|
||||
if filledAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
|
||||
order.FilledAt = t
|
||||
}
|
||||
}
|
||||
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
// GetTraderOrders 获取trader的订单列表
|
||||
// GetTraderOrders gets trader's order list
|
||||
func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
|
||||
symbol, side, position_side, type, time_in_force,
|
||||
quantity, price, stop_price, status,
|
||||
filled_quantity, avg_fill_price, commission, commission_asset,
|
||||
leverage, reduce_only, close_position, working_type, price_protect,
|
||||
order_action, related_position_id,
|
||||
created_at, updated_at, filled_at
|
||||
FROM trader_orders
|
||||
WHERE trader_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, traderID, limit)
|
||||
var orders []*TraderOrder
|
||||
err := s.db.Where("trader_id = ?", traderID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query orders: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var orders []*TraderOrder
|
||||
for rows.Next() {
|
||||
var order TraderOrder
|
||||
var createdAt, updatedAt, filledAt sql.NullString
|
||||
err := rows.Scan(
|
||||
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
|
||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
|
||||
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
|
||||
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
|
||||
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
|
||||
&order.OrderAction, &order.RelatedPositionID,
|
||||
&createdAt, &updatedAt, &filledAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse times
|
||||
if createdAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
||||
order.CreatedAt = t
|
||||
}
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
|
||||
order.UpdatedAt = t
|
||||
}
|
||||
}
|
||||
if filledAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
|
||||
order.FilledAt = t
|
||||
}
|
||||
}
|
||||
|
||||
orders = append(orders, &order)
|
||||
}
|
||||
|
||||
return orders, nil
|
||||
}
|
||||
|
||||
// GetOrderFills 获取订单的成交记录
|
||||
// GetOrderFills gets order's fill records
|
||||
func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
|
||||
symbol, side, price, quantity, quote_quantity,
|
||||
commission, commission_asset, realized_pnl, is_maker,
|
||||
created_at
|
||||
FROM trader_fills
|
||||
WHERE order_id = ?
|
||||
ORDER BY created_at ASC
|
||||
`, orderID)
|
||||
var fills []*TraderFill
|
||||
err := s.db.Where("order_id = ?", orderID).
|
||||
Order("created_at ASC").
|
||||
Find(&fills).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query fills: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var fills []*TraderFill
|
||||
for rows.Next() {
|
||||
var fill TraderFill
|
||||
var createdAt sql.NullString
|
||||
err := rows.Scan(
|
||||
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
|
||||
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
|
||||
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
|
||||
&createdAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
||||
fill.CreatedAt = t
|
||||
}
|
||||
}
|
||||
|
||||
fills = append(fills, &fill)
|
||||
}
|
||||
|
||||
return fills, nil
|
||||
}
|
||||
|
||||
// GetTraderOrderStats 获取trader的订单统计
|
||||
// GetTraderOrderStats gets trader's order statistics
|
||||
func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) {
|
||||
var totalOrders, filledOrders, canceledOrders int
|
||||
var totalCommission, totalVolume float64
|
||||
type result struct {
|
||||
TotalOrders int
|
||||
FilledOrders int
|
||||
CanceledOrders int
|
||||
TotalCommission float64
|
||||
TotalVolume float64
|
||||
}
|
||||
var r result
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT
|
||||
COUNT(*) as total_orders,
|
||||
err := s.db.Model(&TraderOrder{}).
|
||||
Select(`COUNT(*) as total_orders,
|
||||
SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders,
|
||||
SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders,
|
||||
SUM(commission) as total_commission,
|
||||
SUM(filled_quantity * avg_fill_price) as total_volume
|
||||
FROM trader_orders
|
||||
WHERE trader_id = ?
|
||||
`, traderID).Scan(&totalOrders, &filledOrders, &canceledOrders, &totalCommission, &totalVolume)
|
||||
|
||||
SUM(filled_quantity * avg_fill_price) as total_volume`).
|
||||
Where("trader_id = ?", traderID).
|
||||
Scan(&r).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get order stats: %w", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_orders": totalOrders,
|
||||
"filled_orders": filledOrders,
|
||||
"canceled_orders": canceledOrders,
|
||||
"total_commission": totalCommission,
|
||||
"total_volume": totalVolume,
|
||||
"total_orders": r.TotalOrders,
|
||||
"filled_orders": r.FilledOrders,
|
||||
"canceled_orders": r.CanceledOrders,
|
||||
"total_commission": r.TotalCommission,
|
||||
"total_volume": r.TotalVolume,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CleanupDuplicateOrders 清理重复的订单记录(保留最早创建的记录)
|
||||
// CleanupDuplicateOrders cleans up duplicate order records
|
||||
func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
|
||||
result, err := s.db.Exec(`
|
||||
result := s.db.Exec(`
|
||||
DELETE FROM trader_orders
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
@@ -497,17 +257,15 @@ func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
|
||||
GROUP BY exchange_id, exchange_order_id
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", err)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", result.Error)
|
||||
}
|
||||
return int(result.RowsAffected), nil
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// CleanupDuplicateFills 清理重复的成交记录(保留最早创建的记录)
|
||||
// CleanupDuplicateFills cleans up duplicate fill records
|
||||
func (s *OrderStore) CleanupDuplicateFills() (int, error) {
|
||||
result, err := s.db.Exec(`
|
||||
result := s.db.Exec(`
|
||||
DELETE FROM trader_fills
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
@@ -515,73 +273,66 @@ func (s *OrderStore) CleanupDuplicateFills() (int, error) {
|
||||
GROUP BY exchange_id, exchange_trade_id
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", err)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", result.Error)
|
||||
}
|
||||
return int(result.RowsAffected), nil
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// GetDuplicateOrdersCount 获取重复订单的数量(用于诊断)
|
||||
// GetDuplicateOrdersCount gets duplicate orders count
|
||||
func (s *OrderStore) GetDuplicateOrdersCount() (int, error) {
|
||||
var count int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_order_id)
|
||||
FROM trader_orders
|
||||
`).Scan(&count)
|
||||
return count, err
|
||||
var total, distinct int64
|
||||
s.db.Model(&TraderOrder{}).Count(&total)
|
||||
|
||||
// Count distinct combinations
|
||||
var distinctResult struct{ Count int64 }
|
||||
s.db.Model(&TraderOrder{}).
|
||||
Select("COUNT(DISTINCT exchange_id || ',' || exchange_order_id) as count").
|
||||
Scan(&distinctResult)
|
||||
distinct = distinctResult.Count
|
||||
|
||||
return int(total - distinct), nil
|
||||
}
|
||||
|
||||
// GetDuplicateFillsCount 获取重复成交的数量(用于诊断)
|
||||
// GetDuplicateFillsCount gets duplicate fills count
|
||||
func (s *OrderStore) GetDuplicateFillsCount() (int, error) {
|
||||
var count int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_trade_id)
|
||||
FROM trader_fills
|
||||
`).Scan(&count)
|
||||
return count, err
|
||||
var total, distinct int64
|
||||
s.db.Model(&TraderFill{}).Count(&total)
|
||||
|
||||
var distinctResult struct{ Count int64 }
|
||||
s.db.Model(&TraderFill{}).
|
||||
Select("COUNT(DISTINCT exchange_id || ',' || exchange_trade_id) as count").
|
||||
Scan(&distinctResult)
|
||||
distinct = distinctResult.Count
|
||||
|
||||
return int(total - distinct), nil
|
||||
}
|
||||
|
||||
// GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange
|
||||
// Used for incremental sync - only fetch trades with ID > maxTradeID
|
||||
func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id
|
||||
FROM trader_fills
|
||||
WHERE exchange_id = ? AND exchange_trade_id != ''
|
||||
GROUP BY symbol
|
||||
`, exchangeID)
|
||||
type symbolMaxID struct {
|
||||
Symbol string
|
||||
MaxTradeID int64
|
||||
}
|
||||
var results []symbolMaxID
|
||||
|
||||
err := s.db.Model(&TraderFill{}).
|
||||
Select("symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id").
|
||||
Where("exchange_id = ? AND exchange_trade_id != ''", exchangeID).
|
||||
Group("symbol").
|
||||
Find(&results).Error
|
||||
if err != nil {
|
||||
// If CAST fails (non-numeric trade IDs), fallback to string comparison
|
||||
if strings.Contains(err.Error(), "CAST") || strings.Contains(err.Error(), "invalid") {
|
||||
return make(map[string]int64), nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query max trade IDs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[string]int64)
|
||||
for rows.Next() {
|
||||
var symbol string
|
||||
var maxID int64
|
||||
if err := rows.Scan(&symbol, &maxID); err != nil {
|
||||
continue
|
||||
}
|
||||
result[symbol] = maxID
|
||||
for _, r := range results {
|
||||
result[r.Symbol] = r.MaxTradeID
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// formatTimePtr formats time.Time to RFC3339 string, returns NULL for zero time
|
||||
func formatTimePtr(t time.Time) interface{} {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// formatTimeOrNow returns the formatted time if not zero, otherwise returns now
|
||||
func formatTimeOrNow(t time.Time, now time.Time) string {
|
||||
if t.IsZero() {
|
||||
return now.Format(time.RFC3339)
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
+520
-812
File diff suppressed because it is too large
Load Diff
+104
-80
@@ -7,12 +7,15 @@ import (
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Store unified data storage interface
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
driver *DBDriver // Database driver for abstraction
|
||||
gdb *gorm.DB // GORM database connection
|
||||
db *sql.DB // Legacy sql.DB for backward compatibility
|
||||
driver *DBDriver // Database driver for abstraction (legacy)
|
||||
|
||||
// Sub-stores (lazy initialization)
|
||||
user *UserStore
|
||||
@@ -26,105 +29,103 @@ type Store struct {
|
||||
equity *EquityStore
|
||||
order *OrderStore
|
||||
|
||||
// Encryption functions
|
||||
encryptFunc func(string) string
|
||||
decryptFunc func(string) string
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates new Store instance (SQLite mode for backward compatibility)
|
||||
func New(dbPath string) (*Store, error) {
|
||||
driver, err := NewDBDriver(DBConfig{Type: DBTypeSQLite, Path: dbPath})
|
||||
gdb, err := InitGorm(dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{db: driver.DB(), driver: driver}
|
||||
// Get underlying sql.DB for legacy compatibility
|
||||
sqlDB, err := gdb.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{gdb: gdb, db: sqlDB}
|
||||
|
||||
// Initialize all table structures
|
||||
if err := s.initTables(); err != nil {
|
||||
driver.Close()
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
|
||||
}
|
||||
|
||||
// Initialize default data
|
||||
if err := s.initDefaultData(); err != nil {
|
||||
driver.Close()
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("failed to initialize default data: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("✅ Database initialized (type: %s)", driver.Type)
|
||||
logger.Infof("✅ Database initialized (GORM, SQLite)")
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewFromEnv creates new Store instance from environment variables
|
||||
// DB_TYPE: sqlite (default) or postgres
|
||||
// For SQLite: DB_PATH (default: data/data.db)
|
||||
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
|
||||
func NewFromEnv() (*Store, error) {
|
||||
driver, err := NewDBDriverFromEnv()
|
||||
// NewWithConfig creates new Store instance with provided database configuration
|
||||
func NewWithConfig(cfg DBConfig) (*Store, error) {
|
||||
gdb, err := InitGormWithConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{db: driver.DB(), driver: driver}
|
||||
// Get underlying sql.DB for legacy compatibility
|
||||
sqlDB, err := gdb.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{gdb: gdb, db: sqlDB}
|
||||
|
||||
// Initialize all table structures
|
||||
if err := s.initTables(); err != nil {
|
||||
driver.Close()
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
|
||||
}
|
||||
|
||||
// Initialize default data
|
||||
if err := s.initDefaultData(); err != nil {
|
||||
driver.Close()
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("failed to initialize default data: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("✅ Database initialized (type: %s)", driver.Type)
|
||||
dbTypeStr := "SQLite"
|
||||
if cfg.Type == DBTypePostgres {
|
||||
dbTypeStr = "PostgreSQL"
|
||||
}
|
||||
logger.Infof("✅ Database initialized (GORM, %s)", dbTypeStr)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewFromDB creates Store from existing database connection
|
||||
// NewFromGorm creates Store from existing GORM connection
|
||||
func NewFromGorm(gdb *gorm.DB) (*Store, error) {
|
||||
sqlDB, err := gdb.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Store{gdb: gdb, db: sqlDB}, nil
|
||||
}
|
||||
|
||||
// NewFromDB creates Store from existing database connection (legacy)
|
||||
// Deprecated: Use NewFromGorm instead
|
||||
func NewFromDB(db *sql.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// SetCryptoFuncs sets encryption/decryption functions
|
||||
func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.encryptFunc = encrypt
|
||||
s.decryptFunc = decrypt
|
||||
|
||||
// Update already initialized sub-stores
|
||||
if s.aiModel != nil {
|
||||
s.aiModel.encryptFunc = encrypt
|
||||
s.aiModel.decryptFunc = decrypt
|
||||
}
|
||||
if s.exchange != nil {
|
||||
s.exchange.encryptFunc = encrypt
|
||||
s.exchange.decryptFunc = decrypt
|
||||
}
|
||||
if s.trader != nil {
|
||||
s.trader.decryptFunc = decrypt
|
||||
}
|
||||
}
|
||||
|
||||
// initTables initializes all database tables
|
||||
// initTables initializes all database tables using GORM AutoMigrate
|
||||
func (s *Store) initTables() error {
|
||||
// Initialize system config table first
|
||||
if _, err := s.db.Exec(`
|
||||
// Create system_config table (GORM handles this via raw SQL for simplicity)
|
||||
if err := s.gdb.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS system_config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
)
|
||||
`); err != nil {
|
||||
`).Error; err != nil {
|
||||
return fmt.Errorf("failed to create system_config table: %w", err)
|
||||
}
|
||||
|
||||
// Initialize in dependency order
|
||||
// Initialize sub-store tables
|
||||
if err := s.User().initTables(); err != nil {
|
||||
return fmt.Errorf("failed to initialize user tables: %w", err)
|
||||
}
|
||||
@@ -183,7 +184,7 @@ func (s *Store) User() *UserStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.user == nil {
|
||||
s.user = &UserStore{db: s.db}
|
||||
s.user = NewUserStore(s.gdb)
|
||||
}
|
||||
return s.user
|
||||
}
|
||||
@@ -193,11 +194,7 @@ func (s *Store) AIModel() *AIModelStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.aiModel == nil {
|
||||
s.aiModel = &AIModelStore{
|
||||
db: s.db,
|
||||
encryptFunc: s.encryptFunc,
|
||||
decryptFunc: s.decryptFunc,
|
||||
}
|
||||
s.aiModel = NewAIModelStore(s.gdb)
|
||||
}
|
||||
return s.aiModel
|
||||
}
|
||||
@@ -207,11 +204,7 @@ func (s *Store) Exchange() *ExchangeStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.exchange == nil {
|
||||
s.exchange = &ExchangeStore{
|
||||
db: s.db,
|
||||
encryptFunc: s.encryptFunc,
|
||||
decryptFunc: s.decryptFunc,
|
||||
}
|
||||
s.exchange = NewExchangeStore(s.gdb)
|
||||
}
|
||||
return s.exchange
|
||||
}
|
||||
@@ -221,10 +214,7 @@ func (s *Store) Trader() *TraderStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.trader == nil {
|
||||
s.trader = &TraderStore{
|
||||
db: s.db,
|
||||
decryptFunc: s.decryptFunc,
|
||||
}
|
||||
s.trader = NewTraderStore(s.gdb)
|
||||
}
|
||||
return s.trader
|
||||
}
|
||||
@@ -234,7 +224,7 @@ func (s *Store) Decision() *DecisionStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.decision == nil {
|
||||
s.decision = &DecisionStore{db: s.db}
|
||||
s.decision = NewDecisionStore(s.gdb)
|
||||
}
|
||||
return s.decision
|
||||
}
|
||||
@@ -244,7 +234,7 @@ func (s *Store) Backtest() *BacktestStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.backtest == nil {
|
||||
s.backtest = &BacktestStore{db: s.db}
|
||||
s.backtest = NewBacktestStore(s.gdb)
|
||||
}
|
||||
return s.backtest
|
||||
}
|
||||
@@ -254,7 +244,7 @@ func (s *Store) Position() *PositionStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.position == nil {
|
||||
s.position = NewPositionStore(s.db)
|
||||
s.position = NewPositionStore(s.gdb)
|
||||
}
|
||||
return s.position
|
||||
}
|
||||
@@ -264,7 +254,7 @@ func (s *Store) Strategy() *StrategyStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.strategy == nil {
|
||||
s.strategy = &StrategyStore{db: s.db}
|
||||
s.strategy = NewStrategyStore(s.gdb)
|
||||
}
|
||||
return s.strategy
|
||||
}
|
||||
@@ -274,7 +264,7 @@ func (s *Store) Equity() *EquityStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.equity == nil {
|
||||
s.equity = &EquityStore{db: s.db}
|
||||
s.equity = NewEquityStore(s.gdb)
|
||||
}
|
||||
return s.equity
|
||||
}
|
||||
@@ -284,7 +274,7 @@ func (s *Store) Order() *OrderStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.order == nil {
|
||||
s.order = NewOrderStore(s.db)
|
||||
s.order = NewOrderStore(s.gdb)
|
||||
}
|
||||
return s.order
|
||||
}
|
||||
@@ -294,10 +284,18 @@ func (s *Store) Close() error {
|
||||
if s.driver != nil {
|
||||
return s.driver.Close()
|
||||
}
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Driver returns database driver for abstraction
|
||||
// GormDB returns the GORM database connection
|
||||
func (s *Store) GormDB() *gorm.DB {
|
||||
return s.gdb
|
||||
}
|
||||
|
||||
// Driver returns database driver for abstraction (legacy)
|
||||
func (s *Store) Driver() *DBDriver {
|
||||
return s.driver
|
||||
}
|
||||
@@ -307,11 +305,25 @@ func (s *Store) DBType() DBType {
|
||||
if s.driver != nil {
|
||||
return s.driver.Type
|
||||
}
|
||||
// Detect from GORM dialector
|
||||
if s.gdb != nil {
|
||||
switch s.gdb.Dialector.Name() {
|
||||
case "postgres":
|
||||
return DBTypePostgres
|
||||
default:
|
||||
return DBTypeSQLite
|
||||
}
|
||||
}
|
||||
return DBTypeSQLite
|
||||
}
|
||||
|
||||
// DB gets underlying database connection (for legacy code compatibility, gradually deprecated)
|
||||
// Deprecated: use Store methods instead
|
||||
// q converts query placeholders for current database type (legacy helper)
|
||||
func (s *Store) q(query string) string {
|
||||
return convertQuery(query, s.DBType())
|
||||
}
|
||||
|
||||
// DB gets underlying database connection (for legacy code compatibility)
|
||||
// Deprecated: use GormDB() instead
|
||||
func (s *Store) DB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
@@ -319,24 +331,36 @@ func (s *Store) DB() *sql.DB {
|
||||
// GetSystemConfig gets a system configuration value by key
|
||||
func (s *Store) GetSystemConfig(key string) (string, error) {
|
||||
var value string
|
||||
err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value)
|
||||
if err == sql.ErrNoRows {
|
||||
result := s.gdb.Raw("SELECT value FROM system_config WHERE key = ?", key).Scan(&value)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return "", nil
|
||||
}
|
||||
return value, err
|
||||
return "", result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// SetSystemConfig sets a system configuration value
|
||||
func (s *Store) SetSystemConfig(key, value string) error {
|
||||
_, err := s.db.Exec(`
|
||||
// Use GORM-compatible upsert
|
||||
return s.gdb.Exec(`
|
||||
INSERT INTO system_config (key, value) VALUES (?, ?)
|
||||
ON CONFLICT(key) DO UPDATE SET value = excluded.value
|
||||
`, key, value)
|
||||
return err
|
||||
`, key, value).Error
|
||||
}
|
||||
|
||||
// Transaction executes transaction
|
||||
func (s *Store) Transaction(fn func(tx *sql.Tx) error) error {
|
||||
// Transaction executes transaction with GORM
|
||||
func (s *Store) Transaction(fn func(tx *gorm.DB) error) error {
|
||||
return s.gdb.Transaction(fn)
|
||||
}
|
||||
|
||||
// TransactionSQL executes transaction with sql.Tx (legacy)
|
||||
// Deprecated: Use Transaction() instead
|
||||
func (s *Store) TransactionSQL(fn func(tx *sql.Tx) error) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
|
||||
+52
-150
@@ -1,30 +1,33 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// StrategyStore strategy storage
|
||||
type StrategyStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// Strategy strategy configuration
|
||||
type Strategy struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
IsActive bool `json:"is_active"` // whether it is active (a user can only have one active strategy)
|
||||
IsDefault bool `json:"is_default"` // whether it is a system default strategy
|
||||
Config string `json:"config"` // strategy configuration in JSON format
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
UserID string `gorm:"column:user_id;not null;default:'';index" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Description string `gorm:"default:''" json:"description"`
|
||||
IsActive bool `gorm:"column:is_active;default:false;index" json:"is_active"`
|
||||
IsDefault bool `gorm:"column:is_default;default:false" json:"is_default"`
|
||||
Config string `gorm:"not null;default:'{}'" json:"config"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (Strategy) TableName() string { return "strategies" }
|
||||
|
||||
// StrategyConfig strategy configuration details (JSON structure)
|
||||
type StrategyConfig struct {
|
||||
// coin source configuration
|
||||
@@ -136,24 +139,6 @@ type ExternalDataSource struct {
|
||||
}
|
||||
|
||||
// RiskControlConfig risk control configuration
|
||||
// All parameters are clearly defined without ambiguity:
|
||||
//
|
||||
// Position Limits:
|
||||
// - MaxPositions: max number of coins held simultaneously (CODE ENFORCED)
|
||||
//
|
||||
// Trading Leverage (exchange leverage for opening positions):
|
||||
// - BTCETHMaxLeverage: BTC/ETH max exchange leverage (AI guided)
|
||||
// - AltcoinMaxLeverage: Altcoin max exchange leverage (AI guided)
|
||||
//
|
||||
// Position Value Limits (single position notional value / account equity):
|
||||
// - BTCETHMaxPositionValueRatio: BTC/ETH max = equity × ratio (CODE ENFORCED)
|
||||
// - AltcoinMaxPositionValueRatio: Altcoin max = equity × ratio (CODE ENFORCED)
|
||||
//
|
||||
// Risk Controls:
|
||||
// - MaxMarginUsage: max margin utilization percentage (CODE ENFORCED)
|
||||
// - MinPositionSize: minimum position size in USDT (CODE ENFORCED)
|
||||
// - MinRiskRewardRatio: min take_profit / stop_loss ratio (AI guided)
|
||||
// - MinConfidence: min AI confidence to open position (AI guided)
|
||||
type RiskControlConfig struct {
|
||||
// Max number of coins held simultaneously (CODE ENFORCED)
|
||||
MaxPositions int `json:"max_positions"`
|
||||
@@ -179,38 +164,21 @@ type RiskControlConfig struct {
|
||||
MinConfidence int `json:"min_confidence"`
|
||||
}
|
||||
|
||||
func (s *StrategyStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS strategies (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT '',
|
||||
name TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
is_active BOOLEAN DEFAULT 0,
|
||||
is_default BOOLEAN DEFAULT 0,
|
||||
config TEXT NOT NULL DEFAULT '{}',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
// NewStrategyStore creates a new StrategyStore
|
||||
func NewStrategyStore(db *gorm.DB) *StrategyStore {
|
||||
return &StrategyStore{db: db}
|
||||
}
|
||||
|
||||
// create indexes
|
||||
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_user_id ON strategies(user_id)`)
|
||||
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_is_active ON strategies(is_active)`)
|
||||
|
||||
// trigger: automatically update updated_at on update
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_strategies_updated_at
|
||||
AFTER UPDATE ON strategies
|
||||
BEGIN
|
||||
UPDATE strategies SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
|
||||
return err
|
||||
func (s *StrategyStore) initTables() error {
|
||||
// For PostgreSQL with existing table, skip AutoMigrate
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'strategies'`).Scan(&tableExists)
|
||||
if tableExists > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return s.db.AutoMigrate(&Strategy{})
|
||||
}
|
||||
|
||||
func (s *StrategyStore) initDefaultData() error {
|
||||
@@ -322,159 +290,93 @@ Only enter positions when multiple signals resonate. Freely use any effective an
|
||||
|
||||
// Create create a strategy
|
||||
func (s *StrategyStore) Create(strategy *Strategy) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`, strategy.ID, strategy.UserID, strategy.Name, strategy.Description, strategy.IsActive, strategy.IsDefault, strategy.Config)
|
||||
return err
|
||||
return s.db.Create(strategy).Error
|
||||
}
|
||||
|
||||
// Update update a strategy
|
||||
func (s *StrategyStore) Update(strategy *Strategy) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE strategies SET
|
||||
name = ?, description = ?, config = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, strategy.Name, strategy.Description, strategy.Config, strategy.ID, strategy.UserID)
|
||||
return err
|
||||
return s.db.Model(&Strategy{}).
|
||||
Where("id = ? AND user_id = ?", strategy.ID, strategy.UserID).
|
||||
Updates(map[string]interface{}{
|
||||
"name": strategy.Name,
|
||||
"description": strategy.Description,
|
||||
"config": strategy.Config,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// Delete delete a strategy
|
||||
func (s *StrategyStore) Delete(userID, id string) error {
|
||||
// do not allow deleting system default strategy
|
||||
var isDefault bool
|
||||
s.db.QueryRow(`SELECT is_default FROM strategies WHERE id = ?`, id).Scan(&isDefault)
|
||||
if isDefault {
|
||||
var st Strategy
|
||||
if err := s.db.Where("id = ?", id).First(&st).Error; err == nil && st.IsDefault {
|
||||
return fmt.Errorf("cannot delete system default strategy")
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`DELETE FROM strategies WHERE id = ? AND user_id = ?`, id, userID)
|
||||
return err
|
||||
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Strategy{}).Error
|
||||
}
|
||||
|
||||
// List get user's strategy list
|
||||
func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
|
||||
// get user's own strategies + system default strategy
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies
|
||||
WHERE user_id = ? OR is_default = 1
|
||||
ORDER BY is_default DESC, created_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var strategies []*Strategy
|
||||
for rows.Next() {
|
||||
var st Strategy
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
||||
&st.IsActive, &st.IsDefault, &st.Config,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("user_id = ? OR is_default = ?", userID, true).
|
||||
Order("is_default DESC, created_at DESC").
|
||||
Find(&strategies).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
strategies = append(strategies, &st)
|
||||
}
|
||||
return strategies, nil
|
||||
}
|
||||
|
||||
// Get get a single strategy
|
||||
func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
|
||||
var st Strategy
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies
|
||||
WHERE id = ? AND (user_id = ? OR is_default = 1)
|
||||
`, id, userID).Scan(
|
||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
||||
&st.IsActive, &st.IsDefault, &st.Config,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", id, userID, true).
|
||||
First(&st).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &st, nil
|
||||
}
|
||||
|
||||
// GetActive get user's currently active strategy
|
||||
func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
|
||||
var st Strategy
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies
|
||||
WHERE user_id = ? AND is_active = 1
|
||||
`, userID).Scan(
|
||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
||||
&st.IsActive, &st.IsDefault, &st.Config,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&st).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// no active strategy, return system default strategy
|
||||
return s.GetDefault()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &st, nil
|
||||
}
|
||||
|
||||
// GetDefault get system default strategy
|
||||
func (s *StrategyStore) GetDefault() (*Strategy, error) {
|
||||
var st Strategy
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies
|
||||
WHERE is_default = 1
|
||||
LIMIT 1
|
||||
`).Scan(
|
||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
||||
&st.IsActive, &st.IsDefault, &st.Config,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("is_default = ?", true).First(&st).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &st, nil
|
||||
}
|
||||
|
||||
// SetActive set active strategy (will first deactivate other strategies)
|
||||
func (s *StrategyStore) SetActive(userID, strategyID string) error {
|
||||
// begin transaction
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// first deactivate all strategies for the user
|
||||
_, err = tx.Exec(`UPDATE strategies SET is_active = 0 WHERE user_id = ?`, userID)
|
||||
if err != nil {
|
||||
if err := tx.Model(&Strategy{}).Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// activate specified strategy
|
||||
_, err = tx.Exec(`UPDATE strategies SET is_active = 1 WHERE id = ? AND (user_id = ? OR is_default = 1)`, strategyID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return tx.Model(&Strategy{}).
|
||||
Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
|
||||
Update("is_active", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
// Duplicate duplicate a strategy (used to create custom strategy based on default strategy)
|
||||
|
||||
+108
-371
@@ -1,43 +1,52 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TraderStore trader storage
|
||||
type TraderStore struct {
|
||||
db *sql.DB
|
||||
decryptFunc func(string) string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTraderStore creates a new trader store
|
||||
func NewTraderStore(db *gorm.DB) *TraderStore {
|
||||
return &TraderStore{db: db}
|
||||
}
|
||||
|
||||
// Trader trader configuration
|
||||
type Trader struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
StrategyID string `json:"strategy_id"` // Associated strategy ID
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
IsCrossMargin bool `json:"is_cross_margin"`
|
||||
ShowInCompetition bool `json:"show_in_competition"` // Whether to show in competition page
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
||||
Name string `gorm:"column:name;not null" json:"name"`
|
||||
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||
ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"`
|
||||
StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"`
|
||||
InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"`
|
||||
ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"`
|
||||
IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"`
|
||||
IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"`
|
||||
ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||
|
||||
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
|
||||
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage,omitempty"`
|
||||
TradingSymbols string `json:"trading_symbols,omitempty"`
|
||||
UseCoinPool bool `json:"use_coin_pool,omitempty"`
|
||||
UseOITop bool `json:"use_oi_top,omitempty"`
|
||||
CustomPrompt string `json:"custom_prompt,omitempty"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt,omitempty"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template,omitempty"`
|
||||
BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"`
|
||||
AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"`
|
||||
TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"`
|
||||
UseCoinPool bool `gorm:"column:use_coin_pool;default:false" json:"use_coin_pool,omitempty"`
|
||||
UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"`
|
||||
CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"`
|
||||
OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"`
|
||||
SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for Trader
|
||||
func (Trader) TableName() string {
|
||||
return "traders"
|
||||
}
|
||||
|
||||
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
|
||||
@@ -45,331 +54,130 @@ type TraderFullConfig struct {
|
||||
Trader *Trader
|
||||
AIModel *AIModel
|
||||
Exchange *Exchange
|
||||
Strategy *Strategy // Associated strategy configuration
|
||||
Strategy *Strategy
|
||||
}
|
||||
|
||||
func (s *TraderStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS traders (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
ai_model_id TEXT NOT NULL,
|
||||
exchange_id TEXT NOT NULL,
|
||||
initial_balance REAL NOT NULL,
|
||||
scan_interval_minutes INTEGER DEFAULT 3,
|
||||
is_running BOOLEAN DEFAULT 0,
|
||||
btc_eth_leverage INTEGER DEFAULT 5,
|
||||
altcoin_leverage INTEGER DEFAULT 5,
|
||||
trading_symbols TEXT DEFAULT '',
|
||||
use_coin_pool BOOLEAN DEFAULT 0,
|
||||
use_oi_top BOOLEAN DEFAULT 0,
|
||||
custom_prompt TEXT DEFAULT '',
|
||||
override_base_prompt BOOLEAN DEFAULT 0,
|
||||
system_prompt_template TEXT DEFAULT 'default',
|
||||
is_cross_margin BOOLEAN DEFAULT 1,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
|
||||
AFTER UPDATE ON traders
|
||||
BEGIN
|
||||
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Backward compatibility
|
||||
alterQueries := []string{
|
||||
`ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`,
|
||||
`ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`,
|
||||
`ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`,
|
||||
`ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`,
|
||||
`ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`,
|
||||
`ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`,
|
||||
`ALTER TABLE traders ADD COLUMN strategy_id TEXT DEFAULT ''`,
|
||||
`ALTER TABLE traders ADD COLUMN show_in_competition BOOLEAN DEFAULT 1`,
|
||||
}
|
||||
for _, q := range alterQueries {
|
||||
s.db.Exec(q)
|
||||
}
|
||||
|
||||
// Migration: Remove FOREIGN KEY constraint from existing traders table
|
||||
// SQLite doesn't support ALTER TABLE DROP CONSTRAINT, so we need to recreate the table
|
||||
if err := s.migrateTradersRemoveFK(); err != nil {
|
||||
// Log but don't fail - this is a best-effort migration
|
||||
// The constraint may not exist in older databases
|
||||
}
|
||||
|
||||
// For PostgreSQL with existing table, skip AutoMigrate
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists)
|
||||
if tableExists > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateTradersRemoveFK removes FOREIGN KEY constraint from traders table if it exists
|
||||
func (s *TraderStore) migrateTradersRemoveFK() error {
|
||||
// Check if the table has a foreign key constraint by examining the schema
|
||||
var sql string
|
||||
err := s.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='traders'`).Scan(&sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no FOREIGN KEY in schema, no migration needed
|
||||
if !strings.Contains(sql, "FOREIGN KEY") {
|
||||
// Use GORM AutoMigrate
|
||||
if err := s.db.AutoMigrate(&Trader{}); err != nil {
|
||||
return fmt.Errorf("failed to migrate traders table: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Recreate table without FOREIGN KEY constraint
|
||||
_, err = s.db.Exec(`
|
||||
-- Create new table without FOREIGN KEY
|
||||
CREATE TABLE IF NOT EXISTS traders_new (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
ai_model_id TEXT NOT NULL,
|
||||
exchange_id TEXT NOT NULL,
|
||||
initial_balance REAL NOT NULL,
|
||||
scan_interval_minutes INTEGER DEFAULT 3,
|
||||
is_running BOOLEAN DEFAULT 0,
|
||||
btc_eth_leverage INTEGER DEFAULT 5,
|
||||
altcoin_leverage INTEGER DEFAULT 5,
|
||||
trading_symbols TEXT DEFAULT '',
|
||||
use_coin_pool BOOLEAN DEFAULT 0,
|
||||
use_oi_top BOOLEAN DEFAULT 0,
|
||||
custom_prompt TEXT DEFAULT '',
|
||||
override_base_prompt BOOLEAN DEFAULT 0,
|
||||
system_prompt_template TEXT DEFAULT 'default',
|
||||
is_cross_margin BOOLEAN DEFAULT 1,
|
||||
strategy_id TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Copy data from old table
|
||||
INSERT OR IGNORE INTO traders_new
|
||||
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance,
|
||||
scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage,
|
||||
trading_symbols, use_coin_pool, use_oi_top, custom_prompt,
|
||||
override_base_prompt, system_prompt_template, is_cross_margin,
|
||||
COALESCE(strategy_id, ''), created_at, updated_at
|
||||
FROM traders;
|
||||
|
||||
-- Drop old table
|
||||
DROP TABLE traders;
|
||||
|
||||
-- Rename new table
|
||||
ALTER TABLE traders_new RENAME TO traders;
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Recreate trigger
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
|
||||
AFTER UPDATE ON traders
|
||||
BEGIN
|
||||
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *TraderStore) decrypt(encrypted string) string {
|
||||
if s.decryptFunc != nil {
|
||||
return s.decryptFunc(encrypted)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
|
||||
// Create creates trader
|
||||
func (s *TraderStore) Create(trader *Trader) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, strategy_id, initial_balance,
|
||||
scan_interval_minutes, is_running, is_cross_margin, show_in_competition,
|
||||
btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool,
|
||||
use_oi_top, custom_prompt, override_base_prompt, system_prompt_template)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
|
||||
trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning, trader.IsCrossMargin, trader.ShowInCompetition,
|
||||
trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, trader.UseCoinPool,
|
||||
trader.UseOITop, trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate)
|
||||
return err
|
||||
return s.db.Create(trader).Error
|
||||
}
|
||||
|
||||
// List gets user's trader list
|
||||
func (s *TraderStore) List(userID string) ([]*Trader, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
|
||||
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
|
||||
COALESCE(show_in_competition, 1),
|
||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
||||
created_at, updated_at
|
||||
FROM traders WHERE user_id = ? ORDER BY created_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var traders []*Trader
|
||||
for rows.Next() {
|
||||
var t Trader
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
|
||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
|
||||
&t.ShowInCompetition,
|
||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
||||
&t.SystemPromptTemplate, &createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Find(&traders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
traders = append(traders, &t)
|
||||
}
|
||||
return traders, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates trader running status
|
||||
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID)
|
||||
return err
|
||||
return s.db.Model(&Trader{}).
|
||||
Where("id = ? AND user_id = ?", id, userID).
|
||||
Update("is_running", isRunning).Error
|
||||
}
|
||||
|
||||
// UpdateShowInCompetition updates trader competition visibility
|
||||
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET show_in_competition = ? WHERE id = ? AND user_id = ?`, showInCompetition, id, userID)
|
||||
return err
|
||||
return s.db.Model(&Trader{}).
|
||||
Where("id = ? AND user_id = ?", id, userID).
|
||||
Update("show_in_competition", showInCompetition).Error
|
||||
}
|
||||
|
||||
// Update updates trader configuration
|
||||
func (s *TraderStore) Update(trader *Trader) error {
|
||||
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
|
||||
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE traders SET
|
||||
name = ?,
|
||||
ai_model_id = ?,
|
||||
exchange_id = ?,
|
||||
strategy_id = ?,
|
||||
initial_balance = CASE WHEN ? > 0 THEN ? ELSE initial_balance END,
|
||||
scan_interval_minutes = CASE WHEN ? > 0 THEN ? ELSE scan_interval_minutes END,
|
||||
is_cross_margin = ?,
|
||||
show_in_competition = ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
|
||||
trader.InitialBalance, trader.InitialBalance,
|
||||
trader.ScanIntervalMinutes, trader.ScanIntervalMinutes,
|
||||
trader.IsCrossMargin, trader.ShowInCompetition,
|
||||
trader.ID, trader.UserID)
|
||||
return err
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"name": trader.Name,
|
||||
"ai_model_id": trader.AIModelID,
|
||||
"exchange_id": trader.ExchangeID,
|
||||
"strategy_id": trader.StrategyID,
|
||||
"is_cross_margin": trader.IsCrossMargin,
|
||||
"show_in_competition": trader.ShowInCompetition,
|
||||
}
|
||||
|
||||
// Only update these if > 0
|
||||
if trader.InitialBalance > 0 {
|
||||
updates["initial_balance"] = trader.InitialBalance
|
||||
}
|
||||
if trader.ScanIntervalMinutes > 0 {
|
||||
updates["scan_interval_minutes"] = trader.ScanIntervalMinutes
|
||||
}
|
||||
|
||||
return s.db.Model(&Trader{}).
|
||||
Where("id = ? AND user_id = ?", trader.ID, trader.UserID).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateInitialBalance updates initial balance
|
||||
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID)
|
||||
return err
|
||||
return s.db.Model(&Trader{}).
|
||||
Where("id = ? AND user_id = ?", id, userID).
|
||||
Update("initial_balance", newBalance).Error
|
||||
}
|
||||
|
||||
// UpdateCustomPrompt updates custom prompt
|
||||
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`,
|
||||
customPrompt, overrideBase, id, userID)
|
||||
return err
|
||||
return s.db.Model(&Trader{}).
|
||||
Where("id = ? AND user_id = ?", id, userID).
|
||||
Updates(map[string]interface{}{
|
||||
"custom_prompt": customPrompt,
|
||||
"override_base_prompt": overrideBase,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// Delete deletes trader and associated data
|
||||
func (s *TraderStore) Delete(userID, id string) error {
|
||||
// Delete associated equity snapshots first
|
||||
_, _ = s.db.Exec(`DELETE FROM trader_equity_snapshots WHERE trader_id = ?`, id)
|
||||
s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{})
|
||||
|
||||
// Delete the trader
|
||||
_, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID)
|
||||
return err
|
||||
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error
|
||||
}
|
||||
|
||||
// GetFullConfig gets trader full configuration
|
||||
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
|
||||
var trader Trader
|
||||
var aiModel AIModel
|
||||
var exchange Exchange
|
||||
var traderCreatedAt, traderUpdatedAt string
|
||||
var aiModelCreatedAt, aiModelUpdatedAt string
|
||||
var exchangeCreatedAt, exchangeUpdatedAt string
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT
|
||||
t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, COALESCE(t.strategy_id, ''),
|
||||
t.initial_balance, t.scan_interval_minutes, t.is_running, COALESCE(t.is_cross_margin, 1),
|
||||
COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''),
|
||||
COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''),
|
||||
COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'),
|
||||
t.created_at, t.updated_at,
|
||||
a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key,
|
||||
COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at,
|
||||
e.id, COALESCE(e.exchange_type, '') as exchange_type, COALESCE(e.account_name, '') as account_name,
|
||||
e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, COALESCE(e.passphrase, ''), e.testnet,
|
||||
COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''),
|
||||
COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''),
|
||||
COALESCE(e.lighter_api_key_private_key, ''), COALESCE(e.lighter_api_key_index, 0), e.created_at, e.updated_at
|
||||
FROM traders t
|
||||
JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id
|
||||
JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id
|
||||
WHERE t.id = ? AND t.user_id = ?
|
||||
`, traderID, userID).Scan(
|
||||
&trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, &trader.StrategyID,
|
||||
&trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, &trader.IsCrossMargin,
|
||||
&trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols,
|
||||
&trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt,
|
||||
&trader.SystemPromptTemplate, &traderCreatedAt, &traderUpdatedAt,
|
||||
&aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey,
|
||||
&aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt,
|
||||
&exchange.ID, &exchange.ExchangeType, &exchange.AccountName,
|
||||
&exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled,
|
||||
&exchange.APIKey, &exchange.SecretKey, &exchange.Passphrase, &exchange.Testnet, &exchange.HyperliquidWalletAddr,
|
||||
&exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey,
|
||||
&exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, &exchange.LighterAPIKeyIndex,
|
||||
&exchangeCreatedAt, &exchangeUpdatedAt,
|
||||
)
|
||||
err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt)
|
||||
trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt)
|
||||
aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt)
|
||||
aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt)
|
||||
exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt)
|
||||
exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt)
|
||||
// Get AI model
|
||||
var aiModel AIModel
|
||||
err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get AI model: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
aiModel.APIKey = s.decrypt(aiModel.APIKey)
|
||||
exchange.APIKey = s.decrypt(exchange.APIKey)
|
||||
exchange.SecretKey = s.decrypt(exchange.SecretKey)
|
||||
exchange.Passphrase = s.decrypt(exchange.Passphrase)
|
||||
exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey)
|
||||
exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey)
|
||||
exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey)
|
||||
// Get exchange
|
||||
var exchange Exchange
|
||||
err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get exchange: %w", err)
|
||||
}
|
||||
|
||||
// Load associated strategy
|
||||
var strategy *Strategy
|
||||
@@ -392,119 +200,48 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
|
||||
// getStrategyByID internal method: gets strategy by ID
|
||||
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
|
||||
var strategy Strategy
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies WHERE id = ? AND (user_id = ? OR is_default = 1)
|
||||
`, strategyID, userID).Scan(
|
||||
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
|
||||
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
|
||||
First(&strategy).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &strategy, nil
|
||||
}
|
||||
|
||||
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
|
||||
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
|
||||
var strategy Strategy
|
||||
var createdAt, updatedAt string
|
||||
|
||||
// First try to get user's active strategy
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies WHERE user_id = ? AND is_active = 1
|
||||
`, userID).Scan(
|
||||
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
|
||||
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error
|
||||
if err == nil {
|
||||
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &strategy, nil
|
||||
}
|
||||
|
||||
// Fallback to system default strategy
|
||||
err = s.db.QueryRow(`
|
||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
||||
FROM strategies WHERE is_default = 1 LIMIT 1
|
||||
`).Scan(
|
||||
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
|
||||
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
|
||||
)
|
||||
err = s.db.Where("is_default = ?", true).First(&strategy).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &strategy, nil
|
||||
}
|
||||
|
||||
// ListAll gets all users' trader list
|
||||
// GetByID gets a trader by ID without requiring userID (for public APIs)
|
||||
func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
|
||||
var t Trader
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
|
||||
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
|
||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
||||
created_at, updated_at
|
||||
FROM traders WHERE id = ?
|
||||
`, traderID).Scan(
|
||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
|
||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
|
||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
||||
&t.SystemPromptTemplate, &createdAt, &updatedAt,
|
||||
)
|
||||
var trader Trader
|
||||
err := s.db.Where("id = ?", traderID).First(&trader).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &t, nil
|
||||
return &trader, nil
|
||||
}
|
||||
|
||||
// ListAll gets all traders
|
||||
func (s *TraderStore) ListAll() ([]*Trader, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
|
||||
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
|
||||
COALESCE(show_in_competition, 1),
|
||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
||||
created_at, updated_at
|
||||
FROM traders ORDER BY created_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var traders []*Trader
|
||||
for rows.Next() {
|
||||
var t Trader
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
|
||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
|
||||
&t.ShowInCompetition,
|
||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
||||
&t.SystemPromptTemplate, &createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Order("created_at DESC").Find(&traders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
traders = append(traders, &t)
|
||||
}
|
||||
return traders, nil
|
||||
}
|
||||
|
||||
+58
-85
@@ -2,27 +2,30 @@ package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base32"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserStore user storage
|
||||
type UserStore struct {
|
||||
db *sql.DB
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// User user
|
||||
// User user model
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash string `json:"-"`
|
||||
OTPSecret string `json:"-"`
|
||||
OTPVerified bool `json:"otp_verified"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex:idx_users_email;not null" json:"email"`
|
||||
PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
|
||||
OTPSecret string `gorm:"column:otp_secret" json:"-"`
|
||||
OTPVerified bool `gorm:"column:otp_verified;default:false" json:"otp_verified"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (User) TableName() string { return "users" }
|
||||
|
||||
// GenerateOTPSecret generates OTP secret
|
||||
func GenerateOTPSecret() (string, error) {
|
||||
secret := make([]byte, 20)
|
||||
@@ -33,131 +36,101 @@ func GenerateOTPSecret() (string, error) {
|
||||
return base32.StdEncoding.EncodeToString(secret), nil
|
||||
}
|
||||
|
||||
func (s *UserStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
otp_secret TEXT,
|
||||
otp_verified BOOLEAN DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
// NewUserStore creates a new UserStore
|
||||
func NewUserStore(db *gorm.DB) *UserStore {
|
||||
return &UserStore{db: db}
|
||||
}
|
||||
|
||||
// Trigger
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_users_updated_at
|
||||
AFTER UPDATE ON users
|
||||
BEGIN
|
||||
UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *UserStore) initTables() error {
|
||||
// For PostgreSQL with existing table, skip AutoMigrate to avoid index conflicts
|
||||
if s.db.Dialector.Name() == "postgres" {
|
||||
var tableExists int64
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'users'`).Scan(&tableExists)
|
||||
|
||||
if tableExists > 0 {
|
||||
// Table exists - manually ensure all columns exist
|
||||
// Core columns (should already exist)
|
||||
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT NOT NULL DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS password_hash TEXT NOT NULL DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
|
||||
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
|
||||
// OTP columns (added later)
|
||||
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_secret TEXT DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_verified BOOLEAN DEFAULT FALSE`)
|
||||
|
||||
// Ensure unique index exists on email (don't care about the name)
|
||||
var indexExists int64
|
||||
s.db.Raw(`
|
||||
SELECT COUNT(*) FROM pg_indexes
|
||||
WHERE tablename = 'users' AND indexdef LIKE '%email%' AND indexdef LIKE '%UNIQUE%'
|
||||
`).Scan(&indexExists)
|
||||
|
||||
if indexExists == 0 {
|
||||
s.db.Exec("CREATE UNIQUE INDEX idx_users_email ON users(email)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return s.db.AutoMigrate(&User{})
|
||||
}
|
||||
|
||||
// Create creates user
|
||||
func (s *UserStore) Create(user *User) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified)
|
||||
return err
|
||||
return s.db.Create(user).Error
|
||||
}
|
||||
|
||||
// GetByEmail gets user by email
|
||||
func (s *UserStore) GetByEmail(email string) (*User, error) {
|
||||
var user User
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
|
||||
FROM users WHERE email = ?
|
||||
`, email).Scan(
|
||||
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
|
||||
&user.OTPVerified, &createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("email = ?", email).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByID gets user by ID
|
||||
func (s *UserStore) GetByID(userID string) (*User, error) {
|
||||
var user User
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
|
||||
FROM users WHERE id = ?
|
||||
`, userID).Scan(
|
||||
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
|
||||
&user.OTPVerified, &createdAt, &updatedAt,
|
||||
)
|
||||
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Count returns the total number of users
|
||||
func (s *UserStore) Count() (int, error) {
|
||||
var count int
|
||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count)
|
||||
return count, err
|
||||
var count int64
|
||||
err := s.db.Model(&User{}).Count(&count).Error
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// GetAllIDs gets all user IDs
|
||||
func (s *UserStore) GetAllIDs() ([]string, error) {
|
||||
rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs, nil
|
||||
err := s.db.Model(&User{}).Order("id").Pluck("id", &userIDs).Error
|
||||
return userIDs, err
|
||||
}
|
||||
|
||||
// UpdateOTPVerified updates OTP verification status
|
||||
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
|
||||
_, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID)
|
||||
return err
|
||||
return s.db.Model(&User{}).Where("id = ?", userID).Update("otp_verified", verified).Error
|
||||
}
|
||||
|
||||
// UpdatePassword updates password
|
||||
func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?
|
||||
`, passwordHash, userID)
|
||||
return err
|
||||
return s.db.Model(&User{}).Where("id = ?", userID).Updates(map[string]interface{}{
|
||||
"password_hash": passwordHash,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// EnsureAdmin ensures admin user exists
|
||||
func (s *UserStore) EnsureAdmin() error {
|
||||
var count int
|
||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var count int64
|
||||
s.db.Model(&User{}).Where("id = ?", "admin").Count(&count)
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"nofx/store"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// TestScenario represents a trading scenario to test
|
||||
@@ -116,11 +117,12 @@ func runStandardTests(t *testing.T, exchangeName string) {
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.Name, func(t *testing.T) {
|
||||
// Setup database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
positionStore := store.NewPositionStore(db)
|
||||
if err := positionStore.InitTables(); err != nil {
|
||||
@@ -199,11 +201,12 @@ func TestAllExchangesStandardScenarios(t *testing.T) {
|
||||
|
||||
// TestPositionAccumulationBug tests that positions don't accumulate incorrectly
|
||||
func TestPositionAccumulationBug(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
positionStore := store.NewPositionStore(db)
|
||||
if err := positionStore.InitTables(); err != nil {
|
||||
@@ -283,11 +286,12 @@ func TestPositionAccumulationBug(t *testing.T) {
|
||||
|
||||
// TestQuantityPrecision tests handling of quantity precision issues
|
||||
func TestQuantityPrecision(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
positionStore := store.NewPositionStore(db)
|
||||
if err := positionStore.InitTables(); err != nil {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math"
|
||||
"nofx/store"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// TestHyperliquidOrderDirectionParsing tests Dir field parsing
|
||||
@@ -75,11 +76,12 @@ func TestHyperliquidOrderDirectionParsing(t *testing.T) {
|
||||
// TestHyperliquidPositionBuilding tests the complete flow of position building
|
||||
func TestHyperliquidPositionBuilding(t *testing.T) {
|
||||
// Setup in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Initialize stores
|
||||
positionStore := store.NewPositionStore(db)
|
||||
@@ -304,11 +306,12 @@ func TestHyperliquidPositionBuilding(t *testing.T) {
|
||||
// TestHyperliquidBugScenario tests the exact bug we fixed
|
||||
func TestHyperliquidBugScenario(t *testing.T) {
|
||||
// Setup database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
positionStore := store.NewPositionStore(db)
|
||||
if err := positionStore.InitTables(); err != nil {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { motion } from 'framer-motion'
|
||||
import { Bot, TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
|
||||
import { TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
|
||||
|
||||
const agents = [
|
||||
{
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
import { motion } from 'framer-motion'
|
||||
import { Activity, BarChart3, Globe, Wifi, Server, Database, Lock } from 'lucide-react'
|
||||
import { useState, useEffect } from 'react'
|
||||
|
||||
const generateLog = (id) => {
|
||||
interface LogEntry {
|
||||
id: number
|
||||
time: string
|
||||
type: string
|
||||
msg: string
|
||||
color: string
|
||||
}
|
||||
|
||||
const generateLog = (id: number): LogEntry => {
|
||||
const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS']
|
||||
const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD']
|
||||
const actions = ['BUY', 'SELL', 'SHORT', 'LONG']
|
||||
@@ -37,7 +44,7 @@ const generateLog = (id) => {
|
||||
}
|
||||
|
||||
export default function LiveFeed() {
|
||||
const [logs, setLogs] = useState([])
|
||||
const [logs, setLogs] = useState<LogEntry[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
// Initial population
|
||||
|
||||
@@ -159,11 +159,24 @@ export default function TerminalHero() {
|
||||
<span className="text-stroke-1 text-transparent bg-clip-text bg-gradient-to-r from-nofx-gold via-white to-nofx-gold animate-shimmer bg-[length:200%_auto]">TRADING</span>
|
||||
</h1>
|
||||
|
||||
<p className="max-w-xl text-zinc-400 text-lg mb-10 font-light leading-relaxed">
|
||||
<p className="max-w-xl text-zinc-400 text-lg mb-6 font-light leading-relaxed">
|
||||
The World's First Open-Source Agentic Trading OS.
|
||||
Deploy autonomous high-frequency trading agents powered by advanced LLMs.
|
||||
</p>
|
||||
|
||||
{/* Market Access Strip - Prominent Display */}
|
||||
<div className="flex flex-wrap gap-4 mb-12 font-mono">
|
||||
{['CRYPTO', 'US STOCKS', 'FOREX', 'METALS'].map((market) => (
|
||||
<div key={market} className="flex items-center gap-3 px-4 py-2 rounded bg-zinc-900 border border-zinc-700 text-white font-bold tracking-wider hover:border-nofx-gold hover:shadow-[0_0_15px_rgba(255,215,0,0.3)] transition-all duration-300">
|
||||
<span className="relative flex h-2 w-2">
|
||||
<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-nofx-success opacity-75"></span>
|
||||
<span className="relative inline-flex rounded-full h-2 w-2 bg-nofx-success"></span>
|
||||
</span>
|
||||
{market}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Command Line Input Simulation */}
|
||||
<div className="w-full max-w-lg h-12 bg-black/50 border border-zinc-800 rounded flex items-center px-4 mb-10 font-mono text-sm shadow-2xl backdrop-blur-sm group hover:border-nofx-gold/50 transition-colors cursor-text" onClick={() => document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}>
|
||||
<span className="text-nofx-success mr-2">➜</span>
|
||||
@@ -269,7 +282,7 @@ export default function TerminalHero() {
|
||||
import { OFFICIAL_LINKS } from '../../../constants/branding'
|
||||
|
||||
function CommunityStats() {
|
||||
const { stars, forks, contributors, isLoading, error } = useGitHubStats('tinkle-community', 'nofx')
|
||||
const { stars, forks, contributors, isLoading, error } = useGitHubStats('NoFxAiOS', 'nofx')
|
||||
|
||||
const stats = [
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user