From 2d272bb7b8f39c5a114df91a606bd65e713e377e Mon Sep 17 00:00:00 2001 From: tinkle-community Date: Thu, 1 Jan 2026 19:32:49 +0800 Subject: [PATCH] feat: migrate store layer to GORM with PostgreSQL support - Migrate all store packages from raw database/sql to GORM ORM - Add PostgreSQL support alongside SQLite - Move EncryptedString type to crypto package for cleaner architecture - Add automatic encryption/decryption for sensitive fields (API keys, secrets) - Fix PostgreSQL AutoMigrate conflicts by skipping existing tables - Fix duplicate /klines route registration - Update tests to use GORM database connections - Add database configuration support in config package --- .env.example | 13 + api/backtest.go | 2 +- api/server.go | 84 +- api/strategy.go | 18 +- backtest/retention.go | 4 +- backtest/storage_db_impl.go | 8 +- config/config.go | 46 + crypto/crypto.go | 75 + debate/engine.go | 4 +- go.mod | 9 + go.sum | 18 + main.go | 76 +- manager/trader_manager.go | 38 +- store/ai_model.go | 264 ++-- store/backtest.go | 690 ++++----- store/debate.go | 716 +++------ store/decision.go | 333 ++-- store/driver.go | 41 + store/equity.go | 200 +-- store/exchange.go | 455 ++---- store/gorm.go | 146 ++ store/order.go | 647 +++----- store/position.go | 1360 +++++++---------- store/store.go | 186 ++- store/strategy.go | 212 +-- store/trader.go | 485 ++---- store/user.go | 147 +- trader/exchange_sync_test.go | 20 +- trader/hyperliquid_sync_test.go | 15 +- web/src/components/landing/core/AgentGrid.tsx | 2 +- web/src/components/landing/core/LiveFeed.tsx | 13 +- .../components/landing/core/TerminalHero.tsx | 17 +- 32 files changed, 2573 insertions(+), 3771 deletions(-) create mode 100644 store/gorm.go diff --git a/.env.example b/.env.example index 1b31c050..5eafa687 100644 --- a/.env.example +++ b/.env.example @@ -55,3 +55,16 @@ TRANSPORT_ENCRYPTION=false # Telegram notifications (optional) # TELEGRAM_BOT_TOKEN=your-bot-token # TELEGRAM_CHAT_ID=your-chat-id + +DB_TYPE=postgres +DB_HOST=10. +DB_PORT=5432 +DB_USER=nofx_user +DB_PASSWORD= +DB_NAME=nofx +DB_SSLMODE=disable + + +# 数据库配置 - SQLite(默认) +DB_TYPE=sqlite +DB_PATH=data/data.db \ No newline at end of file diff --git a/api/backtest.go b/api/backtest.go index b8bc840a..8e0014b9 100644 --- a/api/backtest.go +++ b/api/backtest.go @@ -824,7 +824,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error { return fmt.Errorf("AI model %s is not enabled yet", model.Name) } - apiKey := strings.TrimSpace(model.APIKey) + apiKey := strings.TrimSpace(string(model.APIKey)) if apiKey == "" { return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name) } diff --git a/api/server.go b/api/server.go index ccc16c8f..9914b7af 100644 --- a/api/server.go +++ b/api/server.go @@ -54,7 +54,7 @@ func NewServer(traderManager *manager.TraderManager, st *store.Store, cryptoServ cryptoHandler := NewCryptoHandler(cryptoService) // Create debate store and handler - debateStore := store.NewDebateStore(st.DB()) + debateStore := store.NewDebateStore(st.GormDB()) if err := debateStore.InitSchema(); err != nil { logger.Errorf("Failed to initialize debate schema: %v", err) } @@ -125,7 +125,6 @@ func (s *Server) setupRoutes() { // Market data (no authentication required) api.GET("/klines", s.handleKlines) - api.GET("/klines", s.handleKlines) api.GET("/symbols", s.handleSymbols) // Authentication related routes (no authentication required) @@ -576,12 +575,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) { var createErr error // Use ExchangeType (e.g., "binance") instead of ID (UUID) + // Convert EncryptedString fields to string switch exchangeCfg.ExchangeType { case "binance": - tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) + tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) case "hyperliquid": tempTrader, createErr = trader.NewHyperliquidTrader( - exchangeCfg.APIKey, // private key + string(exchangeCfg.APIKey), // private key exchangeCfg.HyperliquidWalletAddr, exchangeCfg.Testnet, ) @@ -589,31 +589,31 @@ func (s *Server) handleCreateTrader(c *gin.Context) { tempTrader, createErr = trader.NewAsterTrader( exchangeCfg.AsterUser, exchangeCfg.AsterSigner, - exchangeCfg.AsterPrivateKey, + string(exchangeCfg.AsterPrivateKey), ) case "bybit": tempTrader = trader.NewBybitTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), ) case "okx": tempTrader = trader.NewOKXTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, - exchangeCfg.Passphrase, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), ) case "bitget": tempTrader = trader.NewBitgetTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, - exchangeCfg.Passphrase, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), ) case "lighter": - if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" { + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { // Lighter only supports mainnet tempTrader, createErr = trader.NewLighterTraderV2( exchangeCfg.LighterWalletAddr, - exchangeCfg.LighterAPIKeyPrivateKey, + string(exchangeCfg.LighterAPIKeyPrivateKey), exchangeCfg.LighterAPIKeyIndex, false, // Always use mainnet for Lighter ) @@ -1095,12 +1095,13 @@ func (s *Server) handleSyncBalance(c *gin.Context) { var createErr error // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) + // Convert EncryptedString fields to string switch exchangeCfg.ExchangeType { case "binance": - tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) + tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) case "hyperliquid": tempTrader, createErr = trader.NewHyperliquidTrader( - exchangeCfg.APIKey, + string(exchangeCfg.APIKey), exchangeCfg.HyperliquidWalletAddr, exchangeCfg.Testnet, ) @@ -1108,31 +1109,31 @@ func (s *Server) handleSyncBalance(c *gin.Context) { tempTrader, createErr = trader.NewAsterTrader( exchangeCfg.AsterUser, exchangeCfg.AsterSigner, - exchangeCfg.AsterPrivateKey, + string(exchangeCfg.AsterPrivateKey), ) case "bybit": tempTrader = trader.NewBybitTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), ) case "okx": tempTrader = trader.NewOKXTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, - exchangeCfg.Passphrase, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), ) case "bitget": tempTrader = trader.NewBitgetTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, - exchangeCfg.Passphrase, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), ) case "lighter": - if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" { + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { // Lighter only supports mainnet tempTrader, createErr = trader.NewLighterTraderV2( exchangeCfg.LighterWalletAddr, - exchangeCfg.LighterAPIKeyPrivateKey, + string(exchangeCfg.LighterAPIKeyPrivateKey), exchangeCfg.LighterAPIKeyIndex, false, // Always use mainnet for Lighter ) @@ -1246,12 +1247,13 @@ func (s *Server) handleClosePosition(c *gin.Context) { var createErr error // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) + // Convert EncryptedString fields to string switch exchangeCfg.ExchangeType { case "binance": - tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) + tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) case "hyperliquid": tempTrader, createErr = trader.NewHyperliquidTrader( - exchangeCfg.APIKey, + string(exchangeCfg.APIKey), exchangeCfg.HyperliquidWalletAddr, exchangeCfg.Testnet, ) @@ -1259,31 +1261,31 @@ func (s *Server) handleClosePosition(c *gin.Context) { tempTrader, createErr = trader.NewAsterTrader( exchangeCfg.AsterUser, exchangeCfg.AsterSigner, - exchangeCfg.AsterPrivateKey, + string(exchangeCfg.AsterPrivateKey), ) case "bybit": tempTrader = trader.NewBybitTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), ) case "okx": tempTrader = trader.NewOKXTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, - exchangeCfg.Passphrase, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), ) case "bitget": tempTrader = trader.NewBitgetTrader( - exchangeCfg.APIKey, - exchangeCfg.SecretKey, - exchangeCfg.Passphrase, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), ) case "lighter": - if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" { + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { // Lighter only supports mainnet tempTrader, createErr = trader.NewLighterTraderV2( exchangeCfg.LighterWalletAddr, - exchangeCfg.LighterAPIKeyPrivateKey, + string(exchangeCfg.LighterAPIKeyPrivateKey), exchangeCfg.LighterAPIKeyIndex, false, // Always use mainnet for Lighter ) diff --git a/api/strategy.go b/api/strategy.go index e9a2e2de..6c581703 100644 --- a/api/strategy.go +++ b/api/strategy.go @@ -549,32 +549,34 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string) var aiClient mcp.AIClient provider := model.Provider + // Convert EncryptedString to string for API key + apiKey := string(model.APIKey) switch provider { case "qwen": aiClient = mcp.NewQwenClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) case "deepseek": aiClient = mcp.NewDeepSeekClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) case "claude": aiClient = mcp.NewClaudeClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) case "kimi": aiClient = mcp.NewKimiClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) case "gemini": aiClient = mcp.NewGeminiClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) case "grok": aiClient = mcp.NewGrokClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) case "openai": aiClient = mcp.NewOpenAIClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) default: // Use generic client aiClient = mcp.NewClient() - aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName) + aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) } // Call AI API diff --git a/backtest/retention.go b/backtest/retention.go index 55395c97..a9d34d74 100644 --- a/backtest/retention.go +++ b/backtest/retention.go @@ -76,8 +76,8 @@ func enforceRetentionDB(maxRuns int) { query := ` SELECT run_id FROM backtest_runs WHERE state IN (?, ?, ?, ?) - ORDER BY datetime(updated_at) DESC - LIMIT -1 OFFSET ? + ORDER BY updated_at DESC + OFFSET ? ` rows, err := persistenceDB.Query(query, finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns) diff --git a/backtest/storage_db_impl.go b/backtest/storage_db_impl.go index 67cc0831..f8899aa0 100644 --- a/backtest/storage_db_impl.go +++ b/backtest/storage_db_impl.go @@ -166,7 +166,7 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) { } func loadRunIDsDB() ([]string, error) { - rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`) + rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY updated_at DESC`) if err != nil { return nil, err } @@ -278,9 +278,9 @@ func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error) var rows *sql.Rows var err error if cycle > 0 { - rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`, runID, cycle) + rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle) } else { - rows, err = persistenceDB.Query(query+` ORDER BY datetime(created_at) DESC LIMIT 1`, runID) + rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID) } if err != nil { return nil, err @@ -461,7 +461,7 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) { rows, err := persistenceDB.Query(` SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json FROM backtest_runs - ORDER BY datetime(updated_at) DESC + ORDER BY updated_at DESC `) if err != nil { return nil, err diff --git a/config/config.go b/config/config.go index 33c58f19..1a4a0d96 100644 --- a/config/config.go +++ b/config/config.go @@ -20,6 +20,16 @@ type Config struct { RegistrationEnabled bool MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10) + // Database configuration + DBType string // sqlite or postgres + DBPath string // SQLite database file path + DBHost string // PostgreSQL host + DBPort int // PostgreSQL port + DBUser string // PostgreSQL user + DBPassword string // PostgreSQL password + DBName string // PostgreSQL database name + DBSSLMode string // PostgreSQL SSL mode + // Security configuration // TransportEncryption enables browser-side encryption for API keys // Requires HTTPS or localhost. Set to false for HTTP access via IP. @@ -43,6 +53,14 @@ func Init() { RegistrationEnabled: true, MaxUsers: 10, // Default: 10 users allowed ExperienceImprovement: true, // Default: enabled to help improve the product + // Database defaults + DBType: "sqlite", + DBPath: "data/data.db", + DBHost: "localhost", + DBPort: 5432, + DBUser: "postgres", + DBName: "nofx", + DBSSLMode: "disable", } // Load from environment variables @@ -86,6 +104,34 @@ func Init() { cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY") cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY") + // Database configuration + if v := os.Getenv("DB_TYPE"); v != "" { + cfg.DBType = strings.ToLower(v) + } + if v := os.Getenv("DB_PATH"); v != "" { + cfg.DBPath = v + } + if v := os.Getenv("DB_HOST"); v != "" { + cfg.DBHost = v + } + if v := os.Getenv("DB_PORT"); v != "" { + if port, err := strconv.Atoi(v); err == nil && port > 0 { + cfg.DBPort = port + } + } + if v := os.Getenv("DB_USER"); v != "" { + cfg.DBUser = v + } + if v := os.Getenv("DB_PASSWORD"); v != "" { + cfg.DBPassword = v + } + if v := os.Getenv("DB_NAME"); v != "" { + cfg.DBName = v + } + if v := os.Getenv("DB_SSLMODE"); v != "" { + cfg.DBSSLMode = v + } + global = cfg // Initialize experience improvement (installation ID will be set after database init) diff --git a/crypto/crypto.go b/crypto/crypto.go index 76927f6c..f2142155 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -7,6 +7,7 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" + "database/sql/driver" "encoding/base64" "encoding/hex" "encoding/json" @@ -392,3 +393,77 @@ func GenerateDataKey() (string, error) { } return base64.StdEncoding.EncodeToString(key), nil } + +// ============================================================================ +// EncryptedString - GORM custom type for automatic encryption/decryption +// ============================================================================ + +// Global crypto service for EncryptedString +var globalCryptoService *CryptoService + +// SetGlobalCryptoService sets the global crypto service for EncryptedString +func SetGlobalCryptoService(cs *CryptoService) { + globalCryptoService = cs +} + +// EncryptedString is a custom type that automatically encrypts on save and decrypts on load +// Usage: Use EncryptedString instead of string for sensitive fields in GORM models +type EncryptedString string + +// Scan implements sql.Scanner - called when reading from database +// Automatically decrypts the value +func (es *EncryptedString) Scan(value interface{}) error { + if value == nil { + *es = "" + return nil + } + + var str string + switch v := value.(type) { + case string: + str = v + case []byte: + str = string(v) + default: + *es = "" + return nil + } + + // Decrypt if crypto service is set + if globalCryptoService != nil && str != "" && globalCryptoService.IsEncryptedStorageValue(str) { + decrypted, err := globalCryptoService.DecryptFromStorage(str) + if err != nil { + // If decryption fails, return the original value + *es = EncryptedString(str) + } else { + *es = EncryptedString(decrypted) + } + } else { + *es = EncryptedString(str) + } + return nil +} + +// Value implements driver.Valuer - called when writing to database +// Automatically encrypts the value +func (es EncryptedString) Value() (driver.Value, error) { + if es == "" { + return "", nil + } + + // Encrypt if crypto service is set + if globalCryptoService != nil { + encrypted, err := globalCryptoService.EncryptForStorage(string(es)) + if err != nil { + // If encryption fails, return the original value + return string(es), nil + } + return encrypted, nil + } + return string(es), nil +} + +// String returns the plaintext string value +func (es EncryptedString) String() string { + return string(es) +} diff --git a/debate/engine.go b/debate/engine.go index 48d50122..ca4913f9 100644 --- a/debate/engine.go +++ b/debate/engine.go @@ -101,8 +101,8 @@ func (e *DebateEngine) InitializeClients(participants []*store.DebateParticipant client = mcp.New() } - // Configure client - client.SetAPIKey(aiModel.APIKey, aiModel.CustomAPIURL, aiModel.CustomModelName) + // Configure client (convert EncryptedString to string) + client.SetAPIKey(string(aiModel.APIKey), aiModel.CustomAPIURL, aiModel.CustomModelName) e.clients[p.AIModelID] = client } diff --git a/go.mod b/go.mod index 3c4d11fa..eb33cad4 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,12 @@ require ( github.com/goccy/go-json v0.10.4 // indirect github.com/goccy/go-yaml v1.18.0 // indirect github.com/holiman/uint256 v1.3.2 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -94,6 +100,9 @@ require ( golang.org/x/tools v0.36.0 // indirect google.golang.org/protobuf v1.36.9 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/postgres v1.6.0 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect + gorm.io/gorm v1.31.1 // indirect howett.net/plist v1.0.1 // indirect modernc.org/libc v1.66.10 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index d4d1591b..48302780 100644 --- a/go.sum +++ b/go.sum @@ -111,7 +111,19 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA= github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -284,6 +296,12 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM= howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= diff --git a/main.go b/main.go index 3350d423..9905e7b1 100644 --- a/main.go +++ b/main.go @@ -36,30 +36,44 @@ func main() { cfg := config.Get() logger.Info("✅ Configuration loaded") - // Initialize database from environment variables - // DB_TYPE: sqlite (default) or postgres - // For SQLite: DB_PATH (default: data/data.db) - // For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE - dbPath := os.Getenv("DB_PATH") - if dbPath == "" { - dbPath = "data/data.db" + // Initialize encryption service BEFORE database (so EncryptedString can decrypt on read) + logger.Info("🔐 Initializing encryption service...") + cryptoService, err := crypto.NewCryptoService() + if err != nil { + logger.Fatalf("❌ Failed to initialize encryption service: %v", err) } - // For backward compatibility: command line arg overrides env var (SQLite only) + crypto.SetGlobalCryptoService(cryptoService) + logger.Info("✅ Encryption service initialized successfully") + + // Initialize database from configuration + // For backward compatibility: command line arg overrides config (SQLite only) if len(os.Args) > 1 { - dbPath = os.Args[1] - os.Setenv("DB_PATH", dbPath) + cfg.DBPath = os.Args[1] } // Ensure data directory exists (for SQLite) - if os.Getenv("DB_TYPE") == "" || os.Getenv("DB_TYPE") == "sqlite" { - if dir := filepath.Dir(dbPath); dir != "." { + if cfg.DBType == "sqlite" { + if dir := filepath.Dir(cfg.DBPath); dir != "." { if err := os.MkdirAll(dir, 0755); err != nil { logger.Errorf("Failed to create data directory: %v", err) } } } - logger.Info("📋 Initializing database...") - st, err := store.NewFromEnv() + logger.Infof("📋 Initializing database (%s)...", cfg.DBType) + dbType := store.DBTypeSQLite + if cfg.DBType == "postgres" { + dbType = store.DBTypePostgres + } + st, err := store.NewWithConfig(store.DBConfig{ + Type: dbType, + Path: cfg.DBPath, + Host: cfg.DBHost, + Port: cfg.DBPort, + User: cfg.DBUser, + Password: cfg.DBPassword, + DBName: cfg.DBName, + SSLMode: cfg.DBSSLMode, + }) if err != nil { logger.Fatalf("❌ Failed to initialize database: %v", err) } @@ -69,40 +83,6 @@ func main() { // Initialize installation ID for experience improvement (anonymous statistics) initInstallationID(st) - // Initialize encryption service - logger.Info("🔐 Initializing encryption service...") - cryptoService, err := crypto.NewCryptoService() - if err != nil { - logger.Fatalf("❌ Failed to initialize encryption service: %v", err) - } - encryptFunc := func(plaintext string) string { - if plaintext == "" { - return plaintext - } - encrypted, err := cryptoService.EncryptForStorage(plaintext) - if err != nil { - logger.Warnf("⚠️ Encryption failed: %v", err) - return plaintext - } - return encrypted - } - decryptFunc := func(encrypted string) string { - if encrypted == "" { - return encrypted - } - if !cryptoService.IsEncryptedStorageValue(encrypted) { - return encrypted - } - decrypted, err := cryptoService.DecryptFromStorage(encrypted) - if err != nil { - logger.Warnf("⚠️ Decryption failed: %v", err) - return encrypted - } - return decrypted - } - st.SetCryptoFuncs(encryptFunc, decryptFunc) - logger.Info("✅ Encryption service initialized successfully") - // Set JWT secret auth.SetJWTSecret(cfg.JWTSecret) logger.Info("🔑 JWT secret configured") diff --git a/manager/trader_manager.go b/manager/trader_manager.go index d3a4b3c6..c18c9a9d 100644 --- a/manager/trader_manager.go +++ b/manager/trader_manager.go @@ -664,46 +664,46 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg StrategyConfig: strategyConfig, } - // Set API keys based on exchange type + // Set API keys based on exchange type (convert EncryptedString to string) switch exchangeCfg.ExchangeType { case "binance": - traderConfig.BinanceAPIKey = exchangeCfg.APIKey - traderConfig.BinanceSecretKey = exchangeCfg.SecretKey + traderConfig.BinanceAPIKey = string(exchangeCfg.APIKey) + traderConfig.BinanceSecretKey = string(exchangeCfg.SecretKey) case "bybit": - traderConfig.BybitAPIKey = exchangeCfg.APIKey - traderConfig.BybitSecretKey = exchangeCfg.SecretKey + traderConfig.BybitAPIKey = string(exchangeCfg.APIKey) + traderConfig.BybitSecretKey = string(exchangeCfg.SecretKey) case "okx": - traderConfig.OKXAPIKey = exchangeCfg.APIKey - traderConfig.OKXSecretKey = exchangeCfg.SecretKey - traderConfig.OKXPassphrase = exchangeCfg.Passphrase + traderConfig.OKXAPIKey = string(exchangeCfg.APIKey) + traderConfig.OKXSecretKey = string(exchangeCfg.SecretKey) + traderConfig.OKXPassphrase = string(exchangeCfg.Passphrase) case "bitget": - traderConfig.BitgetAPIKey = exchangeCfg.APIKey - traderConfig.BitgetSecretKey = exchangeCfg.SecretKey - traderConfig.BitgetPassphrase = exchangeCfg.Passphrase + traderConfig.BitgetAPIKey = string(exchangeCfg.APIKey) + traderConfig.BitgetSecretKey = string(exchangeCfg.SecretKey) + traderConfig.BitgetPassphrase = string(exchangeCfg.Passphrase) case "hyperliquid": - traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey + traderConfig.HyperliquidPrivateKey = string(exchangeCfg.APIKey) traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr case "aster": traderConfig.AsterUser = exchangeCfg.AsterUser traderConfig.AsterSigner = exchangeCfg.AsterSigner - traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey + traderConfig.AsterPrivateKey = string(exchangeCfg.AsterPrivateKey) case "lighter": - traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey + traderConfig.LighterPrivateKey = string(exchangeCfg.LighterPrivateKey) traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr - traderConfig.LighterAPIKeyPrivateKey = exchangeCfg.LighterAPIKeyPrivateKey + traderConfig.LighterAPIKeyPrivateKey = string(exchangeCfg.LighterAPIKeyPrivateKey) traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex traderConfig.LighterTestnet = exchangeCfg.Testnet } - // Set API keys based on AI model + // Set API keys based on AI model (convert EncryptedString to string) switch aiModelCfg.Provider { case "qwen": - traderConfig.QwenKey = aiModelCfg.APIKey + traderConfig.QwenKey = string(aiModelCfg.APIKey) case "deepseek": - traderConfig.DeepSeekKey = aiModelCfg.APIKey + traderConfig.DeepSeekKey = string(aiModelCfg.APIKey) default: // For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey - traderConfig.CustomAPIKey = aiModelCfg.APIKey + traderConfig.CustomAPIKey = string(aiModelCfg.APIKey) } // Create trader instance diff --git a/store/ai_model.go b/store/ai_model.go index 932f7501..9950acba 100644 --- a/store/ai_model.go +++ b/store/ai_model.go @@ -1,71 +1,52 @@ package store import ( - "database/sql" "errors" "fmt" + "nofx/crypto" "nofx/logger" "strings" "time" + + "gorm.io/gorm" ) // AIModelStore AI model storage type AIModelStore struct { - db *sql.DB - encryptFunc func(string) string - decryptFunc func(string) string + db *gorm.DB } // AIModel AI model configuration type AIModel struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - CustomAPIURL string `json:"customApiUrl"` - CustomModelName string `json:"customModelName"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `gorm:"primaryKey" json:"id"` + UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"` + Name string `gorm:"not null" json:"name"` + Provider string `gorm:"not null" json:"provider"` + Enabled bool `gorm:"default:false" json:"enabled"` + APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"` + CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"` + CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (AIModel) TableName() string { return "ai_models" } + +// NewAIModelStore creates a new AIModelStore +func NewAIModelStore(db *gorm.DB) *AIModelStore { + return &AIModelStore{db: db} } func (s *AIModelStore) initTables() error { - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS ai_models ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - provider TEXT NOT NULL, - enabled BOOLEAN DEFAULT 0, - api_key TEXT DEFAULT '', - custom_api_url TEXT DEFAULT '', - custom_model_name TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return err + // For PostgreSQL with existing table, skip AutoMigrate + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'ai_models'`).Scan(&tableExists) + if tableExists > 0 { + return nil + } } - - // Trigger - _, err = s.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at - AFTER UPDATE ON ai_models - BEGIN - UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END - `) - if err != nil { - return err - } - - // Backward compatibility: add potentially missing columns - s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`) - s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`) - - return nil + return s.db.AutoMigrate(&AIModel{}) } func (s *AIModelStore) initDefaultData() error { @@ -73,51 +54,13 @@ func (s *AIModelStore) initDefaultData() error { return nil } -func (s *AIModelStore) encrypt(plaintext string) string { - if s.encryptFunc != nil { - return s.encryptFunc(plaintext) - } - return plaintext -} - -func (s *AIModelStore) decrypt(encrypted string) string { - if s.decryptFunc != nil { - return s.decryptFunc(encrypted) - } - return encrypted -} - // List retrieves user's AI model list func (s *AIModelStore) List(userID string) ([]*AIModel, error) { - rows, err := s.db.Query(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, '') as custom_api_url, - COALESCE(custom_model_name, '') as custom_model_name, - created_at, updated_at - FROM ai_models WHERE user_id = ? ORDER BY id - `, userID) + var models []*AIModel + err := s.db.Where("user_id = ?", userID).Order("id").Find(&models).Error if err != nil { return nil, err } - defer rows.Close() - - models := make([]*AIModel, 0) - for rows.Next() { - var model AIModel - var createdAt, updatedAt string - err := rows.Scan( - &model.ID, &model.UserID, &model.Name, &model.Provider, - &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - model.APIKey = s.decrypt(model.APIKey) - models = append(models, &model) - } return models, nil } @@ -140,27 +83,15 @@ func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) { for _, uid := range candidates { var model AIModel - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at - FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1 - `, uid, modelID).Scan( - &model.ID, &model.UserID, &model.Name, &model.Provider, - &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, - &createdAt, &updatedAt, - ) + err := s.db.Where("user_id = ? AND id = ?", uid, modelID).First(&model).Error if err == nil { - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - model.APIKey = s.decrypt(model.APIKey) return &model, nil } - if !errors.Is(err, sql.ErrNoRows) { + if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - return nil, sql.ErrNoRows + return nil, gorm.ErrRecordNotFound } // GetByID retrieves an AI model by ID only (for debate engine) @@ -170,22 +101,10 @@ func (s *AIModelStore) GetByID(modelID string) (*AIModel, error) { } var model AIModel - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at - FROM ai_models WHERE id = ? LIMIT 1 - `, modelID).Scan( - &model.ID, &model.UserID, &model.Name, &model.Provider, - &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, - &createdAt, &updatedAt, - ) + err := s.db.Where("id = ?", modelID).First(&model).Error if err != nil { return nil, err } - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - model.APIKey = s.decrypt(model.APIKey) return &model, nil } @@ -198,7 +117,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) { if err == nil { return model, nil } - if !errors.Is(err, sql.ErrNoRows) { + if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } if userID != "default" { @@ -209,23 +128,12 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) { func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) { var model AIModel - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at - FROM ai_models WHERE user_id = ? AND enabled = 1 - ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1 - `, userID).Scan( - &model.ID, &model.UserID, &model.Name, &model.Provider, - &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, - &createdAt, &updatedAt, - ) + err := s.db.Where("user_id = ? AND enabled = ?", userID, true). + Order("updated_at DESC, id ASC"). + First(&model).Error if err != nil { return nil, err } - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - model.APIKey = s.decrypt(model.APIKey) return &model, nil } @@ -233,44 +141,38 @@ func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) { // IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten) func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error { // Try exact ID match first - var existingID string - err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID) + var existingModel AIModel + err := s.db.Where("user_id = ? AND id = ?", userID, id).First(&existingModel).Error if err == nil { - // If apiKey is empty, preserve the existing API key - if apiKey == "" { - _, err = s.db.Exec(` - UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') - WHERE id = ? AND user_id = ? - `, enabled, customAPIURL, customModelName, existingID, userID) - } else { - encryptedAPIKey := s.encrypt(apiKey) - _, err = s.db.Exec(` - UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') - WHERE id = ? AND user_id = ? - `, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID) + // Update existing model + updates := map[string]interface{}{ + "enabled": enabled, + "custom_api_url": customAPIURL, + "custom_model_name": customModelName, + "updated_at": time.Now(), } - return err + // If apiKey is not empty, update it (encryption handled by crypto.EncryptedString) + if apiKey != "" { + updates["api_key"] = crypto.EncryptedString(apiKey) + } + return s.db.Model(&existingModel).Updates(updates).Error } // Try legacy logic compatibility: use id as provider to search provider := id - err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID) + err = s.db.Where("user_id = ? AND provider = ?", userID, provider).First(&existingModel).Error if err == nil { - logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingID) - // If apiKey is empty, preserve the existing API key - if apiKey == "" { - _, err = s.db.Exec(` - UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') - WHERE id = ? AND user_id = ? - `, enabled, customAPIURL, customModelName, existingID, userID) - } else { - encryptedAPIKey := s.encrypt(apiKey) - _, err = s.db.Exec(` - UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') - WHERE id = ? AND user_id = ? - `, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID) + logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingModel.ID) + updates := map[string]interface{}{ + "enabled": enabled, + "custom_api_url": customAPIURL, + "custom_model_name": customModelName, + "updated_at": time.Now(), } - return err + if apiKey != "" { + updates["api_key"] = crypto.EncryptedString(apiKey) + } + return s.db.Model(&existingModel).Updates(updates).Error } // Create new record @@ -285,9 +187,12 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI } } + // Try to get name from existing model with same provider + var refModel AIModel var name string - err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name) - if err != nil { + if err := s.db.Where("provider = ?", provider).First(&refModel).Error; err == nil { + name = refModel.Name + } else { if provider == "deepseek" { name = "DeepSeek AI" } else if provider == "qwen" { @@ -303,19 +208,30 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI } logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) - encryptedAPIKey := s.encrypt(apiKey) - _, err = s.db.Exec(` - INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) - `, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName) - return err + newModel := &AIModel{ + ID: newModelID, + UserID: userID, + Name: name, + Provider: provider, + Enabled: enabled, + APIKey: crypto.EncryptedString(apiKey), + CustomAPIURL: customAPIURL, + CustomModelName: customModelName, + } + return s.db.Create(newModel).Error } // Create creates an AI model func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error { - _, err := s.db.Exec(` - INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url) - VALUES (?, ?, ?, ?, ?, ?, ?) - `, id, userID, name, provider, enabled, apiKey, customAPIURL) - return err + model := &AIModel{ + ID: id, + UserID: userID, + Name: name, + Provider: provider, + Enabled: enabled, + APIKey: crypto.EncryptedString(apiKey), + CustomAPIURL: customAPIURL, + } + // Use FirstOrCreate to ignore if already exists + return s.db.Where("id = ?", id).FirstOrCreate(model).Error } diff --git a/store/backtest.go b/store/backtest.go index 2ab4c846..ecb59f0e 100644 --- a/store/backtest.go +++ b/store/backtest.go @@ -1,15 +1,26 @@ package store import ( - "database/sql" "encoding/json" "fmt" "time" + + "gorm.io/gorm" ) // BacktestStore backtest data storage type BacktestStore struct { - db *sql.DB + db *gorm.DB +} + +// NewBacktestStore creates a new backtest store +func NewBacktestStore(db *gorm.DB) *BacktestStore { + return &BacktestStore{db: db} +} + +// isPostgres checks if the database is PostgreSQL +func (s *BacktestStore) isPostgres() bool { + return s.db.Dialector.Name() == "postgres" } // RunState backtest state @@ -92,492 +103,469 @@ type RunIndexEntry struct { UpdatedAtISO string `json:"updated_at"` } +// BacktestRun GORM model for backtest_runs table +type BacktestRun struct { + RunID string `gorm:"column:run_id;primaryKey"` + UserID string `gorm:"column:user_id;not null;default:''"` + ConfigJSON []byte `gorm:"column:config_json"` + State string `gorm:"column:state;not null;default:created"` + Label string `gorm:"column:label;default:''"` + SymbolCount int `gorm:"column:symbol_count;default:0"` + DecisionTF string `gorm:"column:decision_tf;default:''"` + ProcessedBars int `gorm:"column:processed_bars;default:0"` + ProgressPct float64 `gorm:"column:progress_pct;default:0"` + EquityLast float64 `gorm:"column:equity_last;default:0"` + MaxDrawdownPct float64 `gorm:"column:max_drawdown_pct;default:0"` + Liquidated bool `gorm:"column:liquidated;default:false"` + LiquidationNote string `gorm:"column:liquidation_note;default:''"` + PromptTemplate string `gorm:"column:prompt_template;default:''"` + CustomPrompt string `gorm:"column:custom_prompt;default:''"` + OverridePrompt bool `gorm:"column:override_prompt;default:false"` + AIProvider string `gorm:"column:ai_provider;default:''"` + AIModel string `gorm:"column:ai_model;default:''"` + LastError string `gorm:"column:last_error;default:''"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` +} + +func (BacktestRun) TableName() string { + return "backtest_runs" +} + +// BacktestCheckpoint GORM model +type BacktestCheckpoint struct { + RunID string `gorm:"column:run_id;primaryKey"` + Payload []byte `gorm:"column:payload;not null"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` +} + +func (BacktestCheckpoint) TableName() string { + return "backtest_checkpoints" +} + +// BacktestEquity GORM model +type BacktestEquity struct { + ID int64 `gorm:"primaryKey;autoIncrement"` + RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"` + TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"` + Equity float64 `gorm:"column:equity;not null"` + Available float64 `gorm:"column:available;not null"` + PnL float64 `gorm:"column:pnl;not null"` + PnLPct float64 `gorm:"column:pnl_pct;not null"` + DDPct float64 `gorm:"column:dd_pct;not null"` + Cycle int `gorm:"column:cycle;not null"` +} + +func (BacktestEquity) TableName() string { + return "backtest_equity" +} + +// BacktestTrade GORM model +type BacktestTrade struct { + ID int64 `gorm:"primaryKey;autoIncrement"` + RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"` + TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"` + Symbol string `gorm:"column:symbol;not null"` + Action string `gorm:"column:action;not null"` + Side string `gorm:"column:side;default:''"` + Qty float64 `gorm:"column:qty;default:0"` + Price float64 `gorm:"column:price;default:0"` + Fee float64 `gorm:"column:fee;default:0"` + Slippage float64 `gorm:"column:slippage;default:0"` + OrderValue float64 `gorm:"column:order_value;default:0"` + RealizedPnL float64 `gorm:"column:realized_pnl;default:0"` + Leverage int `gorm:"column:leverage;default:0"` + Cycle int `gorm:"column:cycle;default:0"` + PositionAfter float64 `gorm:"column:position_after;default:0"` + Liquidation bool `gorm:"column:liquidation;default:false"` + Note string `gorm:"column:note;default:''"` +} + +func (BacktestTrade) TableName() string { + return "backtest_trades" +} + +// BacktestMetrics GORM model +type BacktestMetrics struct { + RunID string `gorm:"column:run_id;primaryKey"` + Payload []byte `gorm:"column:payload;not null"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` +} + +func (BacktestMetrics) TableName() string { + return "backtest_metrics" +} + +// BacktestDecision GORM model +type BacktestDecision struct { + ID int64 `gorm:"primaryKey;autoIncrement"` + RunID string `gorm:"column:run_id;not null;index:idx_backtest_decisions_run_cycle"` + Cycle int `gorm:"column:cycle;not null;index:idx_backtest_decisions_run_cycle"` + Payload []byte `gorm:"column:payload;not null"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` +} + +func (BacktestDecision) TableName() string { + return "backtest_decisions" +} + // initTables initializes backtest related tables func (s *BacktestStore) initTables() error { - queries := []string{ - // Backtest runs main table - `CREATE TABLE IF NOT EXISTS backtest_runs ( - run_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT '', - config_json TEXT NOT NULL DEFAULT '', - state TEXT NOT NULL DEFAULT 'created', - label TEXT DEFAULT '', - symbol_count INTEGER DEFAULT 0, - decision_tf TEXT DEFAULT '', - processed_bars INTEGER DEFAULT 0, - progress_pct REAL DEFAULT 0, - equity_last REAL DEFAULT 0, - max_drawdown_pct REAL DEFAULT 0, - liquidated BOOLEAN DEFAULT 0, - liquidation_note TEXT DEFAULT '', - prompt_template TEXT DEFAULT '', - custom_prompt TEXT DEFAULT '', - override_prompt BOOLEAN DEFAULT 0, - ai_provider TEXT DEFAULT '', - ai_model TEXT DEFAULT '', - last_error TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, + // For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists) - // Backtest checkpoints - `CREATE TABLE IF NOT EXISTS backtest_checkpoints ( - run_id TEXT PRIMARY KEY, - payload BLOB NOT NULL, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // Backtest equity curve - `CREATE TABLE IF NOT EXISTS backtest_equity ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id TEXT NOT NULL, - ts INTEGER NOT NULL, - equity REAL NOT NULL, - available REAL NOT NULL, - pnl REAL NOT NULL, - pnl_pct REAL NOT NULL, - dd_pct REAL NOT NULL, - cycle INTEGER NOT NULL, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // Backtest trade records - `CREATE TABLE IF NOT EXISTS backtest_trades ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id TEXT NOT NULL, - ts INTEGER NOT NULL, - symbol TEXT NOT NULL, - action TEXT NOT NULL, - side TEXT DEFAULT '', - qty REAL DEFAULT 0, - price REAL DEFAULT 0, - fee REAL DEFAULT 0, - slippage REAL DEFAULT 0, - order_value REAL DEFAULT 0, - realized_pnl REAL DEFAULT 0, - leverage INTEGER DEFAULT 0, - cycle INTEGER DEFAULT 0, - position_after REAL DEFAULT 0, - liquidation BOOLEAN DEFAULT 0, - note TEXT DEFAULT '', - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // Backtest metrics - `CREATE TABLE IF NOT EXISTS backtest_metrics ( - run_id TEXT PRIMARY KEY, - payload BLOB NOT NULL, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // Backtest decision logs - `CREATE TABLE IF NOT EXISTS backtest_decisions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id TEXT NOT NULL, - cycle INTEGER NOT NULL, - payload BLOB NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // Indexes - `CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`, - `CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`, - `CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`, - `CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`, - } - - for _, query := range queries { - if _, err := s.db.Exec(query); err != nil { - return fmt.Errorf("failed to execute SQL: %w", err) + if tableExists > 0 { + // Tables exist - just ensure indexes exist + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`) + return nil } } - // Add potentially missing columns (backward compatibility) - s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''") - s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''") - s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0") + // AutoMigrate all backtest tables + if err := s.db.AutoMigrate( + &BacktestRun{}, + &BacktestCheckpoint{}, + &BacktestEquity{}, + &BacktestTrade{}, + &BacktestMetrics{}, + &BacktestDecision{}, + ); err != nil { + return fmt.Errorf("failed to migrate backtest tables: %w", err) + } return nil } -func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) { - rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) - if err != nil { - return - } - defer rows.Close() - - for rows.Next() { - var cid int - var name, ctype string - var notnull, pk int - var dflt interface{} - if err := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); err != nil { - continue - } - if name == column { - return // Column already exists - } - } - - s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition)) -} - // SaveCheckpoint saves checkpoint func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error { - _, err := s.db.Exec(` - INSERT INTO backtest_checkpoints (run_id, payload, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP) - ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP - `, runID, payload) - return err + checkpoint := BacktestCheckpoint{ + RunID: runID, + Payload: payload, + } + return s.db.Save(&checkpoint).Error } // LoadCheckpoint loads checkpoint func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) { - var payload []byte - err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload) - return payload, err + var checkpoint BacktestCheckpoint + err := s.db.Where("run_id = ?", runID).First(&checkpoint).Error + if err != nil { + return nil, err + } + return checkpoint.Payload, nil } // SaveRunMetadata saves run metadata func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error { - created := meta.CreatedAt.UTC().Format(time.RFC3339) - updated := meta.UpdatedAt.UTC().Format(time.RFC3339) - userID := meta.UserID - - if _, err := s.db.Exec(` - INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(run_id) DO NOTHING - `, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil { - return err + run := BacktestRun{ + RunID: meta.RunID, + UserID: meta.UserID, + State: string(meta.State), + Label: meta.Label, + LastError: meta.LastError, + SymbolCount: meta.Summary.SymbolCount, + DecisionTF: meta.Summary.DecisionTF, + ProcessedBars: meta.Summary.ProcessedBars, + ProgressPct: meta.Summary.ProgressPct, + EquityLast: meta.Summary.EquityLast, + MaxDrawdownPct: meta.Summary.MaxDrawdownPct, + Liquidated: meta.Summary.Liquidated, + LiquidationNote: meta.Summary.LiquidationNote, + CreatedAt: meta.CreatedAt, + UpdatedAt: meta.UpdatedAt, } - - _, err := s.db.Exec(` - UPDATE backtest_runs - SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?, - progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?, - liquidation_note = ?, label = ?, last_error = ?, updated_at = ? - WHERE run_id = ? - `, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF, - meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast, - meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote, - meta.Label, meta.LastError, updated, meta.RunID) - return err + return s.db.Save(&run).Error } // LoadRunMetadata loads run metadata func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) { - var ( - userID string - state string - label string - lastErr string - symbolCount int - decisionTF string - processedBars int - progressPct float64 - equityLast float64 - maxDD float64 - liquidated bool - liquidationNote string - createdISO string - updatedISO string - ) - - err := s.db.QueryRow(` - SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars, - progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note, - created_at, updated_at - FROM backtest_runs WHERE run_id = ? - `, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF, - &processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote, - &createdISO, &updatedISO) + var run BacktestRun + err := s.db.Where("run_id = ?", runID).First(&run).Error if err != nil { return nil, err } - meta := &RunMetadata{ - RunID: runID, - UserID: userID, + return &RunMetadata{ + RunID: run.RunID, + UserID: run.UserID, Version: 1, - State: RunState(state), - Label: label, - LastError: lastErr, + State: RunState(run.State), + Label: run.Label, + LastError: run.LastError, Summary: RunSummary{ - SymbolCount: symbolCount, - DecisionTF: decisionTF, - ProcessedBars: processedBars, - ProgressPct: progressPct, - EquityLast: equityLast, - MaxDrawdownPct: maxDD, - Liquidated: liquidated, - LiquidationNote: liquidationNote, + SymbolCount: run.SymbolCount, + DecisionTF: run.DecisionTF, + ProcessedBars: run.ProcessedBars, + ProgressPct: run.ProgressPct, + EquityLast: run.EquityLast, + MaxDrawdownPct: run.MaxDrawdownPct, + Liquidated: run.Liquidated, + LiquidationNote: run.LiquidationNote, }, - } - - meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO) - meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO) - - return meta, nil + CreatedAt: run.CreatedAt, + UpdatedAt: run.UpdatedAt, + }, nil } // ListRunIDs lists all run IDs func (s *BacktestStore) ListRunIDs() ([]string, error) { - rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`) + var runs []BacktestRun + err := s.db.Order("updated_at DESC").Find(&runs).Error if err != nil { return nil, err } - defer rows.Close() - var ids []string - for rows.Next() { - var runID string - if err := rows.Scan(&runID); err != nil { - return nil, err - } - ids = append(ids, runID) + ids := make([]string, len(runs)) + for i, run := range runs { + ids[i] = run.RunID } - return ids, rows.Err() + return ids, nil } // AppendEquityPoint appends equity point func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error { - _, err := s.db.Exec(` - INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - `, runID, point.Timestamp, point.Equity, point.Available, point.PnL, - point.PnLPct, point.DrawdownPct, point.Cycle) - return err + eq := BacktestEquity{ + RunID: runID, + TS: point.Timestamp, + Equity: point.Equity, + Available: point.Available, + PnL: point.PnL, + PnLPct: point.PnLPct, + DDPct: point.DrawdownPct, + Cycle: point.Cycle, + } + return s.db.Create(&eq).Error } // LoadEquityPoints loads equity points func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) { - rows, err := s.db.Query(` - SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle - FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC - `, runID) + var eqs []BacktestEquity + err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&eqs).Error if err != nil { return nil, err } - defer rows.Close() - points := make([]EquityPoint, 0) - for rows.Next() { - var point EquityPoint - if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available, - &point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil { - return nil, err + points := make([]EquityPoint, len(eqs)) + for i, eq := range eqs { + points[i] = EquityPoint{ + Timestamp: eq.TS, + Equity: eq.Equity, + Available: eq.Available, + PnL: eq.PnL, + PnLPct: eq.PnLPct, + DrawdownPct: eq.DDPct, + Cycle: eq.Cycle, } - points = append(points, point) } - return points, rows.Err() + return points, nil } // AppendTradeEvent appends trade event func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error { - _, err := s.db.Exec(` - INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee, - slippage, order_value, realized_pnl, leverage, cycle, - position_after, liquidation, note) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, - event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, - event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note) - return err + trade := BacktestTrade{ + RunID: runID, + TS: event.Timestamp, + Symbol: event.Symbol, + Action: event.Action, + Side: event.Side, + Qty: event.Quantity, + Price: event.Price, + Fee: event.Fee, + Slippage: event.Slippage, + OrderValue: event.OrderValue, + RealizedPnL: event.RealizedPnL, + Leverage: event.Leverage, + Cycle: event.Cycle, + PositionAfter: event.PositionAfter, + Liquidation: event.LiquidationFlag, + Note: event.Note, + } + return s.db.Create(&trade).Error } // LoadTradeEvents loads trade events func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) { - rows, err := s.db.Query(` - SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value, - realized_pnl, leverage, cycle, position_after, liquidation, note - FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC - `, runID) + var trades []BacktestTrade + err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&trades).Error if err != nil { return nil, err } - defer rows.Close() - events := make([]TradeEvent, 0) - for rows.Next() { - var event TradeEvent - if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side, - &event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue, - &event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter, - &event.LiquidationFlag, &event.Note); err != nil { - return nil, err + events := make([]TradeEvent, len(trades)) + for i, trade := range trades { + events[i] = TradeEvent{ + Timestamp: trade.TS, + Symbol: trade.Symbol, + Action: trade.Action, + Side: trade.Side, + Quantity: trade.Qty, + Price: trade.Price, + Fee: trade.Fee, + Slippage: trade.Slippage, + OrderValue: trade.OrderValue, + RealizedPnL: trade.RealizedPnL, + Leverage: trade.Leverage, + Cycle: trade.Cycle, + PositionAfter: trade.PositionAfter, + LiquidationFlag: trade.Liquidation, + Note: trade.Note, } - events = append(events, event) } - return events, rows.Err() + return events, nil } // SaveMetrics saves metrics func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error { - _, err := s.db.Exec(` - INSERT INTO backtest_metrics (run_id, payload, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP) - ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP - `, runID, payload) - return err + metrics := BacktestMetrics{ + RunID: runID, + Payload: payload, + } + return s.db.Save(&metrics).Error } // LoadMetrics loads metrics func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) { - var payload []byte - err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload) - return payload, err + var metrics BacktestMetrics + err := s.db.Where("run_id = ?", runID).First(&metrics).Error + if err != nil { + return nil, err + } + return metrics.Payload, nil } // SaveDecisionRecord saves decision record func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error { - _, err := s.db.Exec(` - INSERT INTO backtest_decisions (run_id, cycle, payload) - VALUES (?, ?, ?) - `, runID, cycle, payload) - return err + decision := BacktestDecision{ + RunID: runID, + Cycle: cycle, + Payload: payload, + } + return s.db.Create(&decision).Error } // LoadDecisionRecords loads decision records func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) { - rows, err := s.db.Query(` - SELECT payload FROM backtest_decisions - WHERE run_id = ? - ORDER BY id DESC - LIMIT ? OFFSET ? - `, runID, limit, offset) + var decisions []BacktestDecision + err := s.db.Where("run_id = ?", runID). + Order("id DESC"). + Limit(limit). + Offset(offset). + Find(&decisions).Error if err != nil { return nil, err } - defer rows.Close() - records := make([]json.RawMessage, 0, limit) - for rows.Next() { - var payload []byte - if err := rows.Scan(&payload); err != nil { - return nil, err - } - records = append(records, json.RawMessage(payload)) + records := make([]json.RawMessage, len(decisions)) + for i, d := range decisions { + records[i] = json.RawMessage(d.Payload) } - return records, rows.Err() + return records, nil } // LoadLatestDecision loads latest decision func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) { - var query string - var args []interface{} - + var decision BacktestDecision + query := s.db.Where("run_id = ?", runID) if cycle > 0 { - query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1` - args = []interface{}{runID, cycle} - } else { - query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1` - args = []interface{}{runID} + query = query.Where("cycle = ?", cycle) } - - var payload []byte - err := s.db.QueryRow(query, args...).Scan(&payload) - return payload, err + err := query.Order("created_at DESC").First(&decision).Error + if err != nil { + return nil, err + } + return decision.Payload, nil } // UpdateProgress updates progress func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error { - _, err := s.db.Exec(` - UPDATE backtest_runs - SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP - WHERE run_id = ? - `, progressPct, equity, barIndex, liquidated, runID) - return err + return s.db.Model(&BacktestRun{}).Where("run_id = ?", runID).Updates(map[string]interface{}{ + "progress_pct": progressPct, + "equity_last": equity, + "processed_bars": barIndex, + "liquidated": liquidated, + }).Error } // ListIndexEntries lists index entries func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) { - rows, err := s.db.Query(` - SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, - created_at, updated_at, config_json - FROM backtest_runs - ORDER BY datetime(updated_at) DESC - `) + var runs []BacktestRun + err := s.db.Order("updated_at DESC").Find(&runs).Error if err != nil { return nil, err } - defer rows.Close() - var entries []RunIndexEntry - for rows.Next() { - var entry RunIndexEntry - var symbolCnt int - var cfgJSON []byte - var createdISO, updatedISO string - - if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF, - &entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil { - return nil, err + entries := make([]RunIndexEntry, len(runs)) + for i, run := range runs { + entry := RunIndexEntry{ + RunID: run.RunID, + State: run.State, + DecisionTF: run.DecisionTF, + EquityLast: run.EquityLast, + MaxDrawdownPct: run.MaxDrawdownPct, + CreatedAtISO: run.CreatedAt.Format(time.RFC3339), + UpdatedAtISO: run.UpdatedAt.Format(time.RFC3339), + Symbols: make([]string, 0, run.SymbolCount), } - entry.CreatedAtISO = createdISO - entry.UpdatedAtISO = updatedISO - entry.Symbols = make([]string, 0, symbolCnt) - - // Try to extract more information from config - if len(cfgJSON) > 0 { + if len(run.ConfigJSON) > 0 { var cfg struct { Symbols []string `json:"symbols"` StartTS int64 `json:"start_ts"` EndTS int64 `json:"end_ts"` } - if json.Unmarshal(cfgJSON, &cfg) == nil { + if json.Unmarshal(run.ConfigJSON, &cfg) == nil { entry.Symbols = cfg.Symbols entry.StartTS = cfg.StartTS entry.EndTS = cfg.EndTS } } - entries = append(entries, entry) + entries[i] = entry } - return entries, rows.Err() + return entries, nil } // DeleteRun deletes run func (s *BacktestStore) DeleteRun(runID string) error { - _, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID) - return err + // Delete related records first (cascade may not work in all cases) + s.db.Where("run_id = ?", runID).Delete(&BacktestCheckpoint{}) + s.db.Where("run_id = ?", runID).Delete(&BacktestEquity{}) + s.db.Where("run_id = ?", runID).Delete(&BacktestTrade{}) + s.db.Where("run_id = ?", runID).Delete(&BacktestMetrics{}) + s.db.Where("run_id = ?", runID).Delete(&BacktestDecision{}) + + return s.db.Where("run_id = ?", runID).Delete(&BacktestRun{}).Error } // SaveConfig saves config func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error { - now := time.Now().UTC().Format(time.RFC3339) if userID == "" { userID = "default" } - _, err := s.db.Exec(` - INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt, - override_prompt, ai_provider, ai_model, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(run_id) DO NOTHING - `, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now) - if err != nil { - return err + run := BacktestRun{ + RunID: runID, + UserID: userID, + ConfigJSON: configJSON, + PromptTemplate: template, + CustomPrompt: customPrompt, + OverridePrompt: override, + AIProvider: provider, + AIModel: model, } - - _, err = s.db.Exec(` - UPDATE backtest_runs - SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?, - override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP - WHERE run_id = ? - `, userID, configJSON, template, customPrompt, override, provider, model, runID) - return err + return s.db.Save(&run).Error } // LoadConfig loads config func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) { - var payload []byte - err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload) - return payload, err + var run BacktestRun + err := s.db.Where("run_id = ?", runID).First(&run).Error + if err != nil { + return nil, err + } + return run.ConfigJSON, nil } diff --git a/store/debate.go b/store/debate.go index 3f57d2fb..03d7600f 100644 --- a/store/debate.go +++ b/store/debate.go @@ -1,12 +1,11 @@ package store import ( - "database/sql" "encoding/json" - "fmt" "time" "github.com/google/uuid" + "gorm.io/gorm" ) // DebateStatus represents the status of a debate session @@ -49,30 +48,6 @@ var PersonalityEmojis = map[DebatePersonality]string{ PersonalityRiskManager: "🛡️", } -// DebateSession represents a debate session -type DebateSession struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - StrategyID string `json:"strategy_id"` - Status DebateStatus `json:"status"` - Symbol string `json:"symbol"` // Primary symbol (for backward compat, may be empty for multi-coin) - MaxRounds int `json:"max_rounds"` - CurrentRound int `json:"current_round"` - IntervalMinutes int `json:"interval_minutes"` // Debate interval (5, 15, 30, 60 minutes) - PromptVariant string `json:"prompt_variant"` // balanced/aggressive/conservative/scalping - FinalDecision *DebateDecision `json:"final_decision,omitempty"` // Single decision (backward compat) - FinalDecisions []*DebateDecision `json:"final_decisions,omitempty"` // Multi-coin decisions - AutoExecute bool `json:"auto_execute"` - TraderID string `json:"trader_id,omitempty"` // Trader to use for auto-execute - // OI Ranking data options - EnableOIRanking bool `json:"enable_oi_ranking"` // Whether to include OI ranking data - OIRankingLimit int `json:"oi_ranking_limit"` // Number of OI ranking entries (default 10) - OIDuration string `json:"oi_duration"` // Duration for OI data (1h, 4h, 24h, etc.) - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - // DebateDecision represents a trading decision from the debate type DebateDecision struct { Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait @@ -86,178 +61,187 @@ type DebateDecision struct { Reasoning string `json:"reasoning"` // Brief reasoning // Execution tracking - Executed bool `json:"executed"` // Whether this decision was executed + Executed bool `json:"executed"` // Whether this decision was executed ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed OrderID string `json:"order_id,omitempty"` // Exchange order ID Error string `json:"error,omitempty"` // Execution error if any } +// DebateSession represents a debate session (API struct) +type DebateSession struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + StrategyID string `json:"strategy_id"` + Status DebateStatus `json:"status"` + Symbol string `json:"symbol"` // Primary symbol (for backward compat, may be empty for multi-coin) + MaxRounds int `json:"max_rounds"` + CurrentRound int `json:"current_round"` + IntervalMinutes int `json:"interval_minutes"` // Debate interval (5, 15, 30, 60 minutes) + PromptVariant string `json:"prompt_variant"` // balanced/aggressive/conservative/scalping + FinalDecision *DebateDecision `json:"final_decision,omitempty"` // Single decision (backward compat) + FinalDecisions []*DebateDecision `json:"final_decisions,omitempty"` // Multi-coin decisions + AutoExecute bool `json:"auto_execute"` + TraderID string `json:"trader_id,omitempty"` // Trader to use for auto-execute + // OI Ranking data options + EnableOIRanking bool `json:"enable_oi_ranking"` // Whether to include OI ranking data + OIRankingLimit int `json:"oi_ranking_limit"` // Number of OI ranking entries (default 10) + OIDuration string `json:"oi_duration"` // Duration for OI data (1h, 4h, 24h, etc.) + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// DebateSessionDB is the GORM model for debate_sessions +type DebateSessionDB struct { + ID string `gorm:"column:id;primaryKey"` + UserID string `gorm:"column:user_id;not null;index"` + Name string `gorm:"column:name;not null"` + StrategyID string `gorm:"column:strategy_id;not null"` + Status DebateStatus `gorm:"column:status;not null;default:pending;index"` + Symbol string `gorm:"column:symbol;not null"` + MaxRounds int `gorm:"column:max_rounds;default:3"` + CurrentRound int `gorm:"column:current_round;default:0"` + IntervalMinutes int `gorm:"column:interval_minutes;default:5"` + PromptVariant string `gorm:"column:prompt_variant;default:balanced"` + FinalDecision string `gorm:"column:final_decision"` // JSON string + AutoExecute bool `gorm:"column:auto_execute;default:false"` + TraderID string `gorm:"column:trader_id"` + EnableOIRanking bool `gorm:"column:enable_oi_ranking;default:false"` + OIRankingLimit int `gorm:"column:oi_ranking_limit;default:10"` + OIDuration string `gorm:"column:oi_duration;default:1h"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` +} + +func (DebateSessionDB) TableName() string { + return "debate_sessions" +} + +func (db *DebateSessionDB) toSession() *DebateSession { + s := &DebateSession{ + ID: db.ID, + UserID: db.UserID, + Name: db.Name, + StrategyID: db.StrategyID, + Status: db.Status, + Symbol: db.Symbol, + MaxRounds: db.MaxRounds, + CurrentRound: db.CurrentRound, + IntervalMinutes: db.IntervalMinutes, + PromptVariant: db.PromptVariant, + AutoExecute: db.AutoExecute, + TraderID: db.TraderID, + EnableOIRanking: db.EnableOIRanking, + OIRankingLimit: db.OIRankingLimit, + OIDuration: db.OIDuration, + CreatedAt: db.CreatedAt, + UpdatedAt: db.UpdatedAt, + } + + // Set defaults + if s.IntervalMinutes == 0 { + s.IntervalMinutes = 5 + } + if s.PromptVariant == "" { + s.PromptVariant = "balanced" + } + if s.OIRankingLimit == 0 { + s.OIRankingLimit = 10 + } + if s.OIDuration == "" { + s.OIDuration = "1h" + } + + // Parse final decision + if db.FinalDecision != "" { + var decision DebateDecision + if json.Unmarshal([]byte(db.FinalDecision), &decision) == nil { + s.FinalDecision = &decision + } + } + + return s +} + // DebateParticipant represents an AI participant in a debate type DebateParticipant struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - AIModelID string `json:"ai_model_id"` - AIModelName string `json:"ai_model_name"` - Provider string `json:"provider"` - Personality DebatePersonality `json:"personality"` - Color string `json:"color"` - SpeakOrder int `json:"speak_order"` - CreatedAt time.Time `json:"created_at"` + ID string `gorm:"column:id;primaryKey" json:"id"` + SessionID string `gorm:"column:session_id;not null;index" json:"session_id"` + AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"` + AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"` + Provider string `gorm:"column:provider;not null" json:"provider"` + Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"` + Color string `gorm:"column:color;not null" json:"color"` + SpeakOrder int `gorm:"column:speak_order;default:0" json:"speak_order"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` +} + +func (DebateParticipant) TableName() string { + return "debate_participants" } // DebateMessage represents a message in the debate type DebateMessage struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - Round int `json:"round"` - AIModelID string `json:"ai_model_id"` - AIModelName string `json:"ai_model_name"` - Provider string `json:"provider"` - Personality DebatePersonality `json:"personality"` - MessageType string `json:"message_type"` // analysis/rebuttal/final/vote - Content string `json:"content"` - Decision *DebateDecision `json:"decision,omitempty"` // Single decision (backward compat) - Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions - Confidence int `json:"confidence"` - CreatedAt time.Time `json:"created_at"` + ID string `gorm:"column:id;primaryKey" json:"id"` + SessionID string `gorm:"column:session_id;not null;index" json:"session_id"` + Round int `gorm:"column:round;not null" json:"round"` + AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"` + AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"` + Provider string `gorm:"column:provider;not null" json:"provider"` + Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"` + MessageType string `gorm:"column:message_type;not null" json:"message_type"` // analysis/rebuttal/final/vote + Content string `gorm:"column:content;not null" json:"content"` + DecisionRaw string `gorm:"column:decision" json:"-"` // JSON string in DB + Decision *DebateDecision `gorm:"-" json:"decision,omitempty"` // Parsed for API + Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions + Confidence int `gorm:"column:confidence;default:0" json:"confidence"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` +} + +func (DebateMessage) TableName() string { + return "debate_messages" } // DebateVote represents a final vote from an AI (can contain multiple coin decisions) type DebateVote struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - AIModelID string `json:"ai_model_id"` - AIModelName string `json:"ai_model_name"` - Action string `json:"action"` // Primary action (backward compat) - Symbol string `json:"symbol"` // Primary symbol (backward compat) - Confidence int `json:"confidence"` - Leverage int `json:"leverage"` - PositionPct float64 `json:"position_pct"` - StopLossPct float64 `json:"stop_loss_pct"` - TakeProfitPct float64 `json:"take_profit_pct"` - Reasoning string `json:"reasoning"` - Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions - CreatedAt time.Time `json:"created_at"` + ID string `gorm:"column:id;primaryKey" json:"id"` + SessionID string `gorm:"column:session_id;not null;index" json:"session_id"` + AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"` + AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"` + Action string `gorm:"column:action;not null" json:"action"` // Primary action (backward compat) + Symbol string `gorm:"column:symbol;not null" json:"symbol"` // Primary symbol (backward compat) + Confidence int `gorm:"column:confidence;default:0" json:"confidence"` + Leverage int `gorm:"column:leverage;default:5" json:"leverage"` + PositionPct float64 `gorm:"column:position_pct;default:0.2" json:"position_pct"` + StopLossPct float64 `gorm:"column:stop_loss_pct;default:0.03" json:"stop_loss_pct"` + TakeProfitPct float64 `gorm:"column:take_profit_pct;default:0.06" json:"take_profit_pct"` + Reasoning string `gorm:"column:reasoning" json:"reasoning"` + Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` +} + +func (DebateVote) TableName() string { + return "debate_votes" } // DebateStore handles database operations for debates type DebateStore struct { - db *sql.DB + db *gorm.DB } // NewDebateStore creates a new DebateStore -func NewDebateStore(db *sql.DB) *DebateStore { +func NewDebateStore(db *gorm.DB) *DebateStore { return &DebateStore{db: db} } -// InitSchema creates the debate tables +// InitSchema creates the debate tables using GORM AutoMigrate func (s *DebateStore) InitSchema() error { - schemas := []string{ - `CREATE TABLE IF NOT EXISTS debate_sessions ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - name TEXT NOT NULL, - strategy_id TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - symbol TEXT NOT NULL, - max_rounds INTEGER DEFAULT 3, - current_round INTEGER DEFAULT 0, - interval_minutes INTEGER DEFAULT 5, - prompt_variant TEXT DEFAULT 'balanced', - final_decision TEXT, - auto_execute BOOLEAN DEFAULT 0, - trader_id TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - `CREATE INDEX IF NOT EXISTS idx_debate_sessions_user_id ON debate_sessions(user_id)`, - `CREATE INDEX IF NOT EXISTS idx_debate_sessions_status ON debate_sessions(status)`, - - `CREATE TABLE IF NOT EXISTS debate_participants ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - ai_model_id TEXT NOT NULL, - ai_model_name TEXT NOT NULL, - provider TEXT NOT NULL, - personality TEXT NOT NULL, - color TEXT NOT NULL, - speak_order INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE - )`, - `CREATE INDEX IF NOT EXISTS idx_debate_participants_session ON debate_participants(session_id)`, - - `CREATE TABLE IF NOT EXISTS debate_messages ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - round INTEGER NOT NULL, - ai_model_id TEXT NOT NULL, - ai_model_name TEXT NOT NULL, - provider TEXT NOT NULL, - personality TEXT NOT NULL, - message_type TEXT NOT NULL, - content TEXT NOT NULL, - decision TEXT, - confidence INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE - )`, - `CREATE INDEX IF NOT EXISTS idx_debate_messages_session ON debate_messages(session_id)`, - `CREATE INDEX IF NOT EXISTS idx_debate_messages_round ON debate_messages(session_id, round)`, - - `CREATE TABLE IF NOT EXISTS debate_votes ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - ai_model_id TEXT NOT NULL, - ai_model_name TEXT NOT NULL, - action TEXT NOT NULL, - symbol TEXT NOT NULL, - confidence INTEGER DEFAULT 0, - leverage INTEGER DEFAULT 5, - position_pct REAL DEFAULT 0.2, - stop_loss_pct REAL DEFAULT 0.03, - take_profit_pct REAL DEFAULT 0.06, - reasoning TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE - )`, - `CREATE INDEX IF NOT EXISTS idx_debate_votes_session ON debate_votes(session_id)`, - - // Trigger to update updated_at - `CREATE TRIGGER IF NOT EXISTS update_debate_sessions_timestamp - AFTER UPDATE ON debate_sessions - FOR EACH ROW - BEGIN - UPDATE debate_sessions SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id; - END`, - } - - for _, schema := range schemas { - if _, err := s.db.Exec(schema); err != nil { - return fmt.Errorf("failed to create debate schema: %w", err) - } - } - - // Migrate: Add new columns to existing tables (ignore errors if columns already exist) - migrations := []string{ - `ALTER TABLE debate_sessions ADD COLUMN interval_minutes INTEGER DEFAULT 5`, - `ALTER TABLE debate_sessions ADD COLUMN prompt_variant TEXT DEFAULT 'balanced'`, - `ALTER TABLE debate_sessions ADD COLUMN trader_id TEXT`, - `ALTER TABLE debate_sessions ADD COLUMN enable_oi_ranking BOOLEAN DEFAULT 0`, - `ALTER TABLE debate_sessions ADD COLUMN oi_ranking_limit INTEGER DEFAULT 10`, - `ALTER TABLE debate_sessions ADD COLUMN oi_duration TEXT DEFAULT '1h'`, - `ALTER TABLE debate_votes ADD COLUMN leverage INTEGER DEFAULT 5`, - `ALTER TABLE debate_votes ADD COLUMN position_pct REAL DEFAULT 0.2`, - `ALTER TABLE debate_votes ADD COLUMN stop_loss_pct REAL DEFAULT 0.03`, - `ALTER TABLE debate_votes ADD COLUMN take_profit_pct REAL DEFAULT 0.06`, - } - - for _, migration := range migrations { - // Ignore errors - column may already exist - s.db.Exec(migration) - } - - return nil + return s.db.AutoMigrate( + &DebateSessionDB{}, + &DebateParticipant{}, + &DebateMessage{}, + &DebateVote{}, + ) } // CreateSession creates a new debate session @@ -279,227 +263,73 @@ func (s *DebateStore) CreateSession(session *DebateSession) error { if session.OIDuration == "" { session.OIDuration = "1h" } - session.CreatedAt = time.Now() - session.UpdatedAt = time.Now() - _, err := s.db.Exec(` - INSERT INTO debate_sessions (id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, interval_minutes, prompt_variant, auto_execute, trader_id, enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - session.ID, session.UserID, session.Name, session.StrategyID, session.Status, - session.Symbol, session.MaxRounds, session.CurrentRound, session.IntervalMinutes, session.PromptVariant, - session.AutoExecute, session.TraderID, session.EnableOIRanking, session.OIRankingLimit, session.OIDuration, - session.CreatedAt, session.UpdatedAt, - ) - return err + db := &DebateSessionDB{ + ID: session.ID, + UserID: session.UserID, + Name: session.Name, + StrategyID: session.StrategyID, + Status: session.Status, + Symbol: session.Symbol, + MaxRounds: session.MaxRounds, + CurrentRound: session.CurrentRound, + IntervalMinutes: session.IntervalMinutes, + PromptVariant: session.PromptVariant, + AutoExecute: session.AutoExecute, + TraderID: session.TraderID, + EnableOIRanking: session.EnableOIRanking, + OIRankingLimit: session.OIRankingLimit, + OIDuration: session.OIDuration, + } + + return s.db.Create(db).Error } // GetSession gets a debate session by ID func (s *DebateStore) GetSession(id string) (*DebateSession, error) { - var session DebateSession - var finalDecisionJSON sql.NullString - var traderID sql.NullString - var intervalMinutes sql.NullInt64 - var promptVariant sql.NullString - var enableOIRanking sql.NullBool - var oiRankingLimit sql.NullInt64 - var oiDuration sql.NullString - - // Try new schema first - err := s.db.QueryRow(` - SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, - interval_minutes, prompt_variant, final_decision, auto_execute, trader_id, - enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at - FROM debate_sessions WHERE id = ?`, id, - ).Scan( - &session.ID, &session.UserID, &session.Name, &session.StrategyID, - &session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound, - &intervalMinutes, &promptVariant, - &finalDecisionJSON, &session.AutoExecute, &traderID, - &enableOIRanking, &oiRankingLimit, &oiDuration, &session.CreatedAt, &session.UpdatedAt, - ) - - // Fallback to basic schema if new columns don't exist - if err != nil { - err = s.db.QueryRow(` - SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, - final_decision, auto_execute, created_at, updated_at - FROM debate_sessions WHERE id = ?`, id, - ).Scan( - &session.ID, &session.UserID, &session.Name, &session.StrategyID, - &session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound, - &finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt, - ) - if err != nil { - return nil, err - } - // Set defaults for new fields - session.IntervalMinutes = 5 - session.PromptVariant = "balanced" - session.OIRankingLimit = 10 - session.OIDuration = "1h" - } else { - // Set defaults for nullable fields - session.IntervalMinutes = 5 - if intervalMinutes.Valid { - session.IntervalMinutes = int(intervalMinutes.Int64) - } - session.PromptVariant = "balanced" - if promptVariant.Valid { - session.PromptVariant = promptVariant.String - } - if traderID.Valid { - session.TraderID = traderID.String - } - if enableOIRanking.Valid { - session.EnableOIRanking = enableOIRanking.Bool - } - session.OIRankingLimit = 10 - if oiRankingLimit.Valid { - session.OIRankingLimit = int(oiRankingLimit.Int64) - } - session.OIDuration = "1h" - if oiDuration.Valid { - session.OIDuration = oiDuration.String - } + var db DebateSessionDB + if err := s.db.Where("id = ?", id).First(&db).Error; err != nil { + return nil, err } - - if finalDecisionJSON.Valid && finalDecisionJSON.String != "" { - var decision DebateDecision - if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil { - session.FinalDecision = &decision - } - } - - return &session, nil + return db.toSession(), nil } // GetSessionsByUser gets all debate sessions for a user func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) { - // First try the new schema with all columns - rows, err := s.db.Query(` - SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, - interval_minutes, prompt_variant, final_decision, auto_execute, trader_id, created_at, updated_at - FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID, - ) - - // If query fails (likely due to missing columns), try basic query - if err != nil { - return s.getSessionsByUserBasic(userID) + var dbs []DebateSessionDB + if err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&dbs).Error; err != nil { + return nil, err } - defer rows.Close() - var sessions []*DebateSession - for rows.Next() { - var session DebateSession - var finalDecisionJSON sql.NullString - var traderID sql.NullString - var intervalMinutes sql.NullInt64 - var promptVariant sql.NullString - - if err := rows.Scan( - &session.ID, &session.UserID, &session.Name, &session.StrategyID, - &session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound, - &intervalMinutes, &promptVariant, - &finalDecisionJSON, &session.AutoExecute, &traderID, &session.CreatedAt, &session.UpdatedAt, - ); err != nil { - return nil, err - } - - // Set defaults for nullable fields - session.IntervalMinutes = 5 - if intervalMinutes.Valid { - session.IntervalMinutes = int(intervalMinutes.Int64) - } - session.PromptVariant = "balanced" - if promptVariant.Valid { - session.PromptVariant = promptVariant.String - } - - if finalDecisionJSON.Valid && finalDecisionJSON.String != "" { - var decision DebateDecision - if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil { - session.FinalDecision = &decision - } - } - if traderID.Valid { - session.TraderID = traderID.String - } - - sessions = append(sessions, &session) + sessions := make([]*DebateSession, len(dbs)) + for i, db := range dbs { + sessions[i] = db.toSession() } return sessions, nil } // ListAllSessions returns all debate sessions (for cleanup on startup) func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) { - rows, err := s.db.Query(`SELECT id, status FROM debate_sessions`) - if err != nil { + var dbs []DebateSessionDB + if err := s.db.Select("id, status").Find(&dbs).Error; err != nil { return nil, err } - defer rows.Close() - var sessions []*DebateSession - for rows.Next() { - var session DebateSession - if err := rows.Scan(&session.ID, &session.Status); err != nil { - return nil, err - } - sessions = append(sessions, &session) - } - return sessions, nil -} - -// getSessionsByUserBasic is a fallback for old schema without new columns -func (s *DebateStore) getSessionsByUserBasic(userID string) ([]*DebateSession, error) { - rows, err := s.db.Query(` - SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, - final_decision, auto_execute, created_at, updated_at - FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID, - ) - if err != nil { - return nil, err - } - defer rows.Close() - - var sessions []*DebateSession - for rows.Next() { - var session DebateSession - var finalDecisionJSON sql.NullString - - if err := rows.Scan( - &session.ID, &session.UserID, &session.Name, &session.StrategyID, - &session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound, - &finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt, - ); err != nil { - return nil, err - } - - // Set defaults for new fields - session.IntervalMinutes = 5 - session.PromptVariant = "balanced" - - if finalDecisionJSON.Valid && finalDecisionJSON.String != "" { - var decision DebateDecision - if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil { - session.FinalDecision = &decision - } - } - - sessions = append(sessions, &session) + sessions := make([]*DebateSession, len(dbs)) + for i, db := range dbs { + sessions[i] = &DebateSession{ID: db.ID, Status: db.Status} } return sessions, nil } // UpdateSessionStatus updates the status of a debate session func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error { - _, err := s.db.Exec(`UPDATE debate_sessions SET status = ? WHERE id = ?`, status, id) - return err + return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("status", status).Error } // UpdateSessionRound updates the current round of a debate session func (s *DebateStore) UpdateSessionRound(id string, round int) error { - _, err := s.db.Exec(`UPDATE debate_sessions SET current_round = ? WHERE id = ?`, round, id) - return err + return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("current_round", round).Error } // UpdateSessionFinalDecision updates the final decision of a debate session (single decision) @@ -508,30 +338,31 @@ func (s *DebateStore) UpdateSessionFinalDecision(id string, decision *DebateDeci if err != nil { return err } - _, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`, - string(decisionJSON), DebateStatusCompleted, id) - return err + return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{ + "final_decision": string(decisionJSON), + "status": DebateStatusCompleted, + }).Error } // UpdateSessionFinalDecisions updates both single and multi-coin final decisions func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error { - // Always store primary decision as a single object (for backward compat) - // This ensures GetSession can deserialize it correctly primaryJSON, err := json.Marshal(primaryDecision) if err != nil { return err } - - // Update final_decision with primary decision and set status to completed - _, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`, - string(primaryJSON), DebateStatusCompleted, id) - return err + return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{ + "final_decision": string(primaryJSON), + "status": DebateStatusCompleted, + }).Error } // DeleteSession deletes a debate session and all related data func (s *DebateStore) DeleteSession(id string) error { - _, err := s.db.Exec(`DELETE FROM debate_sessions WHERE id = ?`, id) - return err + // Delete related data first + s.db.Where("session_id = ?", id).Delete(&DebateParticipant{}) + s.db.Where("session_id = ?", id).Delete(&DebateMessage{}) + s.db.Where("session_id = ?", id).Delete(&DebateVote{}) + return s.db.Where("id = ?", id).Delete(&DebateSessionDB{}).Error } // AddParticipant adds a participant to a debate session @@ -539,9 +370,6 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error { if participant.ID == "" { participant.ID = uuid.New().String() } - participant.CreatedAt = time.Now() - - // Set color based on personality if not provided if participant.Color == "" { if color, ok := PersonalityColors[participant.Personality]; ok { participant.Color = color @@ -549,39 +377,14 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error { participant.Color = "#6B7280" // Default gray } } - - _, err := s.db.Exec(` - INSERT INTO debate_participants (id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, - participant.ID, participant.SessionID, participant.AIModelID, participant.AIModelName, - participant.Provider, participant.Personality, participant.Color, participant.SpeakOrder, participant.CreatedAt, - ) - return err + return s.db.Create(participant).Error } // GetParticipants gets all participants for a debate session func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) { - rows, err := s.db.Query(` - SELECT id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at - FROM debate_participants WHERE session_id = ? ORDER BY speak_order`, sessionID, - ) - if err != nil { - return nil, err - } - defer rows.Close() - var participants []*DebateParticipant - for rows.Next() { - var p DebateParticipant - if err := rows.Scan( - &p.ID, &p.SessionID, &p.AIModelID, &p.AIModelName, - &p.Provider, &p.Personality, &p.Color, &p.SpeakOrder, &p.CreatedAt, - ); err != nil { - return nil, err - } - participants = append(participants, &p) - } - return participants, nil + err := s.db.Where("session_id = ?", sessionID).Order("speak_order").Find(&participants).Error + return participants, err } // AddMessage adds a message to a debate session @@ -589,95 +392,52 @@ func (s *DebateStore) AddMessage(msg *DebateMessage) error { if msg.ID == "" { msg.ID = uuid.New().String() } - msg.CreatedAt = time.Now() - - var decisionJSON sql.NullString if msg.Decision != nil { data, err := json.Marshal(msg.Decision) if err != nil { return err } - decisionJSON = sql.NullString{String: string(data), Valid: true} + msg.DecisionRaw = string(data) } - - _, err := s.db.Exec(` - INSERT INTO debate_messages (id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - msg.ID, msg.SessionID, msg.Round, msg.AIModelID, msg.AIModelName, - msg.Provider, msg.Personality, msg.MessageType, msg.Content, - decisionJSON, msg.Confidence, msg.CreatedAt, - ) - return err + return s.db.Create(msg).Error } // GetMessages gets all messages for a debate session func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) { - rows, err := s.db.Query(` - SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at - FROM debate_messages WHERE session_id = ? ORDER BY round, created_at`, sessionID, - ) + var messages []*DebateMessage + err := s.db.Where("session_id = ?", sessionID).Order("round, created_at").Find(&messages).Error if err != nil { return nil, err } - defer rows.Close() - var messages []*DebateMessage - for rows.Next() { - var msg DebateMessage - var decisionJSON sql.NullString - - if err := rows.Scan( - &msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName, - &msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content, - &decisionJSON, &msg.Confidence, &msg.CreatedAt, - ); err != nil { - return nil, err - } - - if decisionJSON.Valid && decisionJSON.String != "" { + // Parse decision JSON + for _, msg := range messages { + if msg.DecisionRaw != "" { var decision DebateDecision - if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil { + if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil { msg.Decision = &decision } } - - messages = append(messages, &msg) } return messages, nil } // GetMessagesByRound gets messages for a specific round func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) { - rows, err := s.db.Query(` - SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at - FROM debate_messages WHERE session_id = ? AND round = ? ORDER BY created_at`, sessionID, round, - ) + var messages []*DebateMessage + err := s.db.Where("session_id = ? AND round = ?", sessionID, round).Order("created_at").Find(&messages).Error if err != nil { return nil, err } - defer rows.Close() - var messages []*DebateMessage - for rows.Next() { - var msg DebateMessage - var decisionJSON sql.NullString - - if err := rows.Scan( - &msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName, - &msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content, - &decisionJSON, &msg.Confidence, &msg.CreatedAt, - ); err != nil { - return nil, err - } - - if decisionJSON.Valid && decisionJSON.String != "" { + // Parse decision JSON + for _, msg := range messages { + if msg.DecisionRaw != "" { var decision DebateDecision - if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil { + if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil { msg.Decision = &decision } } - - messages = append(messages, &msg) } return messages, nil } @@ -687,40 +447,14 @@ func (s *DebateStore) AddVote(vote *DebateVote) error { if vote.ID == "" { vote.ID = uuid.New().String() } - vote.CreatedAt = time.Now() - - _, err := s.db.Exec(` - INSERT INTO debate_votes (id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - vote.ID, vote.SessionID, vote.AIModelID, vote.AIModelName, - vote.Action, vote.Symbol, vote.Confidence, vote.Leverage, vote.PositionPct, vote.StopLossPct, vote.TakeProfitPct, vote.Reasoning, vote.CreatedAt, - ) - return err + return s.db.Create(vote).Error } // GetVotes gets all votes for a debate session func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) { - rows, err := s.db.Query(` - SELECT id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at - FROM debate_votes WHERE session_id = ? ORDER BY created_at`, sessionID, - ) - if err != nil { - return nil, err - } - defer rows.Close() - var votes []*DebateVote - for rows.Next() { - var vote DebateVote - if err := rows.Scan( - &vote.ID, &vote.SessionID, &vote.AIModelID, &vote.AIModelName, - &vote.Action, &vote.Symbol, &vote.Confidence, &vote.Leverage, &vote.PositionPct, &vote.StopLossPct, &vote.TakeProfitPct, &vote.Reasoning, &vote.CreatedAt, - ); err != nil { - return nil, err - } - votes = append(votes, &vote) - } - return votes, nil + err := s.db.Where("session_id = ?", sessionID).Order("created_at").Find(&votes).Error + return votes, err } // DebateSessionWithDetails combines session with participants and messages diff --git a/store/decision.go b/store/decision.go index 7c96681f..6926f326 100644 --- a/store/decision.go +++ b/store/decision.go @@ -1,18 +1,41 @@ package store import ( - "database/sql" "encoding/json" "fmt" "time" + + "gorm.io/gorm" ) // DecisionStore decision log storage type DecisionStore struct { - db *sql.DB + db *gorm.DB } -// DecisionRecord decision record +// DecisionRecordDB internal GORM model for decision_records table +type DecisionRecordDB struct { + ID int64 `gorm:"primaryKey;autoIncrement"` + TraderID string `gorm:"column:trader_id;not null;index:idx_decision_records_trader_time"` + CycleNumber int `gorm:"column:cycle_number;not null"` + Timestamp time.Time `gorm:"not null;index:idx_decision_records_trader_time,sort:desc;index:idx_decision_records_timestamp,sort:desc"` + SystemPrompt string `gorm:"column:system_prompt;default:''"` + InputPrompt string `gorm:"column:input_prompt;default:''"` + CoTTrace string `gorm:"column:cot_trace;default:''"` + DecisionJSON string `gorm:"column:decision_json;default:''"` + RawResponse string `gorm:"column:raw_response;default:''"` + CandidateCoins string `gorm:"column:candidate_coins;default:''"` + ExecutionLog string `gorm:"column:execution_log;default:''"` + Decisions string `gorm:"column:decisions;default:'[]'"` + Success bool `gorm:"default:false"` + ErrorMessage string `gorm:"column:error_message;default:''"` + AIRequestDurationMs int64 `gorm:"column:ai_request_duration_ms;default:0"` + CreatedAt time.Time `json:"created_at"` +} + +func (DecisionRecordDB) TableName() string { return "decision_records" } + +// DecisionRecord decision record (external API struct) type DecisionRecord struct { ID int64 `json:"id"` TraderID string `json:"trader_id"` @@ -81,49 +104,47 @@ type Statistics struct { TotalClosePositions int `json:"total_close_positions"` } -// initTables initializes AI decision log tables -// Note: Account equity curve data has been migrated to trader_equity_snapshots table (managed by EquityStore) -func (s *DecisionStore) initTables() error { - queries := []string{ - // AI decision log table (records AI input/output, chain of thought, etc.) - `CREATE TABLE IF NOT EXISTS decision_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - trader_id TEXT NOT NULL, - cycle_number INTEGER NOT NULL, - timestamp DATETIME NOT NULL, - system_prompt TEXT DEFAULT '', - input_prompt TEXT DEFAULT '', - cot_trace TEXT DEFAULT '', - decision_json TEXT DEFAULT '', - raw_response TEXT DEFAULT '', - candidate_coins TEXT DEFAULT '', - execution_log TEXT DEFAULT '', - success BOOLEAN DEFAULT 0, - error_message TEXT DEFAULT '', - ai_request_duration_ms INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - // Indexes - `CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`, - `CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`, - } - - for _, query := range queries { - if _, err := s.db.Exec(query); err != nil { - return fmt.Errorf("failed to execute SQL: %w", err) - } - } - - // Migration: add raw_response column if not exists - s.db.Exec(`ALTER TABLE decision_records ADD COLUMN raw_response TEXT DEFAULT ''`) - - // Migration: add decisions column if not exists - s.db.Exec(`ALTER TABLE decision_records ADD COLUMN decisions TEXT DEFAULT '[]'`) - - return nil +// NewDecisionStore creates a new DecisionStore +func NewDecisionStore(db *gorm.DB) *DecisionStore { + return &DecisionStore{db: db} } -// LogDecision logs decision (only saves AI decision log, equity curve has been migrated to equity table) +// initTables initializes AI decision log tables +func (s *DecisionStore) initTables() error { + // For PostgreSQL with existing table, skip AutoMigrate + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'decision_records'`).Scan(&tableExists) + if tableExists > 0 { + return nil + } + } + return s.db.AutoMigrate(&DecisionRecordDB{}) +} + +// toRecord converts DB model to API struct +func (db *DecisionRecordDB) toRecord() *DecisionRecord { + record := &DecisionRecord{ + ID: db.ID, + TraderID: db.TraderID, + CycleNumber: db.CycleNumber, + Timestamp: db.Timestamp, + SystemPrompt: db.SystemPrompt, + InputPrompt: db.InputPrompt, + CoTTrace: db.CoTTrace, + DecisionJSON: db.DecisionJSON, + RawResponse: db.RawResponse, + Success: db.Success, + ErrorMessage: db.ErrorMessage, + AIRequestDurationMs: db.AIRequestDurationMs, + } + json.Unmarshal([]byte(db.CandidateCoins), &record.CandidateCoins) + json.Unmarshal([]byte(db.ExecutionLog), &record.ExecutionLog) + json.Unmarshal([]byte(db.Decisions), &record.Decisions) + return record +} + +// LogDecision logs decision func (s *DecisionStore) LogDecision(record *DecisionRecord) error { if record.Timestamp.IsZero() { record.Timestamp = time.Now().UTC() @@ -131,65 +152,49 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error { record.Timestamp = record.Timestamp.UTC() } - // Serialize candidate coins, execution log and decisions to JSON + // Serialize arrays to JSON candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins) executionLogJSON, _ := json.Marshal(record.ExecutionLog) decisionsJSON, _ := json.Marshal(record.Decisions) - // Insert decision record main table (only save AI decision related content) - result, err := s.db.Exec(` - INSERT INTO decision_records ( - trader_id, cycle_number, timestamp, system_prompt, input_prompt, - cot_trace, decision_json, raw_response, candidate_coins, execution_log, - decisions, success, error_message, ai_request_duration_ms - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, - record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339), - record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON, - record.RawResponse, string(candidateCoinsJSON), string(executionLogJSON), - string(decisionsJSON), record.Success, record.ErrorMessage, record.AIRequestDurationMs, - ) - if err != nil { + dbRecord := &DecisionRecordDB{ + TraderID: record.TraderID, + CycleNumber: record.CycleNumber, + Timestamp: record.Timestamp, + SystemPrompt: record.SystemPrompt, + InputPrompt: record.InputPrompt, + CoTTrace: record.CoTTrace, + DecisionJSON: record.DecisionJSON, + RawResponse: record.RawResponse, + CandidateCoins: string(candidateCoinsJSON), + ExecutionLog: string(executionLogJSON), + Decisions: string(decisionsJSON), + Success: record.Success, + ErrorMessage: record.ErrorMessage, + AIRequestDurationMs: record.AIRequestDurationMs, + } + + if err := s.db.Create(dbRecord).Error; err != nil { return fmt.Errorf("failed to insert decision record: %w", err) } - - decisionID, err := result.LastInsertId() - if err != nil { - return fmt.Errorf("failed to get decision ID: %w", err) - } - record.ID = decisionID - + record.ID = dbRecord.ID return nil } // GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new) func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, - cot_trace, decision_json, candidate_coins, execution_log, - COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms - FROM decision_records - WHERE trader_id = ? - ORDER BY timestamp DESC - LIMIT ? - `, traderID, n) + var dbRecords []*DecisionRecordDB + err := s.db.Where("trader_id = ?", traderID). + Order("timestamp DESC"). + Limit(n). + Find(&dbRecords).Error if err != nil { return nil, fmt.Errorf("failed to query decision records: %w", err) } - defer rows.Close() - var records []*DecisionRecord - for rows.Next() { - record, err := s.scanDecisionRecord(rows) - if err != nil { - continue - } - records = append(records, record) - } - - // Fill associated data - for _, record := range records { - s.fillRecordDetails(record) + records := make([]*DecisionRecord, len(dbRecords)) + for i, db := range dbRecords { + records[i] = db.toRecord() } // Reverse array to sort time from old to new @@ -202,26 +207,15 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec // GetAllLatestRecords gets the latest N records for all traders func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, - cot_trace, decision_json, candidate_coins, execution_log, - COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms - FROM decision_records - ORDER BY timestamp DESC - LIMIT ? - `, n) + var dbRecords []*DecisionRecordDB + err := s.db.Order("timestamp DESC").Limit(n).Find(&dbRecords).Error if err != nil { return nil, fmt.Errorf("failed to query decision records: %w", err) } - defer rows.Close() - var records []*DecisionRecord - for rows.Next() { - record, err := s.scanDecisionRecord(rows) - if err != nil { - continue - } - records = append(records, record) + records := make([]*DecisionRecord, len(dbRecords)) + for i, db := range dbRecords { + records[i] = db.toRecord() } // Reverse array @@ -236,26 +230,17 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) { func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) { dateStr := date.Format("2006-01-02") - rows, err := s.db.Query(` - SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, - cot_trace, decision_json, candidate_coins, execution_log, - COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms - FROM decision_records - WHERE trader_id = ? AND DATE(timestamp) = ? - ORDER BY timestamp ASC - `, traderID, dateStr) + var dbRecords []*DecisionRecordDB + err := s.db.Where("trader_id = ? AND DATE(timestamp) = ?", traderID, dateStr). + Order("timestamp ASC"). + Find(&dbRecords).Error if err != nil { return nil, fmt.Errorf("failed to query decision records: %w", err) } - defer rows.Close() - var records []*DecisionRecord - for rows.Next() { - record, err := s.scanDecisionRecord(rows) - if err != nil { - continue - } - records = append(records, record) + records := make([]*DecisionRecord, len(dbRecords)) + for i, db := range dbRecords { + records[i] = db.toRecord() } return records, nil @@ -263,48 +248,31 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De // CleanOldRecords cleans old records from N days ago func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) { - cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339) + cutoffTime := time.Now().AddDate(0, 0, -days) - result, err := s.db.Exec(` - DELETE FROM decision_records - WHERE trader_id = ? AND timestamp < ? - `, traderID, cutoffTime) - if err != nil { - return 0, fmt.Errorf("failed to clean old records: %w", err) + result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime). + Delete(&DecisionRecordDB{}) + if result.Error != nil { + return 0, fmt.Errorf("failed to clean old records: %w", result.Error) } - - return result.RowsAffected() + return result.RowsAffected, nil } // GetStatistics gets statistics information for specified trader func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) { stats := &Statistics{} - err := s.db.QueryRow(` - SELECT COUNT(*) FROM decision_records WHERE trader_id = ? - `, traderID).Scan(&stats.TotalCycles) - if err != nil { - return nil, fmt.Errorf("failed to query total cycles: %w", err) - } + var totalCount, successCount int64 + s.db.Model(&DecisionRecordDB{}).Where("trader_id = ?", traderID).Count(&totalCount) + s.db.Model(&DecisionRecordDB{}).Where("trader_id = ? AND success = ?", traderID, true).Count(&successCount) - err = s.db.QueryRow(` - SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1 - `, traderID).Scan(&stats.SuccessfulCycles) - if err != nil { - return nil, fmt.Errorf("failed to query successful cycles: %w", err) - } + stats.TotalCycles = int(totalCount) + stats.SuccessfulCycles = int(successCount) stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles - // Count from trader_positions table - s.db.QueryRow(` - SELECT COUNT(*) FROM trader_positions - WHERE trader_id = ? - `, traderID).Scan(&stats.TotalOpenPositions) - - s.db.QueryRow(` - SELECT COUNT(*) FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - `, traderID).Scan(&stats.TotalClosePositions) + // Count from trader_positions table using raw query for cross-table + s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ?", traderID).Scan(&stats.TotalOpenPositions) + s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ? AND status = 'CLOSED'", traderID).Scan(&stats.TotalClosePositions) return stats, nil } @@ -313,64 +281,33 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) { func (s *DecisionStore) GetAllStatistics() (*Statistics, error) { stats := &Statistics{} - s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles) - s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles) + var totalCount, successCount int64 + s.db.Model(&DecisionRecordDB{}).Count(&totalCount) + s.db.Model(&DecisionRecordDB{}).Where("success = ?", true).Count(&successCount) + + stats.TotalCycles = int(totalCount) + stats.SuccessfulCycles = int(successCount) stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles // Count from trader_positions table - s.db.QueryRow(` - SELECT COUNT(*) FROM trader_positions - `).Scan(&stats.TotalOpenPositions) - - s.db.QueryRow(` - SELECT COUNT(*) FROM trader_positions - WHERE status = 'CLOSED' - `).Scan(&stats.TotalClosePositions) + s.db.Raw("SELECT COUNT(*) FROM trader_positions").Scan(&stats.TotalOpenPositions) + s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE status = 'CLOSED'").Scan(&stats.TotalClosePositions) return stats, nil } // GetLastCycleNumber gets the last cycle number for specified trader func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) { - var cycleNumber int - err := s.db.QueryRow(` - SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ? - `, traderID).Scan(&cycleNumber) + var cycleNumber *int + err := s.db.Model(&DecisionRecordDB{}). + Where("trader_id = ?", traderID). + Select("MAX(cycle_number)"). + Scan(&cycleNumber).Error if err != nil { return 0, err } - return cycleNumber, nil -} - -// scanDecisionRecord scans decision record from row -func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) { - var record DecisionRecord - var timestampStr string - var candidateCoinsJSON, executionLogJSON, decisionsJSON string - - err := rows.Scan( - &record.ID, &record.TraderID, &record.CycleNumber, ×tampStr, - &record.SystemPrompt, &record.InputPrompt, &record.CoTTrace, - &record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON, - &decisionsJSON, &record.Success, &record.ErrorMessage, &record.AIRequestDurationMs, - ) - if err != nil { - return nil, err + if cycleNumber == nil { + return 0, nil } - - record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr) - json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins) - json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog) - json.Unmarshal([]byte(decisionsJSON), &record.Decisions) - - return &record, nil -} - -// fillRecordDetails fills associated data for decision record (old associated tables removed, this function kept for compatibility) -// Note: Account snapshot, position snapshot, decision action data are no longer stored in decision related tables -// - For equity data use EquityStore.GetLatest() -// - For order data use OrderStore -func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) { - // Old associated tables removed, no longer need to fill - // AccountState, Positions, Decisions fields will remain at zero values + return *cycleNumber, nil } diff --git a/store/driver.go b/store/driver.go index 244b18a3..003534ba 100644 --- a/store/driver.go +++ b/store/driver.go @@ -238,3 +238,44 @@ func getEnv(key, defaultValue string) string { } return defaultValue } + +// convertQuery converts ? placeholders to $1, $2 for PostgreSQL +// and handles other database-specific syntax differences +func convertQuery(query string, dbType DBType) string { + if dbType != DBTypePostgres { + return query + } + result := query + + // Convert ? to $1, $2, etc. for PostgreSQL + index := 1 + for strings.Contains(result, "?") { + result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1) + index++ + } + + // Convert datetime('now') to CURRENT_TIMESTAMP + result = strings.ReplaceAll(result, "datetime('now')", "CURRENT_TIMESTAMP") + + // Remove datetime() wrapper for ORDER BY (PostgreSQL timestamps sort correctly) + // This handles patterns like "ORDER BY datetime(column) DESC" + result = strings.ReplaceAll(result, "datetime(updated_at)", "updated_at") + result = strings.ReplaceAll(result, "datetime(created_at)", "created_at") + + return result +} + +// boolDefault returns database-appropriate boolean default for COALESCE +// Use in queries like: COALESCE(column, %s) +func boolDefault(dbType DBType, value bool) string { + if dbType == DBTypePostgres { + if value { + return "TRUE" + } + return "FALSE" + } + if value { + return "1" + } + return "0" +} diff --git a/store/equity.go b/store/equity.go index 34ec2c77..9c337019 100644 --- a/store/equity.go +++ b/store/equity.go @@ -1,55 +1,48 @@ package store import ( - "database/sql" "fmt" "time" + + "gorm.io/gorm" ) // EquityStore account equity storage (for plotting return curves) type EquityStore struct { - db *sql.DB + db *gorm.DB } // EquitySnapshot equity snapshot type EquitySnapshot struct { - ID int64 `json:"id"` - TraderID string `json:"trader_id"` - Timestamp time.Time `json:"timestamp"` - TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL) - Balance float64 `json:"balance"` // Account balance - UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss - PositionCount int `json:"position_count"` // Position count - MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage + ID int64 `gorm:"primaryKey;autoIncrement" json:"id"` + TraderID string `gorm:"column:trader_id;not null;index:idx_equity_trader_time" json:"trader_id"` + Timestamp time.Time `gorm:"not null;index:idx_equity_trader_time,sort:desc;index:idx_equity_timestamp,sort:desc" json:"timestamp"` + TotalEquity float64 `gorm:"column:total_equity;not null;default:0" json:"total_equity"` + Balance float64 `gorm:"not null;default:0" json:"balance"` + UnrealizedPnL float64 `gorm:"column:unrealized_pnl;not null;default:0" json:"unrealized_pnl"` + PositionCount int `gorm:"column:position_count;default:0" json:"position_count"` + MarginUsedPct float64 `gorm:"column:margin_used_pct;default:0" json:"margin_used_pct"` + CreatedAt time.Time `json:"created_at"` +} + +func (EquitySnapshot) TableName() string { return "trader_equity_snapshots" } + +// NewEquityStore creates a new EquityStore +func NewEquityStore(db *gorm.DB) *EquityStore { + return &EquityStore{db: db} } // initTables initializes equity tables func (s *EquityStore) initTables() error { - queries := []string{ - // Equity snapshot table - specifically for return curves - `CREATE TABLE IF NOT EXISTS trader_equity_snapshots ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - trader_id TEXT NOT NULL, - timestamp DATETIME NOT NULL, - total_equity REAL NOT NULL DEFAULT 0, - balance REAL NOT NULL DEFAULT 0, - unrealized_pnl REAL NOT NULL DEFAULT 0, - position_count INTEGER DEFAULT 0, - margin_used_pct REAL DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - // Indexes - `CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`, - `CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`, - } - - for _, query := range queries { - if _, err := s.db.Exec(query); err != nil { - return fmt.Errorf("failed to execute SQL: %w", err) + // For PostgreSQL with existing table, skip AutoMigrate + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_equity_snapshots'`).Scan(&tableExists) + if tableExists > 0 { + return nil } } - - return nil + return s.db.AutoMigrate(&EquitySnapshot{}) } // Save saves equity snapshot @@ -60,58 +53,22 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error { snapshot.Timestamp = snapshot.Timestamp.UTC() } - result, err := s.db.Exec(` - INSERT INTO trader_equity_snapshots ( - trader_id, timestamp, total_equity, balance, - unrealized_pnl, position_count, margin_used_pct - ) VALUES (?, ?, ?, ?, ?, ?, ?) - `, - snapshot.TraderID, - snapshot.Timestamp.Format(time.RFC3339), - snapshot.TotalEquity, - snapshot.Balance, - snapshot.UnrealizedPnL, - snapshot.PositionCount, - snapshot.MarginUsedPct, - ) - if err != nil { + if err := s.db.Create(snapshot).Error; err != nil { return fmt.Errorf("failed to save equity snapshot: %w", err) } - - id, _ := result.LastInsertId() - snapshot.ID = id return nil } // GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new) func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, timestamp, total_equity, balance, - unrealized_pnl, position_count, margin_used_pct - FROM trader_equity_snapshots - WHERE trader_id = ? - ORDER BY timestamp DESC - LIMIT ? - `, traderID, limit) + var snapshots []*EquitySnapshot + err := s.db.Where("trader_id = ?", traderID). + Order("timestamp DESC"). + Limit(limit). + Find(&snapshots).Error if err != nil { return nil, fmt.Errorf("failed to query equity records: %w", err) } - defer rows.Close() - - var snapshots []*EquitySnapshot - for rows.Next() { - snap := &EquitySnapshot{} - var timestampStr string - err := rows.Scan( - &snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity, - &snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct, - ) - if err != nil { - continue - } - snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr) - snapshots = append(snapshots, snap) - } // Reverse the array to sort time from old to new (suitable for plotting curves) for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 { @@ -123,116 +80,81 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, // GetByTimeRange gets equity records within specified time range func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, timestamp, total_equity, balance, - unrealized_pnl, position_count, margin_used_pct - FROM trader_equity_snapshots - WHERE trader_id = ? AND timestamp >= ? AND timestamp <= ? - ORDER BY timestamp ASC - `, traderID, start.Format(time.RFC3339), end.Format(time.RFC3339)) + var snapshots []*EquitySnapshot + err := s.db.Where("trader_id = ? AND timestamp >= ? AND timestamp <= ?", traderID, start, end). + Order("timestamp ASC"). + Find(&snapshots).Error if err != nil { return nil, fmt.Errorf("failed to query equity records: %w", err) } - defer rows.Close() - - var snapshots []*EquitySnapshot - for rows.Next() { - snap := &EquitySnapshot{} - var timestampStr string - err := rows.Scan( - &snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity, - &snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct, - ) - if err != nil { - continue - } - snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr) - snapshots = append(snapshots, snap) - } - return snapshots, nil } // GetAllTradersLatest gets latest equity for all traders (for leaderboards) func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) { - rows, err := s.db.Query(` + // Use raw SQL for this complex query with subquery + var snapshots []*EquitySnapshot + err := s.db.Raw(` SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance, - e.unrealized_pnl, e.position_count, e.margin_used_pct + e.unrealized_pnl, e.position_count, e.margin_used_pct, e.created_at FROM trader_equity_snapshots e INNER JOIN ( SELECT trader_id, MAX(timestamp) as max_ts FROM trader_equity_snapshots GROUP BY trader_id ) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts - `) + `).Scan(&snapshots).Error if err != nil { return nil, fmt.Errorf("failed to query latest equity: %w", err) } - defer rows.Close() result := make(map[string]*EquitySnapshot) - for rows.Next() { - snap := &EquitySnapshot{} - var timestampStr string - err := rows.Scan( - &snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity, - &snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct, - ) - if err != nil { - continue - } - snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr) + for _, snap := range snapshots { result[snap.TraderID] = snap } - return result, nil } // CleanOldRecords cleans old records from N days ago func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) { - cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339) + cutoffTime := time.Now().AddDate(0, 0, -days) - result, err := s.db.Exec(` - DELETE FROM trader_equity_snapshots - WHERE trader_id = ? AND timestamp < ? - `, traderID, cutoffTime) - if err != nil { - return 0, fmt.Errorf("failed to clean old records: %w", err) + result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime). + Delete(&EquitySnapshot{}) + if result.Error != nil { + return 0, fmt.Errorf("failed to clean old records: %w", result.Error) } - - return result.RowsAffected() + return result.RowsAffected, nil } // GetCount gets record count for specified trader func (s *EquityStore) GetCount(traderID string) (int, error) { - var count int - err := s.db.QueryRow(` - SELECT COUNT(*) FROM trader_equity_snapshots WHERE trader_id = ? - `, traderID).Scan(&count) - return count, err + var count int64 + err := s.db.Model(&EquitySnapshot{}).Where("trader_id = ?", traderID).Count(&count).Error + return int(count), err } // MigrateFromDecision migrates data from old decision_account_snapshots table func (s *EquityStore) MigrateFromDecision() (int64, error) { // Check if migration is needed (whether new table is empty) - var count int - s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count) + var count int64 + s.db.Model(&EquitySnapshot{}).Count(&count) if count > 0 { return 0, nil // Already has data, skip migration } - // Check if old table exists + // Check if old table exists (SQLite specific check, but works for migration) var tableName string - err := s.db.QueryRow(` + err := s.db.Raw(` SELECT name FROM sqlite_master WHERE type='table' AND name='decision_account_snapshots' - `).Scan(&tableName) - if err != nil { + `).Scan(&tableName).Error + if err != nil || tableName == "" { return 0, nil // Old table doesn't exist, skip } // Migrate data: join query from decision_records + decision_account_snapshots - result, err := s.db.Exec(` + result := s.db.Exec(` INSERT INTO trader_equity_snapshots ( trader_id, timestamp, total_equity, balance, unrealized_pnl, position_count, margin_used_pct @@ -249,9 +171,9 @@ func (s *EquityStore) MigrateFromDecision() (int64, error) { JOIN decision_account_snapshots das ON dr.id = das.decision_id ORDER BY dr.timestamp ASC `) - if err != nil { - return 0, fmt.Errorf("failed to migrate data: %w", err) + if result.Error != nil { + return 0, fmt.Errorf("failed to migrate data: %w", result.Error) } - return result.RowsAffected() + return result.RowsAffected, nil } diff --git a/store/exchange.go b/store/exchange.go index ce2e4399..cb210355 100644 --- a/store/exchange.go +++ b/store/exchange.go @@ -1,83 +1,68 @@ package store import ( - "database/sql" "fmt" + "nofx/crypto" "nofx/logger" - "strings" "time" "github.com/google/uuid" + "gorm.io/gorm" ) // ExchangeStore exchange storage type ExchangeStore struct { - db *sql.DB - encryptFunc func(string) string - decryptFunc func(string) string + db *gorm.DB } // Exchange exchange configuration type Exchange struct { - ID string `json:"id"` // UUID - ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter" - AccountName string `json:"account_name"` // User-defined account name - UserID string `json:"user_id"` - Name string `json:"name"` // Display name (auto-generated or user-defined) - Type string `json:"type"` // "cex" or "dex" - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - SecretKey string `json:"secretKey"` - Passphrase string `json:"passphrase"` // OKX-specific - Testnet bool `json:"testnet"` - HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` - AsterUser string `json:"asterUser"` - AsterSigner string `json:"asterSigner"` - AsterPrivateKey string `json:"asterPrivateKey"` - LighterWalletAddr string `json:"lighterWalletAddr"` - LighterPrivateKey string `json:"lighterPrivateKey"` - LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"` - LighterAPIKeyIndex int `json:"lighterAPIKeyIndex"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `gorm:"primaryKey" json:"id"` + ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"` + AccountName string `gorm:"column:account_name;not null;default:''" json:"account_name"` + UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"` + Name string `gorm:"not null" json:"name"` + Type string `gorm:"not null" json:"type"` // "cex" or "dex" + Enabled bool `gorm:"default:false" json:"enabled"` + APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"` + SecretKey crypto.EncryptedString `gorm:"column:secret_key;default:''" json:"secretKey"` + Passphrase crypto.EncryptedString `gorm:"column:passphrase;default:''" json:"passphrase"` + Testnet bool `gorm:"default:false" json:"testnet"` + HyperliquidWalletAddr string `gorm:"column:hyperliquid_wallet_addr;default:''" json:"hyperliquidWalletAddr"` + AsterUser string `gorm:"column:aster_user;default:''" json:"asterUser"` + AsterSigner string `gorm:"column:aster_signer;default:''" json:"asterSigner"` + AsterPrivateKey crypto.EncryptedString `gorm:"column:aster_private_key;default:''" json:"asterPrivateKey"` + LighterWalletAddr string `gorm:"column:lighter_wallet_addr;default:''" json:"lighterWalletAddr"` + LighterPrivateKey crypto.EncryptedString `gorm:"column:lighter_private_key;default:''" json:"lighterPrivateKey"` + LighterAPIKeyPrivateKey crypto.EncryptedString `gorm:"column:lighter_api_key_private_key;default:''" json:"lighterAPIKeyPrivateKey"` + LighterAPIKeyIndex int `gorm:"column:lighter_api_key_index;default:0" json:"lighterAPIKeyIndex"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (Exchange) TableName() string { return "exchanges" } + +// NewExchangeStore creates a new ExchangeStore +func NewExchangeStore(db *gorm.DB) *ExchangeStore { + return &ExchangeStore{db: db} } func (s *ExchangeStore) initTables() error { - // Create new table structure with UUID as primary key - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS exchanges ( - id TEXT PRIMARY KEY, - exchange_type TEXT NOT NULL DEFAULT '', - account_name TEXT NOT NULL DEFAULT '', - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - type TEXT NOT NULL, - enabled BOOLEAN DEFAULT 0, - api_key TEXT DEFAULT '', - secret_key TEXT DEFAULT '', - passphrase TEXT DEFAULT '', - testnet BOOLEAN DEFAULT 0, - hyperliquid_wallet_addr TEXT DEFAULT '', - aster_user TEXT DEFAULT '', - aster_signer TEXT DEFAULT '', - aster_private_key TEXT DEFAULT '', - lighter_wallet_addr TEXT DEFAULT '', - lighter_private_key TEXT DEFAULT '', - lighter_api_key_private_key TEXT DEFAULT '', - lighter_api_key_index INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return err + // For PostgreSQL with existing table, skip AutoMigrate + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'exchanges'`).Scan(&tableExists) + if tableExists > 0 { + // Still run data migrations + s.migrateToMultiAccount() + s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default") + return nil + } } - // Migration: add new columns if not exists - s.db.Exec(`ALTER TABLE exchanges ADD COLUMN passphrase TEXT DEFAULT ''`) - s.db.Exec(`ALTER TABLE exchanges ADD COLUMN exchange_type TEXT NOT NULL DEFAULT ''`) - s.db.Exec(`ALTER TABLE exchanges ADD COLUMN account_name TEXT NOT NULL DEFAULT ''`) - s.db.Exec(`ALTER TABLE exchanges ADD COLUMN lighter_api_key_index INTEGER DEFAULT 0`) + if err := s.db.AutoMigrate(&Exchange{}); err != nil { + return err + } // Run migration to multi-account if needed if err := s.migrateToMultiAccount(); err != nil { @@ -85,120 +70,65 @@ func (s *ExchangeStore) initTables() error { } // Fix empty account_name for existing records - s.db.Exec(`UPDATE exchanges SET account_name = 'Default' WHERE account_name = '' OR account_name IS NULL`) + s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default") - // Update trigger for new schema - s.db.Exec(`DROP TRIGGER IF EXISTS update_exchanges_updated_at`) - _, err = s.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at - AFTER UPDATE ON exchanges - BEGIN - UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END - `) - return err + return nil } // migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID) func (s *ExchangeStore) migrateToMultiAccount() error { // Check if migration is needed by looking for old-style IDs (non-UUID) - var count int - err := s.db.QueryRow(` - SELECT COUNT(*) FROM exchanges - WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter') - `).Scan(&count) + var count int64 + err := s.db.Model(&Exchange{}). + Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}). + Count(&count).Error if err != nil { return err } if count == 0 { - // No migration needed return nil } logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count) // Get all old records - rows, err := s.db.Query(` - SELECT id, user_id, name, type, enabled, api_key, secret_key, - COALESCE(passphrase, '') as passphrase, testnet, - COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr, - COALESCE(aster_user, '') as aster_user, - COALESCE(aster_signer, '') as aster_signer, - COALESCE(aster_private_key, '') as aster_private_key, - COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr, - COALESCE(lighter_private_key, '') as lighter_private_key, - COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key - FROM exchanges - WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter') - `) + var records []Exchange + err = s.db.Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}). + Find(&records).Error if err != nil { return err } - defer rows.Close() - - type oldRecord struct { - id, userID, name, typ string - enabled, testnet bool - apiKey, secretKey, passphrase string - hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string - lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string - } - - var records []oldRecord - for rows.Next() { - var r oldRecord - if err := rows.Scan(&r.id, &r.userID, &r.name, &r.typ, &r.enabled, - &r.apiKey, &r.secretKey, &r.passphrase, &r.testnet, - &r.hyperliquidWalletAddr, &r.asterUser, &r.asterSigner, &r.asterPrivateKey, - &r.lighterWalletAddr, &r.lighterPrivateKey, &r.lighterApiKeyPrivateKey); err != nil { - return err - } - records = append(records, r) - } // Begin transaction - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() + return s.db.Transaction(func(tx *gorm.DB) error { + for _, r := range records { + newID := uuid.New().String() + oldID := r.ID // This is the exchange type (e.g., "binance") - // Migrate each record - for _, r := range records { - newID := uuid.New().String() - oldID := r.id // This is the exchange type (e.g., "binance") + // Update traders table to use new UUID + if err := tx.Exec("UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?", + newID, oldID, r.UserID).Error; err != nil { + logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err) + return err + } - // Update traders table to use new UUID - _, err = tx.Exec(`UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?`, - newID, oldID, r.userID) - if err != nil { - logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err) - return err + // Update the exchange record + if err := tx.Model(&Exchange{}). + Where("id = ? AND user_id = ?", oldID, r.UserID). + Updates(map[string]interface{}{ + "id": newID, + "exchange_type": oldID, + "account_name": "Default", + }).Error; err != nil { + logger.Errorf("Failed to migrate exchange %s: %v", oldID, err) + return err + } + + logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.UserID) } - - // Update the exchange record - _, err = tx.Exec(` - UPDATE exchanges SET - id = ?, - exchange_type = ?, - account_name = ? - WHERE id = ? AND user_id = ? - `, newID, oldID, "Default", oldID, r.userID) - if err != nil { - logger.Errorf("Failed to migrate exchange %s: %v", oldID, err) - return err - } - - logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.userID) - } - - if err := tx.Commit(); err != nil { - return err - } - - logger.Infof("✅ Multi-account migration completed successfully") - return nil + return nil + }) } func (s *ExchangeStore) initDefaultData() error { @@ -206,108 +136,24 @@ func (s *ExchangeStore) initDefaultData() error { return nil } -func (s *ExchangeStore) encrypt(plaintext string) string { - if s.encryptFunc != nil { - return s.encryptFunc(plaintext) - } - return plaintext -} - -func (s *ExchangeStore) decrypt(encrypted string) string { - if s.decryptFunc != nil { - return s.decryptFunc(encrypted) - } - return encrypted -} - // List gets user's exchange list func (s *ExchangeStore) List(userID string) ([]*Exchange, error) { - rows, err := s.db.Query(` - SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name, - user_id, name, type, enabled, api_key, secret_key, - COALESCE(passphrase, '') as passphrase, testnet, - COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr, - COALESCE(aster_user, '') as aster_user, - COALESCE(aster_signer, '') as aster_signer, - COALESCE(aster_private_key, '') as aster_private_key, - COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr, - COALESCE(lighter_private_key, '') as lighter_private_key, - COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key, - COALESCE(lighter_api_key_index, 0) as lighter_api_key_index, - created_at, updated_at - FROM exchanges WHERE user_id = ? ORDER BY exchange_type, account_name - `, userID) + var exchanges []*Exchange + err := s.db.Where("user_id = ?", userID).Order("exchange_type, account_name").Find(&exchanges).Error if err != nil { return nil, err } - defer rows.Close() - - exchanges := make([]*Exchange, 0) - for rows.Next() { - var e Exchange - var createdAt, updatedAt string - err := rows.Scan( - &e.ID, &e.ExchangeType, &e.AccountName, - &e.UserID, &e.Name, &e.Type, - &e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet, - &e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey, - &e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - e.APIKey = s.decrypt(e.APIKey) - e.SecretKey = s.decrypt(e.SecretKey) - e.Passphrase = s.decrypt(e.Passphrase) - e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey) - e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey) - e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey) - exchanges = append(exchanges, &e) - } return exchanges, nil } // GetByID gets a specific exchange by UUID func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) { - var e Exchange - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name, - user_id, name, type, enabled, api_key, secret_key, - COALESCE(passphrase, '') as passphrase, testnet, - COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr, - COALESCE(aster_user, '') as aster_user, - COALESCE(aster_signer, '') as aster_signer, - COALESCE(aster_private_key, '') as aster_private_key, - COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr, - COALESCE(lighter_private_key, '') as lighter_private_key, - COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key, - COALESCE(lighter_api_key_index, 0) as lighter_api_key_index, - created_at, updated_at - FROM exchanges WHERE id = ? AND user_id = ? - `, id, userID).Scan( - &e.ID, &e.ExchangeType, &e.AccountName, - &e.UserID, &e.Name, &e.Type, - &e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet, - &e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey, - &e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex, - &createdAt, &updatedAt, - ) + var exchange Exchange + err := s.db.Where("id = ? AND user_id = ?", id, userID).First(&exchange).Error if err != nil { return nil, err } - e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - e.APIKey = s.decrypt(e.APIKey) - e.SecretKey = s.decrypt(e.SecretKey) - e.Passphrase = s.decrypt(e.Passphrase) - e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey) - e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey) - e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey) - return &e, nil + return &exchange, nil } // getExchangeNameAndType returns the display name and type for an exchange type @@ -341,7 +187,6 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled id := uuid.New().String() name, typ := getExchangeNameAndType(exchangeType) - // If account name is empty, use "Default" if accountName == "" { accountName = "Default" } @@ -349,19 +194,29 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s", userID, exchangeType, accountName, id) - _, err := s.db.Exec(` - INSERT INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled, - api_key, secret_key, passphrase, testnet, - hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, - lighter_wallet_addr, lighter_private_key, lighter_api_key_private_key, lighter_api_key_index, - created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) - `, id, exchangeType, accountName, userID, name, typ, enabled, - s.encrypt(apiKey), s.encrypt(secretKey), s.encrypt(passphrase), testnet, - hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey), - lighterWalletAddr, s.encrypt(lighterPrivateKey), s.encrypt(lighterApiKeyPrivateKey), lighterApiKeyIndex) + exchange := &Exchange{ + ID: id, + ExchangeType: exchangeType, + AccountName: accountName, + UserID: userID, + Name: name, + Type: typ, + Enabled: enabled, + APIKey: crypto.EncryptedString(apiKey), + SecretKey: crypto.EncryptedString(secretKey), + Passphrase: crypto.EncryptedString(passphrase), + Testnet: testnet, + HyperliquidWalletAddr: hyperliquidWalletAddr, + AsterUser: asterUser, + AsterSigner: asterSigner, + AsterPrivateKey: crypto.EncryptedString(asterPrivateKey), + LighterWalletAddr: lighterWalletAddr, + LighterPrivateKey: crypto.EncryptedString(lighterPrivateKey), + LighterAPIKeyPrivateKey: crypto.EncryptedString(lighterApiKeyPrivateKey), + LighterAPIKeyIndex: lighterApiKeyIndex, + } - if err != nil { + if err := s.db.Create(exchange).Error; err != nil { return "", err } return id, nil @@ -373,53 +228,42 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled) - setClauses := []string{ - "enabled = ?", - "testnet = ?", - "hyperliquid_wallet_addr = ?", - "aster_user = ?", - "aster_signer = ?", - "lighter_wallet_addr = ?", - "lighter_api_key_index = ?", - "updated_at = datetime('now')", + updates := map[string]interface{}{ + "enabled": enabled, + "testnet": testnet, + "hyperliquid_wallet_addr": hyperliquidWalletAddr, + "aster_user": asterUser, + "aster_signer": asterSigner, + "lighter_wallet_addr": lighterWalletAddr, + "lighter_api_key_index": lighterApiKeyIndex, + "updated_at": time.Now(), } - args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr, lighterApiKeyIndex} + // Only update encrypted fields if not empty if apiKey != "" { - setClauses = append(setClauses, "api_key = ?") - args = append(args, s.encrypt(apiKey)) + updates["api_key"] = crypto.EncryptedString(apiKey) } if secretKey != "" { - setClauses = append(setClauses, "secret_key = ?") - args = append(args, s.encrypt(secretKey)) + updates["secret_key"] = crypto.EncryptedString(secretKey) } if passphrase != "" { - setClauses = append(setClauses, "passphrase = ?") - args = append(args, s.encrypt(passphrase)) + updates["passphrase"] = crypto.EncryptedString(passphrase) } if asterPrivateKey != "" { - setClauses = append(setClauses, "aster_private_key = ?") - args = append(args, s.encrypt(asterPrivateKey)) + updates["aster_private_key"] = crypto.EncryptedString(asterPrivateKey) } if lighterPrivateKey != "" { - setClauses = append(setClauses, "lighter_private_key = ?") - args = append(args, s.encrypt(lighterPrivateKey)) + updates["lighter_private_key"] = crypto.EncryptedString(lighterPrivateKey) } if lighterApiKeyPrivateKey != "" { - setClauses = append(setClauses, "lighter_api_key_private_key = ?") - args = append(args, s.encrypt(lighterApiKeyPrivateKey)) + updates["lighter_api_key_private_key"] = crypto.EncryptedString(lighterApiKeyPrivateKey) } - args = append(args, id, userID) - query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", ")) - - result, err := s.db.Exec(query, args...) - if err != nil { - return err + result := s.db.Model(&Exchange{}).Where("id = ? AND user_id = ?", id, userID).Updates(updates) + if result.Error != nil { + return result.Error } - - rowsAffected, _ := result.RowsAffected() - if rowsAffected == 0 { + if result.RowsAffected == 0 { return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID) } return nil @@ -427,13 +271,16 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe // UpdateAccountName updates the account name for an exchange func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error { - result, err := s.db.Exec(`UPDATE exchanges SET account_name = ?, updated_at = datetime('now') WHERE id = ? AND user_id = ?`, - accountName, id, userID) - if err != nil { - return err + result := s.db.Model(&Exchange{}). + Where("id = ? AND user_id = ?", id, userID). + Updates(map[string]interface{}{ + "account_name": accountName, + "updated_at": time.Now(), + }) + if result.Error != nil { + return result.Error } - rowsAffected, _ := result.RowsAffected() - if rowsAffected == 0 { + if result.RowsAffected == 0 { return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID) } return nil @@ -441,12 +288,11 @@ func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error // Delete deletes an exchange account func (s *ExchangeStore) Delete(userID, id string) error { - result, err := s.db.Exec(`DELETE FROM exchanges WHERE id = ? AND user_id = ?`, id, userID) - if err != nil { - return err + result := s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Exchange{}) + if result.Error != nil { + return result.Error } - rowsAffected, _ := result.RowsAffected() - if rowsAffected == 0 { + if result.RowsAffected == 0 { return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID) } logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID) @@ -460,20 +306,25 @@ func (s *ExchangeStore) CreateLegacy(userID, id, name, typ string, enabled bool, // Check if this is an old-style ID (exchange type as ID) if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" { - // Use new Create method with exchange type _, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0) return err } // Otherwise assume it's already a UUID - _, err := s.db.Exec(` - INSERT OR IGNORE INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled, - api_key, secret_key, testnet, - hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, - lighter_wallet_addr, lighter_private_key) - VALUES (?, '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '') - `, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet, - hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey)) - return err + exchange := &Exchange{ + ID: id, + UserID: userID, + Name: name, + Type: typ, + Enabled: enabled, + APIKey: crypto.EncryptedString(apiKey), + SecretKey: crypto.EncryptedString(secretKey), + Testnet: testnet, + HyperliquidWalletAddr: hyperliquidWalletAddr, + AsterUser: asterUser, + AsterSigner: asterSigner, + AsterPrivateKey: crypto.EncryptedString(asterPrivateKey), + } + return s.db.Where("id = ?", id).FirstOrCreate(exchange).Error } diff --git a/store/gorm.go b/store/gorm.go new file mode 100644 index 00000000..f3340275 --- /dev/null +++ b/store/gorm.go @@ -0,0 +1,146 @@ +package store + +import ( + "fmt" + + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// GormDB is the global GORM database connection +var gormDB *gorm.DB + +// DB returns the GORM database connection +func DB() *gorm.DB { + return gormDB +} + +// InitGorm initializes GORM with SQLite +func InitGorm(dbPath string) (*gorm.DB, error) { + db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + return nil, fmt.Errorf("failed to open SQLite database: %w", err) + } + + // Set connection pool for SQLite + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(1) + + // Enable foreign keys for SQLite + db.Exec("PRAGMA foreign_keys = ON") + db.Exec("PRAGMA journal_mode = DELETE") + db.Exec("PRAGMA synchronous = FULL") + db.Exec("PRAGMA busy_timeout = 5000") + + gormDB = db + return db, nil +} + +// InitGormPostgres initializes GORM with PostgreSQL +func InitGormPostgres(host string, port int, user, password, dbname, sslmode string) (*gorm.DB, error) { + dsn := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + host, port, user, password, dbname, sslmode, + ) + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err) + } + + // Set connection pool for PostgreSQL + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + sqlDB.SetMaxOpenConns(25) + sqlDB.SetMaxIdleConns(5) + + gormDB = db + return db, nil +} + +// InitGormWithConfig initializes GORM with provided configuration +// Uses DBConfig from driver.go +func InitGormWithConfig(cfg DBConfig) (*gorm.DB, error) { + switch cfg.Type { + case DBTypeSQLite: + return InitGorm(cfg.Path) + + case DBTypePostgres: + return InitGormPostgres( + cfg.Host, + cfg.Port, + cfg.User, + cfg.Password, + cfg.DBName, + cfg.SSLMode, + ) + + default: + return nil, fmt.Errorf("unsupported DB_TYPE: %s (use 'sqlite' or 'postgres')", cfg.Type) + } +} + +// ============================================================================ +// Query Scopes - Reusable query helpers +// ============================================================================ + +// ForUser returns a scope that filters by user_id +func ForUser(userID string) func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Where("user_id = ?", userID) + } +} + +// ForTrader returns a scope that filters by trader_id +func ForTrader(traderID string) func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Where("trader_id = ?", traderID) + } +} + +// OpenPositions returns a scope for open positions +func OpenPositions() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Where("status = ?", "OPEN") + } +} + +// ClosedPositions returns a scope for closed positions +func ClosedPositions() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Where("status = ?", "CLOSED") + } +} + +// OrderByCreatedDesc returns a scope that orders by created_at DESC +func OrderByCreatedDesc() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Order("created_at DESC") + } +} + +// OrderByUpdatedDesc returns a scope that orders by updated_at DESC +func OrderByUpdatedDesc() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Order("updated_at DESC") + } +} + +// Paginate returns a scope for pagination +func Paginate(limit, offset int) func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Limit(limit).Offset(offset) + } +} diff --git a/store/order.go b/store/order.go index a5c6f4fc..4ae6ef7a 100644 --- a/store/order.go +++ b/store/order.go @@ -1,495 +1,255 @@ package store import ( - "database/sql" "fmt" + "strings" "time" + + "gorm.io/gorm" ) -// TraderOrder 订单记录(完整的订单生命周期追踪) +// TraderOrder order record type TraderOrder struct { - ID int64 `json:"id"` - TraderID string `json:"trader_id"` - ExchangeID string `json:"exchange_id"` // Exchange account UUID - ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc) - ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID - ClientOrderID string `json:"client_order_id"` // Client order ID - Symbol string `json:"symbol"` // Trading pair - Side string `json:"side"` // BUY/SELL - PositionSide string `json:"position_side"` // LONG/SHORT (hedge mode) - Type string `json:"type"` // MARKET/LIMIT/STOP/STOP_MARKET/TAKE_PROFIT/TAKE_PROFIT_MARKET - TimeInForce string `json:"time_in_force"` // GTC/IOC/FOK - Quantity float64 `json:"quantity"` // 订单数量 - Price float64 `json:"price"` // 限价单价格 - StopPrice float64 `json:"stop_price"` // 止损/止盈触发价格 - Status string `json:"status"` // NEW/PARTIALLY_FILLED/FILLED/CANCELED/REJECTED/EXPIRED - FilledQuantity float64 `json:"filled_quantity"` // 已成交数量 - AvgFillPrice float64 `json:"avg_fill_price"` // 平均成交价格 - Commission float64 `json:"commission"` // 手续费总额 - CommissionAsset string `json:"commission_asset"` // 手续费资产(USDT等) - Leverage int `json:"leverage"` // 杠杆倍数 - ReduceOnly bool `json:"reduce_only"` // 是否只减仓 - ClosePosition bool `json:"close_position"` // 是否平仓单 - WorkingType string `json:"working_type"` // CONTRACT_PRICE/MARK_PRICE - PriceProtect bool `json:"price_protect"` // 价格保护 - OrderAction string `json:"order_action"` // OPEN_LONG/OPEN_SHORT/CLOSE_LONG/CLOSE_SHORT/ADD_LONG/ADD_SHORT/STOP_LOSS/TAKE_PROFIT - RelatedPositionID int64 `json:"related_position_id"` // 关联的仓位ID - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - FilledAt time.Time `json:"filled_at"` // 完全成交时间 + ID int64 `gorm:"primaryKey;autoIncrement" json:"id"` + TraderID string `gorm:"column:trader_id;not null;index:idx_orders_trader_id" json:"trader_id"` + ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"` + ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"` + ExchangeOrderID string `gorm:"column:exchange_order_id;not null;uniqueIndex:idx_orders_exchange_unique,priority:2" json:"exchange_order_id"` + ClientOrderID string `gorm:"column:client_order_id;default:''" json:"client_order_id"` + Symbol string `gorm:"column:symbol;not null;index:idx_orders_symbol" json:"symbol"` + Side string `gorm:"column:side;not null" json:"side"` + PositionSide string `gorm:"column:position_side;default:''" json:"position_side"` + Type string `gorm:"column:type;not null" json:"type"` + TimeInForce string `gorm:"column:time_in_force;default:GTC" json:"time_in_force"` + Quantity float64 `gorm:"column:quantity;not null" json:"quantity"` + Price float64 `gorm:"column:price;default:0" json:"price"` + StopPrice float64 `gorm:"column:stop_price;default:0" json:"stop_price"` + Status string `gorm:"column:status;not null;default:NEW;index:idx_orders_status" json:"status"` + FilledQuantity float64 `gorm:"column:filled_quantity;default:0" json:"filled_quantity"` + AvgFillPrice float64 `gorm:"column:avg_fill_price;default:0" json:"avg_fill_price"` + Commission float64 `gorm:"column:commission;default:0" json:"commission"` + CommissionAsset string `gorm:"column:commission_asset;default:USDT" json:"commission_asset"` + Leverage int `gorm:"column:leverage;default:1" json:"leverage"` + ReduceOnly bool `gorm:"column:reduce_only;default:false" json:"reduce_only"` + ClosePosition bool `gorm:"column:close_position;default:false" json:"close_position"` + WorkingType string `gorm:"column:working_type;default:CONTRACT_PRICE" json:"working_type"` + PriceProtect bool `gorm:"column:price_protect;default:false" json:"price_protect"` + OrderAction string `gorm:"column:order_action;default:''" json:"order_action"` + RelatedPositionID int64 `gorm:"column:related_position_id;default:0" json:"related_position_id"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"` + FilledAt time.Time `gorm:"column:filled_at" json:"filled_at"` } -// TraderFill trade record (one order may have multiple fills) +// TableName returns the table name for TraderOrder +func (TraderOrder) TableName() string { + return "trader_orders" +} + +// TraderFill trade record type TraderFill struct { - ID int64 `json:"id"` - TraderID string `json:"trader_id"` - ExchangeID string `json:"exchange_id"` // Exchange account UUID - ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc) - OrderID int64 `json:"order_id"` // Related order ID - ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID - ExchangeTradeID string `json:"exchange_trade_id"` // Exchange trade ID - Symbol string `json:"symbol"` - Side string `json:"side"` // BUY/SELL - Price float64 `json:"price"` // 成交价格 - Quantity float64 `json:"quantity"` // 成交数量 - QuoteQuantity float64 `json:"quote_quantity"` // 成交金额(USDT) - Commission float64 `json:"commission"` // 手续费 - CommissionAsset string `json:"commission_asset"` - RealizedPnL float64 `json:"realized_pnl"` // 实现盈亏(平仓时) - IsMaker bool `json:"is_maker"` // 是否为maker - CreatedAt time.Time `json:"created_at"` + ID int64 `gorm:"primaryKey;autoIncrement" json:"id"` + TraderID string `gorm:"column:trader_id;not null;index:idx_fills_trader_id" json:"trader_id"` + ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"` + ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"` + OrderID int64 `gorm:"column:order_id;not null;index:idx_fills_order_id" json:"order_id"` + ExchangeOrderID string `gorm:"column:exchange_order_id;not null" json:"exchange_order_id"` + ExchangeTradeID string `gorm:"column:exchange_trade_id;not null;uniqueIndex:idx_fills_exchange_unique,priority:2" json:"exchange_trade_id"` + Symbol string `gorm:"column:symbol;not null" json:"symbol"` + Side string `gorm:"column:side;not null" json:"side"` + Price float64 `gorm:"column:price;not null" json:"price"` + Quantity float64 `gorm:"column:quantity;not null" json:"quantity"` + QuoteQuantity float64 `gorm:"column:quote_quantity;not null" json:"quote_quantity"` + Commission float64 `gorm:"column:commission;not null" json:"commission"` + CommissionAsset string `gorm:"column:commission_asset;not null" json:"commission_asset"` + RealizedPnL float64 `gorm:"column:realized_pnl;default:0" json:"realized_pnl"` + IsMaker bool `gorm:"column:is_maker;default:false" json:"is_maker"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` } -// OrderStore 订单存储 +// TableName returns the table name for TraderFill +func (TraderFill) TableName() string { + return "trader_fills" +} + +// OrderStore order storage type OrderStore struct { - db *sql.DB + db *gorm.DB } -// NewOrderStore 创建订单存储实例 -func NewOrderStore(db *sql.DB) *OrderStore { +// NewOrderStore creates order storage instance +func NewOrderStore(db *gorm.DB) *OrderStore { return &OrderStore{db: db} } -// InitTables 初始化订单表 +// InitTables initializes order tables func (s *OrderStore) InitTables() error { - // 创建订单表 - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS trader_orders ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - trader_id TEXT NOT NULL, - exchange_id TEXT NOT NULL DEFAULT '', - exchange_type TEXT NOT NULL DEFAULT '', - exchange_order_id TEXT NOT NULL, - client_order_id TEXT DEFAULT '', - symbol TEXT NOT NULL, - side TEXT NOT NULL, - position_side TEXT DEFAULT '', - type TEXT NOT NULL, - time_in_force TEXT DEFAULT 'GTC', - quantity REAL NOT NULL, - price REAL DEFAULT 0, - stop_price REAL DEFAULT 0, - status TEXT NOT NULL DEFAULT 'NEW', - filled_quantity REAL DEFAULT 0, - avg_fill_price REAL DEFAULT 0, - commission REAL DEFAULT 0, - commission_asset TEXT DEFAULT 'USDT', - leverage INTEGER DEFAULT 1, - reduce_only INTEGER DEFAULT 0, - close_position INTEGER DEFAULT 0, - working_type TEXT DEFAULT 'CONTRACT_PRICE', - price_protect INTEGER DEFAULT 0, - order_action TEXT DEFAULT '', - related_position_id INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - filled_at DATETIME, - UNIQUE(exchange_id, exchange_order_id) - ) - `) - if err != nil { - return fmt.Errorf("failed to create trader_orders table: %w", err) + // For PostgreSQL, check if tables exist to avoid AutoMigrate index conflicts + if s.db.Dialector.Name() == "postgres" { + var ordersExist, fillsExist int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_orders'`).Scan(&ordersExist) + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_fills'`).Scan(&fillsExist) + + if ordersExist > 0 && fillsExist > 0 { + // Tables exist - just ensure indexes exist, skip AutoMigrate + s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`) + s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`) + return nil + } } - // 创建成交记录表 - _, err = s.db.Exec(` - CREATE TABLE IF NOT EXISTS trader_fills ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - trader_id TEXT NOT NULL, - exchange_id TEXT NOT NULL DEFAULT '', - exchange_type TEXT NOT NULL DEFAULT '', - order_id INTEGER NOT NULL, - exchange_order_id TEXT NOT NULL, - exchange_trade_id TEXT NOT NULL, - symbol TEXT NOT NULL, - side TEXT NOT NULL, - price REAL NOT NULL, - quantity REAL NOT NULL, - quote_quantity REAL NOT NULL, - commission REAL NOT NULL, - commission_asset TEXT NOT NULL, - realized_pnl REAL DEFAULT 0, - is_maker INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(exchange_id, exchange_trade_id), - FOREIGN KEY (order_id) REFERENCES trader_orders(id) - ) - `) - if err != nil { - return fmt.Errorf("failed to create trader_fills table: %w", err) + if err := s.db.AutoMigrate(&TraderOrder{}, &TraderFill{}); err != nil { + return fmt.Errorf("failed to migrate order tables: %w", err) } - // 创建索引 - s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`) - s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`) - s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`) - s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_exchange_order_id ON trader_orders(exchange_id, exchange_order_id)`) - s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`) - s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`) + // Create unique composite index for exchange_id + exchange_order_id + s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`) + // Create unique composite index for exchange_id + exchange_trade_id + s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`) return nil } -// CreateOrder 创建订单记录(去重:如果订单已存在则返回已有记录) +// CreateOrder creates order record func (s *OrderStore) CreateOrder(order *TraderOrder) error { - // 1. 先检查订单是否已存在(去重) + // Check if order already exists existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID) if err != nil { return fmt.Errorf("failed to check existing order: %w", err) } if existing != nil { - // 订单已存在,返回已有记录的ID order.ID = existing.ID order.CreatedAt = existing.CreatedAt order.UpdatedAt = existing.UpdatedAt - return nil // 不是错误,只是跳过插入 + return nil } - // 2. 订单不存在,插入新记录 - now := time.Now() - order.CreatedAt = now - order.UpdatedAt = now - - result, err := s.db.Exec(` - INSERT INTO trader_orders ( - trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id, - symbol, side, position_side, type, time_in_force, - quantity, price, stop_price, status, - filled_quantity, avg_fill_price, commission, commission_asset, - leverage, reduce_only, close_position, working_type, price_protect, - order_action, related_position_id, - created_at, updated_at, filled_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, - order.TraderID, order.ExchangeID, order.ExchangeType, order.ExchangeOrderID, order.ClientOrderID, - order.Symbol, order.Side, order.PositionSide, order.Type, order.TimeInForce, - order.Quantity, order.Price, order.StopPrice, order.Status, - order.FilledQuantity, order.AvgFillPrice, order.Commission, order.CommissionAsset, - order.Leverage, order.ReduceOnly, order.ClosePosition, order.WorkingType, order.PriceProtect, - order.OrderAction, order.RelatedPositionID, - formatTimeOrNow(order.CreatedAt, now), formatTimeOrNow(order.UpdatedAt, now), - formatTimePtr(order.FilledAt), - ) - if err != nil { - return fmt.Errorf("failed to create order: %w", err) - } - - id, _ := result.LastInsertId() - order.ID = id - return nil + return s.db.Create(order).Error } -// UpdateOrderStatus 更新订单状态 +// UpdateOrderStatus updates order status func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error { - now := time.Now() - updateSQL := ` - UPDATE trader_orders SET - status = ?, - filled_quantity = ?, - avg_fill_price = ?, - commission = ?, - updated_at = ? - ` - args := []interface{}{status, filledQty, avgPrice, commission, now.Format(time.RFC3339)} + updates := map[string]interface{}{ + "status": status, + "filled_quantity": filledQty, + "avg_fill_price": avgPrice, + "commission": commission, + } - // 如果完全成交,记录成交时间 if status == "FILLED" { - updateSQL += `, filled_at = ?` - args = append(args, now.Format(time.RFC3339)) + updates["filled_at"] = time.Now() } - updateSQL += ` WHERE id = ?` - args = append(args, id) - - _, err := s.db.Exec(updateSQL, args...) - if err != nil { - return fmt.Errorf("failed to update order status: %w", err) - } - return nil + return s.db.Model(&TraderOrder{}).Where("id = ?", id).Updates(updates).Error } -// CreateFill 创建成交记录(去重:如果成交已存在则跳过) +// CreateFill creates fill record func (s *OrderStore) CreateFill(fill *TraderFill) error { - // 1. 先检查成交是否已存在(去重) + // Check if fill already exists existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID) if err != nil { return fmt.Errorf("failed to check existing fill: %w", err) } if existing != nil { - // 成交已存在,返回已有记录的ID fill.ID = existing.ID fill.CreatedAt = existing.CreatedAt - return nil // 不是错误,只是跳过插入 + return nil } - // 2. 成交不存在,插入新记录 - now := time.Now() - fill.CreatedAt = now - - result, err := s.db.Exec(` - INSERT INTO trader_fills ( - trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id, - symbol, side, price, quantity, quote_quantity, - commission, commission_asset, realized_pnl, is_maker, - created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, - fill.TraderID, fill.ExchangeID, fill.ExchangeType, fill.OrderID, fill.ExchangeOrderID, fill.ExchangeTradeID, - fill.Symbol, fill.Side, fill.Price, fill.Quantity, fill.QuoteQuantity, - fill.Commission, fill.CommissionAsset, fill.RealizedPnL, fill.IsMaker, - now.Format(time.RFC3339), - ) - if err != nil { - return fmt.Errorf("failed to create fill: %w", err) - } - - id, _ := result.LastInsertId() - fill.ID = id - return nil + return s.db.Create(fill).Error } -// GetFillByExchangeTradeID 根据交易所成交ID获取成交记录 +// GetFillByExchangeTradeID gets fill by exchange trade ID func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) { - row := s.db.QueryRow(` - SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id, - symbol, side, price, quantity, quote_quantity, - commission, commission_asset, realized_pnl, is_maker, - created_at - FROM trader_fills - WHERE exchange_id = ? AND exchange_trade_id = ? - `, exchangeID, exchangeTradeID) - var fill TraderFill - var createdAt sql.NullString - err := row.Scan( - &fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID, - &fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity, - &fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker, - &createdAt, - ) - if err == sql.ErrNoRows { - return nil, nil - } + err := s.db.Where("exchange_id = ? AND exchange_trade_id = ?", exchangeID, exchangeTradeID).First(&fill).Error if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } return nil, fmt.Errorf("failed to get fill: %w", err) } - - // Parse time - if createdAt.Valid { - if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil { - fill.CreatedAt = t - } - } - return &fill, nil } -// GetOrderByExchangeID 根据交易所订单ID获取订单 +// GetOrderByExchangeID gets order by exchange order ID func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) { - row := s.db.QueryRow(` - SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id, - symbol, side, position_side, type, time_in_force, - quantity, price, stop_price, status, - filled_quantity, avg_fill_price, commission, commission_asset, - leverage, reduce_only, close_position, working_type, price_protect, - order_action, related_position_id, - created_at, updated_at, filled_at - FROM trader_orders - WHERE exchange_id = ? AND exchange_order_id = ? - `, exchangeID, exchangeOrderID) - var order TraderOrder - var createdAt, updatedAt, filledAt sql.NullString - err := row.Scan( - &order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID, - &order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce, - &order.Quantity, &order.Price, &order.StopPrice, &order.Status, - &order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset, - &order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect, - &order.OrderAction, &order.RelatedPositionID, - &createdAt, &updatedAt, &filledAt, - ) - if err == sql.ErrNoRows { - return nil, nil - } + err := s.db.Where("exchange_id = ? AND exchange_order_id = ?", exchangeID, exchangeOrderID).First(&order).Error if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } return nil, fmt.Errorf("failed to get order: %w", err) } - - // Parse times - if createdAt.Valid { - if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil { - order.CreatedAt = t - } - } - if updatedAt.Valid { - if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil { - order.UpdatedAt = t - } - } - if filledAt.Valid { - if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil { - order.FilledAt = t - } - } - return &order, nil } -// GetTraderOrders 获取trader的订单列表 +// GetTraderOrders gets trader's order list func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id, - symbol, side, position_side, type, time_in_force, - quantity, price, stop_price, status, - filled_quantity, avg_fill_price, commission, commission_asset, - leverage, reduce_only, close_position, working_type, price_protect, - order_action, related_position_id, - created_at, updated_at, filled_at - FROM trader_orders - WHERE trader_id = ? - ORDER BY created_at DESC - LIMIT ? - `, traderID, limit) + var orders []*TraderOrder + err := s.db.Where("trader_id = ?", traderID). + Order("created_at DESC"). + Limit(limit). + Find(&orders).Error if err != nil { return nil, fmt.Errorf("failed to query orders: %w", err) } - defer rows.Close() - - var orders []*TraderOrder - for rows.Next() { - var order TraderOrder - var createdAt, updatedAt, filledAt sql.NullString - err := rows.Scan( - &order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID, - &order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce, - &order.Quantity, &order.Price, &order.StopPrice, &order.Status, - &order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset, - &order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect, - &order.OrderAction, &order.RelatedPositionID, - &createdAt, &updatedAt, &filledAt, - ) - if err != nil { - continue - } - - // Parse times - if createdAt.Valid { - if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil { - order.CreatedAt = t - } - } - if updatedAt.Valid { - if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil { - order.UpdatedAt = t - } - } - if filledAt.Valid { - if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil { - order.FilledAt = t - } - } - - orders = append(orders, &order) - } - return orders, nil } -// GetOrderFills 获取订单的成交记录 +// GetOrderFills gets order's fill records func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id, - symbol, side, price, quantity, quote_quantity, - commission, commission_asset, realized_pnl, is_maker, - created_at - FROM trader_fills - WHERE order_id = ? - ORDER BY created_at ASC - `, orderID) + var fills []*TraderFill + err := s.db.Where("order_id = ?", orderID). + Order("created_at ASC"). + Find(&fills).Error if err != nil { return nil, fmt.Errorf("failed to query fills: %w", err) } - defer rows.Close() - - var fills []*TraderFill - for rows.Next() { - var fill TraderFill - var createdAt sql.NullString - err := rows.Scan( - &fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID, - &fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity, - &fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker, - &createdAt, - ) - if err != nil { - continue - } - - if createdAt.Valid { - if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil { - fill.CreatedAt = t - } - } - - fills = append(fills, &fill) - } - return fills, nil } -// GetTraderOrderStats 获取trader的订单统计 +// GetTraderOrderStats gets trader's order statistics func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) { - var totalOrders, filledOrders, canceledOrders int - var totalCommission, totalVolume float64 - - err := s.db.QueryRow(` - SELECT - COUNT(*) as total_orders, - SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders, - SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders, - SUM(commission) as total_commission, - SUM(filled_quantity * avg_fill_price) as total_volume - FROM trader_orders - WHERE trader_id = ? - `, traderID).Scan(&totalOrders, &filledOrders, &canceledOrders, &totalCommission, &totalVolume) + type result struct { + TotalOrders int + FilledOrders int + CanceledOrders int + TotalCommission float64 + TotalVolume float64 + } + var r result + err := s.db.Model(&TraderOrder{}). + Select(`COUNT(*) as total_orders, + SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders, + SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders, + SUM(commission) as total_commission, + SUM(filled_quantity * avg_fill_price) as total_volume`). + Where("trader_id = ?", traderID). + Scan(&r).Error if err != nil { return nil, fmt.Errorf("failed to get order stats: %w", err) } return map[string]interface{}{ - "total_orders": totalOrders, - "filled_orders": filledOrders, - "canceled_orders": canceledOrders, - "total_commission": totalCommission, - "total_volume": totalVolume, + "total_orders": r.TotalOrders, + "filled_orders": r.FilledOrders, + "canceled_orders": r.CanceledOrders, + "total_commission": r.TotalCommission, + "total_volume": r.TotalVolume, }, nil } -// CleanupDuplicateOrders 清理重复的订单记录(保留最早创建的记录) +// CleanupDuplicateOrders cleans up duplicate order records func (s *OrderStore) CleanupDuplicateOrders() (int, error) { - result, err := s.db.Exec(` + result := s.db.Exec(` DELETE FROM trader_orders WHERE id NOT IN ( SELECT MIN(id) @@ -497,17 +257,15 @@ func (s *OrderStore) CleanupDuplicateOrders() (int, error) { GROUP BY exchange_id, exchange_order_id ) `) - if err != nil { - return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", err) + if result.Error != nil { + return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", result.Error) } - - rowsAffected, _ := result.RowsAffected() - return int(rowsAffected), nil + return int(result.RowsAffected), nil } -// CleanupDuplicateFills 清理重复的成交记录(保留最早创建的记录) +// CleanupDuplicateFills cleans up duplicate fill records func (s *OrderStore) CleanupDuplicateFills() (int, error) { - result, err := s.db.Exec(` + result := s.db.Exec(` DELETE FROM trader_fills WHERE id NOT IN ( SELECT MIN(id) @@ -515,73 +273,66 @@ func (s *OrderStore) CleanupDuplicateFills() (int, error) { GROUP BY exchange_id, exchange_trade_id ) `) - if err != nil { - return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", err) + if result.Error != nil { + return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", result.Error) } - - rowsAffected, _ := result.RowsAffected() - return int(rowsAffected), nil + return int(result.RowsAffected), nil } -// GetDuplicateOrdersCount 获取重复订单的数量(用于诊断) +// GetDuplicateOrdersCount gets duplicate orders count func (s *OrderStore) GetDuplicateOrdersCount() (int, error) { - var count int - err := s.db.QueryRow(` - SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_order_id) - FROM trader_orders - `).Scan(&count) - return count, err + var total, distinct int64 + s.db.Model(&TraderOrder{}).Count(&total) + + // Count distinct combinations + var distinctResult struct{ Count int64 } + s.db.Model(&TraderOrder{}). + Select("COUNT(DISTINCT exchange_id || ',' || exchange_order_id) as count"). + Scan(&distinctResult) + distinct = distinctResult.Count + + return int(total - distinct), nil } -// GetDuplicateFillsCount 获取重复成交的数量(用于诊断) +// GetDuplicateFillsCount gets duplicate fills count func (s *OrderStore) GetDuplicateFillsCount() (int, error) { - var count int - err := s.db.QueryRow(` - SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_trade_id) - FROM trader_fills - `).Scan(&count) - return count, err + var total, distinct int64 + s.db.Model(&TraderFill{}).Count(&total) + + var distinctResult struct{ Count int64 } + s.db.Model(&TraderFill{}). + Select("COUNT(DISTINCT exchange_id || ',' || exchange_trade_id) as count"). + Scan(&distinctResult) + distinct = distinctResult.Count + + return int(total - distinct), nil } // GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange -// Used for incremental sync - only fetch trades with ID > maxTradeID func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) { - rows, err := s.db.Query(` - SELECT symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id - FROM trader_fills - WHERE exchange_id = ? AND exchange_trade_id != '' - GROUP BY symbol - `, exchangeID) + type symbolMaxID struct { + Symbol string + MaxTradeID int64 + } + var results []symbolMaxID + + err := s.db.Model(&TraderFill{}). + Select("symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id"). + Where("exchange_id = ? AND exchange_trade_id != ''", exchangeID). + Group("symbol"). + Find(&results).Error if err != nil { + // If CAST fails (non-numeric trade IDs), fallback to string comparison + if strings.Contains(err.Error(), "CAST") || strings.Contains(err.Error(), "invalid") { + return make(map[string]int64), nil + } return nil, fmt.Errorf("failed to query max trade IDs: %w", err) } - defer rows.Close() result := make(map[string]int64) - for rows.Next() { - var symbol string - var maxID int64 - if err := rows.Scan(&symbol, &maxID); err != nil { - continue - } - result[symbol] = maxID + for _, r := range results { + result[r.Symbol] = r.MaxTradeID } return result, nil } - -// formatTimePtr formats time.Time to RFC3339 string, returns NULL for zero time -func formatTimePtr(t time.Time) interface{} { - if t.IsZero() { - return nil - } - return t.Format(time.RFC3339) -} - -// formatTimeOrNow returns the formatted time if not zero, otherwise returns now -func formatTimeOrNow(t time.Time, now time.Time) string { - if t.IsZero() { - return now.Format(time.RFC3339) - } - return t.Format(time.RFC3339) -} diff --git a/store/position.go b/store/position.go index cf86c623..3fc42798 100644 --- a/store/position.go +++ b/store/position.go @@ -1,463 +1,338 @@ package store import ( - "database/sql" "fmt" "math" "strings" "time" + + "gorm.io/gorm" ) // TraderStats trading statistics metrics type TraderStats struct { - TotalTrades int `json:"total_trades"` // Total trades (closed) - WinTrades int `json:"win_trades"` // Winning trades - LossTrades int `json:"loss_trades"` // Losing trades - WinRate float64 `json:"win_rate"` // Win rate (%) - ProfitFactor float64 `json:"profit_factor"` // Profit factor - SharpeRatio float64 `json:"sharpe_ratio"` // Sharpe ratio - TotalPnL float64 `json:"total_pnl"` // Total PnL - TotalFee float64 `json:"total_fee"` // Total fees - AvgWin float64 `json:"avg_win"` // Average win - AvgLoss float64 `json:"avg_loss"` // Average loss - MaxDrawdownPct float64 `json:"max_drawdown_pct"` // Max drawdown (%) + TotalTrades int `json:"total_trades"` + WinTrades int `json:"win_trades"` + LossTrades int `json:"loss_trades"` + WinRate float64 `json:"win_rate"` + ProfitFactor float64 `json:"profit_factor"` + SharpeRatio float64 `json:"sharpe_ratio"` + TotalPnL float64 `json:"total_pnl"` + TotalFee float64 `json:"total_fee"` + AvgWin float64 `json:"avg_win"` + AvgLoss float64 `json:"avg_loss"` + MaxDrawdownPct float64 `json:"max_drawdown_pct"` } -// TraderPosition position record (complete open/close position tracking) +// TraderPosition position record type TraderPosition struct { - ID int64 `json:"id"` - TraderID string `json:"trader_id"` - ExchangeID string `json:"exchange_id"` // Exchange account UUID (for multi-account support) - ExchangeType string `json:"exchange_type"` // Exchange type: binance/bybit/okx/hyperliquid/aster/lighter - ExchangePositionID string `json:"exchange_position_id"` // Exchange-specific unique position ID for deduplication - Symbol string `json:"symbol"` - Side string `json:"side"` // LONG/SHORT - EntryQuantity float64 `json:"entry_quantity"` // Original entry quantity (never modified) - Quantity float64 `json:"quantity"` // Remaining quantity (reduced on partial close) - EntryPrice float64 `json:"entry_price"` // Entry price - EntryOrderID string `json:"entry_order_id"` // Entry order ID - EntryTime time.Time `json:"entry_time"` // Entry time - ExitPrice float64 `json:"exit_price"` // Exit price - ExitOrderID string `json:"exit_order_id"` // Exit order ID - ExitTime *time.Time `json:"exit_time"` // Exit time - RealizedPnL float64 `json:"realized_pnl"` // Realized profit and loss - Fee float64 `json:"fee"` // Fee - Leverage int `json:"leverage"` // Leverage multiplier - Status string `json:"status"` // OPEN/CLOSED - CloseReason string `json:"close_reason"` // Close reason: ai_decision/manual/stop_loss/take_profit - Source string `json:"source"` // Source: system/manual/sync - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `gorm:"primaryKey;autoIncrement" json:"id"` + TraderID string `gorm:"column:trader_id;not null;index:idx_positions_trader" json:"trader_id"` + ExchangeID string `gorm:"column:exchange_id;not null;default:'';index:idx_positions_exchange" json:"exchange_id"` + ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"` + ExchangePositionID string `gorm:"column:exchange_position_id;not null;default:''" json:"exchange_position_id"` + Symbol string `gorm:"column:symbol;not null" json:"symbol"` + Side string `gorm:"column:side;not null" json:"side"` + EntryQuantity float64 `gorm:"column:entry_quantity;default:0" json:"entry_quantity"` + Quantity float64 `gorm:"column:quantity;not null" json:"quantity"` + EntryPrice float64 `gorm:"column:entry_price;not null" json:"entry_price"` + EntryOrderID string `gorm:"column:entry_order_id;default:''" json:"entry_order_id"` + EntryTime time.Time `gorm:"column:entry_time;not null;index:idx_positions_entry" json:"entry_time"` + ExitPrice float64 `gorm:"column:exit_price;default:0" json:"exit_price"` + ExitOrderID string `gorm:"column:exit_order_id;default:''" json:"exit_order_id"` + ExitTime *time.Time `gorm:"column:exit_time;index:idx_positions_exit" json:"exit_time"` + RealizedPnL float64 `gorm:"column:realized_pnl;default:0" json:"realized_pnl"` + Fee float64 `gorm:"column:fee;default:0" json:"fee"` + Leverage int `gorm:"column:leverage;default:1" json:"leverage"` + Status string `gorm:"column:status;default:OPEN;index:idx_positions_status" json:"status"` + CloseReason string `gorm:"column:close_reason;default:''" json:"close_reason"` + Source string `gorm:"column:source;default:system" json:"source"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"` +} + +// TableName returns the table name +func (TraderPosition) TableName() string { + return "trader_positions" } // PositionStore position storage type PositionStore struct { - db *sql.DB + db *gorm.DB } // NewPositionStore creates position storage instance -func NewPositionStore(db *sql.DB) *PositionStore { +func NewPositionStore(db *gorm.DB) *PositionStore { return &PositionStore{db: db} } +// isPostgres checks if the database is PostgreSQL +func (s *PositionStore) isPostgres() bool { + return s.db.Dialector.Name() == "postgres" +} + // InitTables initializes position tables func (s *PositionStore) InitTables() error { - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS trader_positions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - trader_id TEXT NOT NULL, - exchange_id TEXT NOT NULL DEFAULT '', - exchange_type TEXT NOT NULL DEFAULT '', - exchange_position_id TEXT NOT NULL DEFAULT '', - symbol TEXT NOT NULL, - side TEXT NOT NULL, - quantity REAL NOT NULL, - entry_price REAL NOT NULL, - entry_order_id TEXT DEFAULT '', - entry_time DATETIME NOT NULL, - exit_price REAL DEFAULT 0, - exit_order_id TEXT DEFAULT '', - exit_time DATETIME, - realized_pnl REAL DEFAULT 0, - fee REAL DEFAULT 0, - leverage INTEGER DEFAULT 1, - status TEXT DEFAULT 'OPEN', - close_reason TEXT DEFAULT '', - source TEXT DEFAULT 'system', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return fmt.Errorf("failed to create trader_positions table: %w", err) + // For PostgreSQL with existing table, skip AutoMigrate + if s.isPostgres() { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_positions'`).Scan(&tableExists) + if tableExists > 0 { + // Just ensure index exists + s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_positions_exchange_pos_unique ON trader_positions(exchange_id, exchange_position_id) WHERE exchange_position_id != ''`) + return nil + } } - // Migration: add exchange_id column to existing table (if not exists) - // Must be executed before creating indexes! - s.db.Exec(`ALTER TABLE trader_positions ADD COLUMN exchange_id TEXT NOT NULL DEFAULT ''`) - // Migration: add exchange_type column (binance/bybit/okx/etc) - s.db.Exec(`ALTER TABLE trader_positions ADD COLUMN exchange_type TEXT NOT NULL DEFAULT ''`) - // Migration: add exchange_position_id for deduplication - s.db.Exec(`ALTER TABLE trader_positions ADD COLUMN exchange_position_id TEXT NOT NULL DEFAULT ''`) - // Migration: add source field (system/manual/sync) - s.db.Exec(`ALTER TABLE trader_positions ADD COLUMN source TEXT DEFAULT 'system'`) - // Migration: add entry_quantity field (original quantity, never modified on partial close) - s.db.Exec(`ALTER TABLE trader_positions ADD COLUMN entry_quantity REAL DEFAULT 0`) - // Backfill: set entry_quantity = quantity for existing records where entry_quantity is 0 - s.db.Exec(`UPDATE trader_positions SET entry_quantity = quantity WHERE entry_quantity = 0 OR entry_quantity IS NULL`) - - // Create indexes (after migration) - indices := []string{ - `CREATE INDEX IF NOT EXISTS idx_positions_trader ON trader_positions(trader_id)`, - `CREATE INDEX IF NOT EXISTS idx_positions_exchange ON trader_positions(exchange_id)`, - `CREATE INDEX IF NOT EXISTS idx_positions_status ON trader_positions(trader_id, status)`, - `CREATE INDEX IF NOT EXISTS idx_positions_symbol ON trader_positions(trader_id, symbol, side, status)`, - `CREATE INDEX IF NOT EXISTS idx_positions_entry ON trader_positions(trader_id, entry_time DESC)`, - `CREATE INDEX IF NOT EXISTS idx_positions_exit ON trader_positions(trader_id, exit_time DESC)`, - // Unique index based on exchange_id (account UUID), not trader_id - // This ensures the same position from an exchange account is not duplicated across different traders - `CREATE UNIQUE INDEX IF NOT EXISTS idx_positions_exchange_pos_unique ON trader_positions(exchange_id, exchange_position_id) WHERE exchange_position_id != ''`, + if err := s.db.AutoMigrate(&TraderPosition{}); err != nil { + return fmt.Errorf("failed to migrate trader_positions table: %w", err) } - for _, idx := range indices { - if _, err := s.db.Exec(idx); err != nil { - // Ignore unique index creation errors for existing data - if !strings.Contains(err.Error(), "UNIQUE constraint failed") { - return fmt.Errorf("failed to create index: %w", err) - } + + // Create unique partial index for exchange position deduplication + var indexSQL string + if s.isPostgres() { + indexSQL = `CREATE UNIQUE INDEX IF NOT EXISTS idx_positions_exchange_pos_unique ON trader_positions(exchange_id, exchange_position_id) WHERE exchange_position_id != ''` + } else { + indexSQL = `CREATE UNIQUE INDEX IF NOT EXISTS idx_positions_exchange_pos_unique ON trader_positions(exchange_id, exchange_position_id) WHERE exchange_position_id != ''` + } + if err := s.db.Exec(indexSQL).Error; err != nil { + if !strings.Contains(err.Error(), "already exists") && !strings.Contains(err.Error(), "UNIQUE constraint failed") { + return fmt.Errorf("failed to create unique index: %w", err) } } return nil } -// Create creates position record (called when opening position) +// Create creates position record func (s *PositionStore) Create(pos *TraderPosition) error { - now := time.Now() - pos.CreatedAt = now - pos.UpdatedAt = now pos.Status = "OPEN" - // Set EntryQuantity to same as Quantity if not already set if pos.EntryQuantity == 0 { pos.EntryQuantity = pos.Quantity } - - result, err := s.db.Exec(` - INSERT INTO trader_positions ( - trader_id, exchange_id, exchange_type, symbol, side, quantity, entry_quantity, entry_price, entry_order_id, - entry_time, leverage, status, created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, - pos.TraderID, pos.ExchangeID, pos.ExchangeType, pos.Symbol, pos.Side, pos.Quantity, pos.EntryQuantity, pos.EntryPrice, - pos.EntryOrderID, pos.EntryTime.Format(time.RFC3339), pos.Leverage, - pos.Status, now.Format(time.RFC3339), now.Format(time.RFC3339), - ) - if err != nil { - return fmt.Errorf("failed to create position record: %w", err) - } - - id, _ := result.LastInsertId() - pos.ID = id - return nil + return s.db.Create(pos).Error } -// ClosePosition closes position (updates position record) +// ClosePosition closes position func (s *PositionStore) ClosePosition(id int64, exitPrice float64, exitOrderID string, realizedPnL float64, fee float64, closeReason string) error { now := time.Now() - _, err := s.db.Exec(` - UPDATE trader_positions SET - exit_price = ?, exit_order_id = ?, exit_time = ?, - realized_pnl = ?, fee = ?, status = 'CLOSED', - close_reason = ?, updated_at = ? - WHERE id = ? - `, - exitPrice, exitOrderID, now.Format(time.RFC3339), - realizedPnL, fee, closeReason, now.Format(time.RFC3339), id, - ) - if err != nil { - return fmt.Errorf("failed to update position record: %w", err) - } - return nil + return s.db.Model(&TraderPosition{}).Where("id = ?", id).Updates(map[string]interface{}{ + "exit_price": exitPrice, + "exit_order_id": exitOrderID, + "exit_time": now, + "realized_pnl": realizedPnL, + "fee": fee, + "status": "CLOSED", + "close_reason": closeReason, + }).Error } -// UpdatePositionQuantityAndPrice updates position quantity and recalculates entry price (weighted average) when adding to position -// Both quantity and entry_quantity are updated to reflect the new total position size +// UpdatePositionQuantityAndPrice updates position quantity and recalculates entry price func (s *PositionStore) UpdatePositionQuantityAndPrice(id int64, addQty float64, addPrice float64, addFee float64) error { - // First, get current position data - var currentQty, currentEntryQty, currentEntryPrice, currentFee float64 - err := s.db.QueryRow(` - SELECT quantity, COALESCE(entry_quantity, quantity), entry_price, fee FROM trader_positions WHERE id = ? - `, id).Scan(¤tQty, ¤tEntryQty, ¤tEntryPrice, ¤tFee) - if err != nil { + var pos TraderPosition + if err := s.db.First(&pos, id).Error; err != nil { return fmt.Errorf("failed to get current position: %w", err) } - // Calculate weighted average entry price - newQty := currentQty + addQty - newEntryQty := currentEntryQty + addQty - // Round quantity to 4 decimal places to avoid floating point precision issues - newQty = math.Round(newQty*10000) / 10000 - newEntryQty = math.Round(newEntryQty*10000) / 10000 - - newEntryPrice := (currentEntryPrice*currentQty + addPrice*addQty) / newQty - // Round to 2 decimal places to avoid floating point precision issues - newEntryPrice = math.Round(newEntryPrice*100) / 100 - - // Accumulate fees - newFee := currentFee + addFee - - // Update position (both quantity and entry_quantity) - now := time.Now() - _, err = s.db.Exec(` - UPDATE trader_positions SET - quantity = ?, entry_quantity = ?, entry_price = ?, fee = ?, updated_at = ? - WHERE id = ? - `, newQty, newEntryQty, newEntryPrice, newFee, now.Format(time.RFC3339), id) - if err != nil { - return fmt.Errorf("failed to update position quantity and price: %w", err) + currentEntryQty := pos.EntryQuantity + if currentEntryQty == 0 { + currentEntryQty = pos.Quantity } - return nil + + newQty := math.Round((pos.Quantity+addQty)*10000) / 10000 + newEntryQty := math.Round((currentEntryQty+addQty)*10000) / 10000 + newEntryPrice := (pos.EntryPrice*pos.Quantity + addPrice*addQty) / newQty + newEntryPrice = math.Round(newEntryPrice*100) / 100 + newFee := pos.Fee + addFee + + return s.db.Model(&TraderPosition{}).Where("id = ?", id).Updates(map[string]interface{}{ + "quantity": newQty, + "entry_quantity": newEntryQty, + "entry_price": newEntryPrice, + "fee": newFee, + }).Error } -// ReducePositionQuantity reduces position quantity for partial close (keeps status as OPEN) -// Also updates exit_price with weighted average of all partial closes +// ReducePositionQuantity reduces position quantity for partial close func (s *PositionStore) ReducePositionQuantity(id int64, reduceQty float64, exitPrice float64, addFee float64, addPnL float64) error { - // First get current position data - var currentQty, currentFee, currentExitPrice, entryQty, currentPnL float64 - err := s.db.QueryRow(`SELECT quantity, fee, exit_price, entry_quantity, realized_pnl FROM trader_positions WHERE id = ?`, id).Scan(¤tQty, ¤tFee, ¤tExitPrice, &entryQty, ¤tPnL) - if err != nil { + var pos TraderPosition + if err := s.db.First(&pos, id).Error; err != nil { return fmt.Errorf("failed to get current position: %w", err) } - // Calculate new quantity and fee - newQty := math.Round((currentQty-reduceQty)*10000) / 10000 - newFee := currentFee + addFee - newPnL := currentPnL + addPnL + newQty := math.Round((pos.Quantity-reduceQty)*10000) / 10000 + newFee := pos.Fee + addFee + newPnL := pos.RealizedPnL + addPnL - // Calculate weighted average exit price - // closedQty = entryQty - currentQty (already closed before this trade) - // newClosedQty = closedQty + reduceQty (total closed after this trade) - closedQty := entryQty - currentQty + closedQty := pos.EntryQuantity - pos.Quantity newClosedQty := closedQty + reduceQty var newExitPrice float64 if newClosedQty > 0 { - // Weighted average: (old_exit * old_closed + new_price * new_close) / total_closed - newExitPrice = (currentExitPrice*closedQty + exitPrice*reduceQty) / newClosedQty - newExitPrice = math.Round(newExitPrice*100) / 100 // Round to 2 decimal places + newExitPrice = (pos.ExitPrice*closedQty + exitPrice*reduceQty) / newClosedQty + newExitPrice = math.Round(newExitPrice*100) / 100 } - now := time.Now() - _, err = s.db.Exec(` - UPDATE trader_positions SET - quantity = ?, - fee = ?, - exit_price = ?, - realized_pnl = ?, - updated_at = ? - WHERE id = ? - `, newQty, newFee, newExitPrice, newPnL, now.Format(time.RFC3339), id) - if err != nil { - return fmt.Errorf("failed to reduce position quantity: %w", err) - } - return nil + return s.db.Model(&TraderPosition{}).Where("id = ?", id).Updates(map[string]interface{}{ + "quantity": newQty, + "fee": newFee, + "exit_price": newExitPrice, + "realized_pnl": newPnL, + }).Error } -// UpdatePositionExchangeInfo updates exchange_id and exchange_type for a position +// UpdatePositionExchangeInfo updates exchange_id and exchange_type func (s *PositionStore) UpdatePositionExchangeInfo(id int64, exchangeID, exchangeType string) error { - now := time.Now() - _, err := s.db.Exec(` - UPDATE trader_positions SET - exchange_id = ?, - exchange_type = ?, - updated_at = ? - WHERE id = ? - `, exchangeID, exchangeType, now.Format(time.RFC3339), id) - if err != nil { - return fmt.Errorf("failed to update position exchange info: %w", err) - } - return nil + return s.db.Model(&TraderPosition{}).Where("id = ?", id).Updates(map[string]interface{}{ + "exchange_id": exchangeID, + "exchange_type": exchangeType, + }).Error } -// ClosePositionFully marks position as fully closed with exit time and accumulated PnL -func (s *PositionStore) ClosePositionFully( - id int64, - exitPrice float64, - exitOrderID string, - exitTime time.Time, - totalRealizedPnL float64, - totalFee float64, - closeReason string, -) error { - now := time.Now() - // When closing, restore quantity to entry_quantity so closed position shows original size - _, err := s.db.Exec(` - UPDATE trader_positions SET - quantity = CASE WHEN entry_quantity > 0 THEN entry_quantity ELSE quantity END, - exit_price = ?, - exit_order_id = ?, - exit_time = ?, - realized_pnl = ?, - fee = ?, - status = 'CLOSED', - close_reason = ?, - updated_at = ? - WHERE id = ? - `, - exitPrice, exitOrderID, exitTime.Format(time.RFC3339), - totalRealizedPnL, totalFee, closeReason, now.Format(time.RFC3339), id, - ) - if err != nil { - return fmt.Errorf("failed to close position: %w", err) +// ClosePositionFully marks position as fully closed +func (s *PositionStore) ClosePositionFully(id int64, exitPrice float64, exitOrderID string, exitTime time.Time, totalRealizedPnL float64, totalFee float64, closeReason string) error { + var pos TraderPosition + if err := s.db.First(&pos, id).Error; err != nil { + return fmt.Errorf("failed to get position: %w", err) } - return nil + + quantity := pos.Quantity + if pos.EntryQuantity > 0 { + quantity = pos.EntryQuantity + } + + return s.db.Model(&TraderPosition{}).Where("id = ?", id).Updates(map[string]interface{}{ + "quantity": quantity, + "exit_price": exitPrice, + "exit_order_id": exitOrderID, + "exit_time": exitTime, + "realized_pnl": totalRealizedPnL, + "fee": totalFee, + "status": "CLOSED", + "close_reason": closeReason, + }).Error } -// DeleteAllOpenPositions deletes all OPEN positions for a trader (used for snapshot reset) +// DeleteAllOpenPositions deletes all OPEN positions for a trader func (s *PositionStore) DeleteAllOpenPositions(traderID string) error { - _, err := s.db.Exec(` - DELETE FROM trader_positions WHERE trader_id = ? AND status = 'OPEN' - `, traderID) - if err != nil { - return fmt.Errorf("failed to delete open positions: %w", err) - } - return nil + return s.db.Where("trader_id = ? AND status = ?", traderID, "OPEN").Delete(&TraderPosition{}).Error } // GetOpenPositions gets all open positions func (s *PositionStore) GetOpenPositions(traderID string) ([]*TraderPosition, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, exchange_id, COALESCE(exchange_type, '') as exchange_type, symbol, side, quantity, COALESCE(entry_quantity, quantity) as entry_quantity, entry_price, entry_order_id, - entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, - leverage, status, close_reason, created_at, updated_at - FROM trader_positions - WHERE trader_id = ? AND status = 'OPEN' - ORDER BY entry_time DESC - `, traderID) + var positions []*TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "OPEN"). + Order("entry_time DESC"). + Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query open positions: %w", err) } - defer rows.Close() - return s.scanPositions(rows) + // Fix EntryQuantity if it's 0 + for _, pos := range positions { + if pos.EntryQuantity == 0 { + pos.EntryQuantity = pos.Quantity + } + } + return positions, nil } // GetOpenPositionBySymbol gets open position for specified symbol and direction -// It tries both the normalized symbol (ETHUSDT) and base symbol (ETH) for compatibility func (s *PositionStore) GetOpenPositionBySymbol(traderID, symbol, side string) (*TraderPosition, error) { var pos TraderPosition - var entryTime, exitTime, createdAt, updatedAt sql.NullString + err := s.db.Where("trader_id = ? AND symbol = ? AND side = ? AND status = ?", traderID, symbol, side, "OPEN"). + Order("entry_time DESC"). + First(&pos).Error - // Try with the exact symbol first - err := s.db.QueryRow(` - SELECT id, trader_id, exchange_id, COALESCE(exchange_type, '') as exchange_type, symbol, side, quantity, COALESCE(entry_quantity, quantity) as entry_quantity, entry_price, entry_order_id, - entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, - leverage, status, close_reason, created_at, updated_at - FROM trader_positions - WHERE trader_id = ? AND symbol = ? AND side = ? AND status = 'OPEN' - ORDER BY entry_time DESC LIMIT 1 - `, traderID, symbol, side).Scan( - &pos.ID, &pos.TraderID, &pos.ExchangeID, &pos.ExchangeType, &pos.Symbol, &pos.Side, &pos.Quantity, &pos.EntryQuantity, - &pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice, - &pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee, - &pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt, - ) if err == nil { - s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt) + if pos.EntryQuantity == 0 { + pos.EntryQuantity = pos.Quantity + } return &pos, nil } - // If not found and symbol ends with USDT, try without USDT suffix (for backward compatibility) - if err == sql.ErrNoRows && strings.HasSuffix(symbol, "USDT") { - baseSymbol := strings.TrimSuffix(symbol, "USDT") - err = s.db.QueryRow(` - SELECT id, trader_id, exchange_id, COALESCE(exchange_type, '') as exchange_type, symbol, side, quantity, COALESCE(entry_quantity, quantity) as entry_quantity, entry_price, entry_order_id, - entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, - leverage, status, close_reason, created_at, updated_at - FROM trader_positions - WHERE trader_id = ? AND symbol = ? AND side = ? AND status = 'OPEN' - ORDER BY entry_time DESC LIMIT 1 - `, traderID, baseSymbol, side).Scan( - &pos.ID, &pos.TraderID, &pos.ExchangeID, &pos.ExchangeType, &pos.Symbol, &pos.Side, &pos.Quantity, &pos.EntryQuantity, - &pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice, - &pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee, - &pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt, - ) - if err == nil { - s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt) - return &pos, nil + if err == gorm.ErrRecordNotFound { + // Try without USDT suffix for backward compatibility + if strings.HasSuffix(symbol, "USDT") { + baseSymbol := strings.TrimSuffix(symbol, "USDT") + err = s.db.Where("trader_id = ? AND symbol = ? AND side = ? AND status = ?", traderID, baseSymbol, side, "OPEN"). + Order("entry_time DESC"). + First(&pos).Error + if err == nil { + if pos.EntryQuantity == 0 { + pos.EntryQuantity = pos.Quantity + } + return &pos, nil + } } - } - - if err == sql.ErrNoRows { return nil, nil } return nil, err } -// GetClosedPositions gets closed positions (historical records) +// GetClosedPositions gets closed positions func (s *PositionStore) GetClosedPositions(traderID string, limit int) ([]*TraderPosition, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, exchange_id, COALESCE(exchange_type, '') as exchange_type, symbol, side, quantity, COALESCE(entry_quantity, quantity) as entry_quantity, entry_price, entry_order_id, - entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, - leverage, status, close_reason, created_at, updated_at - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - ORDER BY exit_time DESC - LIMIT ? - `, traderID, limit) + var positions []*TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC"). + Limit(limit). + Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query closed positions: %w", err) } - defer rows.Close() - return s.scanPositions(rows) + for _, pos := range positions { + if pos.EntryQuantity == 0 { + pos.EntryQuantity = pos.Quantity + } + } + return positions, nil } -// GetAllOpenPositions gets all traders' open positions (for global sync) +// GetAllOpenPositions gets all traders' open positions func (s *PositionStore) GetAllOpenPositions() ([]*TraderPosition, error) { - rows, err := s.db.Query(` - SELECT id, trader_id, exchange_id, COALESCE(exchange_type, '') as exchange_type, symbol, side, quantity, COALESCE(entry_quantity, quantity) as entry_quantity, entry_price, entry_order_id, - entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, - leverage, status, close_reason, created_at, updated_at - FROM trader_positions - WHERE status = 'OPEN' - ORDER BY trader_id, entry_time DESC - `) + var positions []*TraderPosition + err := s.db.Where("status = ?", "OPEN"). + Order("trader_id, entry_time DESC"). + Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query all open positions: %w", err) } - defer rows.Close() - return s.scanPositions(rows) + for _, pos := range positions { + if pos.EntryQuantity == 0 { + pos.EntryQuantity = pos.Quantity + } + } + return positions, nil } -// GetPositionStats gets position statistics (simplified version) +// GetPositionStats gets position statistics func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{}, error) { stats := make(map[string]interface{}) - // Total trades - var totalTrades, winTrades int - var totalPnL, totalFee float64 + type result struct { + Total int + Wins int + TotalPnL float64 + TotalFee float64 + } + var r result - err := s.db.QueryRow(` - SELECT - COUNT(*) as total, - SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as wins, - COALESCE(SUM(realized_pnl), 0) as total_pnl, - COALESCE(SUM(fee), 0) as total_fee - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - `, traderID).Scan(&totalTrades, &winTrades, &totalPnL, &totalFee) + err := s.db.Model(&TraderPosition{}). + Select("COUNT(*) as total, SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as wins, COALESCE(SUM(realized_pnl), 0) as total_pnl, COALESCE(SUM(fee), 0) as total_fee"). + Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Scan(&r).Error if err != nil { return nil, err } - stats["total_trades"] = totalTrades - stats["win_trades"] = winTrades - stats["total_pnl"] = totalPnL - stats["total_fee"] = totalFee - if totalTrades > 0 { - stats["win_rate"] = float64(winTrades) / float64(totalTrades) * 100 + stats["total_trades"] = r.Total + stats["win_trades"] = r.Wins + stats["total_pnl"] = r.TotalPnL + stats["total_fee"] = r.TotalFee + if r.Total > 0 { + stats["win_rate"] = float64(r.Wins) / float64(r.Total) * 100 } else { stats["win_rate"] = 0.0 } @@ -465,79 +340,59 @@ func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{ return stats, nil } -// GetFullStats gets complete trading statistics (compatible with TraderStats) +// GetFullStats gets complete trading statistics func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) { stats := &TraderStats{} - // First check how many rows exist - var count int - if err := s.db.QueryRow(`SELECT COUNT(*) FROM trader_positions WHERE trader_id = ? AND status = 'CLOSED'`, traderID).Scan(&count); err == nil { - if count == 0 { - // No closed positions, return empty stats - return stats, nil - } + var count int64 + if err := s.db.Model(&TraderPosition{}).Where("trader_id = ? AND status = ?", traderID, "CLOSED").Count(&count).Error; err != nil { + return nil, err + } + if count == 0 { + return stats, nil } - // Query all closed positions - rows, err := s.db.Query(` - SELECT realized_pnl, fee, exit_time - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - ORDER BY exit_time ASC - `, traderID) + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time ASC"). + Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query position statistics: %w", err) } - defer rows.Close() var pnls []float64 var totalWin, totalLoss float64 - for rows.Next() { - var pnl, fee float64 - var exitTime sql.NullString - if err := rows.Scan(&pnl, &fee, &exitTime); err != nil { - continue - } - + for _, pos := range positions { stats.TotalTrades++ - stats.TotalPnL += pnl - stats.TotalFee += fee - pnls = append(pnls, pnl) + stats.TotalPnL += pos.RealizedPnL + stats.TotalFee += pos.Fee + pnls = append(pnls, pos.RealizedPnL) - if pnl > 0 { + if pos.RealizedPnL > 0 { stats.WinTrades++ - totalWin += pnl - } else if pnl < 0 { + totalWin += pos.RealizedPnL + } else if pos.RealizedPnL < 0 { stats.LossTrades++ - totalLoss += -pnl // Convert to positive + totalLoss += -pos.RealizedPnL } } - // Calculate win rate if stats.TotalTrades > 0 { stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100 } - - // Calculate profit factor if totalLoss > 0 { stats.ProfitFactor = totalWin / totalLoss } - - // Calculate average profit/loss if stats.WinTrades > 0 { stats.AvgWin = totalWin / float64(stats.WinTrades) } if stats.LossTrades > 0 { stats.AvgLoss = totalLoss / float64(stats.LossTrades) } - - // Calculate Sharpe ratio if len(pnls) > 1 { stats.SharpeRatio = calculateSharpeRatioFromPnls(pnls) } - - // Calculate maximum drawdown if len(pnls) > 0 { stats.MaxDrawdownPct = calculateMaxDrawdownFromPnls(pnls) } @@ -545,89 +400,62 @@ func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) { return stats, nil } -// RecentTrade recent trade record (for AI input) +// RecentTrade recent trade record type RecentTrade struct { Symbol string `json:"symbol"` - Side string `json:"side"` // long/short + Side string `json:"side"` EntryPrice float64 `json:"entry_price"` ExitPrice float64 `json:"exit_price"` RealizedPnL float64 `json:"realized_pnl"` PnLPct float64 `json:"pnl_pct"` - EntryTime int64 `json:"entry_time"` // Entry time Unix timestamp (seconds) - ExitTime int64 `json:"exit_time"` // Exit time Unix timestamp (seconds) - HoldDuration string `json:"hold_duration"` // Hold duration (持仓时长), e.g. "2h30m" + EntryTime int64 `json:"entry_time"` + ExitTime int64 `json:"exit_time"` + HoldDuration string `json:"hold_duration"` } // GetRecentTrades gets recent closed trades func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTrade, error) { - rows, err := s.db.Query(` - SELECT symbol, side, entry_price, exit_price, realized_pnl, leverage, entry_time, exit_time - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - ORDER BY exit_time DESC - LIMIT ? - `, traderID, limit) + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC"). + Limit(limit). + Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query recent trades: %w", err) } - defer rows.Close() var trades []RecentTrade - for rows.Next() { - var t RecentTrade - var leverage int - var entryTime, exitTime sql.NullString - - err := rows.Scan(&t.Symbol, &t.Side, &t.EntryPrice, &t.ExitPrice, &t.RealizedPnL, &leverage, &entryTime, &exitTime) - if err != nil { - continue + for _, pos := range positions { + t := RecentTrade{ + Symbol: pos.Symbol, + Side: strings.ToLower(pos.Side), + EntryPrice: pos.EntryPrice, + ExitPrice: pos.ExitPrice, + RealizedPnL: pos.RealizedPnL, + EntryTime: pos.EntryTime.Unix(), } - // Convert side format - if t.Side == "LONG" { - t.Side = "long" - } else if t.Side == "SHORT" { - t.Side = "short" - } - - // Calculate profit/loss percentage - if t.EntryPrice > 0 { - if t.Side == "long" { - t.PnLPct = (t.ExitPrice - t.EntryPrice) / t.EntryPrice * 100 * float64(leverage) - } else { - t.PnLPct = (t.EntryPrice - t.ExitPrice) / t.EntryPrice * 100 * float64(leverage) - } - } - - // Parse entry time and exit time, return as Unix timestamps (seconds) - var parsedEntryTime, parsedExitTime time.Time - if entryTime.Valid { - if parsed, err := time.Parse(time.RFC3339, entryTime.String); err == nil { - parsedEntryTime = parsed.UTC() - t.EntryTime = parsedEntryTime.Unix() // Unix timestamp in seconds - } - } - if exitTime.Valid { - if parsed, err := time.Parse(time.RFC3339, exitTime.String); err == nil { - parsedExitTime = parsed.UTC() - t.ExitTime = parsedExitTime.Unix() // Unix timestamp in seconds - } - } - - // Calculate hold duration - if !parsedEntryTime.IsZero() && !parsedExitTime.IsZero() { - duration := parsedExitTime.Sub(parsedEntryTime) + if pos.ExitTime != nil { + t.ExitTime = pos.ExitTime.Unix() + duration := pos.ExitTime.Sub(pos.EntryTime) t.HoldDuration = formatDuration(duration) } + if pos.EntryPrice > 0 { + if t.Side == "long" { + t.PnLPct = (pos.ExitPrice - pos.EntryPrice) / pos.EntryPrice * 100 * float64(pos.Leverage) + } else { + t.PnLPct = (pos.EntryPrice - pos.ExitPrice) / pos.EntryPrice * 100 * float64(pos.Leverage) + } + } + trades = append(trades, t) } return trades, nil } -// formatDuration formats a duration into a human-readable string -// e.g. "2d3h", "5h30m", "45m", "30s" +// formatDuration formats a duration func formatDuration(d time.Duration) string { if d < time.Minute { return fmt.Sprintf("%ds", int(d.Seconds())) @@ -677,13 +505,11 @@ func calculateSharpeRatioFromPnls(pnls []float64) float64 { } // calculateMaxDrawdownFromPnls calculates maximum drawdown -// Uses a virtual starting equity of 10000 to calculate percentage drawdown func calculateMaxDrawdownFromPnls(pnls []float64) float64 { if len(pnls) == 0 { return 0 } - // Use virtual starting equity for percentage calculation const startingEquity = 10000.0 equity := startingEquity peak := startingEquity @@ -705,47 +531,6 @@ func calculateMaxDrawdownFromPnls(pnls []float64) float64 { return maxDD } -// scanPositions scans position rows into structs -func (s *PositionStore) scanPositions(rows *sql.Rows) ([]*TraderPosition, error) { - var positions []*TraderPosition - for rows.Next() { - var pos TraderPosition - var entryTime, exitTime, createdAt, updatedAt sql.NullString - - err := rows.Scan( - &pos.ID, &pos.TraderID, &pos.ExchangeID, &pos.ExchangeType, &pos.Symbol, &pos.Side, &pos.Quantity, &pos.EntryQuantity, - &pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice, - &pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee, - &pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt, - ) - if err != nil { - continue - } - - s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt) - positions = append(positions, &pos) - } - - return positions, nil -} - -// parsePositionTimes parses time fields -func (s *PositionStore) parsePositionTimes(pos *TraderPosition, entryTime, exitTime, createdAt, updatedAt sql.NullString) { - if entryTime.Valid { - pos.EntryTime, _ = time.Parse(time.RFC3339, entryTime.String) - } - if exitTime.Valid { - t, _ := time.Parse(time.RFC3339, exitTime.String) - pos.ExitTime = &t - } - if createdAt.Valid { - pos.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String) - } - if updatedAt.Valid { - pos.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String) - } -} - // SymbolStats per-symbol trading statistics type SymbolStats struct { Symbol string `json:"symbol"` @@ -754,97 +539,137 @@ type SymbolStats struct { WinRate float64 `json:"win_rate"` TotalPnL float64 `json:"total_pnl"` AvgPnL float64 `json:"avg_pnl"` - AvgHoldMins float64 `json:"avg_hold_mins"` // Average holding time in minutes + AvgHoldMins float64 `json:"avg_hold_mins"` } // GetSymbolStats gets per-symbol trading statistics func (s *PositionStore) GetSymbolStats(traderID string, limit int) ([]SymbolStats, error) { - rows, err := s.db.Query(` - SELECT - symbol, - COUNT(*) as total_trades, - SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as win_trades, - COALESCE(SUM(realized_pnl), 0) as total_pnl, - COALESCE(AVG(realized_pnl), 0) as avg_pnl, - COALESCE(AVG((julianday(exit_time) - julianday(entry_time)) * 24 * 60), 0) as avg_hold_mins - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - GROUP BY symbol - ORDER BY total_pnl DESC - LIMIT ? - `, traderID, limit) + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED").Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query symbol stats: %w", err) } - defer rows.Close() + + // Group by symbol + symbolMap := make(map[string]*SymbolStats) + symbolHoldMins := make(map[string][]float64) + + for _, pos := range positions { + if _, ok := symbolMap[pos.Symbol]; !ok { + symbolMap[pos.Symbol] = &SymbolStats{Symbol: pos.Symbol} + symbolHoldMins[pos.Symbol] = []float64{} + } + s := symbolMap[pos.Symbol] + s.TotalTrades++ + s.TotalPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + s.WinTrades++ + } + + if pos.ExitTime != nil { + holdMins := pos.ExitTime.Sub(pos.EntryTime).Minutes() + symbolHoldMins[pos.Symbol] = append(symbolHoldMins[pos.Symbol], holdMins) + } + } var stats []SymbolStats - for rows.Next() { - var s SymbolStats - err := rows.Scan(&s.Symbol, &s.TotalTrades, &s.WinTrades, &s.TotalPnL, &s.AvgPnL, &s.AvgHoldMins) - if err != nil { - continue - } + for symbol, s := range symbolMap { if s.TotalTrades > 0 { s.WinRate = float64(s.WinTrades) / float64(s.TotalTrades) * 100 + s.AvgPnL = s.TotalPnL / float64(s.TotalTrades) } - stats = append(stats, s) + if len(symbolHoldMins[symbol]) > 0 { + var totalMins float64 + for _, m := range symbolHoldMins[symbol] { + totalMins += m + } + s.AvgHoldMins = totalMins / float64(len(symbolHoldMins[symbol])) + } + stats = append(stats, *s) } + + // Sort by TotalPnL descending and limit + for i := 0; i < len(stats)-1; i++ { + for j := i + 1; j < len(stats); j++ { + if stats[j].TotalPnL > stats[i].TotalPnL { + stats[i], stats[j] = stats[j], stats[i] + } + } + } + + if limit > 0 && len(stats) > limit { + stats = stats[:limit] + } + return stats, nil } // HoldingTimeStats holding duration analysis type HoldingTimeStats struct { - Range string `json:"range"` // e.g., "<1h", "1-4h", "4-24h", ">24h" - TradeCount int `json:"trade_count"` - WinRate float64 `json:"win_rate"` - AvgPnL float64 `json:"avg_pnl"` + Range string `json:"range"` + TradeCount int `json:"trade_count"` + WinRate float64 `json:"win_rate"` + AvgPnL float64 `json:"avg_pnl"` } // GetHoldingTimeStats analyzes performance by holding duration func (s *PositionStore) GetHoldingTimeStats(traderID string) ([]HoldingTimeStats, error) { - rows, err := s.db.Query(` - WITH holding AS ( - SELECT - realized_pnl, - (julianday(exit_time) - julianday(entry_time)) * 24 as hold_hours - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' AND exit_time IS NOT NULL - ) - SELECT - CASE - WHEN hold_hours < 1 THEN '<1h' - WHEN hold_hours < 4 THEN '1-4h' - WHEN hold_hours < 24 THEN '4-24h' - ELSE '>24h' - END as time_range, - COUNT(*) as trade_count, - SUM(CASE WHEN realized_pnl > 0 THEN 1.0 ELSE 0.0 END) / COUNT(*) * 100 as win_rate, - AVG(realized_pnl) as avg_pnl - FROM holding - GROUP BY time_range - ORDER BY - CASE time_range - WHEN '<1h' THEN 1 - WHEN '1-4h' THEN 2 - WHEN '4-24h' THEN 3 - ELSE 4 - END - `, traderID) + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ? AND exit_time IS NOT NULL", traderID, "CLOSED").Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query holding time stats: %w", err) } - defer rows.Close() - var stats []HoldingTimeStats - for rows.Next() { - var s HoldingTimeStats - err := rows.Scan(&s.Range, &s.TradeCount, &s.WinRate, &s.AvgPnL) - if err != nil { + rangeStats := map[string]*struct { + count int + wins int + totalPnL float64 + }{ + "<1h": {}, + "1-4h": {}, + "4-24h": {}, + ">24h": {}, + } + + for _, pos := range positions { + if pos.ExitTime == nil { continue } - stats = append(stats, s) + holdHours := pos.ExitTime.Sub(pos.EntryTime).Hours() + + var rangeKey string + switch { + case holdHours < 1: + rangeKey = "<1h" + case holdHours < 4: + rangeKey = "1-4h" + case holdHours < 24: + rangeKey = "4-24h" + default: + rangeKey = ">24h" + } + + r := rangeStats[rangeKey] + r.count++ + r.totalPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + r.wins++ + } } + + var stats []HoldingTimeStats + for _, rangeKey := range []string{"<1h", "1-4h", "4-24h", ">24h"} { + r := rangeStats[rangeKey] + if r.count > 0 { + stats = append(stats, HoldingTimeStats{ + Range: rangeKey, + TradeCount: r.count, + WinRate: float64(r.wins) / float64(r.count) * 100, + AvgPnL: r.totalPnL / float64(r.count), + }) + } + } + return stats, nil } @@ -859,71 +684,67 @@ type DirectionStats struct { // GetDirectionStats analyzes long vs short performance func (s *PositionStore) GetDirectionStats(traderID string) ([]DirectionStats, error) { - rows, err := s.db.Query(` - SELECT - side, - COUNT(*) as trade_count, - SUM(CASE WHEN realized_pnl > 0 THEN 1.0 ELSE 0.0 END) / COUNT(*) * 100 as win_rate, - COALESCE(SUM(realized_pnl), 0) as total_pnl, - COALESCE(AVG(realized_pnl), 0) as avg_pnl - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - GROUP BY side - `, traderID) + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED").Find(&positions).Error if err != nil { return nil, fmt.Errorf("failed to query direction stats: %w", err) } - defer rows.Close() + + sideStats := make(map[string]*DirectionStats) + for _, pos := range positions { + if _, ok := sideStats[pos.Side]; !ok { + sideStats[pos.Side] = &DirectionStats{Side: pos.Side} + } + s := sideStats[pos.Side] + s.TradeCount++ + s.TotalPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + s.WinRate++ + } + } var stats []DirectionStats - for rows.Next() { - var s DirectionStats - err := rows.Scan(&s.Side, &s.TradeCount, &s.WinRate, &s.TotalPnL, &s.AvgPnL) - if err != nil { - continue + for _, s := range sideStats { + if s.TradeCount > 0 { + s.AvgPnL = s.TotalPnL / float64(s.TradeCount) + s.WinRate = s.WinRate / float64(s.TradeCount) * 100 } - stats = append(stats, s) + stats = append(stats, *s) } + return stats, nil } // HistorySummary comprehensive trading history for AI context type HistorySummary struct { - // Overall stats TotalTrades int `json:"total_trades"` WinRate float64 `json:"win_rate"` TotalPnL float64 `json:"total_pnl"` - AvgTradeReturn float64 `json:"avg_trade_return"` // Percentage + AvgTradeReturn float64 `json:"avg_trade_return"` - // Best/Worst performers - BestSymbols []SymbolStats `json:"best_symbols"` // Top 3 profitable - WorstSymbols []SymbolStats `json:"worst_symbols"` // Top 3 losing + BestSymbols []SymbolStats `json:"best_symbols"` + WorstSymbols []SymbolStats `json:"worst_symbols"` - // Direction analysis LongWinRate float64 `json:"long_win_rate"` ShortWinRate float64 `json:"short_win_rate"` LongPnL float64 `json:"long_pnl"` ShortPnL float64 `json:"short_pnl"` - // Time analysis AvgHoldingMins float64 `json:"avg_holding_mins"` - BestHoldRange string `json:"best_hold_range"` // e.g., "1-4h" + BestHoldRange string `json:"best_hold_range"` - // Recent performance (last 20 trades) RecentWinRate float64 `json:"recent_win_rate"` RecentPnL float64 `json:"recent_pnl"` - // Streak info - CurrentStreak int `json:"current_streak"` // Positive = wins, negative = losses - MaxWinStreak int `json:"max_win_streak"` - MaxLoseStreak int `json:"max_lose_streak"` + CurrentStreak int `json:"current_streak"` + MaxWinStreak int `json:"max_win_streak"` + MaxLoseStreak int `json:"max_lose_streak"` } // GetHistorySummary generates comprehensive AI context summary func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, error) { summary := &HistorySummary{} - // Get overall stats fullStats, err := s.GetFullStats(traderID) if err != nil { return nil, err @@ -935,16 +756,13 @@ func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, err summary.AvgTradeReturn = fullStats.TotalPnL / float64(fullStats.TotalTrades) } - // Get symbol stats - best performers symbolStats, _ := s.GetSymbolStats(traderID, 20) if len(symbolStats) > 0 { - // Best 3 for i := 0; i < len(symbolStats) && i < 3; i++ { if symbolStats[i].TotalPnL > 0 { summary.BestSymbols = append(summary.BestSymbols, symbolStats[i]) } } - // Worst 3 (from the end) for i := len(symbolStats) - 1; i >= 0 && len(summary.WorstSymbols) < 3; i-- { if symbolStats[i].TotalPnL < 0 { summary.WorstSymbols = append(summary.WorstSymbols, symbolStats[i]) @@ -952,7 +770,6 @@ func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, err } } - // Get direction stats dirStats, _ := s.GetDirectionStats(traderID) for _, d := range dirStats { if d.Side == "LONG" { @@ -964,7 +781,6 @@ func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, err } } - // Get holding time stats holdStats, _ := s.GetHoldingTimeStats(traderID) var bestHoldWinRate float64 for _, h := range holdStats { @@ -975,40 +791,30 @@ func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, err } // Calculate average holding time - var avgHold sql.NullFloat64 - s.db.QueryRow(` - SELECT AVG((julianday(exit_time) - julianday(entry_time)) * 24 * 60) - FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' AND exit_time IS NOT NULL - `, traderID).Scan(&avgHold) - if avgHold.Valid { - summary.AvgHoldingMins = avgHold.Float64 - } - - // Get recent 20 trades performance - var recentWins int - var recentTotal int - var recentPnL float64 - rows, err := s.db.Query(` - SELECT realized_pnl FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - ORDER BY exit_time DESC LIMIT 20 - `, traderID) - if err == nil { - defer rows.Close() - for rows.Next() { - var pnl float64 - rows.Scan(&pnl) - recentTotal++ - recentPnL += pnl - if pnl > 0 { - recentWins++ + var positions []TraderPosition + s.db.Where("trader_id = ? AND status = ? AND exit_time IS NOT NULL", traderID, "CLOSED").Find(&positions) + if len(positions) > 0 { + var totalMins float64 + for _, pos := range positions { + if pos.ExitTime != nil { + totalMins += pos.ExitTime.Sub(pos.EntryTime).Minutes() } } + summary.AvgHoldingMins = totalMins / float64(len(positions)) } - if recentTotal > 0 { - summary.RecentWinRate = float64(recentWins) / float64(recentTotal) * 100 - summary.RecentPnL = recentPnL + + // Recent 20 trades + var recent []TraderPosition + s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC").Limit(20).Find(&recent) + for _, pos := range recent { + summary.RecentPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + summary.RecentWinRate++ + } + } + if len(recent) > 0 { + summary.RecentWinRate = summary.RecentWinRate / float64(len(recent)) * 100 } // Calculate streaks @@ -1019,24 +825,20 @@ func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, err // calculateStreaks calculates win/loss streaks func (s *PositionStore) calculateStreaks(traderID string, summary *HistorySummary) { - rows, err := s.db.Query(` - SELECT realized_pnl FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' - ORDER BY exit_time DESC - `, traderID) - if err != nil { + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC"). + Find(&positions).Error + if err != nil || len(positions) == 0 { return } - defer rows.Close() var currentStreak, maxWin, maxLose int var prevWin *bool isFirst := true - for rows.Next() { - var pnl float64 - rows.Scan(&pnl) - isWin := pnl > 0 + for _, pos := range positions { + isWin := pos.RealizedPnL > 0 if isFirst { if isWin { @@ -1076,23 +878,16 @@ func (s *PositionStore) calculateStreaks(traderID string, summary *HistorySummar summary.MaxLoseStreak = maxLose } -// ============================================================================= -// Deduplication and Sync Methods -// ============================================================================= - -// ExistsWithExchangePositionID checks if a position with the given exchange position ID already exists -// Note: Uses exchange_id (account UUID) for deduplication, not trader_id -// This ensures that the same position from an exchange account is not duplicated across different traders +// ExistsWithExchangePositionID checks if a position exists func (s *PositionStore) ExistsWithExchangePositionID(exchangeID, exchangePositionID string) (bool, error) { if exchangePositionID == "" { return false, nil } - var count int - err := s.db.QueryRow(` - SELECT COUNT(*) FROM trader_positions - WHERE exchange_id = ? AND exchange_position_id = ? - `, exchangeID, exchangePositionID).Scan(&count) + var count int64 + err := s.db.Model(&TraderPosition{}). + Where("exchange_id = ? AND exchange_position_id = ?", exchangeID, exchangePositionID). + Count(&count).Error if err != nil { return false, fmt.Errorf("failed to check position existence: %w", err) } @@ -1100,146 +895,28 @@ func (s *PositionStore) ExistsWithExchangePositionID(exchangeID, exchangePositio } // GetOpenPositionByExchangePositionID gets an OPEN position by exchange_position_id -// Used for accumulating into existing position when duplicate exchange_position_id is detected func (s *PositionStore) GetOpenPositionByExchangePositionID(exchangeID, exchangePositionID string) (*TraderPosition, error) { if exchangePositionID == "" { return nil, nil } var pos TraderPosition - var entryTime, exitTime, createdAt, updatedAt sql.NullString - - err := s.db.QueryRow(` - SELECT id, trader_id, exchange_id, COALESCE(exchange_type, '') as exchange_type, symbol, side, quantity, COALESCE(entry_quantity, quantity) as entry_quantity, entry_price, entry_order_id, - entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, - leverage, status, close_reason, created_at, updated_at - FROM trader_positions - WHERE exchange_id = ? AND exchange_position_id = ? AND status = 'OPEN' - LIMIT 1 - `, exchangeID, exchangePositionID).Scan( - &pos.ID, &pos.TraderID, &pos.ExchangeID, &pos.ExchangeType, &pos.Symbol, &pos.Side, &pos.Quantity, &pos.EntryQuantity, - &pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice, - &pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee, - &pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt, - ) + err := s.db.Where("exchange_id = ? AND exchange_position_id = ? AND status = ?", exchangeID, exchangePositionID, "OPEN"). + First(&pos).Error if err != nil { - if err == sql.ErrNoRows { + if err == gorm.ErrRecordNotFound { return nil, nil } return nil, err } - s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt) + if pos.EntryQuantity == 0 { + pos.EntryQuantity = pos.Quantity + } return &pos, nil } -// CreateFromClosedPnL creates a closed position record from exchange closed PnL data -// This is used for syncing historical positions from exchange -// Returns true if created, false if already exists (deduped) or invalid data -func (s *PositionStore) CreateFromClosedPnL(traderID, exchangeID, exchangeType string, record *ClosedPnLRecord) (bool, error) { - // ========================================================================== - // Step 1: Validate required fields - // ========================================================================== - if record.Symbol == "" { - return false, nil // Skip: no symbol - } - - // Normalize and validate side - side := strings.ToUpper(record.Side) - if side == "LONG" || side == "BUY" { - side = "LONG" - } else if side == "SHORT" || side == "SELL" { - side = "SHORT" - } else { - return false, nil // Skip: invalid side - } - - // Validate quantity - if record.Quantity <= 0 { - return false, nil // Skip: invalid quantity - } - - // Validate prices (entry price can be calculated, but should be positive) - if record.ExitPrice <= 0 { - return false, nil // Skip: invalid exit price - } - if record.EntryPrice <= 0 { - return false, nil // Skip: invalid entry price - } - - // ========================================================================== - // Step 2: Generate unique exchange position ID for deduplication - // ========================================================================== - exchangePositionID := record.ExchangeID - if exchangePositionID == "" { - // Fallback: generate from symbol + side + exit time + pnl (to ensure uniqueness) - exchangePositionID = fmt.Sprintf("%s_%s_%d_%.8f", - record.Symbol, side, record.ExitTime.UnixMilli(), record.RealizedPnL) - } - - // ========================================================================== - // Step 3: Check for duplicates based on (exchange_id, exchange_position_id) - // ========================================================================== - exists, err := s.ExistsWithExchangePositionID(exchangeID, exchangePositionID) - if err != nil { - return false, err - } - if exists { - return false, nil // Already exists, skip - } - - // ========================================================================== - // Step 4: Handle timestamps - // ========================================================================== - now := time.Now() - exitTime := record.ExitTime - entryTime := record.EntryTime - - // Validate exit time - if exitTime.IsZero() || exitTime.Year() < 2000 { - return false, nil // Skip: invalid exit time - } - - // Handle zero entry time - use exit time as approximation - if entryTime.IsZero() || entryTime.Year() < 2000 { - entryTime = exitTime - } - - // Entry time should not be after exit time - if entryTime.After(exitTime) { - entryTime = exitTime - } - - // ========================================================================== - // Step 5: Insert into database - // ========================================================================== - _, err = s.db.Exec(` - INSERT INTO trader_positions ( - trader_id, exchange_id, exchange_type, exchange_position_id, symbol, side, quantity, - entry_price, entry_order_id, entry_time, - exit_price, exit_order_id, exit_time, - realized_pnl, fee, leverage, status, close_reason, source, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'CLOSED', ?, 'sync', ?, ?) - `, - traderID, exchangeID, exchangeType, exchangePositionID, record.Symbol, side, record.Quantity, - record.EntryPrice, "", entryTime.Format(time.RFC3339), - record.ExitPrice, record.OrderID, exitTime.Format(time.RFC3339), - record.RealizedPnL, record.Fee, record.Leverage, record.CloseType, - now.Format(time.RFC3339), now.Format(time.RFC3339), - ) - if err != nil { - // Duplicate key error, treat as already exists - if strings.Contains(err.Error(), "UNIQUE constraint failed") { - return false, nil - } - return false, fmt.Errorf("failed to create position from closed PnL: %w", err) - } - - return true, nil -} - -// ClosedPnLRecord represents a closed position record from exchange (duplicated here for store package) +// ClosedPnLRecord represents a closed position record from exchange type ClosedPnLRecord struct { Symbol string Side string @@ -1256,92 +933,133 @@ type ClosedPnLRecord struct { ExchangeID string } -// GetLastClosedPositionTime gets the most recent exit time from closed positions -// This is used to determine the start time for syncing new closed positions -func (s *PositionStore) GetLastClosedPositionTime(traderID string) (time.Time, error) { - var exitTime sql.NullString - err := s.db.QueryRow(` - SELECT exit_time FROM trader_positions - WHERE trader_id = ? AND status = 'CLOSED' AND exit_time IS NOT NULL - ORDER BY exit_time DESC LIMIT 1 - `, traderID).Scan(&exitTime) +// CreateFromClosedPnL creates a closed position record from exchange data +func (s *PositionStore) CreateFromClosedPnL(traderID, exchangeID, exchangeType string, record *ClosedPnLRecord) (bool, error) { + if record.Symbol == "" { + return false, nil + } - if err == sql.ErrNoRows || !exitTime.Valid { - // No closed positions, return 30 days ago as default + side := strings.ToUpper(record.Side) + if side == "LONG" || side == "BUY" { + side = "LONG" + } else if side == "SHORT" || side == "SELL" { + side = "SHORT" + } else { + return false, nil + } + + if record.Quantity <= 0 || record.ExitPrice <= 0 || record.EntryPrice <= 0 { + return false, nil + } + + exchangePositionID := record.ExchangeID + if exchangePositionID == "" { + exchangePositionID = fmt.Sprintf("%s_%s_%d_%.8f", record.Symbol, side, record.ExitTime.UnixMilli(), record.RealizedPnL) + } + + exists, err := s.ExistsWithExchangePositionID(exchangeID, exchangePositionID) + if err != nil { + return false, err + } + if exists { + return false, nil + } + + exitTime := record.ExitTime + entryTime := record.EntryTime + + if exitTime.IsZero() || exitTime.Year() < 2000 { + return false, nil + } + if entryTime.IsZero() || entryTime.Year() < 2000 { + entryTime = exitTime + } + if entryTime.After(exitTime) { + entryTime = exitTime + } + + pos := &TraderPosition{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + ExchangePositionID: exchangePositionID, + Symbol: record.Symbol, + Side: side, + Quantity: record.Quantity, + EntryQuantity: record.Quantity, + EntryPrice: record.EntryPrice, + EntryTime: entryTime, + ExitPrice: record.ExitPrice, + ExitOrderID: record.OrderID, + ExitTime: &exitTime, + RealizedPnL: record.RealizedPnL, + Fee: record.Fee, + Leverage: record.Leverage, + Status: "CLOSED", + CloseReason: record.CloseType, + Source: "sync", + } + + err = s.db.Create(pos).Error + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") { + return false, nil + } + return false, fmt.Errorf("failed to create position from closed PnL: %w", err) + } + + return true, nil +} + +// GetLastClosedPositionTime gets the most recent exit time +func (s *PositionStore) GetLastClosedPositionTime(traderID string) (time.Time, error) { + var pos TraderPosition + err := s.db.Where("trader_id = ? AND status = ? AND exit_time IS NOT NULL", traderID, "CLOSED"). + Order("exit_time DESC"). + First(&pos).Error + + if err == gorm.ErrRecordNotFound || pos.ExitTime == nil { return time.Now().Add(-30 * 24 * time.Hour), nil } if err != nil { return time.Time{}, fmt.Errorf("failed to get last closed position time: %w", err) } - t, _ := time.Parse(time.RFC3339, exitTime.String) - return t, nil + return *pos.ExitTime, nil } -// CreateOpenPosition creates an open position record with exchange position ID -// NOTE: This function should only be called when GetOpenPositionBySymbol returns nil. -// If a position with the same exchange_position_id already exists (e.g., due to same millisecond trades), -// this function will accumulate into the existing position instead of silently skipping. +// CreateOpenPosition creates an open position func (s *PositionStore) CreateOpenPosition(pos *TraderPosition) error { - // Check if already exists by exchange position ID - // If exists, accumulate into that position instead of skipping if pos.ExchangePositionID != "" && pos.ExchangeID != "" { existingPos, err := s.GetOpenPositionByExchangePositionID(pos.ExchangeID, pos.ExchangePositionID) if err != nil { return err } if existingPos != nil { - // Position with same exchange_position_id exists and is OPEN, accumulate into it return s.UpdatePositionQuantityAndPrice(existingPos.ID, pos.Quantity, pos.EntryPrice, pos.Fee) } - // Check if position exists but is CLOSED exists, err := s.ExistsWithExchangePositionID(pos.ExchangeID, pos.ExchangePositionID) if err != nil { return err } if exists { - // Position exists but is CLOSED, skip (this is a valid case for historical sync) return nil } } - now := time.Now() - pos.CreatedAt = now - pos.UpdatedAt = now - // Only set status to OPEN if not already set (allows creating CLOSED positions) if pos.Status == "" { pos.Status = "OPEN" } if pos.Source == "" { pos.Source = "system" } - // Set EntryQuantity to same as Quantity if not already set if pos.EntryQuantity == 0 { pos.EntryQuantity = pos.Quantity } - // Format exit time if present - var exitTimeStr *string - if pos.ExitTime != nil { - s := pos.ExitTime.Format(time.RFC3339) - exitTimeStr = &s - } - - result, err := s.db.Exec(` - INSERT INTO trader_positions ( - trader_id, exchange_id, exchange_type, exchange_position_id, symbol, side, quantity, entry_quantity, - entry_price, entry_order_id, entry_time, exit_price, exit_order_id, exit_time, - realized_pnl, leverage, status, source, fee, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, - pos.TraderID, pos.ExchangeID, pos.ExchangeType, pos.ExchangePositionID, pos.Symbol, pos.Side, pos.Quantity, pos.EntryQuantity, - pos.EntryPrice, pos.EntryOrderID, pos.EntryTime.Format(time.RFC3339), pos.ExitPrice, pos.ExitOrderID, exitTimeStr, - pos.RealizedPnL, pos.Leverage, pos.Status, pos.Source, pos.Fee, now.Format(time.RFC3339), now.Format(time.RFC3339), - ) + err := s.db.Create(pos).Error if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { - // UNIQUE constraint failed, try to accumulate into existing position existingPos, findErr := s.GetOpenPositionByExchangePositionID(pos.ExchangeID, pos.ExchangePositionID) if findErr != nil { return findErr @@ -1349,42 +1067,32 @@ func (s *PositionStore) CreateOpenPosition(pos *TraderPosition) error { if existingPos != nil { return s.UpdatePositionQuantityAndPrice(existingPos.ID, pos.Quantity, pos.EntryPrice, pos.Fee) } - // Position is CLOSED, skip return nil } return fmt.Errorf("failed to create open position: %w", err) } - id, _ := result.LastInsertId() - pos.ID = id return nil } // ClosePositionWithAccurateData closes a position with accurate data from exchange func (s *PositionStore) ClosePositionWithAccurateData(id int64, exitPrice float64, exitOrderID string, exitTime time.Time, realizedPnL float64, fee float64, closeReason string) error { - now := time.Now() - _, err := s.db.Exec(` - UPDATE trader_positions SET - exit_price = ?, exit_order_id = ?, exit_time = ?, - realized_pnl = ?, fee = ?, status = 'CLOSED', - close_reason = ?, updated_at = ? - WHERE id = ? - `, - exitPrice, exitOrderID, exitTime.Format(time.RFC3339), - realizedPnL, fee, closeReason, now.Format(time.RFC3339), id, - ) - if err != nil { - return fmt.Errorf("failed to close position with accurate data: %w", err) - } - return nil + return s.db.Model(&TraderPosition{}).Where("id = ?", id).Updates(map[string]interface{}{ + "exit_price": exitPrice, + "exit_order_id": exitOrderID, + "exit_time": exitTime, + "realized_pnl": realizedPnL, + "fee": fee, + "status": "CLOSED", + "close_reason": closeReason, + }).Error } -// SyncClosedPositions syncs closed positions from exchange to local database -// Returns (created count, skipped count, error) +// SyncClosedPositions syncs closed positions from exchange func (s *PositionStore) SyncClosedPositions(traderID, exchangeID, exchangeType string, records []ClosedPnLRecord) (int, int, error) { created, skipped := 0, 0 for _, record := range records { - rec := record // Create local copy to avoid closure issues + rec := record wasCreated, err := s.CreateFromClosedPnL(traderID, exchangeID, exchangeType, &rec) if err != nil { return created, skipped, fmt.Errorf("failed to sync position: %w", err) diff --git a/store/store.go b/store/store.go index 4a3947b5..21c15813 100644 --- a/store/store.go +++ b/store/store.go @@ -7,12 +7,15 @@ import ( "fmt" "nofx/logger" "sync" + + "gorm.io/gorm" ) // Store unified data storage interface type Store struct { - db *sql.DB - driver *DBDriver // Database driver for abstraction + gdb *gorm.DB // GORM database connection + db *sql.DB // Legacy sql.DB for backward compatibility + driver *DBDriver // Database driver for abstraction (legacy) // Sub-stores (lazy initialization) user *UserStore @@ -26,105 +29,103 @@ type Store struct { equity *EquityStore order *OrderStore - // Encryption functions - encryptFunc func(string) string - decryptFunc func(string) string - mu sync.RWMutex } // New creates new Store instance (SQLite mode for backward compatibility) func New(dbPath string) (*Store, error) { - driver, err := NewDBDriver(DBConfig{Type: DBTypeSQLite, Path: dbPath}) + gdb, err := InitGorm(dbPath) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } - s := &Store{db: driver.DB(), driver: driver} + // Get underlying sql.DB for legacy compatibility + sqlDB, err := gdb.DB() + if err != nil { + return nil, fmt.Errorf("failed to get sql.DB: %w", err) + } + + s := &Store{gdb: gdb, db: sqlDB} // Initialize all table structures if err := s.initTables(); err != nil { - driver.Close() + sqlDB.Close() return nil, fmt.Errorf("failed to initialize table structure: %w", err) } // Initialize default data if err := s.initDefaultData(); err != nil { - driver.Close() + sqlDB.Close() return nil, fmt.Errorf("failed to initialize default data: %w", err) } - logger.Infof("✅ Database initialized (type: %s)", driver.Type) + logger.Infof("✅ Database initialized (GORM, SQLite)") return s, nil } -// NewFromEnv creates new Store instance from environment variables -// DB_TYPE: sqlite (default) or postgres -// For SQLite: DB_PATH (default: data/data.db) -// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE -func NewFromEnv() (*Store, error) { - driver, err := NewDBDriverFromEnv() +// NewWithConfig creates new Store instance with provided database configuration +func NewWithConfig(cfg DBConfig) (*Store, error) { + gdb, err := InitGormWithConfig(cfg) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } - s := &Store{db: driver.DB(), driver: driver} + // Get underlying sql.DB for legacy compatibility + sqlDB, err := gdb.DB() + if err != nil { + return nil, fmt.Errorf("failed to get sql.DB: %w", err) + } + + s := &Store{gdb: gdb, db: sqlDB} // Initialize all table structures if err := s.initTables(); err != nil { - driver.Close() + sqlDB.Close() return nil, fmt.Errorf("failed to initialize table structure: %w", err) } // Initialize default data if err := s.initDefaultData(); err != nil { - driver.Close() + sqlDB.Close() return nil, fmt.Errorf("failed to initialize default data: %w", err) } - logger.Infof("✅ Database initialized (type: %s)", driver.Type) + dbTypeStr := "SQLite" + if cfg.Type == DBTypePostgres { + dbTypeStr = "PostgreSQL" + } + logger.Infof("✅ Database initialized (GORM, %s)", dbTypeStr) return s, nil } -// NewFromDB creates Store from existing database connection +// NewFromGorm creates Store from existing GORM connection +func NewFromGorm(gdb *gorm.DB) (*Store, error) { + sqlDB, err := gdb.DB() + if err != nil { + return nil, err + } + return &Store{gdb: gdb, db: sqlDB}, nil +} + +// NewFromDB creates Store from existing database connection (legacy) +// Deprecated: Use NewFromGorm instead func NewFromDB(db *sql.DB) *Store { return &Store{db: db} } -// SetCryptoFuncs sets encryption/decryption functions -func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) { - s.mu.Lock() - defer s.mu.Unlock() - s.encryptFunc = encrypt - s.decryptFunc = decrypt - - // Update already initialized sub-stores - if s.aiModel != nil { - s.aiModel.encryptFunc = encrypt - s.aiModel.decryptFunc = decrypt - } - if s.exchange != nil { - s.exchange.encryptFunc = encrypt - s.exchange.decryptFunc = decrypt - } - if s.trader != nil { - s.trader.decryptFunc = decrypt - } -} - -// initTables initializes all database tables +// initTables initializes all database tables using GORM AutoMigrate func (s *Store) initTables() error { - // Initialize system config table first - if _, err := s.db.Exec(` + // Create system_config table (GORM handles this via raw SQL for simplicity) + if err := s.gdb.Exec(` CREATE TABLE IF NOT EXISTS system_config ( key TEXT PRIMARY KEY, value TEXT NOT NULL ) - `); err != nil { + `).Error; err != nil { return fmt.Errorf("failed to create system_config table: %w", err) } - // Initialize in dependency order + // Initialize sub-store tables if err := s.User().initTables(); err != nil { return fmt.Errorf("failed to initialize user tables: %w", err) } @@ -183,7 +184,7 @@ func (s *Store) User() *UserStore { s.mu.Lock() defer s.mu.Unlock() if s.user == nil { - s.user = &UserStore{db: s.db} + s.user = NewUserStore(s.gdb) } return s.user } @@ -193,11 +194,7 @@ func (s *Store) AIModel() *AIModelStore { s.mu.Lock() defer s.mu.Unlock() if s.aiModel == nil { - s.aiModel = &AIModelStore{ - db: s.db, - encryptFunc: s.encryptFunc, - decryptFunc: s.decryptFunc, - } + s.aiModel = NewAIModelStore(s.gdb) } return s.aiModel } @@ -207,11 +204,7 @@ func (s *Store) Exchange() *ExchangeStore { s.mu.Lock() defer s.mu.Unlock() if s.exchange == nil { - s.exchange = &ExchangeStore{ - db: s.db, - encryptFunc: s.encryptFunc, - decryptFunc: s.decryptFunc, - } + s.exchange = NewExchangeStore(s.gdb) } return s.exchange } @@ -221,10 +214,7 @@ func (s *Store) Trader() *TraderStore { s.mu.Lock() defer s.mu.Unlock() if s.trader == nil { - s.trader = &TraderStore{ - db: s.db, - decryptFunc: s.decryptFunc, - } + s.trader = NewTraderStore(s.gdb) } return s.trader } @@ -234,7 +224,7 @@ func (s *Store) Decision() *DecisionStore { s.mu.Lock() defer s.mu.Unlock() if s.decision == nil { - s.decision = &DecisionStore{db: s.db} + s.decision = NewDecisionStore(s.gdb) } return s.decision } @@ -244,7 +234,7 @@ func (s *Store) Backtest() *BacktestStore { s.mu.Lock() defer s.mu.Unlock() if s.backtest == nil { - s.backtest = &BacktestStore{db: s.db} + s.backtest = NewBacktestStore(s.gdb) } return s.backtest } @@ -254,7 +244,7 @@ func (s *Store) Position() *PositionStore { s.mu.Lock() defer s.mu.Unlock() if s.position == nil { - s.position = NewPositionStore(s.db) + s.position = NewPositionStore(s.gdb) } return s.position } @@ -264,7 +254,7 @@ func (s *Store) Strategy() *StrategyStore { s.mu.Lock() defer s.mu.Unlock() if s.strategy == nil { - s.strategy = &StrategyStore{db: s.db} + s.strategy = NewStrategyStore(s.gdb) } return s.strategy } @@ -274,7 +264,7 @@ func (s *Store) Equity() *EquityStore { s.mu.Lock() defer s.mu.Unlock() if s.equity == nil { - s.equity = &EquityStore{db: s.db} + s.equity = NewEquityStore(s.gdb) } return s.equity } @@ -284,7 +274,7 @@ func (s *Store) Order() *OrderStore { s.mu.Lock() defer s.mu.Unlock() if s.order == nil { - s.order = NewOrderStore(s.db) + s.order = NewOrderStore(s.gdb) } return s.order } @@ -294,10 +284,18 @@ func (s *Store) Close() error { if s.driver != nil { return s.driver.Close() } - return s.db.Close() + if s.db != nil { + return s.db.Close() + } + return nil } -// Driver returns database driver for abstraction +// GormDB returns the GORM database connection +func (s *Store) GormDB() *gorm.DB { + return s.gdb +} + +// Driver returns database driver for abstraction (legacy) func (s *Store) Driver() *DBDriver { return s.driver } @@ -307,11 +305,25 @@ func (s *Store) DBType() DBType { if s.driver != nil { return s.driver.Type } + // Detect from GORM dialector + if s.gdb != nil { + switch s.gdb.Dialector.Name() { + case "postgres": + return DBTypePostgres + default: + return DBTypeSQLite + } + } return DBTypeSQLite } -// DB gets underlying database connection (for legacy code compatibility, gradually deprecated) -// Deprecated: use Store methods instead +// q converts query placeholders for current database type (legacy helper) +func (s *Store) q(query string) string { + return convertQuery(query, s.DBType()) +} + +// DB gets underlying database connection (for legacy code compatibility) +// Deprecated: use GormDB() instead func (s *Store) DB() *sql.DB { return s.db } @@ -319,24 +331,36 @@ func (s *Store) DB() *sql.DB { // GetSystemConfig gets a system configuration value by key func (s *Store) GetSystemConfig(key string) (string, error) { var value string - err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value) - if err == sql.ErrNoRows { + result := s.gdb.Raw("SELECT value FROM system_config WHERE key = ?", key).Scan(&value) + if result.Error != nil { + if result.Error == gorm.ErrRecordNotFound { + return "", nil + } + return "", result.Error + } + if result.RowsAffected == 0 { return "", nil } - return value, err + return value, nil } // SetSystemConfig sets a system configuration value func (s *Store) SetSystemConfig(key, value string) error { - _, err := s.db.Exec(` + // Use GORM-compatible upsert + return s.gdb.Exec(` INSERT INTO system_config (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value - `, key, value) - return err + `, key, value).Error } -// Transaction executes transaction -func (s *Store) Transaction(fn func(tx *sql.Tx) error) error { +// Transaction executes transaction with GORM +func (s *Store) Transaction(fn func(tx *gorm.DB) error) error { + return s.gdb.Transaction(fn) +} + +// TransactionSQL executes transaction with sql.Tx (legacy) +// Deprecated: Use Transaction() instead +func (s *Store) TransactionSQL(fn func(tx *sql.Tx) error) error { tx, err := s.db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) diff --git a/store/strategy.go b/store/strategy.go index e5f78bf9..1e4af1bd 100644 --- a/store/strategy.go +++ b/store/strategy.go @@ -1,30 +1,33 @@ package store import ( - "database/sql" "encoding/json" "fmt" "time" + + "gorm.io/gorm" ) // StrategyStore strategy storage type StrategyStore struct { - db *sql.DB + db *gorm.DB } // Strategy strategy configuration type Strategy struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Description string `json:"description"` - IsActive bool `json:"is_active"` // whether it is active (a user can only have one active strategy) - IsDefault bool `json:"is_default"` // whether it is a system default strategy - Config string `json:"config"` // strategy configuration in JSON format + ID string `gorm:"primaryKey" json:"id"` + UserID string `gorm:"column:user_id;not null;default:'';index" json:"user_id"` + Name string `gorm:"not null" json:"name"` + Description string `gorm:"default:''" json:"description"` + IsActive bool `gorm:"column:is_active;default:false;index" json:"is_active"` + IsDefault bool `gorm:"column:is_default;default:false" json:"is_default"` + Config string `gorm:"not null;default:'{}'" json:"config"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } +func (Strategy) TableName() string { return "strategies" } + // StrategyConfig strategy configuration details (JSON structure) type StrategyConfig struct { // coin source configuration @@ -136,24 +139,6 @@ type ExternalDataSource struct { } // RiskControlConfig risk control configuration -// All parameters are clearly defined without ambiguity: -// -// Position Limits: -// - MaxPositions: max number of coins held simultaneously (CODE ENFORCED) -// -// Trading Leverage (exchange leverage for opening positions): -// - BTCETHMaxLeverage: BTC/ETH max exchange leverage (AI guided) -// - AltcoinMaxLeverage: Altcoin max exchange leverage (AI guided) -// -// Position Value Limits (single position notional value / account equity): -// - BTCETHMaxPositionValueRatio: BTC/ETH max = equity × ratio (CODE ENFORCED) -// - AltcoinMaxPositionValueRatio: Altcoin max = equity × ratio (CODE ENFORCED) -// -// Risk Controls: -// - MaxMarginUsage: max margin utilization percentage (CODE ENFORCED) -// - MinPositionSize: minimum position size in USDT (CODE ENFORCED) -// - MinRiskRewardRatio: min take_profit / stop_loss ratio (AI guided) -// - MinConfidence: min AI confidence to open position (AI guided) type RiskControlConfig struct { // Max number of coins held simultaneously (CODE ENFORCED) MaxPositions int `json:"max_positions"` @@ -179,38 +164,21 @@ type RiskControlConfig struct { MinConfidence int `json:"min_confidence"` } +// NewStrategyStore creates a new StrategyStore +func NewStrategyStore(db *gorm.DB) *StrategyStore { + return &StrategyStore{db: db} +} + func (s *StrategyStore) initTables() error { - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS strategies ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT '', - name TEXT NOT NULL, - description TEXT DEFAULT '', - is_active BOOLEAN DEFAULT 0, - is_default BOOLEAN DEFAULT 0, - config TEXT NOT NULL DEFAULT '{}', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return err + // For PostgreSQL with existing table, skip AutoMigrate + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'strategies'`).Scan(&tableExists) + if tableExists > 0 { + return nil + } } - - // create indexes - _, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_user_id ON strategies(user_id)`) - _, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_is_active ON strategies(is_active)`) - - // trigger: automatically update updated_at on update - _, err = s.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_strategies_updated_at - AFTER UPDATE ON strategies - BEGIN - UPDATE strategies SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END - `) - - return err + return s.db.AutoMigrate(&Strategy{}) } func (s *StrategyStore) initDefaultData() error { @@ -322,159 +290,93 @@ Only enter positions when multiple signals resonate. Freely use any effective an // Create create a strategy func (s *StrategyStore) Create(strategy *Strategy) error { - _, err := s.db.Exec(` - INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config) - VALUES (?, ?, ?, ?, ?, ?, ?) - `, strategy.ID, strategy.UserID, strategy.Name, strategy.Description, strategy.IsActive, strategy.IsDefault, strategy.Config) - return err + return s.db.Create(strategy).Error } // Update update a strategy func (s *StrategyStore) Update(strategy *Strategy) error { - _, err := s.db.Exec(` - UPDATE strategies SET - name = ?, description = ?, config = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? AND user_id = ? - `, strategy.Name, strategy.Description, strategy.Config, strategy.ID, strategy.UserID) - return err + return s.db.Model(&Strategy{}). + Where("id = ? AND user_id = ?", strategy.ID, strategy.UserID). + Updates(map[string]interface{}{ + "name": strategy.Name, + "description": strategy.Description, + "config": strategy.Config, + "updated_at": time.Now(), + }).Error } // Delete delete a strategy func (s *StrategyStore) Delete(userID, id string) error { // do not allow deleting system default strategy - var isDefault bool - s.db.QueryRow(`SELECT is_default FROM strategies WHERE id = ?`, id).Scan(&isDefault) - if isDefault { + var st Strategy + if err := s.db.Where("id = ?", id).First(&st).Error; err == nil && st.IsDefault { return fmt.Errorf("cannot delete system default strategy") } - _, err := s.db.Exec(`DELETE FROM strategies WHERE id = ? AND user_id = ?`, id, userID) - return err + return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Strategy{}).Error } // List get user's strategy list func (s *StrategyStore) List(userID string) ([]*Strategy, error) { - // get user's own strategies + system default strategy - rows, err := s.db.Query(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies - WHERE user_id = ? OR is_default = 1 - ORDER BY is_default DESC, created_at DESC - `, userID) + var strategies []*Strategy + err := s.db.Where("user_id = ? OR is_default = ?", userID, true). + Order("is_default DESC, created_at DESC"). + Find(&strategies).Error if err != nil { return nil, err } - defer rows.Close() - - var strategies []*Strategy - for rows.Next() { - var st Strategy - var createdAt, updatedAt string - err := rows.Scan( - &st.ID, &st.UserID, &st.Name, &st.Description, - &st.IsActive, &st.IsDefault, &st.Config, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - strategies = append(strategies, &st) - } return strategies, nil } // Get get a single strategy func (s *StrategyStore) Get(userID, id string) (*Strategy, error) { var st Strategy - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies - WHERE id = ? AND (user_id = ? OR is_default = 1) - `, id, userID).Scan( - &st.ID, &st.UserID, &st.Name, &st.Description, - &st.IsActive, &st.IsDefault, &st.Config, - &createdAt, &updatedAt, - ) + err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", id, userID, true). + First(&st).Error if err != nil { return nil, err } - st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &st, nil } // GetActive get user's currently active strategy func (s *StrategyStore) GetActive(userID string) (*Strategy, error) { var st Strategy - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies - WHERE user_id = ? AND is_active = 1 - `, userID).Scan( - &st.ID, &st.UserID, &st.Name, &st.Description, - &st.IsActive, &st.IsDefault, &st.Config, - &createdAt, &updatedAt, - ) - if err == sql.ErrNoRows { + err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&st).Error + if err == gorm.ErrRecordNotFound { // no active strategy, return system default strategy return s.GetDefault() } if err != nil { return nil, err } - st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &st, nil } // GetDefault get system default strategy func (s *StrategyStore) GetDefault() (*Strategy, error) { var st Strategy - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies - WHERE is_default = 1 - LIMIT 1 - `).Scan( - &st.ID, &st.UserID, &st.Name, &st.Description, - &st.IsActive, &st.IsDefault, &st.Config, - &createdAt, &updatedAt, - ) + err := s.db.Where("is_default = ?", true).First(&st).Error if err != nil { return nil, err } - st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &st, nil } // SetActive set active strategy (will first deactivate other strategies) func (s *StrategyStore) SetActive(userID, strategyID string) error { - // begin transaction - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() + return s.db.Transaction(func(tx *gorm.DB) error { + // first deactivate all strategies for the user + if err := tx.Model(&Strategy{}).Where("user_id = ?", userID). + Update("is_active", false).Error; err != nil { + return err + } - // first deactivate all strategies for the user - _, err = tx.Exec(`UPDATE strategies SET is_active = 0 WHERE user_id = ?`, userID) - if err != nil { - return err - } - - // activate specified strategy - _, err = tx.Exec(`UPDATE strategies SET is_active = 1 WHERE id = ? AND (user_id = ? OR is_default = 1)`, strategyID, userID) - if err != nil { - return err - } - - return tx.Commit() + // activate specified strategy + return tx.Model(&Strategy{}). + Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true). + Update("is_active", true).Error + }) } // Duplicate duplicate a strategy (used to create custom strategy based on default strategy) diff --git a/store/trader.go b/store/trader.go index 240a2ae0..b7f364eb 100644 --- a/store/trader.go +++ b/store/trader.go @@ -1,43 +1,52 @@ package store import ( - "database/sql" "fmt" - "strings" "time" + + "gorm.io/gorm" ) // TraderStore trader storage type TraderStore struct { - db *sql.DB - decryptFunc func(string) string + db *gorm.DB +} + +// NewTraderStore creates a new trader store +func NewTraderStore(db *gorm.DB) *TraderStore { + return &TraderStore{db: db} } // Trader trader configuration type Trader struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - StrategyID string `json:"strategy_id"` // Associated strategy ID - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsRunning bool `json:"is_running"` - IsCrossMargin bool `json:"is_cross_margin"` - ShowInCompetition bool `json:"show_in_competition"` // Whether to show in competition page - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `gorm:"primaryKey" json:"id"` + UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"` + Name string `gorm:"column:name;not null" json:"name"` + AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"` + ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"` + StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"` + InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"` + ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"` + IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"` + IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"` + ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"` // Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID - BTCETHLeverage int `json:"btc_eth_leverage,omitempty"` - AltcoinLeverage int `json:"altcoin_leverage,omitempty"` - TradingSymbols string `json:"trading_symbols,omitempty"` - UseCoinPool bool `json:"use_coin_pool,omitempty"` - UseOITop bool `json:"use_oi_top,omitempty"` - CustomPrompt string `json:"custom_prompt,omitempty"` - OverrideBasePrompt bool `json:"override_base_prompt,omitempty"` - SystemPromptTemplate string `json:"system_prompt_template,omitempty"` + BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"` + AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"` + TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"` + UseCoinPool bool `gorm:"column:use_coin_pool;default:false" json:"use_coin_pool,omitempty"` + UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"` + CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"` + OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"` + SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"` +} + +// TableName returns the table name for Trader +func (Trader) TableName() string { + return "traders" } // TraderFullConfig trader full configuration (includes AI model, exchange and strategy) @@ -45,331 +54,130 @@ type TraderFullConfig struct { Trader *Trader AIModel *AIModel Exchange *Exchange - Strategy *Strategy // Associated strategy configuration + Strategy *Strategy } func (s *TraderStore) initTables() error { - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS traders ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - ai_model_id TEXT NOT NULL, - exchange_id TEXT NOT NULL, - initial_balance REAL NOT NULL, - scan_interval_minutes INTEGER DEFAULT 3, - is_running BOOLEAN DEFAULT 0, - btc_eth_leverage INTEGER DEFAULT 5, - altcoin_leverage INTEGER DEFAULT 5, - trading_symbols TEXT DEFAULT '', - use_coin_pool BOOLEAN DEFAULT 0, - use_oi_top BOOLEAN DEFAULT 0, - custom_prompt TEXT DEFAULT '', - override_base_prompt BOOLEAN DEFAULT 0, - system_prompt_template TEXT DEFAULT 'default', - is_cross_margin BOOLEAN DEFAULT 1, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return err + // For PostgreSQL with existing table, skip AutoMigrate + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists) + if tableExists > 0 { + return nil + } } - - // Trigger - _, err = s.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_traders_updated_at - AFTER UPDATE ON traders - BEGIN - UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END - `) - if err != nil { - return err + // Use GORM AutoMigrate + if err := s.db.AutoMigrate(&Trader{}); err != nil { + return fmt.Errorf("failed to migrate traders table: %w", err) } - - // Backward compatibility - alterQueries := []string{ - `ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`, - `ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`, - `ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`, - `ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`, - `ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`, - `ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`, - `ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`, - `ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`, - `ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`, - `ALTER TABLE traders ADD COLUMN strategy_id TEXT DEFAULT ''`, - `ALTER TABLE traders ADD COLUMN show_in_competition BOOLEAN DEFAULT 1`, - } - for _, q := range alterQueries { - s.db.Exec(q) - } - - // Migration: Remove FOREIGN KEY constraint from existing traders table - // SQLite doesn't support ALTER TABLE DROP CONSTRAINT, so we need to recreate the table - if err := s.migrateTradersRemoveFK(); err != nil { - // Log but don't fail - this is a best-effort migration - // The constraint may not exist in older databases - } - return nil } -// migrateTradersRemoveFK removes FOREIGN KEY constraint from traders table if it exists -func (s *TraderStore) migrateTradersRemoveFK() error { - // Check if the table has a foreign key constraint by examining the schema - var sql string - err := s.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='traders'`).Scan(&sql) - if err != nil { - return err - } - - // If no FOREIGN KEY in schema, no migration needed - if !strings.Contains(sql, "FOREIGN KEY") { - return nil - } - - // Recreate table without FOREIGN KEY constraint - _, err = s.db.Exec(` - -- Create new table without FOREIGN KEY - CREATE TABLE IF NOT EXISTS traders_new ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - ai_model_id TEXT NOT NULL, - exchange_id TEXT NOT NULL, - initial_balance REAL NOT NULL, - scan_interval_minutes INTEGER DEFAULT 3, - is_running BOOLEAN DEFAULT 0, - btc_eth_leverage INTEGER DEFAULT 5, - altcoin_leverage INTEGER DEFAULT 5, - trading_symbols TEXT DEFAULT '', - use_coin_pool BOOLEAN DEFAULT 0, - use_oi_top BOOLEAN DEFAULT 0, - custom_prompt TEXT DEFAULT '', - override_base_prompt BOOLEAN DEFAULT 0, - system_prompt_template TEXT DEFAULT 'default', - is_cross_margin BOOLEAN DEFAULT 1, - strategy_id TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ); - - -- Copy data from old table - INSERT OR IGNORE INTO traders_new - SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, - scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage, - trading_symbols, use_coin_pool, use_oi_top, custom_prompt, - override_base_prompt, system_prompt_template, is_cross_margin, - COALESCE(strategy_id, ''), created_at, updated_at - FROM traders; - - -- Drop old table - DROP TABLE traders; - - -- Rename new table - ALTER TABLE traders_new RENAME TO traders; - `) - - if err != nil { - return err - } - - // Recreate trigger - _, err = s.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_traders_updated_at - AFTER UPDATE ON traders - BEGIN - UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END - `) - - return err -} - -func (s *TraderStore) decrypt(encrypted string) string { - if s.decryptFunc != nil { - return s.decryptFunc(encrypted) - } - return encrypted -} - // Create creates trader func (s *TraderStore) Create(trader *Trader) error { - _, err := s.db.Exec(` - INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, strategy_id, initial_balance, - scan_interval_minutes, is_running, is_cross_margin, show_in_competition, - btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool, - use_oi_top, custom_prompt, override_base_prompt, system_prompt_template) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID, - trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning, trader.IsCrossMargin, trader.ShowInCompetition, - trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, trader.UseCoinPool, - trader.UseOITop, trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate) - return err + return s.db.Create(trader).Error } // List gets user's trader list func (s *TraderStore) List(userID string) ([]*Trader, error) { - rows, err := s.db.Query(` - SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''), - initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1), - COALESCE(show_in_competition, 1), - COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''), - COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''), - COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'), - created_at, updated_at - FROM traders WHERE user_id = ? ORDER BY created_at DESC - `, userID) + var traders []*Trader + err := s.db.Where("user_id = ?", userID). + Order("created_at DESC"). + Find(&traders).Error if err != nil { return nil, err } - defer rows.Close() - - var traders []*Trader - for rows.Next() { - var t Trader - var createdAt, updatedAt string - err := rows.Scan( - &t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID, - &t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin, - &t.ShowInCompetition, - &t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols, - &t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt, - &t.SystemPromptTemplate, &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - traders = append(traders, &t) - } return traders, nil } // UpdateStatus updates trader running status func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error { - _, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID) - return err + return s.db.Model(&Trader{}). + Where("id = ? AND user_id = ?", id, userID). + Update("is_running", isRunning).Error } // UpdateShowInCompetition updates trader competition visibility func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error { - _, err := s.db.Exec(`UPDATE traders SET show_in_competition = ? WHERE id = ? AND user_id = ?`, showInCompetition, id, userID) - return err + return s.db.Model(&Trader{}). + Where("id = ? AND user_id = ?", id, userID). + Update("show_in_competition", showInCompetition).Error } // Update updates trader configuration func (s *TraderStore) Update(trader *Trader) error { fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n", trader.ID, trader.Name, trader.AIModelID, trader.StrategyID) - _, err := s.db.Exec(` - UPDATE traders SET - name = ?, - ai_model_id = ?, - exchange_id = ?, - strategy_id = ?, - initial_balance = CASE WHEN ? > 0 THEN ? ELSE initial_balance END, - scan_interval_minutes = CASE WHEN ? > 0 THEN ? ELSE scan_interval_minutes END, - is_cross_margin = ?, - show_in_competition = ?, - updated_at = CURRENT_TIMESTAMP - WHERE id = ? AND user_id = ? - `, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID, - trader.InitialBalance, trader.InitialBalance, - trader.ScanIntervalMinutes, trader.ScanIntervalMinutes, - trader.IsCrossMargin, trader.ShowInCompetition, - trader.ID, trader.UserID) - return err + + updates := map[string]interface{}{ + "name": trader.Name, + "ai_model_id": trader.AIModelID, + "exchange_id": trader.ExchangeID, + "strategy_id": trader.StrategyID, + "is_cross_margin": trader.IsCrossMargin, + "show_in_competition": trader.ShowInCompetition, + } + + // Only update these if > 0 + if trader.InitialBalance > 0 { + updates["initial_balance"] = trader.InitialBalance + } + if trader.ScanIntervalMinutes > 0 { + updates["scan_interval_minutes"] = trader.ScanIntervalMinutes + } + + return s.db.Model(&Trader{}). + Where("id = ? AND user_id = ?", trader.ID, trader.UserID). + Updates(updates).Error } // UpdateInitialBalance updates initial balance func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error { - _, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID) - return err + return s.db.Model(&Trader{}). + Where("id = ? AND user_id = ?", id, userID). + Update("initial_balance", newBalance).Error } // UpdateCustomPrompt updates custom prompt func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error { - _, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`, - customPrompt, overrideBase, id, userID) - return err + return s.db.Model(&Trader{}). + Where("id = ? AND user_id = ?", id, userID). + Updates(map[string]interface{}{ + "custom_prompt": customPrompt, + "override_base_prompt": overrideBase, + }).Error } // Delete deletes trader and associated data func (s *TraderStore) Delete(userID, id string) error { // Delete associated equity snapshots first - _, _ = s.db.Exec(`DELETE FROM trader_equity_snapshots WHERE trader_id = ?`, id) + s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{}) // Delete the trader - _, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID) - return err + return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error } // GetFullConfig gets trader full configuration func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) { var trader Trader - var aiModel AIModel - var exchange Exchange - var traderCreatedAt, traderUpdatedAt string - var aiModelCreatedAt, aiModelUpdatedAt string - var exchangeCreatedAt, exchangeUpdatedAt string - - err := s.db.QueryRow(` - SELECT - t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, COALESCE(t.strategy_id, ''), - t.initial_balance, t.scan_interval_minutes, t.is_running, COALESCE(t.is_cross_margin, 1), - COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''), - COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''), - COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'), - t.created_at, t.updated_at, - a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key, - COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at, - e.id, COALESCE(e.exchange_type, '') as exchange_type, COALESCE(e.account_name, '') as account_name, - e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, COALESCE(e.passphrase, ''), e.testnet, - COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''), - COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''), - COALESCE(e.lighter_api_key_private_key, ''), COALESCE(e.lighter_api_key_index, 0), e.created_at, e.updated_at - FROM traders t - JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id - JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id - WHERE t.id = ? AND t.user_id = ? - `, traderID, userID).Scan( - &trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, &trader.StrategyID, - &trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, &trader.IsCrossMargin, - &trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols, - &trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt, - &trader.SystemPromptTemplate, &traderCreatedAt, &traderUpdatedAt, - &aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey, - &aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt, - &exchange.ID, &exchange.ExchangeType, &exchange.AccountName, - &exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled, - &exchange.APIKey, &exchange.SecretKey, &exchange.Passphrase, &exchange.Testnet, &exchange.HyperliquidWalletAddr, - &exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey, - &exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, &exchange.LighterAPIKeyIndex, - &exchangeCreatedAt, &exchangeUpdatedAt, - ) + err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error if err != nil { return nil, err } - trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt) - trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt) - aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt) - aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt) - exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt) - exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt) + // Get AI model + var aiModel AIModel + err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error + if err != nil { + return nil, fmt.Errorf("failed to get AI model: %w", err) + } - // Decrypt - aiModel.APIKey = s.decrypt(aiModel.APIKey) - exchange.APIKey = s.decrypt(exchange.APIKey) - exchange.SecretKey = s.decrypt(exchange.SecretKey) - exchange.Passphrase = s.decrypt(exchange.Passphrase) - exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey) - exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey) - exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey) + // Get exchange + var exchange Exchange + err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error + if err != nil { + return nil, fmt.Errorf("failed to get exchange: %w", err) + } // Load associated strategy var strategy *Strategy @@ -392,119 +200,48 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, // getStrategyByID internal method: gets strategy by ID func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) { var strategy Strategy - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies WHERE id = ? AND (user_id = ? OR is_default = 1) - `, strategyID, userID).Scan( - &strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description, - &strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt, - ) + err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true). + First(&strategy).Error if err != nil { return nil, err } - strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &strategy, nil } // getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) { var strategy Strategy - var createdAt, updatedAt string // First try to get user's active strategy - err := s.db.QueryRow(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies WHERE user_id = ? AND is_active = 1 - `, userID).Scan( - &strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description, - &strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt, - ) + err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error if err == nil { - strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &strategy, nil } // Fallback to system default strategy - err = s.db.QueryRow(` - SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at - FROM strategies WHERE is_default = 1 LIMIT 1 - `).Scan( - &strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description, - &strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt, - ) + err = s.db.Where("is_default = ?", true).First(&strategy).Error if err != nil { return nil, err } - strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &strategy, nil } -// ListAll gets all users' trader list // GetByID gets a trader by ID without requiring userID (for public APIs) func (s *TraderStore) GetByID(traderID string) (*Trader, error) { - var t Trader - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''), - initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1), - COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''), - COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''), - COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'), - created_at, updated_at - FROM traders WHERE id = ? - `, traderID).Scan( - &t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID, - &t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin, - &t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols, - &t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt, - &t.SystemPromptTemplate, &createdAt, &updatedAt, - ) + var trader Trader + err := s.db.Where("id = ?", traderID).First(&trader).Error if err != nil { return nil, err } - t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - return &t, nil + return &trader, nil } +// ListAll gets all traders func (s *TraderStore) ListAll() ([]*Trader, error) { - rows, err := s.db.Query(` - SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''), - initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1), - COALESCE(show_in_competition, 1), - COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''), - COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''), - COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'), - created_at, updated_at - FROM traders ORDER BY created_at DESC - `) + var traders []*Trader + err := s.db.Order("created_at DESC").Find(&traders).Error if err != nil { return nil, err } - defer rows.Close() - - var traders []*Trader - for rows.Next() { - var t Trader - var createdAt, updatedAt string - err := rows.Scan( - &t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID, - &t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin, - &t.ShowInCompetition, - &t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols, - &t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt, - &t.SystemPromptTemplate, &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - traders = append(traders, &t) - } return traders, nil } diff --git a/store/user.go b/store/user.go index 8f53d236..b840f22c 100644 --- a/store/user.go +++ b/store/user.go @@ -2,27 +2,30 @@ package store import ( "crypto/rand" - "database/sql" "encoding/base32" "time" + + "gorm.io/gorm" ) // UserStore user storage type UserStore struct { - db *sql.DB + db *gorm.DB } -// User user +// User user model type User struct { - ID string `json:"id"` - Email string `json:"email"` - PasswordHash string `json:"-"` - OTPSecret string `json:"-"` - OTPVerified bool `json:"otp_verified"` + ID string `gorm:"primaryKey" json:"id"` + Email string `gorm:"uniqueIndex:idx_users_email;not null" json:"email"` + PasswordHash string `gorm:"column:password_hash;not null" json:"-"` + OTPSecret string `gorm:"column:otp_secret" json:"-"` + OTPVerified bool `gorm:"column:otp_verified;default:false" json:"otp_verified"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } +func (User) TableName() string { return "users" } + // GenerateOTPSecret generates OTP secret func GenerateOTPSecret() (string, error) { secret := make([]byte, 20) @@ -33,131 +36,101 @@ func GenerateOTPSecret() (string, error) { return base32.StdEncoding.EncodeToString(secret), nil } +// NewUserStore creates a new UserStore +func NewUserStore(db *gorm.DB) *UserStore { + return &UserStore{db: db} +} + func (s *UserStore) initTables() error { - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, - email TEXT UNIQUE NOT NULL, - password_hash TEXT NOT NULL, - otp_secret TEXT, - otp_verified BOOLEAN DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return err - } + // For PostgreSQL with existing table, skip AutoMigrate to avoid index conflicts + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'users'`).Scan(&tableExists) - // Trigger - _, err = s.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_users_updated_at - AFTER UPDATE ON users - BEGIN - UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END - `) - if err != nil { - return err - } + if tableExists > 0 { + // Table exists - manually ensure all columns exist + // Core columns (should already exist) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT NOT NULL DEFAULT ''`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS password_hash TEXT NOT NULL DEFAULT ''`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`) + // OTP columns (added later) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_secret TEXT DEFAULT ''`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_verified BOOLEAN DEFAULT FALSE`) - return nil + // Ensure unique index exists on email (don't care about the name) + var indexExists int64 + s.db.Raw(` + SELECT COUNT(*) FROM pg_indexes + WHERE tablename = 'users' AND indexdef LIKE '%email%' AND indexdef LIKE '%UNIQUE%' + `).Scan(&indexExists) + + if indexExists == 0 { + s.db.Exec("CREATE UNIQUE INDEX idx_users_email ON users(email)") + } + + return nil + } + } + return s.db.AutoMigrate(&User{}) } // Create creates user func (s *UserStore) Create(user *User) error { - _, err := s.db.Exec(` - INSERT INTO users (id, email, password_hash, otp_secret, otp_verified) - VALUES (?, ?, ?, ?, ?) - `, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified) - return err + return s.db.Create(user).Error } // GetByEmail gets user by email func (s *UserStore) GetByEmail(email string) (*User, error) { var user User - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at - FROM users WHERE email = ? - `, email).Scan( - &user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret, - &user.OTPVerified, &createdAt, &updatedAt, - ) + err := s.db.Where("email = ?", email).First(&user).Error if err != nil { return nil, err } - user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &user, nil } // GetByID gets user by ID func (s *UserStore) GetByID(userID string) (*User, error) { var user User - var createdAt, updatedAt string - err := s.db.QueryRow(` - SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at - FROM users WHERE id = ? - `, userID).Scan( - &user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret, - &user.OTPVerified, &createdAt, &updatedAt, - ) + err := s.db.Where("id = ?", userID).First(&user).Error if err != nil { return nil, err } - user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) return &user, nil } // Count returns the total number of users func (s *UserStore) Count() (int, error) { - var count int - err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count) - return count, err + var count int64 + err := s.db.Model(&User{}).Count(&count).Error + return int(count), err } // GetAllIDs gets all user IDs func (s *UserStore) GetAllIDs() ([]string, error) { - rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`) - if err != nil { - return nil, err - } - defer rows.Close() - var userIDs []string - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - userIDs = append(userIDs, userID) - } - return userIDs, nil + err := s.db.Model(&User{}).Order("id").Pluck("id", &userIDs).Error + return userIDs, err } // UpdateOTPVerified updates OTP verification status func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error { - _, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID) - return err + return s.db.Model(&User{}).Where("id = ?", userID).Update("otp_verified", verified).Error } // UpdatePassword updates password func (s *UserStore) UpdatePassword(userID, passwordHash string) error { - _, err := s.db.Exec(` - UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ? - `, passwordHash, userID) - return err + return s.db.Model(&User{}).Where("id = ?", userID).Updates(map[string]interface{}{ + "password_hash": passwordHash, + "updated_at": time.Now(), + }).Error } // EnsureAdmin ensures admin user exists func (s *UserStore) EnsureAdmin() error { - var count int - err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count) - if err != nil { - return err - } + var count int64 + s.db.Model(&User{}).Where("id = ?", "admin").Count(&count) if count > 0 { return nil } diff --git a/trader/exchange_sync_test.go b/trader/exchange_sync_test.go index 3a996f50..7811d4cd 100644 --- a/trader/exchange_sync_test.go +++ b/trader/exchange_sync_test.go @@ -1,12 +1,13 @@ package trader import ( - "database/sql" "nofx/store" "testing" "time" - _ "github.com/mattn/go-sqlite3" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" ) // TestScenario represents a trading scenario to test @@ -116,11 +117,12 @@ func runStandardTests(t *testing.T, exchangeName string) { for _, scenario := range scenarios { t.Run(scenario.Name, func(t *testing.T) { // Setup database - db, err := sql.Open("sqlite3", ":memory:") + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) if err != nil { t.Fatalf("Failed to create test database: %v", err) } - defer db.Close() positionStore := store.NewPositionStore(db) if err := positionStore.InitTables(); err != nil { @@ -199,11 +201,12 @@ func TestAllExchangesStandardScenarios(t *testing.T) { // TestPositionAccumulationBug tests that positions don't accumulate incorrectly func TestPositionAccumulationBug(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) if err != nil { t.Fatalf("Failed to create test database: %v", err) } - defer db.Close() positionStore := store.NewPositionStore(db) if err := positionStore.InitTables(); err != nil { @@ -283,11 +286,12 @@ func TestPositionAccumulationBug(t *testing.T) { // TestQuantityPrecision tests handling of quantity precision issues func TestQuantityPrecision(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) if err != nil { t.Fatalf("Failed to create test database: %v", err) } - defer db.Close() positionStore := store.NewPositionStore(db) if err := positionStore.InitTables(); err != nil { diff --git a/trader/hyperliquid_sync_test.go b/trader/hyperliquid_sync_test.go index 92683e7f..fce8fdc8 100644 --- a/trader/hyperliquid_sync_test.go +++ b/trader/hyperliquid_sync_test.go @@ -1,13 +1,14 @@ package trader import ( - "database/sql" "math" "nofx/store" "testing" "time" - _ "github.com/mattn/go-sqlite3" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" ) // TestHyperliquidOrderDirectionParsing tests Dir field parsing @@ -75,11 +76,12 @@ func TestHyperliquidOrderDirectionParsing(t *testing.T) { // TestHyperliquidPositionBuilding tests the complete flow of position building func TestHyperliquidPositionBuilding(t *testing.T) { // Setup in-memory database - db, err := sql.Open("sqlite3", ":memory:") + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) if err != nil { t.Fatalf("Failed to create test database: %v", err) } - defer db.Close() // Initialize stores positionStore := store.NewPositionStore(db) @@ -304,11 +306,12 @@ func TestHyperliquidPositionBuilding(t *testing.T) { // TestHyperliquidBugScenario tests the exact bug we fixed func TestHyperliquidBugScenario(t *testing.T) { // Setup database - db, err := sql.Open("sqlite3", ":memory:") + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) if err != nil { t.Fatalf("Failed to create test database: %v", err) } - defer db.Close() positionStore := store.NewPositionStore(db) if err := positionStore.InitTables(); err != nil { diff --git a/web/src/components/landing/core/AgentGrid.tsx b/web/src/components/landing/core/AgentGrid.tsx index 691e7a6e..09152ab7 100644 --- a/web/src/components/landing/core/AgentGrid.tsx +++ b/web/src/components/landing/core/AgentGrid.tsx @@ -1,5 +1,5 @@ import { motion } from 'framer-motion' -import { Bot, TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react' +import { TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react' const agents = [ { diff --git a/web/src/components/landing/core/LiveFeed.tsx b/web/src/components/landing/core/LiveFeed.tsx index 3d4e2e29..84d0b662 100644 --- a/web/src/components/landing/core/LiveFeed.tsx +++ b/web/src/components/landing/core/LiveFeed.tsx @@ -1,8 +1,15 @@ import { motion } from 'framer-motion' -import { Activity, BarChart3, Globe, Wifi, Server, Database, Lock } from 'lucide-react' import { useState, useEffect } from 'react' -const generateLog = (id) => { +interface LogEntry { + id: number + time: string + type: string + msg: string + color: string +} + +const generateLog = (id: number): LogEntry => { const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS'] const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD'] const actions = ['BUY', 'SELL', 'SHORT', 'LONG'] @@ -37,7 +44,7 @@ const generateLog = (id) => { } export default function LiveFeed() { - const [logs, setLogs] = useState([]) + const [logs, setLogs] = useState([]) useEffect(() => { // Initial population diff --git a/web/src/components/landing/core/TerminalHero.tsx b/web/src/components/landing/core/TerminalHero.tsx index df375857..7e56e103 100644 --- a/web/src/components/landing/core/TerminalHero.tsx +++ b/web/src/components/landing/core/TerminalHero.tsx @@ -159,11 +159,24 @@ export default function TerminalHero() { TRADING -

+

The World's First Open-Source Agentic Trading OS. Deploy autonomous high-frequency trading agents powered by advanced LLMs.

+ {/* Market Access Strip - Prominent Display */} +
+ {['CRYPTO', 'US STOCKS', 'FOREX', 'METALS'].map((market) => ( +
+ + + + + {market} +
+ ))} +
+ {/* Command Line Input Simulation */}
document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}> @@ -269,7 +282,7 @@ export default function TerminalHero() { import { OFFICIAL_LINKS } from '../../../constants/branding' function CommunityStats() { - const { stars, forks, contributors, isLoading, error } = useGitHubStats('tinkle-community', 'nofx') + const { stars, forks, contributors, isLoading, error } = useGitHubStats('NoFxAiOS', 'nofx') const stats = [ {