feat: migrate store layer to GORM with PostgreSQL support

- Migrate all store packages from raw database/sql to GORM ORM
- Add PostgreSQL support alongside SQLite
- Move EncryptedString type to crypto package for cleaner architecture
- Add automatic encryption/decryption for sensitive fields (API keys, secrets)
- Fix PostgreSQL AutoMigrate conflicts by skipping existing tables
- Fix duplicate /klines route registration
- Update tests to use GORM database connections
- Add database configuration support in config package
This commit is contained in:
tinkle-community
2026-01-01 19:32:49 +08:00
parent d547863ebb
commit 2d272bb7b8
32 changed files with 2573 additions and 3771 deletions
+13
View File
@@ -55,3 +55,16 @@ TRANSPORT_ENCRYPTION=false
# Telegram notifications (optional)
# TELEGRAM_BOT_TOKEN=your-bot-token
# TELEGRAM_CHAT_ID=your-chat-id
DB_TYPE=postgres
DB_HOST=10.
DB_PORT=5432
DB_USER=nofx_user
DB_PASSWORD=
DB_NAME=nofx
DB_SSLMODE=disable
# 数据库配置 - SQLite(默认)
DB_TYPE=sqlite
DB_PATH=data/data.db
+1 -1
View File
@@ -824,7 +824,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
return fmt.Errorf("AI model %s is not enabled yet", model.Name)
}
apiKey := strings.TrimSpace(model.APIKey)
apiKey := strings.TrimSpace(string(model.APIKey))
if apiKey == "" {
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
}
+43 -41
View File
@@ -54,7 +54,7 @@ func NewServer(traderManager *manager.TraderManager, st *store.Store, cryptoServ
cryptoHandler := NewCryptoHandler(cryptoService)
// Create debate store and handler
debateStore := store.NewDebateStore(st.DB())
debateStore := store.NewDebateStore(st.GormDB())
if err := debateStore.InitSchema(); err != nil {
logger.Errorf("Failed to initialize debate schema: %v", err)
}
@@ -125,7 +125,6 @@ func (s *Server) setupRoutes() {
// Market data (no authentication required)
api.GET("/klines", s.handleKlines)
api.GET("/klines", s.handleKlines)
api.GET("/symbols", s.handleSymbols)
// Authentication related routes (no authentication required)
@@ -576,12 +575,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
var createErr error
// Use ExchangeType (e.g., "binance") instead of ID (UUID)
// Convert EncryptedString fields to string
switch exchangeCfg.ExchangeType {
case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey, // private key
string(exchangeCfg.APIKey), // private key
exchangeCfg.HyperliquidWalletAddr,
exchangeCfg.Testnet,
)
@@ -589,31 +589,31 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
tempTrader, createErr = trader.NewAsterTrader(
exchangeCfg.AsterUser,
exchangeCfg.AsterSigner,
exchangeCfg.AsterPrivateKey,
string(exchangeCfg.AsterPrivateKey),
)
case "bybit":
tempTrader = trader.NewBybitTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
)
case "okx":
tempTrader = trader.NewOKXTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
exchangeCfg.Passphrase,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
string(exchangeCfg.Passphrase),
)
case "bitget":
tempTrader = trader.NewBitgetTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
exchangeCfg.Passphrase,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
string(exchangeCfg.Passphrase),
)
case "lighter":
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
// Lighter only supports mainnet
tempTrader, createErr = trader.NewLighterTraderV2(
exchangeCfg.LighterWalletAddr,
exchangeCfg.LighterAPIKeyPrivateKey,
string(exchangeCfg.LighterAPIKeyPrivateKey),
exchangeCfg.LighterAPIKeyIndex,
false, // Always use mainnet for Lighter
)
@@ -1095,12 +1095,13 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
var createErr error
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
// Convert EncryptedString fields to string
switch exchangeCfg.ExchangeType {
case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey,
string(exchangeCfg.APIKey),
exchangeCfg.HyperliquidWalletAddr,
exchangeCfg.Testnet,
)
@@ -1108,31 +1109,31 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
tempTrader, createErr = trader.NewAsterTrader(
exchangeCfg.AsterUser,
exchangeCfg.AsterSigner,
exchangeCfg.AsterPrivateKey,
string(exchangeCfg.AsterPrivateKey),
)
case "bybit":
tempTrader = trader.NewBybitTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
)
case "okx":
tempTrader = trader.NewOKXTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
exchangeCfg.Passphrase,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
string(exchangeCfg.Passphrase),
)
case "bitget":
tempTrader = trader.NewBitgetTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
exchangeCfg.Passphrase,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
string(exchangeCfg.Passphrase),
)
case "lighter":
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
// Lighter only supports mainnet
tempTrader, createErr = trader.NewLighterTraderV2(
exchangeCfg.LighterWalletAddr,
exchangeCfg.LighterAPIKeyPrivateKey,
string(exchangeCfg.LighterAPIKeyPrivateKey),
exchangeCfg.LighterAPIKeyIndex,
false, // Always use mainnet for Lighter
)
@@ -1246,12 +1247,13 @@ func (s *Server) handleClosePosition(c *gin.Context) {
var createErr error
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
// Convert EncryptedString fields to string
switch exchangeCfg.ExchangeType {
case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey,
string(exchangeCfg.APIKey),
exchangeCfg.HyperliquidWalletAddr,
exchangeCfg.Testnet,
)
@@ -1259,31 +1261,31 @@ func (s *Server) handleClosePosition(c *gin.Context) {
tempTrader, createErr = trader.NewAsterTrader(
exchangeCfg.AsterUser,
exchangeCfg.AsterSigner,
exchangeCfg.AsterPrivateKey,
string(exchangeCfg.AsterPrivateKey),
)
case "bybit":
tempTrader = trader.NewBybitTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
)
case "okx":
tempTrader = trader.NewOKXTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
exchangeCfg.Passphrase,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
string(exchangeCfg.Passphrase),
)
case "bitget":
tempTrader = trader.NewBitgetTrader(
exchangeCfg.APIKey,
exchangeCfg.SecretKey,
exchangeCfg.Passphrase,
string(exchangeCfg.APIKey),
string(exchangeCfg.SecretKey),
string(exchangeCfg.Passphrase),
)
case "lighter":
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
// Lighter only supports mainnet
tempTrader, createErr = trader.NewLighterTraderV2(
exchangeCfg.LighterWalletAddr,
exchangeCfg.LighterAPIKeyPrivateKey,
string(exchangeCfg.LighterAPIKeyPrivateKey),
exchangeCfg.LighterAPIKeyIndex,
false, // Always use mainnet for Lighter
)
+10 -8
View File
@@ -549,32 +549,34 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
var aiClient mcp.AIClient
provider := model.Provider
// Convert EncryptedString to string for API key
apiKey := string(model.APIKey)
switch provider {
case "qwen":
aiClient = mcp.NewQwenClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "deepseek":
aiClient = mcp.NewDeepSeekClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "claude":
aiClient = mcp.NewClaudeClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "kimi":
aiClient = mcp.NewKimiClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "gemini":
aiClient = mcp.NewGeminiClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "grok":
aiClient = mcp.NewGrokClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "openai":
aiClient = mcp.NewOpenAIClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
default:
// Use generic client
aiClient = mcp.NewClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
}
// Call AI API
+2 -2
View File
@@ -76,8 +76,8 @@ func enforceRetentionDB(maxRuns int) {
query := `
SELECT run_id FROM backtest_runs
WHERE state IN (?, ?, ?, ?)
ORDER BY datetime(updated_at) DESC
LIMIT -1 OFFSET ?
ORDER BY updated_at DESC
OFFSET ?
`
rows, err := persistenceDB.Query(query,
finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns)
+4 -4
View File
@@ -166,7 +166,7 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) {
}
func loadRunIDsDB() ([]string, error) {
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY updated_at DESC`)
if err != nil {
return nil, err
}
@@ -278,9 +278,9 @@ func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error)
var rows *sql.Rows
var err error
if cycle > 0 {
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`, runID, cycle)
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle)
} else {
rows, err = persistenceDB.Query(query+` ORDER BY datetime(created_at) DESC LIMIT 1`, runID)
rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID)
}
if err != nil {
return nil, err
@@ -461,7 +461,7 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) {
rows, err := persistenceDB.Query(`
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json
FROM backtest_runs
ORDER BY datetime(updated_at) DESC
ORDER BY updated_at DESC
`)
if err != nil {
return nil, err
+46
View File
@@ -20,6 +20,16 @@ type Config struct {
RegistrationEnabled bool
MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10)
// Database configuration
DBType string // sqlite or postgres
DBPath string // SQLite database file path
DBHost string // PostgreSQL host
DBPort int // PostgreSQL port
DBUser string // PostgreSQL user
DBPassword string // PostgreSQL password
DBName string // PostgreSQL database name
DBSSLMode string // PostgreSQL SSL mode
// Security configuration
// TransportEncryption enables browser-side encryption for API keys
// Requires HTTPS or localhost. Set to false for HTTP access via IP.
@@ -43,6 +53,14 @@ func Init() {
RegistrationEnabled: true,
MaxUsers: 10, // Default: 10 users allowed
ExperienceImprovement: true, // Default: enabled to help improve the product
// Database defaults
DBType: "sqlite",
DBPath: "data/data.db",
DBHost: "localhost",
DBPort: 5432,
DBUser: "postgres",
DBName: "nofx",
DBSSLMode: "disable",
}
// Load from environment variables
@@ -86,6 +104,34 @@ func Init() {
cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY")
cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY")
// Database configuration
if v := os.Getenv("DB_TYPE"); v != "" {
cfg.DBType = strings.ToLower(v)
}
if v := os.Getenv("DB_PATH"); v != "" {
cfg.DBPath = v
}
if v := os.Getenv("DB_HOST"); v != "" {
cfg.DBHost = v
}
if v := os.Getenv("DB_PORT"); v != "" {
if port, err := strconv.Atoi(v); err == nil && port > 0 {
cfg.DBPort = port
}
}
if v := os.Getenv("DB_USER"); v != "" {
cfg.DBUser = v
}
if v := os.Getenv("DB_PASSWORD"); v != "" {
cfg.DBPassword = v
}
if v := os.Getenv("DB_NAME"); v != "" {
cfg.DBName = v
}
if v := os.Getenv("DB_SSLMODE"); v != "" {
cfg.DBSSLMode = v
}
global = cfg
// Initialize experience improvement (installation ID will be set after database init)
+75
View File
@@ -7,6 +7,7 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"database/sql/driver"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -392,3 +393,77 @@ func GenerateDataKey() (string, error) {
}
return base64.StdEncoding.EncodeToString(key), nil
}
// ============================================================================
// EncryptedString - GORM custom type for automatic encryption/decryption
// ============================================================================
// Global crypto service for EncryptedString
var globalCryptoService *CryptoService
// SetGlobalCryptoService sets the global crypto service for EncryptedString
func SetGlobalCryptoService(cs *CryptoService) {
globalCryptoService = cs
}
// EncryptedString is a custom type that automatically encrypts on save and decrypts on load
// Usage: Use EncryptedString instead of string for sensitive fields in GORM models
type EncryptedString string
// Scan implements sql.Scanner - called when reading from database
// Automatically decrypts the value
func (es *EncryptedString) Scan(value interface{}) error {
if value == nil {
*es = ""
return nil
}
var str string
switch v := value.(type) {
case string:
str = v
case []byte:
str = string(v)
default:
*es = ""
return nil
}
// Decrypt if crypto service is set
if globalCryptoService != nil && str != "" && globalCryptoService.IsEncryptedStorageValue(str) {
decrypted, err := globalCryptoService.DecryptFromStorage(str)
if err != nil {
// If decryption fails, return the original value
*es = EncryptedString(str)
} else {
*es = EncryptedString(decrypted)
}
} else {
*es = EncryptedString(str)
}
return nil
}
// Value implements driver.Valuer - called when writing to database
// Automatically encrypts the value
func (es EncryptedString) Value() (driver.Value, error) {
if es == "" {
return "", nil
}
// Encrypt if crypto service is set
if globalCryptoService != nil {
encrypted, err := globalCryptoService.EncryptForStorage(string(es))
if err != nil {
// If encryption fails, return the original value
return string(es), nil
}
return encrypted, nil
}
return string(es), nil
}
// String returns the plaintext string value
func (es EncryptedString) String() string {
return string(es)
}
+2 -2
View File
@@ -101,8 +101,8 @@ func (e *DebateEngine) InitializeClients(participants []*store.DebateParticipant
client = mcp.New()
}
// Configure client
client.SetAPIKey(aiModel.APIKey, aiModel.CustomAPIURL, aiModel.CustomModelName)
// Configure client (convert EncryptedString to string)
client.SetAPIKey(string(aiModel.APIKey), aiModel.CustomAPIURL, aiModel.CustomModelName)
e.clients[p.AIModelID] = client
}
+9
View File
@@ -51,6 +51,12 @@ require (
github.com/goccy/go-json v0.10.4 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/holiman/uint256 v1.3.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/jpillora/backoff v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@@ -94,6 +100,9 @@ require (
golang.org/x/tools v0.36.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.6.0 // indirect
gorm.io/driver/sqlite v1.6.0 // indirect
gorm.io/gorm v1.31.1 // indirect
howett.net/plist v1.0.1 // indirect
modernc.org/libc v1.66.10 // indirect
modernc.org/mathutil v1.7.1 // indirect
+18
View File
@@ -111,7 +111,19 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
@@ -284,6 +296,12 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
+28 -48
View File
@@ -36,30 +36,44 @@ func main() {
cfg := config.Get()
logger.Info("✅ Configuration loaded")
// Initialize database from environment variables
// DB_TYPE: sqlite (default) or postgres
// For SQLite: DB_PATH (default: data/data.db)
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
dbPath := os.Getenv("DB_PATH")
if dbPath == "" {
dbPath = "data/data.db"
// Initialize encryption service BEFORE database (so EncryptedString can decrypt on read)
logger.Info("🔐 Initializing encryption service...")
cryptoService, err := crypto.NewCryptoService()
if err != nil {
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
}
// For backward compatibility: command line arg overrides env var (SQLite only)
crypto.SetGlobalCryptoService(cryptoService)
logger.Info("✅ Encryption service initialized successfully")
// Initialize database from configuration
// For backward compatibility: command line arg overrides config (SQLite only)
if len(os.Args) > 1 {
dbPath = os.Args[1]
os.Setenv("DB_PATH", dbPath)
cfg.DBPath = os.Args[1]
}
// Ensure data directory exists (for SQLite)
if os.Getenv("DB_TYPE") == "" || os.Getenv("DB_TYPE") == "sqlite" {
if dir := filepath.Dir(dbPath); dir != "." {
if cfg.DBType == "sqlite" {
if dir := filepath.Dir(cfg.DBPath); dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
logger.Errorf("Failed to create data directory: %v", err)
}
}
}
logger.Info("📋 Initializing database...")
st, err := store.NewFromEnv()
logger.Infof("📋 Initializing database (%s)...", cfg.DBType)
dbType := store.DBTypeSQLite
if cfg.DBType == "postgres" {
dbType = store.DBTypePostgres
}
st, err := store.NewWithConfig(store.DBConfig{
Type: dbType,
Path: cfg.DBPath,
Host: cfg.DBHost,
Port: cfg.DBPort,
User: cfg.DBUser,
Password: cfg.DBPassword,
DBName: cfg.DBName,
SSLMode: cfg.DBSSLMode,
})
if err != nil {
logger.Fatalf("❌ Failed to initialize database: %v", err)
}
@@ -69,40 +83,6 @@ func main() {
// Initialize installation ID for experience improvement (anonymous statistics)
initInstallationID(st)
// Initialize encryption service
logger.Info("🔐 Initializing encryption service...")
cryptoService, err := crypto.NewCryptoService()
if err != nil {
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
}
encryptFunc := func(plaintext string) string {
if plaintext == "" {
return plaintext
}
encrypted, err := cryptoService.EncryptForStorage(plaintext)
if err != nil {
logger.Warnf("⚠️ Encryption failed: %v", err)
return plaintext
}
return encrypted
}
decryptFunc := func(encrypted string) string {
if encrypted == "" {
return encrypted
}
if !cryptoService.IsEncryptedStorageValue(encrypted) {
return encrypted
}
decrypted, err := cryptoService.DecryptFromStorage(encrypted)
if err != nil {
logger.Warnf("⚠️ Decryption failed: %v", err)
return encrypted
}
return decrypted
}
st.SetCryptoFuncs(encryptFunc, decryptFunc)
logger.Info("✅ Encryption service initialized successfully")
// Set JWT secret
auth.SetJWTSecret(cfg.JWTSecret)
logger.Info("🔑 JWT secret configured")
+19 -19
View File
@@ -664,46 +664,46 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
StrategyConfig: strategyConfig,
}
// Set API keys based on exchange type
// Set API keys based on exchange type (convert EncryptedString to string)
switch exchangeCfg.ExchangeType {
case "binance":
traderConfig.BinanceAPIKey = exchangeCfg.APIKey
traderConfig.BinanceSecretKey = exchangeCfg.SecretKey
traderConfig.BinanceAPIKey = string(exchangeCfg.APIKey)
traderConfig.BinanceSecretKey = string(exchangeCfg.SecretKey)
case "bybit":
traderConfig.BybitAPIKey = exchangeCfg.APIKey
traderConfig.BybitSecretKey = exchangeCfg.SecretKey
traderConfig.BybitAPIKey = string(exchangeCfg.APIKey)
traderConfig.BybitSecretKey = string(exchangeCfg.SecretKey)
case "okx":
traderConfig.OKXAPIKey = exchangeCfg.APIKey
traderConfig.OKXSecretKey = exchangeCfg.SecretKey
traderConfig.OKXPassphrase = exchangeCfg.Passphrase
traderConfig.OKXAPIKey = string(exchangeCfg.APIKey)
traderConfig.OKXSecretKey = string(exchangeCfg.SecretKey)
traderConfig.OKXPassphrase = string(exchangeCfg.Passphrase)
case "bitget":
traderConfig.BitgetAPIKey = exchangeCfg.APIKey
traderConfig.BitgetSecretKey = exchangeCfg.SecretKey
traderConfig.BitgetPassphrase = exchangeCfg.Passphrase
traderConfig.BitgetAPIKey = string(exchangeCfg.APIKey)
traderConfig.BitgetSecretKey = string(exchangeCfg.SecretKey)
traderConfig.BitgetPassphrase = string(exchangeCfg.Passphrase)
case "hyperliquid":
traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey
traderConfig.HyperliquidPrivateKey = string(exchangeCfg.APIKey)
traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr
case "aster":
traderConfig.AsterUser = exchangeCfg.AsterUser
traderConfig.AsterSigner = exchangeCfg.AsterSigner
traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey
traderConfig.AsterPrivateKey = string(exchangeCfg.AsterPrivateKey)
case "lighter":
traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey
traderConfig.LighterPrivateKey = string(exchangeCfg.LighterPrivateKey)
traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr
traderConfig.LighterAPIKeyPrivateKey = exchangeCfg.LighterAPIKeyPrivateKey
traderConfig.LighterAPIKeyPrivateKey = string(exchangeCfg.LighterAPIKeyPrivateKey)
traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex
traderConfig.LighterTestnet = exchangeCfg.Testnet
}
// Set API keys based on AI model
// Set API keys based on AI model (convert EncryptedString to string)
switch aiModelCfg.Provider {
case "qwen":
traderConfig.QwenKey = aiModelCfg.APIKey
traderConfig.QwenKey = string(aiModelCfg.APIKey)
case "deepseek":
traderConfig.DeepSeekKey = aiModelCfg.APIKey
traderConfig.DeepSeekKey = string(aiModelCfg.APIKey)
default:
// For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey
traderConfig.CustomAPIKey = aiModelCfg.APIKey
traderConfig.CustomAPIKey = string(aiModelCfg.APIKey)
}
// Create trader instance
+88 -172
View File
@@ -1,71 +1,52 @@
package store
import (
"database/sql"
"errors"
"fmt"
"nofx/crypto"
"nofx/logger"
"strings"
"time"
"gorm.io/gorm"
)
// AIModelStore AI model storage
type AIModelStore struct {
db *sql.DB
encryptFunc func(string) string
decryptFunc func(string) string
db *gorm.DB
}
// AIModel AI model configuration
type AIModel struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey"`
CustomAPIURL string `json:"customApiUrl"`
CustomModelName string `json:"customModelName"`
ID string `gorm:"primaryKey" json:"id"`
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `gorm:"not null" json:"name"`
Provider string `gorm:"not null" json:"provider"`
Enabled bool `gorm:"default:false" json:"enabled"`
APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"`
CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (AIModel) TableName() string { return "ai_models" }
// NewAIModelStore creates a new AIModelStore
func NewAIModelStore(db *gorm.DB) *AIModelStore {
return &AIModelStore{db: db}
}
func (s *AIModelStore) initTables() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS ai_models (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT 'default',
name TEXT NOT NULL,
provider TEXT NOT NULL,
enabled BOOLEAN DEFAULT 0,
api_key TEXT DEFAULT '',
custom_api_url TEXT DEFAULT '',
custom_model_name TEXT DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at
AFTER UPDATE ON ai_models
BEGIN
UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
if err != nil {
return err
}
// Backward compatibility: add potentially missing columns
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`)
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`)
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'ai_models'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
}
return s.db.AutoMigrate(&AIModel{})
}
func (s *AIModelStore) initDefaultData() error {
@@ -73,51 +54,13 @@ func (s *AIModelStore) initDefaultData() error {
return nil
}
func (s *AIModelStore) encrypt(plaintext string) string {
if s.encryptFunc != nil {
return s.encryptFunc(plaintext)
}
return plaintext
}
func (s *AIModelStore) decrypt(encrypted string) string {
if s.decryptFunc != nil {
return s.decryptFunc(encrypted)
}
return encrypted
}
// List retrieves user's AI model list
func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, '') as custom_api_url,
COALESCE(custom_model_name, '') as custom_model_name,
created_at, updated_at
FROM ai_models WHERE user_id = ? ORDER BY id
`, userID)
var models []*AIModel
err := s.db.Where("user_id = ?", userID).Order("id").Find(&models).Error
if err != nil {
return nil, err
}
defer rows.Close()
models := make([]*AIModel, 0)
for rows.Next() {
var model AIModel
var createdAt, updatedAt string
err := rows.Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
models = append(models, &model)
}
return models, nil
}
@@ -140,27 +83,15 @@ func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
for _, uid := range candidates {
var model AIModel
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1
`, uid, modelID).Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
err := s.db.Where("user_id = ? AND id = ?", uid, modelID).First(&model).Error
if err == nil {
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
return &model, nil
}
if !errors.Is(err, sql.ErrNoRows) {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
return nil, sql.ErrNoRows
return nil, gorm.ErrRecordNotFound
}
// GetByID retrieves an AI model by ID only (for debate engine)
@@ -170,22 +101,10 @@ func (s *AIModelStore) GetByID(modelID string) (*AIModel, error) {
}
var model AIModel
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
FROM ai_models WHERE id = ? LIMIT 1
`, modelID).Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
err := s.db.Where("id = ?", modelID).First(&model).Error
if err != nil {
return nil, err
}
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
return &model, nil
}
@@ -198,7 +117,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
if err == nil {
return model, nil
}
if !errors.Is(err, sql.ErrNoRows) {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if userID != "default" {
@@ -209,23 +128,12 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
var model AIModel
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
FROM ai_models WHERE user_id = ? AND enabled = 1
ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1
`, userID).Scan(
&model.ID, &model.UserID, &model.Name, &model.Provider,
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
&createdAt, &updatedAt,
)
err := s.db.Where("user_id = ? AND enabled = ?", userID, true).
Order("updated_at DESC, id ASC").
First(&model).Error
if err != nil {
return nil, err
}
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
model.APIKey = s.decrypt(model.APIKey)
return &model, nil
}
@@ -233,44 +141,38 @@ func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
// IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten)
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
// Try exact ID match first
var existingID string
err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID)
var existingModel AIModel
err := s.db.Where("user_id = ? AND id = ?", userID, id).First(&existingModel).Error
if err == nil {
// If apiKey is empty, preserve the existing API key
if apiKey == "" {
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
WHERE id = ? AND user_id = ?
`, enabled, customAPIURL, customModelName, existingID, userID)
} else {
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
WHERE id = ? AND user_id = ?
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
// Update existing model
updates := map[string]interface{}{
"enabled": enabled,
"custom_api_url": customAPIURL,
"custom_model_name": customModelName,
"updated_at": time.Now(),
}
return err
// If apiKey is not empty, update it (encryption handled by crypto.EncryptedString)
if apiKey != "" {
updates["api_key"] = crypto.EncryptedString(apiKey)
}
return s.db.Model(&existingModel).Updates(updates).Error
}
// Try legacy logic compatibility: use id as provider to search
provider := id
err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID)
err = s.db.Where("user_id = ? AND provider = ?", userID, provider).First(&existingModel).Error
if err == nil {
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingID)
// If apiKey is empty, preserve the existing API key
if apiKey == "" {
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
WHERE id = ? AND user_id = ?
`, enabled, customAPIURL, customModelName, existingID, userID)
} else {
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
WHERE id = ? AND user_id = ?
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingModel.ID)
updates := map[string]interface{}{
"enabled": enabled,
"custom_api_url": customAPIURL,
"custom_model_name": customModelName,
"updated_at": time.Now(),
}
return err
if apiKey != "" {
updates["api_key"] = crypto.EncryptedString(apiKey)
}
return s.db.Model(&existingModel).Updates(updates).Error
}
// Create new record
@@ -285,9 +187,12 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
}
}
// Try to get name from existing model with same provider
var refModel AIModel
var name string
err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name)
if err != nil {
if err := s.db.Where("provider = ?", provider).First(&refModel).Error; err == nil {
name = refModel.Name
} else {
if provider == "deepseek" {
name = "DeepSeek AI"
} else if provider == "qwen" {
@@ -303,19 +208,30 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
}
logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
`, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName)
return err
newModel := &AIModel{
ID: newModelID,
UserID: userID,
Name: name,
Provider: provider,
Enabled: enabled,
APIKey: crypto.EncryptedString(apiKey),
CustomAPIURL: customAPIURL,
CustomModelName: customModelName,
}
return s.db.Create(newModel).Error
}
// Create creates an AI model
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
_, err := s.db.Exec(`
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
VALUES (?, ?, ?, ?, ?, ?, ?)
`, id, userID, name, provider, enabled, apiKey, customAPIURL)
return err
model := &AIModel{
ID: id,
UserID: userID,
Name: name,
Provider: provider,
Enabled: enabled,
APIKey: crypto.EncryptedString(apiKey),
CustomAPIURL: customAPIURL,
}
// Use FirstOrCreate to ignore if already exists
return s.db.Where("id = ?", id).FirstOrCreate(model).Error
}
+339 -351
View File
@@ -1,15 +1,26 @@
package store
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"gorm.io/gorm"
)
// BacktestStore backtest data storage
type BacktestStore struct {
db *sql.DB
db *gorm.DB
}
// NewBacktestStore creates a new backtest store
func NewBacktestStore(db *gorm.DB) *BacktestStore {
return &BacktestStore{db: db}
}
// isPostgres checks if the database is PostgreSQL
func (s *BacktestStore) isPostgres() bool {
return s.db.Dialector.Name() == "postgres"
}
// RunState backtest state
@@ -92,492 +103,469 @@ type RunIndexEntry struct {
UpdatedAtISO string `json:"updated_at"`
}
// BacktestRun GORM model for backtest_runs table
type BacktestRun struct {
RunID string `gorm:"column:run_id;primaryKey"`
UserID string `gorm:"column:user_id;not null;default:''"`
ConfigJSON []byte `gorm:"column:config_json"`
State string `gorm:"column:state;not null;default:created"`
Label string `gorm:"column:label;default:''"`
SymbolCount int `gorm:"column:symbol_count;default:0"`
DecisionTF string `gorm:"column:decision_tf;default:''"`
ProcessedBars int `gorm:"column:processed_bars;default:0"`
ProgressPct float64 `gorm:"column:progress_pct;default:0"`
EquityLast float64 `gorm:"column:equity_last;default:0"`
MaxDrawdownPct float64 `gorm:"column:max_drawdown_pct;default:0"`
Liquidated bool `gorm:"column:liquidated;default:false"`
LiquidationNote string `gorm:"column:liquidation_note;default:''"`
PromptTemplate string `gorm:"column:prompt_template;default:''"`
CustomPrompt string `gorm:"column:custom_prompt;default:''"`
OverridePrompt bool `gorm:"column:override_prompt;default:false"`
AIProvider string `gorm:"column:ai_provider;default:''"`
AIModel string `gorm:"column:ai_model;default:''"`
LastError string `gorm:"column:last_error;default:''"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (BacktestRun) TableName() string {
return "backtest_runs"
}
// BacktestCheckpoint GORM model
type BacktestCheckpoint struct {
RunID string `gorm:"column:run_id;primaryKey"`
Payload []byte `gorm:"column:payload;not null"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (BacktestCheckpoint) TableName() string {
return "backtest_checkpoints"
}
// BacktestEquity GORM model
type BacktestEquity struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"`
TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"`
Equity float64 `gorm:"column:equity;not null"`
Available float64 `gorm:"column:available;not null"`
PnL float64 `gorm:"column:pnl;not null"`
PnLPct float64 `gorm:"column:pnl_pct;not null"`
DDPct float64 `gorm:"column:dd_pct;not null"`
Cycle int `gorm:"column:cycle;not null"`
}
func (BacktestEquity) TableName() string {
return "backtest_equity"
}
// BacktestTrade GORM model
type BacktestTrade struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"`
TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"`
Symbol string `gorm:"column:symbol;not null"`
Action string `gorm:"column:action;not null"`
Side string `gorm:"column:side;default:''"`
Qty float64 `gorm:"column:qty;default:0"`
Price float64 `gorm:"column:price;default:0"`
Fee float64 `gorm:"column:fee;default:0"`
Slippage float64 `gorm:"column:slippage;default:0"`
OrderValue float64 `gorm:"column:order_value;default:0"`
RealizedPnL float64 `gorm:"column:realized_pnl;default:0"`
Leverage int `gorm:"column:leverage;default:0"`
Cycle int `gorm:"column:cycle;default:0"`
PositionAfter float64 `gorm:"column:position_after;default:0"`
Liquidation bool `gorm:"column:liquidation;default:false"`
Note string `gorm:"column:note;default:''"`
}
func (BacktestTrade) TableName() string {
return "backtest_trades"
}
// BacktestMetrics GORM model
type BacktestMetrics struct {
RunID string `gorm:"column:run_id;primaryKey"`
Payload []byte `gorm:"column:payload;not null"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
func (BacktestMetrics) TableName() string {
return "backtest_metrics"
}
// BacktestDecision GORM model
type BacktestDecision struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
RunID string `gorm:"column:run_id;not null;index:idx_backtest_decisions_run_cycle"`
Cycle int `gorm:"column:cycle;not null;index:idx_backtest_decisions_run_cycle"`
Payload []byte `gorm:"column:payload;not null"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
}
func (BacktestDecision) TableName() string {
return "backtest_decisions"
}
// initTables initializes backtest related tables
func (s *BacktestStore) initTables() error {
queries := []string{
// Backtest runs main table
`CREATE TABLE IF NOT EXISTS backtest_runs (
run_id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT '',
config_json TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'created',
label TEXT DEFAULT '',
symbol_count INTEGER DEFAULT 0,
decision_tf TEXT DEFAULT '',
processed_bars INTEGER DEFAULT 0,
progress_pct REAL DEFAULT 0,
equity_last REAL DEFAULT 0,
max_drawdown_pct REAL DEFAULT 0,
liquidated BOOLEAN DEFAULT 0,
liquidation_note TEXT DEFAULT '',
prompt_template TEXT DEFAULT '',
custom_prompt TEXT DEFAULT '',
override_prompt BOOLEAN DEFAULT 0,
ai_provider TEXT DEFAULT '',
ai_model TEXT DEFAULT '',
last_error TEXT DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists)
// Backtest checkpoints
`CREATE TABLE IF NOT EXISTS backtest_checkpoints (
run_id TEXT PRIMARY KEY,
payload BLOB NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest equity curve
`CREATE TABLE IF NOT EXISTS backtest_equity (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
ts INTEGER NOT NULL,
equity REAL NOT NULL,
available REAL NOT NULL,
pnl REAL NOT NULL,
pnl_pct REAL NOT NULL,
dd_pct REAL NOT NULL,
cycle INTEGER NOT NULL,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest trade records
`CREATE TABLE IF NOT EXISTS backtest_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
ts INTEGER NOT NULL,
symbol TEXT NOT NULL,
action TEXT NOT NULL,
side TEXT DEFAULT '',
qty REAL DEFAULT 0,
price REAL DEFAULT 0,
fee REAL DEFAULT 0,
slippage REAL DEFAULT 0,
order_value REAL DEFAULT 0,
realized_pnl REAL DEFAULT 0,
leverage INTEGER DEFAULT 0,
cycle INTEGER DEFAULT 0,
position_after REAL DEFAULT 0,
liquidation BOOLEAN DEFAULT 0,
note TEXT DEFAULT '',
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest metrics
`CREATE TABLE IF NOT EXISTS backtest_metrics (
run_id TEXT PRIMARY KEY,
payload BLOB NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Backtest decision logs
`CREATE TABLE IF NOT EXISTS backtest_decisions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
cycle INTEGER NOT NULL,
payload BLOB NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// Indexes
`CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
if tableExists > 0 {
// Tables exist - just ensure indexes exist
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`)
return nil
}
}
// Add potentially missing columns (backward compatibility)
s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''")
s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''")
s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0")
// AutoMigrate all backtest tables
if err := s.db.AutoMigrate(
&BacktestRun{},
&BacktestCheckpoint{},
&BacktestEquity{},
&BacktestTrade{},
&BacktestMetrics{},
&BacktestDecision{},
); err != nil {
return fmt.Errorf("failed to migrate backtest tables: %w", err)
}
return nil
}
func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) {
rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
if err != nil {
return
}
defer rows.Close()
for rows.Next() {
var cid int
var name, ctype string
var notnull, pk int
var dflt interface{}
if err := rows.Scan(&cid, &name, &ctype, &notnull, &dflt, &pk); err != nil {
continue
}
if name == column {
return // Column already exists
}
}
s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
}
// SaveCheckpoint saves checkpoint
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
_, err := s.db.Exec(`
INSERT INTO backtest_checkpoints (run_id, payload, updated_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
`, runID, payload)
return err
checkpoint := BacktestCheckpoint{
RunID: runID,
Payload: payload,
}
return s.db.Save(&checkpoint).Error
}
// LoadCheckpoint loads checkpoint
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
var payload []byte
err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload)
return payload, err
var checkpoint BacktestCheckpoint
err := s.db.Where("run_id = ?", runID).First(&checkpoint).Error
if err != nil {
return nil, err
}
return checkpoint.Payload, nil
}
// SaveRunMetadata saves run metadata
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
created := meta.CreatedAt.UTC().Format(time.RFC3339)
updated := meta.UpdatedAt.UTC().Format(time.RFC3339)
userID := meta.UserID
if _, err := s.db.Exec(`
INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(run_id) DO NOTHING
`, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil {
return err
run := BacktestRun{
RunID: meta.RunID,
UserID: meta.UserID,
State: string(meta.State),
Label: meta.Label,
LastError: meta.LastError,
SymbolCount: meta.Summary.SymbolCount,
DecisionTF: meta.Summary.DecisionTF,
ProcessedBars: meta.Summary.ProcessedBars,
ProgressPct: meta.Summary.ProgressPct,
EquityLast: meta.Summary.EquityLast,
MaxDrawdownPct: meta.Summary.MaxDrawdownPct,
Liquidated: meta.Summary.Liquidated,
LiquidationNote: meta.Summary.LiquidationNote,
CreatedAt: meta.CreatedAt,
UpdatedAt: meta.UpdatedAt,
}
_, err := s.db.Exec(`
UPDATE backtest_runs
SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?,
progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?,
liquidation_note = ?, label = ?, last_error = ?, updated_at = ?
WHERE run_id = ?
`, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF,
meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast,
meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote,
meta.Label, meta.LastError, updated, meta.RunID)
return err
return s.db.Save(&run).Error
}
// LoadRunMetadata loads run metadata
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
var (
userID string
state string
label string
lastErr string
symbolCount int
decisionTF string
processedBars int
progressPct float64
equityLast float64
maxDD float64
liquidated bool
liquidationNote string
createdISO string
updatedISO string
)
err := s.db.QueryRow(`
SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars,
progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note,
created_at, updated_at
FROM backtest_runs WHERE run_id = ?
`, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF,
&processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote,
&createdISO, &updatedISO)
var run BacktestRun
err := s.db.Where("run_id = ?", runID).First(&run).Error
if err != nil {
return nil, err
}
meta := &RunMetadata{
RunID: runID,
UserID: userID,
return &RunMetadata{
RunID: run.RunID,
UserID: run.UserID,
Version: 1,
State: RunState(state),
Label: label,
LastError: lastErr,
State: RunState(run.State),
Label: run.Label,
LastError: run.LastError,
Summary: RunSummary{
SymbolCount: symbolCount,
DecisionTF: decisionTF,
ProcessedBars: processedBars,
ProgressPct: progressPct,
EquityLast: equityLast,
MaxDrawdownPct: maxDD,
Liquidated: liquidated,
LiquidationNote: liquidationNote,
SymbolCount: run.SymbolCount,
DecisionTF: run.DecisionTF,
ProcessedBars: run.ProcessedBars,
ProgressPct: run.ProgressPct,
EquityLast: run.EquityLast,
MaxDrawdownPct: run.MaxDrawdownPct,
Liquidated: run.Liquidated,
LiquidationNote: run.LiquidationNote,
},
}
meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO)
meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO)
return meta, nil
CreatedAt: run.CreatedAt,
UpdatedAt: run.UpdatedAt,
}, nil
}
// ListRunIDs lists all run IDs
func (s *BacktestStore) ListRunIDs() ([]string, error) {
rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
var runs []BacktestRun
err := s.db.Order("updated_at DESC").Find(&runs).Error
if err != nil {
return nil, err
}
defer rows.Close()
var ids []string
for rows.Next() {
var runID string
if err := rows.Scan(&runID); err != nil {
return nil, err
ids := make([]string, len(runs))
for i, run := range runs {
ids[i] = run.RunID
}
ids = append(ids, runID)
}
return ids, rows.Err()
return ids, nil
}
// AppendEquityPoint appends equity point
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
_, err := s.db.Exec(`
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`, runID, point.Timestamp, point.Equity, point.Available, point.PnL,
point.PnLPct, point.DrawdownPct, point.Cycle)
return err
eq := BacktestEquity{
RunID: runID,
TS: point.Timestamp,
Equity: point.Equity,
Available: point.Available,
PnL: point.PnL,
PnLPct: point.PnLPct,
DDPct: point.DrawdownPct,
Cycle: point.Cycle,
}
return s.db.Create(&eq).Error
}
// LoadEquityPoints loads equity points
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
rows, err := s.db.Query(`
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle
FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC
`, runID)
var eqs []BacktestEquity
err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&eqs).Error
if err != nil {
return nil, err
}
defer rows.Close()
points := make([]EquityPoint, 0)
for rows.Next() {
var point EquityPoint
if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available,
&point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil {
return nil, err
points := make([]EquityPoint, len(eqs))
for i, eq := range eqs {
points[i] = EquityPoint{
Timestamp: eq.TS,
Equity: eq.Equity,
Available: eq.Available,
PnL: eq.PnL,
PnLPct: eq.PnLPct,
DrawdownPct: eq.DDPct,
Cycle: eq.Cycle,
}
points = append(points, point)
}
return points, rows.Err()
return points, nil
}
// AppendTradeEvent appends trade event
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
_, err := s.db.Exec(`
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee,
slippage, order_value, realized_pnl, leverage, cycle,
position_after, liquidation, note)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity,
event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL,
event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note)
return err
trade := BacktestTrade{
RunID: runID,
TS: event.Timestamp,
Symbol: event.Symbol,
Action: event.Action,
Side: event.Side,
Qty: event.Quantity,
Price: event.Price,
Fee: event.Fee,
Slippage: event.Slippage,
OrderValue: event.OrderValue,
RealizedPnL: event.RealizedPnL,
Leverage: event.Leverage,
Cycle: event.Cycle,
PositionAfter: event.PositionAfter,
Liquidation: event.LiquidationFlag,
Note: event.Note,
}
return s.db.Create(&trade).Error
}
// LoadTradeEvents loads trade events
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
rows, err := s.db.Query(`
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value,
realized_pnl, leverage, cycle, position_after, liquidation, note
FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC
`, runID)
var trades []BacktestTrade
err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&trades).Error
if err != nil {
return nil, err
}
defer rows.Close()
events := make([]TradeEvent, 0)
for rows.Next() {
var event TradeEvent
if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side,
&event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue,
&event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter,
&event.LiquidationFlag, &event.Note); err != nil {
return nil, err
events := make([]TradeEvent, len(trades))
for i, trade := range trades {
events[i] = TradeEvent{
Timestamp: trade.TS,
Symbol: trade.Symbol,
Action: trade.Action,
Side: trade.Side,
Quantity: trade.Qty,
Price: trade.Price,
Fee: trade.Fee,
Slippage: trade.Slippage,
OrderValue: trade.OrderValue,
RealizedPnL: trade.RealizedPnL,
Leverage: trade.Leverage,
Cycle: trade.Cycle,
PositionAfter: trade.PositionAfter,
LiquidationFlag: trade.Liquidation,
Note: trade.Note,
}
events = append(events, event)
}
return events, rows.Err()
return events, nil
}
// SaveMetrics saves metrics
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
_, err := s.db.Exec(`
INSERT INTO backtest_metrics (run_id, payload, updated_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
`, runID, payload)
return err
metrics := BacktestMetrics{
RunID: runID,
Payload: payload,
}
return s.db.Save(&metrics).Error
}
// LoadMetrics loads metrics
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
var payload []byte
err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload)
return payload, err
var metrics BacktestMetrics
err := s.db.Where("run_id = ?", runID).First(&metrics).Error
if err != nil {
return nil, err
}
return metrics.Payload, nil
}
// SaveDecisionRecord saves decision record
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
_, err := s.db.Exec(`
INSERT INTO backtest_decisions (run_id, cycle, payload)
VALUES (?, ?, ?)
`, runID, cycle, payload)
return err
decision := BacktestDecision{
RunID: runID,
Cycle: cycle,
Payload: payload,
}
return s.db.Create(&decision).Error
}
// LoadDecisionRecords loads decision records
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
rows, err := s.db.Query(`
SELECT payload FROM backtest_decisions
WHERE run_id = ?
ORDER BY id DESC
LIMIT ? OFFSET ?
`, runID, limit, offset)
var decisions []BacktestDecision
err := s.db.Where("run_id = ?", runID).
Order("id DESC").
Limit(limit).
Offset(offset).
Find(&decisions).Error
if err != nil {
return nil, err
}
defer rows.Close()
records := make([]json.RawMessage, 0, limit)
for rows.Next() {
var payload []byte
if err := rows.Scan(&payload); err != nil {
return nil, err
records := make([]json.RawMessage, len(decisions))
for i, d := range decisions {
records[i] = json.RawMessage(d.Payload)
}
records = append(records, json.RawMessage(payload))
}
return records, rows.Err()
return records, nil
}
// LoadLatestDecision loads latest decision
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
var query string
var args []interface{}
var decision BacktestDecision
query := s.db.Where("run_id = ?", runID)
if cycle > 0 {
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`
args = []interface{}{runID, cycle}
} else {
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1`
args = []interface{}{runID}
query = query.Where("cycle = ?", cycle)
}
var payload []byte
err := s.db.QueryRow(query, args...).Scan(&payload)
return payload, err
err := query.Order("created_at DESC").First(&decision).Error
if err != nil {
return nil, err
}
return decision.Payload, nil
}
// UpdateProgress updates progress
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
_, err := s.db.Exec(`
UPDATE backtest_runs
SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP
WHERE run_id = ?
`, progressPct, equity, barIndex, liquidated, runID)
return err
return s.db.Model(&BacktestRun{}).Where("run_id = ?", runID).Updates(map[string]interface{}{
"progress_pct": progressPct,
"equity_last": equity,
"processed_bars": barIndex,
"liquidated": liquidated,
}).Error
}
// ListIndexEntries lists index entries
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
rows, err := s.db.Query(`
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct,
created_at, updated_at, config_json
FROM backtest_runs
ORDER BY datetime(updated_at) DESC
`)
var runs []BacktestRun
err := s.db.Order("updated_at DESC").Find(&runs).Error
if err != nil {
return nil, err
}
defer rows.Close()
var entries []RunIndexEntry
for rows.Next() {
var entry RunIndexEntry
var symbolCnt int
var cfgJSON []byte
var createdISO, updatedISO string
if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF,
&entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil {
return nil, err
entries := make([]RunIndexEntry, len(runs))
for i, run := range runs {
entry := RunIndexEntry{
RunID: run.RunID,
State: run.State,
DecisionTF: run.DecisionTF,
EquityLast: run.EquityLast,
MaxDrawdownPct: run.MaxDrawdownPct,
CreatedAtISO: run.CreatedAt.Format(time.RFC3339),
UpdatedAtISO: run.UpdatedAt.Format(time.RFC3339),
Symbols: make([]string, 0, run.SymbolCount),
}
entry.CreatedAtISO = createdISO
entry.UpdatedAtISO = updatedISO
entry.Symbols = make([]string, 0, symbolCnt)
// Try to extract more information from config
if len(cfgJSON) > 0 {
if len(run.ConfigJSON) > 0 {
var cfg struct {
Symbols []string `json:"symbols"`
StartTS int64 `json:"start_ts"`
EndTS int64 `json:"end_ts"`
}
if json.Unmarshal(cfgJSON, &cfg) == nil {
if json.Unmarshal(run.ConfigJSON, &cfg) == nil {
entry.Symbols = cfg.Symbols
entry.StartTS = cfg.StartTS
entry.EndTS = cfg.EndTS
}
}
entries = append(entries, entry)
entries[i] = entry
}
return entries, rows.Err()
return entries, nil
}
// DeleteRun deletes run
func (s *BacktestStore) DeleteRun(runID string) error {
_, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID)
return err
// Delete related records first (cascade may not work in all cases)
s.db.Where("run_id = ?", runID).Delete(&BacktestCheckpoint{})
s.db.Where("run_id = ?", runID).Delete(&BacktestEquity{})
s.db.Where("run_id = ?", runID).Delete(&BacktestTrade{})
s.db.Where("run_id = ?", runID).Delete(&BacktestMetrics{})
s.db.Where("run_id = ?", runID).Delete(&BacktestDecision{})
return s.db.Where("run_id = ?", runID).Delete(&BacktestRun{}).Error
}
// SaveConfig saves config
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
now := time.Now().UTC().Format(time.RFC3339)
if userID == "" {
userID = "default"
}
_, err := s.db.Exec(`
INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt,
override_prompt, ai_provider, ai_model, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(run_id) DO NOTHING
`, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now)
if err != nil {
return err
run := BacktestRun{
RunID: runID,
UserID: userID,
ConfigJSON: configJSON,
PromptTemplate: template,
CustomPrompt: customPrompt,
OverridePrompt: override,
AIProvider: provider,
AIModel: model,
}
_, err = s.db.Exec(`
UPDATE backtest_runs
SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?,
override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP
WHERE run_id = ?
`, userID, configJSON, template, customPrompt, override, provider, model, runID)
return err
return s.db.Save(&run).Error
}
// LoadConfig loads config
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
var payload []byte
err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload)
return payload, err
var run BacktestRun
err := s.db.Where("run_id = ?", runID).First(&run).Error
if err != nil {
return nil, err
}
return run.ConfigJSON, nil
}
+215 -481
View File
@@ -1,12 +1,11 @@
package store
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// DebateStatus represents the status of a debate session
@@ -49,7 +48,26 @@ var PersonalityEmojis = map[DebatePersonality]string{
PersonalityRiskManager: "🛡️",
}
// DebateSession represents a debate session
// DebateDecision represents a trading decision from the debate
type DebateDecision struct {
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
Symbol string `json:"symbol"` // Trading pair
Confidence int `json:"confidence"` // 0-100
Leverage int `json:"leverage"` // Recommended leverage
PositionPct float64 `json:"position_pct"` // Position size as percentage of equity (0.0-1.0)
PositionSizeUSD float64 `json:"position_size_usd"` // Position size in USD (calculated from pct)
StopLoss float64 `json:"stop_loss"` // Stop loss price
TakeProfit float64 `json:"take_profit"` // Take profit price
Reasoning string `json:"reasoning"` // Brief reasoning
// Execution tracking
Executed bool `json:"executed"` // Whether this decision was executed
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
OrderID string `json:"order_id,omitempty"` // Exchange order ID
Error string `json:"error,omitempty"` // Execution error if any
}
// DebateSession represents a debate session (API struct)
type DebateSession struct {
ID string `json:"id"`
UserID string `json:"user_id"`
@@ -73,191 +91,157 @@ type DebateSession struct {
UpdatedAt time.Time `json:"updated_at"`
}
// DebateDecision represents a trading decision from the debate
type DebateDecision struct {
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
Symbol string `json:"symbol"` // Trading pair
Confidence int `json:"confidence"` // 0-100
Leverage int `json:"leverage"` // Recommended leverage
PositionPct float64 `json:"position_pct"` // Position size as percentage of equity (0.0-1.0)
PositionSizeUSD float64 `json:"position_size_usd"` // Position size in USD (calculated from pct)
StopLoss float64 `json:"stop_loss"` // Stop loss price
TakeProfit float64 `json:"take_profit"` // Take profit price
Reasoning string `json:"reasoning"` // Brief reasoning
// DebateSessionDB is the GORM model for debate_sessions
type DebateSessionDB struct {
ID string `gorm:"column:id;primaryKey"`
UserID string `gorm:"column:user_id;not null;index"`
Name string `gorm:"column:name;not null"`
StrategyID string `gorm:"column:strategy_id;not null"`
Status DebateStatus `gorm:"column:status;not null;default:pending;index"`
Symbol string `gorm:"column:symbol;not null"`
MaxRounds int `gorm:"column:max_rounds;default:3"`
CurrentRound int `gorm:"column:current_round;default:0"`
IntervalMinutes int `gorm:"column:interval_minutes;default:5"`
PromptVariant string `gorm:"column:prompt_variant;default:balanced"`
FinalDecision string `gorm:"column:final_decision"` // JSON string
AutoExecute bool `gorm:"column:auto_execute;default:false"`
TraderID string `gorm:"column:trader_id"`
EnableOIRanking bool `gorm:"column:enable_oi_ranking;default:false"`
OIRankingLimit int `gorm:"column:oi_ranking_limit;default:10"`
OIDuration string `gorm:"column:oi_duration;default:1h"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
}
// Execution tracking
Executed bool `json:"executed"` // Whether this decision was executed
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
OrderID string `json:"order_id,omitempty"` // Exchange order ID
Error string `json:"error,omitempty"` // Execution error if any
func (DebateSessionDB) TableName() string {
return "debate_sessions"
}
func (db *DebateSessionDB) toSession() *DebateSession {
s := &DebateSession{
ID: db.ID,
UserID: db.UserID,
Name: db.Name,
StrategyID: db.StrategyID,
Status: db.Status,
Symbol: db.Symbol,
MaxRounds: db.MaxRounds,
CurrentRound: db.CurrentRound,
IntervalMinutes: db.IntervalMinutes,
PromptVariant: db.PromptVariant,
AutoExecute: db.AutoExecute,
TraderID: db.TraderID,
EnableOIRanking: db.EnableOIRanking,
OIRankingLimit: db.OIRankingLimit,
OIDuration: db.OIDuration,
CreatedAt: db.CreatedAt,
UpdatedAt: db.UpdatedAt,
}
// Set defaults
if s.IntervalMinutes == 0 {
s.IntervalMinutes = 5
}
if s.PromptVariant == "" {
s.PromptVariant = "balanced"
}
if s.OIRankingLimit == 0 {
s.OIRankingLimit = 10
}
if s.OIDuration == "" {
s.OIDuration = "1h"
}
// Parse final decision
if db.FinalDecision != "" {
var decision DebateDecision
if json.Unmarshal([]byte(db.FinalDecision), &decision) == nil {
s.FinalDecision = &decision
}
}
return s
}
// DebateParticipant represents an AI participant in a debate
type DebateParticipant struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
AIModelID string `json:"ai_model_id"`
AIModelName string `json:"ai_model_name"`
Provider string `json:"provider"`
Personality DebatePersonality `json:"personality"`
Color string `json:"color"`
SpeakOrder int `json:"speak_order"`
CreatedAt time.Time `json:"created_at"`
ID string `gorm:"column:id;primaryKey" json:"id"`
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
Provider string `gorm:"column:provider;not null" json:"provider"`
Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
Color string `gorm:"column:color;not null" json:"color"`
SpeakOrder int `gorm:"column:speak_order;default:0" json:"speak_order"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
func (DebateParticipant) TableName() string {
return "debate_participants"
}
// DebateMessage represents a message in the debate
type DebateMessage struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
Round int `json:"round"`
AIModelID string `json:"ai_model_id"`
AIModelName string `json:"ai_model_name"`
Provider string `json:"provider"`
Personality DebatePersonality `json:"personality"`
MessageType string `json:"message_type"` // analysis/rebuttal/final/vote
Content string `json:"content"`
Decision *DebateDecision `json:"decision,omitempty"` // Single decision (backward compat)
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions
Confidence int `json:"confidence"`
CreatedAt time.Time `json:"created_at"`
ID string `gorm:"column:id;primaryKey" json:"id"`
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
Round int `gorm:"column:round;not null" json:"round"`
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
Provider string `gorm:"column:provider;not null" json:"provider"`
Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
MessageType string `gorm:"column:message_type;not null" json:"message_type"` // analysis/rebuttal/final/vote
Content string `gorm:"column:content;not null" json:"content"`
DecisionRaw string `gorm:"column:decision" json:"-"` // JSON string in DB
Decision *DebateDecision `gorm:"-" json:"decision,omitempty"` // Parsed for API
Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
func (DebateMessage) TableName() string {
return "debate_messages"
}
// DebateVote represents a final vote from an AI (can contain multiple coin decisions)
type DebateVote struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
AIModelID string `json:"ai_model_id"`
AIModelName string `json:"ai_model_name"`
Action string `json:"action"` // Primary action (backward compat)
Symbol string `json:"symbol"` // Primary symbol (backward compat)
Confidence int `json:"confidence"`
Leverage int `json:"leverage"`
PositionPct float64 `json:"position_pct"`
StopLossPct float64 `json:"stop_loss_pct"`
TakeProfitPct float64 `json:"take_profit_pct"`
Reasoning string `json:"reasoning"`
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions
CreatedAt time.Time `json:"created_at"`
ID string `gorm:"column:id;primaryKey" json:"id"`
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
Action string `gorm:"column:action;not null" json:"action"` // Primary action (backward compat)
Symbol string `gorm:"column:symbol;not null" json:"symbol"` // Primary symbol (backward compat)
Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
Leverage int `gorm:"column:leverage;default:5" json:"leverage"`
PositionPct float64 `gorm:"column:position_pct;default:0.2" json:"position_pct"`
StopLossPct float64 `gorm:"column:stop_loss_pct;default:0.03" json:"stop_loss_pct"`
TakeProfitPct float64 `gorm:"column:take_profit_pct;default:0.06" json:"take_profit_pct"`
Reasoning string `gorm:"column:reasoning" json:"reasoning"`
Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
func (DebateVote) TableName() string {
return "debate_votes"
}
// DebateStore handles database operations for debates
type DebateStore struct {
db *sql.DB
db *gorm.DB
}
// NewDebateStore creates a new DebateStore
func NewDebateStore(db *sql.DB) *DebateStore {
func NewDebateStore(db *gorm.DB) *DebateStore {
return &DebateStore{db: db}
}
// InitSchema creates the debate tables
// InitSchema creates the debate tables using GORM AutoMigrate
func (s *DebateStore) InitSchema() error {
schemas := []string{
`CREATE TABLE IF NOT EXISTS debate_sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
name TEXT NOT NULL,
strategy_id TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
symbol TEXT NOT NULL,
max_rounds INTEGER DEFAULT 3,
current_round INTEGER DEFAULT 0,
interval_minutes INTEGER DEFAULT 5,
prompt_variant TEXT DEFAULT 'balanced',
final_decision TEXT,
auto_execute BOOLEAN DEFAULT 0,
trader_id TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_user_id ON debate_sessions(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_status ON debate_sessions(status)`,
`CREATE TABLE IF NOT EXISTS debate_participants (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
ai_model_name TEXT NOT NULL,
provider TEXT NOT NULL,
personality TEXT NOT NULL,
color TEXT NOT NULL,
speak_order INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_participants_session ON debate_participants(session_id)`,
`CREATE TABLE IF NOT EXISTS debate_messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
round INTEGER NOT NULL,
ai_model_id TEXT NOT NULL,
ai_model_name TEXT NOT NULL,
provider TEXT NOT NULL,
personality TEXT NOT NULL,
message_type TEXT NOT NULL,
content TEXT NOT NULL,
decision TEXT,
confidence INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_messages_session ON debate_messages(session_id)`,
`CREATE INDEX IF NOT EXISTS idx_debate_messages_round ON debate_messages(session_id, round)`,
`CREATE TABLE IF NOT EXISTS debate_votes (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
ai_model_name TEXT NOT NULL,
action TEXT NOT NULL,
symbol TEXT NOT NULL,
confidence INTEGER DEFAULT 0,
leverage INTEGER DEFAULT 5,
position_pct REAL DEFAULT 0.2,
stop_loss_pct REAL DEFAULT 0.03,
take_profit_pct REAL DEFAULT 0.06,
reasoning TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
)`,
`CREATE INDEX IF NOT EXISTS idx_debate_votes_session ON debate_votes(session_id)`,
// Trigger to update updated_at
`CREATE TRIGGER IF NOT EXISTS update_debate_sessions_timestamp
AFTER UPDATE ON debate_sessions
FOR EACH ROW
BEGIN
UPDATE debate_sessions SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id;
END`,
}
for _, schema := range schemas {
if _, err := s.db.Exec(schema); err != nil {
return fmt.Errorf("failed to create debate schema: %w", err)
}
}
// Migrate: Add new columns to existing tables (ignore errors if columns already exist)
migrations := []string{
`ALTER TABLE debate_sessions ADD COLUMN interval_minutes INTEGER DEFAULT 5`,
`ALTER TABLE debate_sessions ADD COLUMN prompt_variant TEXT DEFAULT 'balanced'`,
`ALTER TABLE debate_sessions ADD COLUMN trader_id TEXT`,
`ALTER TABLE debate_sessions ADD COLUMN enable_oi_ranking BOOLEAN DEFAULT 0`,
`ALTER TABLE debate_sessions ADD COLUMN oi_ranking_limit INTEGER DEFAULT 10`,
`ALTER TABLE debate_sessions ADD COLUMN oi_duration TEXT DEFAULT '1h'`,
`ALTER TABLE debate_votes ADD COLUMN leverage INTEGER DEFAULT 5`,
`ALTER TABLE debate_votes ADD COLUMN position_pct REAL DEFAULT 0.2`,
`ALTER TABLE debate_votes ADD COLUMN stop_loss_pct REAL DEFAULT 0.03`,
`ALTER TABLE debate_votes ADD COLUMN take_profit_pct REAL DEFAULT 0.06`,
}
for _, migration := range migrations {
// Ignore errors - column may already exist
s.db.Exec(migration)
}
return nil
return s.db.AutoMigrate(
&DebateSessionDB{},
&DebateParticipant{},
&DebateMessage{},
&DebateVote{},
)
}
// CreateSession creates a new debate session
@@ -279,227 +263,73 @@ func (s *DebateStore) CreateSession(session *DebateSession) error {
if session.OIDuration == "" {
session.OIDuration = "1h"
}
session.CreatedAt = time.Now()
session.UpdatedAt = time.Now()
_, err := s.db.Exec(`
INSERT INTO debate_sessions (id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, interval_minutes, prompt_variant, auto_execute, trader_id, enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.Name, session.StrategyID, session.Status,
session.Symbol, session.MaxRounds, session.CurrentRound, session.IntervalMinutes, session.PromptVariant,
session.AutoExecute, session.TraderID, session.EnableOIRanking, session.OIRankingLimit, session.OIDuration,
session.CreatedAt, session.UpdatedAt,
)
return err
db := &DebateSessionDB{
ID: session.ID,
UserID: session.UserID,
Name: session.Name,
StrategyID: session.StrategyID,
Status: session.Status,
Symbol: session.Symbol,
MaxRounds: session.MaxRounds,
CurrentRound: session.CurrentRound,
IntervalMinutes: session.IntervalMinutes,
PromptVariant: session.PromptVariant,
AutoExecute: session.AutoExecute,
TraderID: session.TraderID,
EnableOIRanking: session.EnableOIRanking,
OIRankingLimit: session.OIRankingLimit,
OIDuration: session.OIDuration,
}
return s.db.Create(db).Error
}
// GetSession gets a debate session by ID
func (s *DebateStore) GetSession(id string) (*DebateSession, error) {
var session DebateSession
var finalDecisionJSON sql.NullString
var traderID sql.NullString
var intervalMinutes sql.NullInt64
var promptVariant sql.NullString
var enableOIRanking sql.NullBool
var oiRankingLimit sql.NullInt64
var oiDuration sql.NullString
// Try new schema first
err := s.db.QueryRow(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id,
enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at
FROM debate_sessions WHERE id = ?`, id,
).Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&intervalMinutes, &promptVariant,
&finalDecisionJSON, &session.AutoExecute, &traderID,
&enableOIRanking, &oiRankingLimit, &oiDuration, &session.CreatedAt, &session.UpdatedAt,
)
// Fallback to basic schema if new columns don't exist
if err != nil {
err = s.db.QueryRow(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
final_decision, auto_execute, created_at, updated_at
FROM debate_sessions WHERE id = ?`, id,
).Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
)
if err != nil {
var db DebateSessionDB
if err := s.db.Where("id = ?", id).First(&db).Error; err != nil {
return nil, err
}
// Set defaults for new fields
session.IntervalMinutes = 5
session.PromptVariant = "balanced"
session.OIRankingLimit = 10
session.OIDuration = "1h"
} else {
// Set defaults for nullable fields
session.IntervalMinutes = 5
if intervalMinutes.Valid {
session.IntervalMinutes = int(intervalMinutes.Int64)
}
session.PromptVariant = "balanced"
if promptVariant.Valid {
session.PromptVariant = promptVariant.String
}
if traderID.Valid {
session.TraderID = traderID.String
}
if enableOIRanking.Valid {
session.EnableOIRanking = enableOIRanking.Bool
}
session.OIRankingLimit = 10
if oiRankingLimit.Valid {
session.OIRankingLimit = int(oiRankingLimit.Int64)
}
session.OIDuration = "1h"
if oiDuration.Valid {
session.OIDuration = oiDuration.String
}
}
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
session.FinalDecision = &decision
}
}
return &session, nil
return db.toSession(), nil
}
// GetSessionsByUser gets all debate sessions for a user
func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) {
// First try the new schema with all columns
rows, err := s.db.Query(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id, created_at, updated_at
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
)
// If query fails (likely due to missing columns), try basic query
if err != nil {
return s.getSessionsByUserBasic(userID)
}
defer rows.Close()
var sessions []*DebateSession
for rows.Next() {
var session DebateSession
var finalDecisionJSON sql.NullString
var traderID sql.NullString
var intervalMinutes sql.NullInt64
var promptVariant sql.NullString
if err := rows.Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&intervalMinutes, &promptVariant,
&finalDecisionJSON, &session.AutoExecute, &traderID, &session.CreatedAt, &session.UpdatedAt,
); err != nil {
var dbs []DebateSessionDB
if err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&dbs).Error; err != nil {
return nil, err
}
// Set defaults for nullable fields
session.IntervalMinutes = 5
if intervalMinutes.Valid {
session.IntervalMinutes = int(intervalMinutes.Int64)
}
session.PromptVariant = "balanced"
if promptVariant.Valid {
session.PromptVariant = promptVariant.String
}
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
session.FinalDecision = &decision
}
}
if traderID.Valid {
session.TraderID = traderID.String
}
sessions = append(sessions, &session)
sessions := make([]*DebateSession, len(dbs))
for i, db := range dbs {
sessions[i] = db.toSession()
}
return sessions, nil
}
// ListAllSessions returns all debate sessions (for cleanup on startup)
func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) {
rows, err := s.db.Query(`SELECT id, status FROM debate_sessions`)
if err != nil {
return nil, err
}
defer rows.Close()
var sessions []*DebateSession
for rows.Next() {
var session DebateSession
if err := rows.Scan(&session.ID, &session.Status); err != nil {
return nil, err
}
sessions = append(sessions, &session)
}
return sessions, nil
}
// getSessionsByUserBasic is a fallback for old schema without new columns
func (s *DebateStore) getSessionsByUserBasic(userID string) ([]*DebateSession, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
final_decision, auto_execute, created_at, updated_at
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var sessions []*DebateSession
for rows.Next() {
var session DebateSession
var finalDecisionJSON sql.NullString
if err := rows.Scan(
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
); err != nil {
var dbs []DebateSessionDB
if err := s.db.Select("id, status").Find(&dbs).Error; err != nil {
return nil, err
}
// Set defaults for new fields
session.IntervalMinutes = 5
session.PromptVariant = "balanced"
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
session.FinalDecision = &decision
}
}
sessions = append(sessions, &session)
sessions := make([]*DebateSession, len(dbs))
for i, db := range dbs {
sessions[i] = &DebateSession{ID: db.ID, Status: db.Status}
}
return sessions, nil
}
// UpdateSessionStatus updates the status of a debate session
func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error {
_, err := s.db.Exec(`UPDATE debate_sessions SET status = ? WHERE id = ?`, status, id)
return err
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("status", status).Error
}
// UpdateSessionRound updates the current round of a debate session
func (s *DebateStore) UpdateSessionRound(id string, round int) error {
_, err := s.db.Exec(`UPDATE debate_sessions SET current_round = ? WHERE id = ?`, round, id)
return err
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("current_round", round).Error
}
// UpdateSessionFinalDecision updates the final decision of a debate session (single decision)
@@ -508,30 +338,31 @@ func (s *DebateStore) UpdateSessionFinalDecision(id string, decision *DebateDeci
if err != nil {
return err
}
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`,
string(decisionJSON), DebateStatusCompleted, id)
return err
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
"final_decision": string(decisionJSON),
"status": DebateStatusCompleted,
}).Error
}
// UpdateSessionFinalDecisions updates both single and multi-coin final decisions
func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error {
// Always store primary decision as a single object (for backward compat)
// This ensures GetSession can deserialize it correctly
primaryJSON, err := json.Marshal(primaryDecision)
if err != nil {
return err
}
// Update final_decision with primary decision and set status to completed
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`,
string(primaryJSON), DebateStatusCompleted, id)
return err
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
"final_decision": string(primaryJSON),
"status": DebateStatusCompleted,
}).Error
}
// DeleteSession deletes a debate session and all related data
func (s *DebateStore) DeleteSession(id string) error {
_, err := s.db.Exec(`DELETE FROM debate_sessions WHERE id = ?`, id)
return err
// Delete related data first
s.db.Where("session_id = ?", id).Delete(&DebateParticipant{})
s.db.Where("session_id = ?", id).Delete(&DebateMessage{})
s.db.Where("session_id = ?", id).Delete(&DebateVote{})
return s.db.Where("id = ?", id).Delete(&DebateSessionDB{}).Error
}
// AddParticipant adds a participant to a debate session
@@ -539,9 +370,6 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
if participant.ID == "" {
participant.ID = uuid.New().String()
}
participant.CreatedAt = time.Now()
// Set color based on personality if not provided
if participant.Color == "" {
if color, ok := PersonalityColors[participant.Personality]; ok {
participant.Color = color
@@ -549,39 +377,14 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
participant.Color = "#6B7280" // Default gray
}
}
_, err := s.db.Exec(`
INSERT INTO debate_participants (id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
participant.ID, participant.SessionID, participant.AIModelID, participant.AIModelName,
participant.Provider, participant.Personality, participant.Color, participant.SpeakOrder, participant.CreatedAt,
)
return err
return s.db.Create(participant).Error
}
// GetParticipants gets all participants for a debate session
func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) {
rows, err := s.db.Query(`
SELECT id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at
FROM debate_participants WHERE session_id = ? ORDER BY speak_order`, sessionID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var participants []*DebateParticipant
for rows.Next() {
var p DebateParticipant
if err := rows.Scan(
&p.ID, &p.SessionID, &p.AIModelID, &p.AIModelName,
&p.Provider, &p.Personality, &p.Color, &p.SpeakOrder, &p.CreatedAt,
); err != nil {
return nil, err
}
participants = append(participants, &p)
}
return participants, nil
err := s.db.Where("session_id = ?", sessionID).Order("speak_order").Find(&participants).Error
return participants, err
}
// AddMessage adds a message to a debate session
@@ -589,95 +392,52 @@ func (s *DebateStore) AddMessage(msg *DebateMessage) error {
if msg.ID == "" {
msg.ID = uuid.New().String()
}
msg.CreatedAt = time.Now()
var decisionJSON sql.NullString
if msg.Decision != nil {
data, err := json.Marshal(msg.Decision)
if err != nil {
return err
}
decisionJSON = sql.NullString{String: string(data), Valid: true}
msg.DecisionRaw = string(data)
}
_, err := s.db.Exec(`
INSERT INTO debate_messages (id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
msg.ID, msg.SessionID, msg.Round, msg.AIModelID, msg.AIModelName,
msg.Provider, msg.Personality, msg.MessageType, msg.Content,
decisionJSON, msg.Confidence, msg.CreatedAt,
)
return err
return s.db.Create(msg).Error
}
// GetMessages gets all messages for a debate session
func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) {
rows, err := s.db.Query(`
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at
FROM debate_messages WHERE session_id = ? ORDER BY round, created_at`, sessionID,
)
var messages []*DebateMessage
err := s.db.Where("session_id = ?", sessionID).Order("round, created_at").Find(&messages).Error
if err != nil {
return nil, err
}
defer rows.Close()
var messages []*DebateMessage
for rows.Next() {
var msg DebateMessage
var decisionJSON sql.NullString
if err := rows.Scan(
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
); err != nil {
return nil, err
}
if decisionJSON.Valid && decisionJSON.String != "" {
// Parse decision JSON
for _, msg := range messages {
if msg.DecisionRaw != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil {
if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
msg.Decision = &decision
}
}
messages = append(messages, &msg)
}
return messages, nil
}
// GetMessagesByRound gets messages for a specific round
func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) {
rows, err := s.db.Query(`
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at
FROM debate_messages WHERE session_id = ? AND round = ? ORDER BY created_at`, sessionID, round,
)
var messages []*DebateMessage
err := s.db.Where("session_id = ? AND round = ?", sessionID, round).Order("created_at").Find(&messages).Error
if err != nil {
return nil, err
}
defer rows.Close()
var messages []*DebateMessage
for rows.Next() {
var msg DebateMessage
var decisionJSON sql.NullString
if err := rows.Scan(
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
); err != nil {
return nil, err
}
if decisionJSON.Valid && decisionJSON.String != "" {
// Parse decision JSON
for _, msg := range messages {
if msg.DecisionRaw != "" {
var decision DebateDecision
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil {
if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
msg.Decision = &decision
}
}
messages = append(messages, &msg)
}
return messages, nil
}
@@ -687,40 +447,14 @@ func (s *DebateStore) AddVote(vote *DebateVote) error {
if vote.ID == "" {
vote.ID = uuid.New().String()
}
vote.CreatedAt = time.Now()
_, err := s.db.Exec(`
INSERT INTO debate_votes (id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
vote.ID, vote.SessionID, vote.AIModelID, vote.AIModelName,
vote.Action, vote.Symbol, vote.Confidence, vote.Leverage, vote.PositionPct, vote.StopLossPct, vote.TakeProfitPct, vote.Reasoning, vote.CreatedAt,
)
return err
return s.db.Create(vote).Error
}
// GetVotes gets all votes for a debate session
func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) {
rows, err := s.db.Query(`
SELECT id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at
FROM debate_votes WHERE session_id = ? ORDER BY created_at`, sessionID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var votes []*DebateVote
for rows.Next() {
var vote DebateVote
if err := rows.Scan(
&vote.ID, &vote.SessionID, &vote.AIModelID, &vote.AIModelName,
&vote.Action, &vote.Symbol, &vote.Confidence, &vote.Leverage, &vote.PositionPct, &vote.StopLossPct, &vote.TakeProfitPct, &vote.Reasoning, &vote.CreatedAt,
); err != nil {
return nil, err
}
votes = append(votes, &vote)
}
return votes, nil
err := s.db.Where("session_id = ?", sessionID).Order("created_at").Find(&votes).Error
return votes, err
}
// DebateSessionWithDetails combines session with participants and messages
+135 -198
View File
@@ -1,18 +1,41 @@
package store
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"gorm.io/gorm"
)
// DecisionStore decision log storage
type DecisionStore struct {
db *sql.DB
db *gorm.DB
}
// DecisionRecord decision record
// DecisionRecordDB internal GORM model for decision_records table
type DecisionRecordDB struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
TraderID string `gorm:"column:trader_id;not null;index:idx_decision_records_trader_time"`
CycleNumber int `gorm:"column:cycle_number;not null"`
Timestamp time.Time `gorm:"not null;index:idx_decision_records_trader_time,sort:desc;index:idx_decision_records_timestamp,sort:desc"`
SystemPrompt string `gorm:"column:system_prompt;default:''"`
InputPrompt string `gorm:"column:input_prompt;default:''"`
CoTTrace string `gorm:"column:cot_trace;default:''"`
DecisionJSON string `gorm:"column:decision_json;default:''"`
RawResponse string `gorm:"column:raw_response;default:''"`
CandidateCoins string `gorm:"column:candidate_coins;default:''"`
ExecutionLog string `gorm:"column:execution_log;default:''"`
Decisions string `gorm:"column:decisions;default:'[]'"`
Success bool `gorm:"default:false"`
ErrorMessage string `gorm:"column:error_message;default:''"`
AIRequestDurationMs int64 `gorm:"column:ai_request_duration_ms;default:0"`
CreatedAt time.Time `json:"created_at"`
}
func (DecisionRecordDB) TableName() string { return "decision_records" }
// DecisionRecord decision record (external API struct)
type DecisionRecord struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
@@ -81,49 +104,47 @@ type Statistics struct {
TotalClosePositions int `json:"total_close_positions"`
}
// initTables initializes AI decision log tables
// Note: Account equity curve data has been migrated to trader_equity_snapshots table (managed by EquityStore)
func (s *DecisionStore) initTables() error {
queries := []string{
// AI decision log table (records AI input/output, chain of thought, etc.)
`CREATE TABLE IF NOT EXISTS decision_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
cycle_number INTEGER NOT NULL,
timestamp DATETIME NOT NULL,
system_prompt TEXT DEFAULT '',
input_prompt TEXT DEFAULT '',
cot_trace TEXT DEFAULT '',
decision_json TEXT DEFAULT '',
raw_response TEXT DEFAULT '',
candidate_coins TEXT DEFAULT '',
execution_log TEXT DEFAULT '',
success BOOLEAN DEFAULT 0,
error_message TEXT DEFAULT '',
ai_request_duration_ms INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// Indexes
`CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`,
`CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
}
}
// Migration: add raw_response column if not exists
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN raw_response TEXT DEFAULT ''`)
// Migration: add decisions column if not exists
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN decisions TEXT DEFAULT '[]'`)
return nil
// NewDecisionStore creates a new DecisionStore
func NewDecisionStore(db *gorm.DB) *DecisionStore {
return &DecisionStore{db: db}
}
// LogDecision logs decision (only saves AI decision log, equity curve has been migrated to equity table)
// initTables initializes AI decision log tables
func (s *DecisionStore) initTables() error {
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'decision_records'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
}
return s.db.AutoMigrate(&DecisionRecordDB{})
}
// toRecord converts DB model to API struct
func (db *DecisionRecordDB) toRecord() *DecisionRecord {
record := &DecisionRecord{
ID: db.ID,
TraderID: db.TraderID,
CycleNumber: db.CycleNumber,
Timestamp: db.Timestamp,
SystemPrompt: db.SystemPrompt,
InputPrompt: db.InputPrompt,
CoTTrace: db.CoTTrace,
DecisionJSON: db.DecisionJSON,
RawResponse: db.RawResponse,
Success: db.Success,
ErrorMessage: db.ErrorMessage,
AIRequestDurationMs: db.AIRequestDurationMs,
}
json.Unmarshal([]byte(db.CandidateCoins), &record.CandidateCoins)
json.Unmarshal([]byte(db.ExecutionLog), &record.ExecutionLog)
json.Unmarshal([]byte(db.Decisions), &record.Decisions)
return record
}
// LogDecision logs decision
func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
if record.Timestamp.IsZero() {
record.Timestamp = time.Now().UTC()
@@ -131,65 +152,49 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
record.Timestamp = record.Timestamp.UTC()
}
// Serialize candidate coins, execution log and decisions to JSON
// Serialize arrays to JSON
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
executionLogJSON, _ := json.Marshal(record.ExecutionLog)
decisionsJSON, _ := json.Marshal(record.Decisions)
// Insert decision record main table (only save AI decision related content)
result, err := s.db.Exec(`
INSERT INTO decision_records (
trader_id, cycle_number, timestamp, system_prompt, input_prompt,
cot_trace, decision_json, raw_response, candidate_coins, execution_log,
decisions, success, error_message, ai_request_duration_ms
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339),
record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON,
record.RawResponse, string(candidateCoinsJSON), string(executionLogJSON),
string(decisionsJSON), record.Success, record.ErrorMessage, record.AIRequestDurationMs,
)
if err != nil {
dbRecord := &DecisionRecordDB{
TraderID: record.TraderID,
CycleNumber: record.CycleNumber,
Timestamp: record.Timestamp,
SystemPrompt: record.SystemPrompt,
InputPrompt: record.InputPrompt,
CoTTrace: record.CoTTrace,
DecisionJSON: record.DecisionJSON,
RawResponse: record.RawResponse,
CandidateCoins: string(candidateCoinsJSON),
ExecutionLog: string(executionLogJSON),
Decisions: string(decisionsJSON),
Success: record.Success,
ErrorMessage: record.ErrorMessage,
AIRequestDurationMs: record.AIRequestDurationMs,
}
if err := s.db.Create(dbRecord).Error; err != nil {
return fmt.Errorf("failed to insert decision record: %w", err)
}
decisionID, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("failed to get decision ID: %w", err)
}
record.ID = decisionID
record.ID = dbRecord.ID
return nil
}
// GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new)
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
cot_trace, decision_json, candidate_coins, execution_log,
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
FROM decision_records
WHERE trader_id = ?
ORDER BY timestamp DESC
LIMIT ?
`, traderID, n)
var dbRecords []*DecisionRecordDB
err := s.db.Where("trader_id = ?", traderID).
Order("timestamp DESC").
Limit(n).
Find(&dbRecords).Error
if err != nil {
return nil, fmt.Errorf("failed to query decision records: %w", err)
}
defer rows.Close()
var records []*DecisionRecord
for rows.Next() {
record, err := s.scanDecisionRecord(rows)
if err != nil {
continue
}
records = append(records, record)
}
// Fill associated data
for _, record := range records {
s.fillRecordDetails(record)
records := make([]*DecisionRecord, len(dbRecords))
for i, db := range dbRecords {
records[i] = db.toRecord()
}
// Reverse array to sort time from old to new
@@ -202,26 +207,15 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
// GetAllLatestRecords gets the latest N records for all traders
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
cot_trace, decision_json, candidate_coins, execution_log,
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
FROM decision_records
ORDER BY timestamp DESC
LIMIT ?
`, n)
var dbRecords []*DecisionRecordDB
err := s.db.Order("timestamp DESC").Limit(n).Find(&dbRecords).Error
if err != nil {
return nil, fmt.Errorf("failed to query decision records: %w", err)
}
defer rows.Close()
var records []*DecisionRecord
for rows.Next() {
record, err := s.scanDecisionRecord(rows)
if err != nil {
continue
}
records = append(records, record)
records := make([]*DecisionRecord, len(dbRecords))
for i, db := range dbRecords {
records[i] = db.toRecord()
}
// Reverse array
@@ -236,26 +230,17 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
dateStr := date.Format("2006-01-02")
rows, err := s.db.Query(`
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
cot_trace, decision_json, candidate_coins, execution_log,
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
FROM decision_records
WHERE trader_id = ? AND DATE(timestamp) = ?
ORDER BY timestamp ASC
`, traderID, dateStr)
var dbRecords []*DecisionRecordDB
err := s.db.Where("trader_id = ? AND DATE(timestamp) = ?", traderID, dateStr).
Order("timestamp ASC").
Find(&dbRecords).Error
if err != nil {
return nil, fmt.Errorf("failed to query decision records: %w", err)
}
defer rows.Close()
var records []*DecisionRecord
for rows.Next() {
record, err := s.scanDecisionRecord(rows)
if err != nil {
continue
}
records = append(records, record)
records := make([]*DecisionRecord, len(dbRecords))
for i, db := range dbRecords {
records[i] = db.toRecord()
}
return records, nil
@@ -263,48 +248,31 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De
// CleanOldRecords cleans old records from N days ago
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
cutoffTime := time.Now().AddDate(0, 0, -days)
result, err := s.db.Exec(`
DELETE FROM decision_records
WHERE trader_id = ? AND timestamp < ?
`, traderID, cutoffTime)
if err != nil {
return 0, fmt.Errorf("failed to clean old records: %w", err)
result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
Delete(&DecisionRecordDB{})
if result.Error != nil {
return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
}
return result.RowsAffected()
return result.RowsAffected, nil
}
// GetStatistics gets statistics information for specified trader
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
stats := &Statistics{}
err := s.db.QueryRow(`
SELECT COUNT(*) FROM decision_records WHERE trader_id = ?
`, traderID).Scan(&stats.TotalCycles)
if err != nil {
return nil, fmt.Errorf("failed to query total cycles: %w", err)
}
var totalCount, successCount int64
s.db.Model(&DecisionRecordDB{}).Where("trader_id = ?", traderID).Count(&totalCount)
s.db.Model(&DecisionRecordDB{}).Where("trader_id = ? AND success = ?", traderID, true).Count(&successCount)
err = s.db.QueryRow(`
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1
`, traderID).Scan(&stats.SuccessfulCycles)
if err != nil {
return nil, fmt.Errorf("failed to query successful cycles: %w", err)
}
stats.TotalCycles = int(totalCount)
stats.SuccessfulCycles = int(successCount)
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
// Count from trader_positions table
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_positions
WHERE trader_id = ?
`, traderID).Scan(&stats.TotalOpenPositions)
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_positions
WHERE trader_id = ? AND status = 'CLOSED'
`, traderID).Scan(&stats.TotalClosePositions)
// Count from trader_positions table using raw query for cross-table
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ?", traderID).Scan(&stats.TotalOpenPositions)
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ? AND status = 'CLOSED'", traderID).Scan(&stats.TotalClosePositions)
return stats, nil
}
@@ -313,64 +281,33 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
stats := &Statistics{}
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles)
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles)
var totalCount, successCount int64
s.db.Model(&DecisionRecordDB{}).Count(&totalCount)
s.db.Model(&DecisionRecordDB{}).Where("success = ?", true).Count(&successCount)
stats.TotalCycles = int(totalCount)
stats.SuccessfulCycles = int(successCount)
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
// Count from trader_positions table
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_positions
`).Scan(&stats.TotalOpenPositions)
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_positions
WHERE status = 'CLOSED'
`).Scan(&stats.TotalClosePositions)
s.db.Raw("SELECT COUNT(*) FROM trader_positions").Scan(&stats.TotalOpenPositions)
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE status = 'CLOSED'").Scan(&stats.TotalClosePositions)
return stats, nil
}
// GetLastCycleNumber gets the last cycle number for specified trader
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
var cycleNumber int
err := s.db.QueryRow(`
SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ?
`, traderID).Scan(&cycleNumber)
var cycleNumber *int
err := s.db.Model(&DecisionRecordDB{}).
Where("trader_id = ?", traderID).
Select("MAX(cycle_number)").
Scan(&cycleNumber).Error
if err != nil {
return 0, err
}
return cycleNumber, nil
}
// scanDecisionRecord scans decision record from row
func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) {
var record DecisionRecord
var timestampStr string
var candidateCoinsJSON, executionLogJSON, decisionsJSON string
err := rows.Scan(
&record.ID, &record.TraderID, &record.CycleNumber, &timestampStr,
&record.SystemPrompt, &record.InputPrompt, &record.CoTTrace,
&record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON,
&decisionsJSON, &record.Success, &record.ErrorMessage, &record.AIRequestDurationMs,
)
if err != nil {
return nil, err
if cycleNumber == nil {
return 0, nil
}
record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins)
json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog)
json.Unmarshal([]byte(decisionsJSON), &record.Decisions)
return &record, nil
}
// fillRecordDetails fills associated data for decision record (old associated tables removed, this function kept for compatibility)
// Note: Account snapshot, position snapshot, decision action data are no longer stored in decision related tables
// - For equity data use EquityStore.GetLatest()
// - For order data use OrderStore
func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) {
// Old associated tables removed, no longer need to fill
// AccountState, Positions, Decisions fields will remain at zero values
return *cycleNumber, nil
}
+41
View File
@@ -238,3 +238,44 @@ func getEnv(key, defaultValue string) string {
}
return defaultValue
}
// convertQuery converts ? placeholders to $1, $2 for PostgreSQL
// and handles other database-specific syntax differences
func convertQuery(query string, dbType DBType) string {
if dbType != DBTypePostgres {
return query
}
result := query
// Convert ? to $1, $2, etc. for PostgreSQL
index := 1
for strings.Contains(result, "?") {
result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1)
index++
}
// Convert datetime('now') to CURRENT_TIMESTAMP
result = strings.ReplaceAll(result, "datetime('now')", "CURRENT_TIMESTAMP")
// Remove datetime() wrapper for ORDER BY (PostgreSQL timestamps sort correctly)
// This handles patterns like "ORDER BY datetime(column) DESC"
result = strings.ReplaceAll(result, "datetime(updated_at)", "updated_at")
result = strings.ReplaceAll(result, "datetime(created_at)", "created_at")
return result
}
// boolDefault returns database-appropriate boolean default for COALESCE
// Use in queries like: COALESCE(column, %s)
func boolDefault(dbType DBType, value bool) string {
if dbType == DBTypePostgres {
if value {
return "TRUE"
}
return "FALSE"
}
if value {
return "1"
}
return "0"
}
+62 -140
View File
@@ -1,55 +1,48 @@
package store
import (
"database/sql"
"fmt"
"time"
"gorm.io/gorm"
)
// EquityStore account equity storage (for plotting return curves)
type EquityStore struct {
db *sql.DB
db *gorm.DB
}
// EquitySnapshot equity snapshot
type EquitySnapshot struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
Timestamp time.Time `json:"timestamp"`
TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL)
Balance float64 `json:"balance"` // Account balance
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss
PositionCount int `json:"position_count"` // Position count
MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TraderID string `gorm:"column:trader_id;not null;index:idx_equity_trader_time" json:"trader_id"`
Timestamp time.Time `gorm:"not null;index:idx_equity_trader_time,sort:desc;index:idx_equity_timestamp,sort:desc" json:"timestamp"`
TotalEquity float64 `gorm:"column:total_equity;not null;default:0" json:"total_equity"`
Balance float64 `gorm:"not null;default:0" json:"balance"`
UnrealizedPnL float64 `gorm:"column:unrealized_pnl;not null;default:0" json:"unrealized_pnl"`
PositionCount int `gorm:"column:position_count;default:0" json:"position_count"`
MarginUsedPct float64 `gorm:"column:margin_used_pct;default:0" json:"margin_used_pct"`
CreatedAt time.Time `json:"created_at"`
}
func (EquitySnapshot) TableName() string { return "trader_equity_snapshots" }
// NewEquityStore creates a new EquityStore
func NewEquityStore(db *gorm.DB) *EquityStore {
return &EquityStore{db: db}
}
// initTables initializes equity tables
func (s *EquityStore) initTables() error {
queries := []string{
// Equity snapshot table - specifically for return curves
`CREATE TABLE IF NOT EXISTS trader_equity_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
timestamp DATETIME NOT NULL,
total_equity REAL NOT NULL DEFAULT 0,
balance REAL NOT NULL DEFAULT 0,
unrealized_pnl REAL NOT NULL DEFAULT 0,
position_count INTEGER DEFAULT 0,
margin_used_pct REAL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// Indexes
`CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`,
`CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
}
}
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_equity_snapshots'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
}
return s.db.AutoMigrate(&EquitySnapshot{})
}
// Save saves equity snapshot
@@ -60,58 +53,22 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
snapshot.Timestamp = snapshot.Timestamp.UTC()
}
result, err := s.db.Exec(`
INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
) VALUES (?, ?, ?, ?, ?, ?, ?)
`,
snapshot.TraderID,
snapshot.Timestamp.Format(time.RFC3339),
snapshot.TotalEquity,
snapshot.Balance,
snapshot.UnrealizedPnL,
snapshot.PositionCount,
snapshot.MarginUsedPct,
)
if err != nil {
if err := s.db.Create(snapshot).Error; err != nil {
return fmt.Errorf("failed to save equity snapshot: %w", err)
}
id, _ := result.LastInsertId()
snapshot.ID = id
return nil
}
// GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new)
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
FROM trader_equity_snapshots
WHERE trader_id = ?
ORDER BY timestamp DESC
LIMIT ?
`, traderID, limit)
var snapshots []*EquitySnapshot
err := s.db.Where("trader_id = ?", traderID).
Order("timestamp DESC").
Limit(limit).
Find(&snapshots).Error
if err != nil {
return nil, fmt.Errorf("failed to query equity records: %w", err)
}
defer rows.Close()
var snapshots []*EquitySnapshot
for rows.Next() {
snap := &EquitySnapshot{}
var timestampStr string
err := rows.Scan(
&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
)
if err != nil {
continue
}
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
snapshots = append(snapshots, snap)
}
// Reverse the array to sort time from old to new (suitable for plotting curves)
for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
@@ -123,116 +80,81 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
// GetByTimeRange gets equity records within specified time range
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
FROM trader_equity_snapshots
WHERE trader_id = ? AND timestamp >= ? AND timestamp <= ?
ORDER BY timestamp ASC
`, traderID, start.Format(time.RFC3339), end.Format(time.RFC3339))
var snapshots []*EquitySnapshot
err := s.db.Where("trader_id = ? AND timestamp >= ? AND timestamp <= ?", traderID, start, end).
Order("timestamp ASC").
Find(&snapshots).Error
if err != nil {
return nil, fmt.Errorf("failed to query equity records: %w", err)
}
defer rows.Close()
var snapshots []*EquitySnapshot
for rows.Next() {
snap := &EquitySnapshot{}
var timestampStr string
err := rows.Scan(
&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
)
if err != nil {
continue
}
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
snapshots = append(snapshots, snap)
}
return snapshots, nil
}
// GetAllTradersLatest gets latest equity for all traders (for leaderboards)
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
rows, err := s.db.Query(`
// Use raw SQL for this complex query with subquery
var snapshots []*EquitySnapshot
err := s.db.Raw(`
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
e.unrealized_pnl, e.position_count, e.margin_used_pct
e.unrealized_pnl, e.position_count, e.margin_used_pct, e.created_at
FROM trader_equity_snapshots e
INNER JOIN (
SELECT trader_id, MAX(timestamp) as max_ts
FROM trader_equity_snapshots
GROUP BY trader_id
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
`)
`).Scan(&snapshots).Error
if err != nil {
return nil, fmt.Errorf("failed to query latest equity: %w", err)
}
defer rows.Close()
result := make(map[string]*EquitySnapshot)
for rows.Next() {
snap := &EquitySnapshot{}
var timestampStr string
err := rows.Scan(
&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
)
if err != nil {
continue
}
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
for _, snap := range snapshots {
result[snap.TraderID] = snap
}
return result, nil
}
// CleanOldRecords cleans old records from N days ago
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
cutoffTime := time.Now().AddDate(0, 0, -days)
result, err := s.db.Exec(`
DELETE FROM trader_equity_snapshots
WHERE trader_id = ? AND timestamp < ?
`, traderID, cutoffTime)
if err != nil {
return 0, fmt.Errorf("failed to clean old records: %w", err)
result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
Delete(&EquitySnapshot{})
if result.Error != nil {
return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
}
return result.RowsAffected()
return result.RowsAffected, nil
}
// GetCount gets record count for specified trader
func (s *EquityStore) GetCount(traderID string) (int, error) {
var count int
err := s.db.QueryRow(`
SELECT COUNT(*) FROM trader_equity_snapshots WHERE trader_id = ?
`, traderID).Scan(&count)
return count, err
var count int64
err := s.db.Model(&EquitySnapshot{}).Where("trader_id = ?", traderID).Count(&count).Error
return int(count), err
}
// MigrateFromDecision migrates data from old decision_account_snapshots table
func (s *EquityStore) MigrateFromDecision() (int64, error) {
// Check if migration is needed (whether new table is empty)
var count int
s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count)
var count int64
s.db.Model(&EquitySnapshot{}).Count(&count)
if count > 0 {
return 0, nil // Already has data, skip migration
}
// Check if old table exists
// Check if old table exists (SQLite specific check, but works for migration)
var tableName string
err := s.db.QueryRow(`
err := s.db.Raw(`
SELECT name FROM sqlite_master
WHERE type='table' AND name='decision_account_snapshots'
`).Scan(&tableName)
if err != nil {
`).Scan(&tableName).Error
if err != nil || tableName == "" {
return 0, nil // Old table doesn't exist, skip
}
// Migrate data: join query from decision_records + decision_account_snapshots
result, err := s.db.Exec(`
result := s.db.Exec(`
INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
@@ -249,9 +171,9 @@ func (s *EquityStore) MigrateFromDecision() (int64, error) {
JOIN decision_account_snapshots das ON dr.id = das.decision_id
ORDER BY dr.timestamp ASC
`)
if err != nil {
return 0, fmt.Errorf("failed to migrate data: %w", err)
if result.Error != nil {
return 0, fmt.Errorf("failed to migrate data: %w", result.Error)
}
return result.RowsAffected()
return result.RowsAffected, nil
}
+139 -288
View File
@@ -1,83 +1,68 @@
package store
import (
"database/sql"
"fmt"
"nofx/crypto"
"nofx/logger"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// ExchangeStore exchange storage
type ExchangeStore struct {
db *sql.DB
encryptFunc func(string) string
decryptFunc func(string) string
db *gorm.DB
}
// Exchange exchange configuration
type Exchange struct {
ID string `json:"id"` // UUID
ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
AccountName string `json:"account_name"` // User-defined account name
UserID string `json:"user_id"`
Name string `json:"name"` // Display name (auto-generated or user-defined)
Type string `json:"type"` // "cex" or "dex"
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey"`
SecretKey string `json:"secretKey"`
Passphrase string `json:"passphrase"` // OKX-specific
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
AsterUser string `json:"asterUser"`
AsterSigner string `json:"asterSigner"`
AsterPrivateKey string `json:"asterPrivateKey"`
LighterWalletAddr string `json:"lighterWalletAddr"`
LighterPrivateKey string `json:"lighterPrivateKey"`
LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"`
LighterAPIKeyIndex int `json:"lighterAPIKeyIndex"`
ID string `gorm:"primaryKey" json:"id"`
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
AccountName string `gorm:"column:account_name;not null;default:''" json:"account_name"`
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `gorm:"not null" json:"name"`
Type string `gorm:"not null" json:"type"` // "cex" or "dex"
Enabled bool `gorm:"default:false" json:"enabled"`
APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
SecretKey crypto.EncryptedString `gorm:"column:secret_key;default:''" json:"secretKey"`
Passphrase crypto.EncryptedString `gorm:"column:passphrase;default:''" json:"passphrase"`
Testnet bool `gorm:"default:false" json:"testnet"`
HyperliquidWalletAddr string `gorm:"column:hyperliquid_wallet_addr;default:''" json:"hyperliquidWalletAddr"`
AsterUser string `gorm:"column:aster_user;default:''" json:"asterUser"`
AsterSigner string `gorm:"column:aster_signer;default:''" json:"asterSigner"`
AsterPrivateKey crypto.EncryptedString `gorm:"column:aster_private_key;default:''" json:"asterPrivateKey"`
LighterWalletAddr string `gorm:"column:lighter_wallet_addr;default:''" json:"lighterWalletAddr"`
LighterPrivateKey crypto.EncryptedString `gorm:"column:lighter_private_key;default:''" json:"lighterPrivateKey"`
LighterAPIKeyPrivateKey crypto.EncryptedString `gorm:"column:lighter_api_key_private_key;default:''" json:"lighterAPIKeyPrivateKey"`
LighterAPIKeyIndex int `gorm:"column:lighter_api_key_index;default:0" json:"lighterAPIKeyIndex"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (Exchange) TableName() string { return "exchanges" }
// NewExchangeStore creates a new ExchangeStore
func NewExchangeStore(db *gorm.DB) *ExchangeStore {
return &ExchangeStore{db: db}
}
func (s *ExchangeStore) initTables() error {
// Create new table structure with UUID as primary key
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS exchanges (
id TEXT PRIMARY KEY,
exchange_type TEXT NOT NULL DEFAULT '',
account_name TEXT NOT NULL DEFAULT '',
user_id TEXT NOT NULL DEFAULT 'default',
name TEXT NOT NULL,
type TEXT NOT NULL,
enabled BOOLEAN DEFAULT 0,
api_key TEXT DEFAULT '',
secret_key TEXT DEFAULT '',
passphrase TEXT DEFAULT '',
testnet BOOLEAN DEFAULT 0,
hyperliquid_wallet_addr TEXT DEFAULT '',
aster_user TEXT DEFAULT '',
aster_signer TEXT DEFAULT '',
aster_private_key TEXT DEFAULT '',
lighter_wallet_addr TEXT DEFAULT '',
lighter_private_key TEXT DEFAULT '',
lighter_api_key_private_key TEXT DEFAULT '',
lighter_api_key_index INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'exchanges'`).Scan(&tableExists)
if tableExists > 0 {
// Still run data migrations
s.migrateToMultiAccount()
s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
return nil
}
}
// Migration: add new columns if not exists
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN passphrase TEXT DEFAULT ''`)
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN exchange_type TEXT NOT NULL DEFAULT ''`)
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN account_name TEXT NOT NULL DEFAULT ''`)
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN lighter_api_key_index INTEGER DEFAULT 0`)
if err := s.db.AutoMigrate(&Exchange{}); err != nil {
return err
}
// Run migration to multi-account if needed
if err := s.migrateToMultiAccount(); err != nil {
@@ -85,120 +70,65 @@ func (s *ExchangeStore) initTables() error {
}
// Fix empty account_name for existing records
s.db.Exec(`UPDATE exchanges SET account_name = 'Default' WHERE account_name = '' OR account_name IS NULL`)
s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
// Update trigger for new schema
s.db.Exec(`DROP TRIGGER IF EXISTS update_exchanges_updated_at`)
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at
AFTER UPDATE ON exchanges
BEGIN
UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
return err
return nil
}
// migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID)
func (s *ExchangeStore) migrateToMultiAccount() error {
// Check if migration is needed by looking for old-style IDs (non-UUID)
var count int
err := s.db.QueryRow(`
SELECT COUNT(*) FROM exchanges
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
`).Scan(&count)
var count int64
err := s.db.Model(&Exchange{}).
Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
Count(&count).Error
if err != nil {
return err
}
if count == 0 {
// No migration needed
return nil
}
logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count)
// Get all old records
rows, err := s.db.Query(`
SELECT id, user_id, name, type, enabled, api_key, secret_key,
COALESCE(passphrase, '') as passphrase, testnet,
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
COALESCE(lighter_private_key, '') as lighter_private_key,
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key
FROM exchanges
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
`)
var records []Exchange
err = s.db.Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
Find(&records).Error
if err != nil {
return err
}
defer rows.Close()
type oldRecord struct {
id, userID, name, typ string
enabled, testnet bool
apiKey, secretKey, passphrase string
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string
lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string
}
var records []oldRecord
for rows.Next() {
var r oldRecord
if err := rows.Scan(&r.id, &r.userID, &r.name, &r.typ, &r.enabled,
&r.apiKey, &r.secretKey, &r.passphrase, &r.testnet,
&r.hyperliquidWalletAddr, &r.asterUser, &r.asterSigner, &r.asterPrivateKey,
&r.lighterWalletAddr, &r.lighterPrivateKey, &r.lighterApiKeyPrivateKey); err != nil {
return err
}
records = append(records, r)
}
// Begin transaction
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Migrate each record
return s.db.Transaction(func(tx *gorm.DB) error {
for _, r := range records {
newID := uuid.New().String()
oldID := r.id // This is the exchange type (e.g., "binance")
oldID := r.ID // This is the exchange type (e.g., "binance")
// Update traders table to use new UUID
_, err = tx.Exec(`UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?`,
newID, oldID, r.userID)
if err != nil {
if err := tx.Exec("UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?",
newID, oldID, r.UserID).Error; err != nil {
logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err)
return err
}
// Update the exchange record
_, err = tx.Exec(`
UPDATE exchanges SET
id = ?,
exchange_type = ?,
account_name = ?
WHERE id = ? AND user_id = ?
`, newID, oldID, "Default", oldID, r.userID)
if err != nil {
if err := tx.Model(&Exchange{}).
Where("id = ? AND user_id = ?", oldID, r.UserID).
Updates(map[string]interface{}{
"id": newID,
"exchange_type": oldID,
"account_name": "Default",
}).Error; err != nil {
logger.Errorf("Failed to migrate exchange %s: %v", oldID, err)
return err
}
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.userID)
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.UserID)
}
if err := tx.Commit(); err != nil {
return err
}
logger.Infof("✅ Multi-account migration completed successfully")
return nil
})
}
func (s *ExchangeStore) initDefaultData() error {
@@ -206,108 +136,24 @@ func (s *ExchangeStore) initDefaultData() error {
return nil
}
func (s *ExchangeStore) encrypt(plaintext string) string {
if s.encryptFunc != nil {
return s.encryptFunc(plaintext)
}
return plaintext
}
func (s *ExchangeStore) decrypt(encrypted string) string {
if s.decryptFunc != nil {
return s.decryptFunc(encrypted)
}
return encrypted
}
// List gets user's exchange list
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
rows, err := s.db.Query(`
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
user_id, name, type, enabled, api_key, secret_key,
COALESCE(passphrase, '') as passphrase, testnet,
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
COALESCE(lighter_private_key, '') as lighter_private_key,
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
created_at, updated_at
FROM exchanges WHERE user_id = ? ORDER BY exchange_type, account_name
`, userID)
var exchanges []*Exchange
err := s.db.Where("user_id = ?", userID).Order("exchange_type, account_name").Find(&exchanges).Error
if err != nil {
return nil, err
}
defer rows.Close()
exchanges := make([]*Exchange, 0)
for rows.Next() {
var e Exchange
var createdAt, updatedAt string
err := rows.Scan(
&e.ID, &e.ExchangeType, &e.AccountName,
&e.UserID, &e.Name, &e.Type,
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
&createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
e.APIKey = s.decrypt(e.APIKey)
e.SecretKey = s.decrypt(e.SecretKey)
e.Passphrase = s.decrypt(e.Passphrase)
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
exchanges = append(exchanges, &e)
}
return exchanges, nil
}
// GetByID gets a specific exchange by UUID
func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) {
var e Exchange
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
user_id, name, type, enabled, api_key, secret_key,
COALESCE(passphrase, '') as passphrase, testnet,
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
COALESCE(lighter_private_key, '') as lighter_private_key,
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
created_at, updated_at
FROM exchanges WHERE id = ? AND user_id = ?
`, id, userID).Scan(
&e.ID, &e.ExchangeType, &e.AccountName,
&e.UserID, &e.Name, &e.Type,
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
&createdAt, &updatedAt,
)
var exchange Exchange
err := s.db.Where("id = ? AND user_id = ?", id, userID).First(&exchange).Error
if err != nil {
return nil, err
}
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
e.APIKey = s.decrypt(e.APIKey)
e.SecretKey = s.decrypt(e.SecretKey)
e.Passphrase = s.decrypt(e.Passphrase)
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
return &e, nil
return &exchange, nil
}
// getExchangeNameAndType returns the display name and type for an exchange type
@@ -341,7 +187,6 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
id := uuid.New().String()
name, typ := getExchangeNameAndType(exchangeType)
// If account name is empty, use "Default"
if accountName == "" {
accountName = "Default"
}
@@ -349,19 +194,29 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s",
userID, exchangeType, accountName, id)
_, err := s.db.Exec(`
INSERT INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled,
api_key, secret_key, passphrase, testnet,
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
lighter_wallet_addr, lighter_private_key, lighter_api_key_private_key, lighter_api_key_index,
created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
`, id, exchangeType, accountName, userID, name, typ, enabled,
s.encrypt(apiKey), s.encrypt(secretKey), s.encrypt(passphrase), testnet,
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey),
lighterWalletAddr, s.encrypt(lighterPrivateKey), s.encrypt(lighterApiKeyPrivateKey), lighterApiKeyIndex)
exchange := &Exchange{
ID: id,
ExchangeType: exchangeType,
AccountName: accountName,
UserID: userID,
Name: name,
Type: typ,
Enabled: enabled,
APIKey: crypto.EncryptedString(apiKey),
SecretKey: crypto.EncryptedString(secretKey),
Passphrase: crypto.EncryptedString(passphrase),
Testnet: testnet,
HyperliquidWalletAddr: hyperliquidWalletAddr,
AsterUser: asterUser,
AsterSigner: asterSigner,
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
LighterWalletAddr: lighterWalletAddr,
LighterPrivateKey: crypto.EncryptedString(lighterPrivateKey),
LighterAPIKeyPrivateKey: crypto.EncryptedString(lighterApiKeyPrivateKey),
LighterAPIKeyIndex: lighterApiKeyIndex,
}
if err != nil {
if err := s.db.Create(exchange).Error; err != nil {
return "", err
}
return id, nil
@@ -373,53 +228,42 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled)
setClauses := []string{
"enabled = ?",
"testnet = ?",
"hyperliquid_wallet_addr = ?",
"aster_user = ?",
"aster_signer = ?",
"lighter_wallet_addr = ?",
"lighter_api_key_index = ?",
"updated_at = datetime('now')",
updates := map[string]interface{}{
"enabled": enabled,
"testnet": testnet,
"hyperliquid_wallet_addr": hyperliquidWalletAddr,
"aster_user": asterUser,
"aster_signer": asterSigner,
"lighter_wallet_addr": lighterWalletAddr,
"lighter_api_key_index": lighterApiKeyIndex,
"updated_at": time.Now(),
}
args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr, lighterApiKeyIndex}
// Only update encrypted fields if not empty
if apiKey != "" {
setClauses = append(setClauses, "api_key = ?")
args = append(args, s.encrypt(apiKey))
updates["api_key"] = crypto.EncryptedString(apiKey)
}
if secretKey != "" {
setClauses = append(setClauses, "secret_key = ?")
args = append(args, s.encrypt(secretKey))
updates["secret_key"] = crypto.EncryptedString(secretKey)
}
if passphrase != "" {
setClauses = append(setClauses, "passphrase = ?")
args = append(args, s.encrypt(passphrase))
updates["passphrase"] = crypto.EncryptedString(passphrase)
}
if asterPrivateKey != "" {
setClauses = append(setClauses, "aster_private_key = ?")
args = append(args, s.encrypt(asterPrivateKey))
updates["aster_private_key"] = crypto.EncryptedString(asterPrivateKey)
}
if lighterPrivateKey != "" {
setClauses = append(setClauses, "lighter_private_key = ?")
args = append(args, s.encrypt(lighterPrivateKey))
updates["lighter_private_key"] = crypto.EncryptedString(lighterPrivateKey)
}
if lighterApiKeyPrivateKey != "" {
setClauses = append(setClauses, "lighter_api_key_private_key = ?")
args = append(args, s.encrypt(lighterApiKeyPrivateKey))
updates["lighter_api_key_private_key"] = crypto.EncryptedString(lighterApiKeyPrivateKey)
}
args = append(args, id, userID)
query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", "))
result, err := s.db.Exec(query, args...)
if err != nil {
return err
result := s.db.Model(&Exchange{}).Where("id = ? AND user_id = ?", id, userID).Updates(updates)
if result.Error != nil {
return result.Error
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
if result.RowsAffected == 0 {
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
}
return nil
@@ -427,13 +271,16 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
// UpdateAccountName updates the account name for an exchange
func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error {
result, err := s.db.Exec(`UPDATE exchanges SET account_name = ?, updated_at = datetime('now') WHERE id = ? AND user_id = ?`,
accountName, id, userID)
if err != nil {
return err
result := s.db.Model(&Exchange{}).
Where("id = ? AND user_id = ?", id, userID).
Updates(map[string]interface{}{
"account_name": accountName,
"updated_at": time.Now(),
})
if result.Error != nil {
return result.Error
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
if result.RowsAffected == 0 {
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
}
return nil
@@ -441,12 +288,11 @@ func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error
// Delete deletes an exchange account
func (s *ExchangeStore) Delete(userID, id string) error {
result, err := s.db.Exec(`DELETE FROM exchanges WHERE id = ? AND user_id = ?`, id, userID)
if err != nil {
return err
result := s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Exchange{})
if result.Error != nil {
return result.Error
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
if result.RowsAffected == 0 {
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
}
logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID)
@@ -460,20 +306,25 @@ func (s *ExchangeStore) CreateLegacy(userID, id, name, typ string, enabled bool,
// Check if this is an old-style ID (exchange type as ID)
if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" {
// Use new Create method with exchange type
_, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet,
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0)
return err
}
// Otherwise assume it's already a UUID
_, err := s.db.Exec(`
INSERT OR IGNORE INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled,
api_key, secret_key, testnet,
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
lighter_wallet_addr, lighter_private_key)
VALUES (?, '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '')
`, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet,
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey))
return err
exchange := &Exchange{
ID: id,
UserID: userID,
Name: name,
Type: typ,
Enabled: enabled,
APIKey: crypto.EncryptedString(apiKey),
SecretKey: crypto.EncryptedString(secretKey),
Testnet: testnet,
HyperliquidWalletAddr: hyperliquidWalletAddr,
AsterUser: asterUser,
AsterSigner: asterSigner,
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
}
return s.db.Where("id = ?", id).FirstOrCreate(exchange).Error
}
+146
View File
@@ -0,0 +1,146 @@
package store
import (
"fmt"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// GormDB is the global GORM database connection
var gormDB *gorm.DB
// DB returns the GORM database connection
func DB() *gorm.DB {
return gormDB
}
// InitGorm initializes GORM with SQLite
func InitGorm(dbPath string) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
}
// Set connection pool for SQLite
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
// Enable foreign keys for SQLite
db.Exec("PRAGMA foreign_keys = ON")
db.Exec("PRAGMA journal_mode = DELETE")
db.Exec("PRAGMA synchronous = FULL")
db.Exec("PRAGMA busy_timeout = 5000")
gormDB = db
return db, nil
}
// InitGormPostgres initializes GORM with PostgreSQL
func InitGormPostgres(host string, port int, user, password, dbname, sslmode string) (*gorm.DB, error) {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
host, port, user, password, dbname, sslmode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
}
// Set connection pool for PostgreSQL
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxOpenConns(25)
sqlDB.SetMaxIdleConns(5)
gormDB = db
return db, nil
}
// InitGormWithConfig initializes GORM with provided configuration
// Uses DBConfig from driver.go
func InitGormWithConfig(cfg DBConfig) (*gorm.DB, error) {
switch cfg.Type {
case DBTypeSQLite:
return InitGorm(cfg.Path)
case DBTypePostgres:
return InitGormPostgres(
cfg.Host,
cfg.Port,
cfg.User,
cfg.Password,
cfg.DBName,
cfg.SSLMode,
)
default:
return nil, fmt.Errorf("unsupported DB_TYPE: %s (use 'sqlite' or 'postgres')", cfg.Type)
}
}
// ============================================================================
// Query Scopes - Reusable query helpers
// ============================================================================
// ForUser returns a scope that filters by user_id
func ForUser(userID string) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("user_id = ?", userID)
}
}
// ForTrader returns a scope that filters by trader_id
func ForTrader(traderID string) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("trader_id = ?", traderID)
}
}
// OpenPositions returns a scope for open positions
func OpenPositions() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("status = ?", "OPEN")
}
}
// ClosedPositions returns a scope for closed positions
func ClosedPositions() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("status = ?", "CLOSED")
}
}
// OrderByCreatedDesc returns a scope that orders by created_at DESC
func OrderByCreatedDesc() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Order("created_at DESC")
}
}
// OrderByUpdatedDesc returns a scope that orders by updated_at DESC
func OrderByUpdatedDesc() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Order("updated_at DESC")
}
}
// Paginate returns a scope for pagination
func Paginate(limit, offset int) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Limit(limit).Offset(offset)
}
}
+195 -444
View File
@@ -1,495 +1,255 @@
package store
import (
"database/sql"
"fmt"
"strings"
"time"
"gorm.io/gorm"
)
// TraderOrder 订单记录(完整的订单生命周期追踪)
// TraderOrder order record
type TraderOrder struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
ExchangeID string `json:"exchange_id"` // Exchange account UUID
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc)
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID
ClientOrderID string `json:"client_order_id"` // Client order ID
Symbol string `json:"symbol"` // Trading pair
Side string `json:"side"` // BUY/SELL
PositionSide string `json:"position_side"` // LONG/SHORT (hedge mode)
Type string `json:"type"` // MARKET/LIMIT/STOP/STOP_MARKET/TAKE_PROFIT/TAKE_PROFIT_MARKET
TimeInForce string `json:"time_in_force"` // GTC/IOC/FOK
Quantity float64 `json:"quantity"` // 订单数量
Price float64 `json:"price"` // 限价单价格
StopPrice float64 `json:"stop_price"` // 止损/止盈触发价格
Status string `json:"status"` // NEW/PARTIALLY_FILLED/FILLED/CANCELED/REJECTED/EXPIRED
FilledQuantity float64 `json:"filled_quantity"` // 已成交数量
AvgFillPrice float64 `json:"avg_fill_price"` // 平均成交价格
Commission float64 `json:"commission"` // 手续费总额
CommissionAsset string `json:"commission_asset"` // 手续费资产(USDT等)
Leverage int `json:"leverage"` // 杠杆倍数
ReduceOnly bool `json:"reduce_only"` // 是否只减仓
ClosePosition bool `json:"close_position"` // 是否平仓单
WorkingType string `json:"working_type"` // CONTRACT_PRICE/MARK_PRICE
PriceProtect bool `json:"price_protect"` // 价格保护
OrderAction string `json:"order_action"` // OPEN_LONG/OPEN_SHORT/CLOSE_LONG/CLOSE_SHORT/ADD_LONG/ADD_SHORT/STOP_LOSS/TAKE_PROFIT
RelatedPositionID int64 `json:"related_position_id"` // 关联的仓位ID
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
FilledAt time.Time `json:"filled_at"` // 完全成交时间
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TraderID string `gorm:"column:trader_id;not null;index:idx_orders_trader_id" json:"trader_id"`
ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
ExchangeOrderID string `gorm:"column:exchange_order_id;not null;uniqueIndex:idx_orders_exchange_unique,priority:2" json:"exchange_order_id"`
ClientOrderID string `gorm:"column:client_order_id;default:''" json:"client_order_id"`
Symbol string `gorm:"column:symbol;not null;index:idx_orders_symbol" json:"symbol"`
Side string `gorm:"column:side;not null" json:"side"`
PositionSide string `gorm:"column:position_side;default:''" json:"position_side"`
Type string `gorm:"column:type;not null" json:"type"`
TimeInForce string `gorm:"column:time_in_force;default:GTC" json:"time_in_force"`
Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
Price float64 `gorm:"column:price;default:0" json:"price"`
StopPrice float64 `gorm:"column:stop_price;default:0" json:"stop_price"`
Status string `gorm:"column:status;not null;default:NEW;index:idx_orders_status" json:"status"`
FilledQuantity float64 `gorm:"column:filled_quantity;default:0" json:"filled_quantity"`
AvgFillPrice float64 `gorm:"column:avg_fill_price;default:0" json:"avg_fill_price"`
Commission float64 `gorm:"column:commission;default:0" json:"commission"`
CommissionAsset string `gorm:"column:commission_asset;default:USDT" json:"commission_asset"`
Leverage int `gorm:"column:leverage;default:1" json:"leverage"`
ReduceOnly bool `gorm:"column:reduce_only;default:false" json:"reduce_only"`
ClosePosition bool `gorm:"column:close_position;default:false" json:"close_position"`
WorkingType string `gorm:"column:working_type;default:CONTRACT_PRICE" json:"working_type"`
PriceProtect bool `gorm:"column:price_protect;default:false" json:"price_protect"`
OrderAction string `gorm:"column:order_action;default:''" json:"order_action"`
RelatedPositionID int64 `gorm:"column:related_position_id;default:0" json:"related_position_id"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
FilledAt time.Time `gorm:"column:filled_at" json:"filled_at"`
}
// TraderFill trade record (one order may have multiple fills)
// TableName returns the table name for TraderOrder
func (TraderOrder) TableName() string {
return "trader_orders"
}
// TraderFill trade record
type TraderFill struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
ExchangeID string `json:"exchange_id"` // Exchange account UUID
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc)
OrderID int64 `json:"order_id"` // Related order ID
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID
ExchangeTradeID string `json:"exchange_trade_id"` // Exchange trade ID
Symbol string `json:"symbol"`
Side string `json:"side"` // BUY/SELL
Price float64 `json:"price"` // 成交价格
Quantity float64 `json:"quantity"` // 成交数量
QuoteQuantity float64 `json:"quote_quantity"` // 成交金额(USDT
Commission float64 `json:"commission"` // 手续费
CommissionAsset string `json:"commission_asset"`
RealizedPnL float64 `json:"realized_pnl"` // 实现盈亏(平仓时)
IsMaker bool `json:"is_maker"` // 是否为maker
CreatedAt time.Time `json:"created_at"`
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TraderID string `gorm:"column:trader_id;not null;index:idx_fills_trader_id" json:"trader_id"`
ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
OrderID int64 `gorm:"column:order_id;not null;index:idx_fills_order_id" json:"order_id"`
ExchangeOrderID string `gorm:"column:exchange_order_id;not null" json:"exchange_order_id"`
ExchangeTradeID string `gorm:"column:exchange_trade_id;not null;uniqueIndex:idx_fills_exchange_unique,priority:2" json:"exchange_trade_id"`
Symbol string `gorm:"column:symbol;not null" json:"symbol"`
Side string `gorm:"column:side;not null" json:"side"`
Price float64 `gorm:"column:price;not null" json:"price"`
Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
QuoteQuantity float64 `gorm:"column:quote_quantity;not null" json:"quote_quantity"`
Commission float64 `gorm:"column:commission;not null" json:"commission"`
CommissionAsset string `gorm:"column:commission_asset;not null" json:"commission_asset"`
RealizedPnL float64 `gorm:"column:realized_pnl;default:0" json:"realized_pnl"`
IsMaker bool `gorm:"column:is_maker;default:false" json:"is_maker"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
// OrderStore 订单存储
// TableName returns the table name for TraderFill
func (TraderFill) TableName() string {
return "trader_fills"
}
// OrderStore order storage
type OrderStore struct {
db *sql.DB
db *gorm.DB
}
// NewOrderStore 创建订单存储实例
func NewOrderStore(db *sql.DB) *OrderStore {
// NewOrderStore creates order storage instance
func NewOrderStore(db *gorm.DB) *OrderStore {
return &OrderStore{db: db}
}
// InitTables 初始化订单表
// InitTables initializes order tables
func (s *OrderStore) InitTables() error {
// 创建订单表
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS trader_orders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
exchange_id TEXT NOT NULL DEFAULT '',
exchange_type TEXT NOT NULL DEFAULT '',
exchange_order_id TEXT NOT NULL,
client_order_id TEXT DEFAULT '',
symbol TEXT NOT NULL,
side TEXT NOT NULL,
position_side TEXT DEFAULT '',
type TEXT NOT NULL,
time_in_force TEXT DEFAULT 'GTC',
quantity REAL NOT NULL,
price REAL DEFAULT 0,
stop_price REAL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'NEW',
filled_quantity REAL DEFAULT 0,
avg_fill_price REAL DEFAULT 0,
commission REAL DEFAULT 0,
commission_asset TEXT DEFAULT 'USDT',
leverage INTEGER DEFAULT 1,
reduce_only INTEGER DEFAULT 0,
close_position INTEGER DEFAULT 0,
working_type TEXT DEFAULT 'CONTRACT_PRICE',
price_protect INTEGER DEFAULT 0,
order_action TEXT DEFAULT '',
related_position_id INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
filled_at DATETIME,
UNIQUE(exchange_id, exchange_order_id)
)
`)
if err != nil {
return fmt.Errorf("failed to create trader_orders table: %w", err)
}
// For PostgreSQL, check if tables exist to avoid AutoMigrate index conflicts
if s.db.Dialector.Name() == "postgres" {
var ordersExist, fillsExist int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_orders'`).Scan(&ordersExist)
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_fills'`).Scan(&fillsExist)
// 创建成交记录表
_, err = s.db.Exec(`
CREATE TABLE IF NOT EXISTS trader_fills (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
exchange_id TEXT NOT NULL DEFAULT '',
exchange_type TEXT NOT NULL DEFAULT '',
order_id INTEGER NOT NULL,
exchange_order_id TEXT NOT NULL,
exchange_trade_id TEXT NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
price REAL NOT NULL,
quantity REAL NOT NULL,
quote_quantity REAL NOT NULL,
commission REAL NOT NULL,
commission_asset TEXT NOT NULL,
realized_pnl REAL DEFAULT 0,
is_maker INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(exchange_id, exchange_trade_id),
FOREIGN KEY (order_id) REFERENCES trader_orders(id)
)
`)
if err != nil {
return fmt.Errorf("failed to create trader_fills table: %w", err)
}
// 创建索引
if ordersExist > 0 && fillsExist > 0 {
// Tables exist - just ensure indexes exist, skip AutoMigrate
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_exchange_order_id ON trader_orders(exchange_id, exchange_order_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
return nil
}
}
if err := s.db.AutoMigrate(&TraderOrder{}, &TraderFill{}); err != nil {
return fmt.Errorf("failed to migrate order tables: %w", err)
}
// Create unique composite index for exchange_id + exchange_order_id
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
// Create unique composite index for exchange_id + exchange_trade_id
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
return nil
}
// CreateOrder 创建订单记录(去重:如果订单已存在则返回已有记录)
// CreateOrder creates order record
func (s *OrderStore) CreateOrder(order *TraderOrder) error {
// 1. 先检查订单是否已存在(去重)
// Check if order already exists
existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID)
if err != nil {
return fmt.Errorf("failed to check existing order: %w", err)
}
if existing != nil {
// 订单已存在,返回已有记录的ID
order.ID = existing.ID
order.CreatedAt = existing.CreatedAt
order.UpdatedAt = existing.UpdatedAt
return nil // 不是错误,只是跳过插入
}
// 2. 订单不存在,插入新记录
now := time.Now()
order.CreatedAt = now
order.UpdatedAt = now
result, err := s.db.Exec(`
INSERT INTO trader_orders (
trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
symbol, side, position_side, type, time_in_force,
quantity, price, stop_price, status,
filled_quantity, avg_fill_price, commission, commission_asset,
leverage, reduce_only, close_position, working_type, price_protect,
order_action, related_position_id,
created_at, updated_at, filled_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
order.TraderID, order.ExchangeID, order.ExchangeType, order.ExchangeOrderID, order.ClientOrderID,
order.Symbol, order.Side, order.PositionSide, order.Type, order.TimeInForce,
order.Quantity, order.Price, order.StopPrice, order.Status,
order.FilledQuantity, order.AvgFillPrice, order.Commission, order.CommissionAsset,
order.Leverage, order.ReduceOnly, order.ClosePosition, order.WorkingType, order.PriceProtect,
order.OrderAction, order.RelatedPositionID,
formatTimeOrNow(order.CreatedAt, now), formatTimeOrNow(order.UpdatedAt, now),
formatTimePtr(order.FilledAt),
)
if err != nil {
return fmt.Errorf("failed to create order: %w", err)
}
id, _ := result.LastInsertId()
order.ID = id
return nil
}
return s.db.Create(order).Error
}
// UpdateOrderStatus 更新订单状态
// UpdateOrderStatus updates order status
func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error {
now := time.Now()
updateSQL := `
UPDATE trader_orders SET
status = ?,
filled_quantity = ?,
avg_fill_price = ?,
commission = ?,
updated_at = ?
`
args := []interface{}{status, filledQty, avgPrice, commission, now.Format(time.RFC3339)}
updates := map[string]interface{}{
"status": status,
"filled_quantity": filledQty,
"avg_fill_price": avgPrice,
"commission": commission,
}
// 如果完全成交,记录成交时间
if status == "FILLED" {
updateSQL += `, filled_at = ?`
args = append(args, now.Format(time.RFC3339))
updates["filled_at"] = time.Now()
}
updateSQL += ` WHERE id = ?`
args = append(args, id)
_, err := s.db.Exec(updateSQL, args...)
if err != nil {
return fmt.Errorf("failed to update order status: %w", err)
}
return nil
return s.db.Model(&TraderOrder{}).Where("id = ?", id).Updates(updates).Error
}
// CreateFill 创建成交记录(去重:如果成交已存在则跳过)
// CreateFill creates fill record
func (s *OrderStore) CreateFill(fill *TraderFill) error {
// 1. 先检查成交是否已存在(去重)
// Check if fill already exists
existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID)
if err != nil {
return fmt.Errorf("failed to check existing fill: %w", err)
}
if existing != nil {
// 成交已存在,返回已有记录的ID
fill.ID = existing.ID
fill.CreatedAt = existing.CreatedAt
return nil // 不是错误,只是跳过插入
}
// 2. 成交不存在,插入新记录
now := time.Now()
fill.CreatedAt = now
result, err := s.db.Exec(`
INSERT INTO trader_fills (
trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
symbol, side, price, quantity, quote_quantity,
commission, commission_asset, realized_pnl, is_maker,
created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
fill.TraderID, fill.ExchangeID, fill.ExchangeType, fill.OrderID, fill.ExchangeOrderID, fill.ExchangeTradeID,
fill.Symbol, fill.Side, fill.Price, fill.Quantity, fill.QuoteQuantity,
fill.Commission, fill.CommissionAsset, fill.RealizedPnL, fill.IsMaker,
now.Format(time.RFC3339),
)
if err != nil {
return fmt.Errorf("failed to create fill: %w", err)
}
id, _ := result.LastInsertId()
fill.ID = id
return nil
}
return s.db.Create(fill).Error
}
// GetFillByExchangeTradeID 根据交易所成交ID获取成交记录
// GetFillByExchangeTradeID gets fill by exchange trade ID
func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) {
row := s.db.QueryRow(`
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
symbol, side, price, quantity, quote_quantity,
commission, commission_asset, realized_pnl, is_maker,
created_at
FROM trader_fills
WHERE exchange_id = ? AND exchange_trade_id = ?
`, exchangeID, exchangeTradeID)
var fill TraderFill
var createdAt sql.NullString
err := row.Scan(
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
&createdAt,
)
if err == sql.ErrNoRows {
err := s.db.Where("exchange_id = ? AND exchange_trade_id = ?", exchangeID, exchangeTradeID).First(&fill).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get fill: %w", err)
}
// Parse time
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
fill.CreatedAt = t
}
}
return &fill, nil
}
// GetOrderByExchangeID 根据交易所订单ID获取订单
// GetOrderByExchangeID gets order by exchange order ID
func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) {
row := s.db.QueryRow(`
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
symbol, side, position_side, type, time_in_force,
quantity, price, stop_price, status,
filled_quantity, avg_fill_price, commission, commission_asset,
leverage, reduce_only, close_position, working_type, price_protect,
order_action, related_position_id,
created_at, updated_at, filled_at
FROM trader_orders
WHERE exchange_id = ? AND exchange_order_id = ?
`, exchangeID, exchangeOrderID)
var order TraderOrder
var createdAt, updatedAt, filledAt sql.NullString
err := row.Scan(
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
&order.OrderAction, &order.RelatedPositionID,
&createdAt, &updatedAt, &filledAt,
)
if err == sql.ErrNoRows {
err := s.db.Where("exchange_id = ? AND exchange_order_id = ?", exchangeID, exchangeOrderID).First(&order).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get order: %w", err)
}
// Parse times
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
order.CreatedAt = t
}
}
if updatedAt.Valid {
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
order.UpdatedAt = t
}
}
if filledAt.Valid {
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
order.FilledAt = t
}
}
return &order, nil
}
// GetTraderOrders 获取trader的订单列表
// GetTraderOrders gets trader's order list
func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
symbol, side, position_side, type, time_in_force,
quantity, price, stop_price, status,
filled_quantity, avg_fill_price, commission, commission_asset,
leverage, reduce_only, close_position, working_type, price_protect,
order_action, related_position_id,
created_at, updated_at, filled_at
FROM trader_orders
WHERE trader_id = ?
ORDER BY created_at DESC
LIMIT ?
`, traderID, limit)
var orders []*TraderOrder
err := s.db.Where("trader_id = ?", traderID).
Order("created_at DESC").
Limit(limit).
Find(&orders).Error
if err != nil {
return nil, fmt.Errorf("failed to query orders: %w", err)
}
defer rows.Close()
var orders []*TraderOrder
for rows.Next() {
var order TraderOrder
var createdAt, updatedAt, filledAt sql.NullString
err := rows.Scan(
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
&order.OrderAction, &order.RelatedPositionID,
&createdAt, &updatedAt, &filledAt,
)
if err != nil {
continue
}
// Parse times
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
order.CreatedAt = t
}
}
if updatedAt.Valid {
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
order.UpdatedAt = t
}
}
if filledAt.Valid {
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
order.FilledAt = t
}
}
orders = append(orders, &order)
}
return orders, nil
}
// GetOrderFills 获取订单的成交记录
// GetOrderFills gets order's fill records
func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
symbol, side, price, quantity, quote_quantity,
commission, commission_asset, realized_pnl, is_maker,
created_at
FROM trader_fills
WHERE order_id = ?
ORDER BY created_at ASC
`, orderID)
var fills []*TraderFill
err := s.db.Where("order_id = ?", orderID).
Order("created_at ASC").
Find(&fills).Error
if err != nil {
return nil, fmt.Errorf("failed to query fills: %w", err)
}
defer rows.Close()
var fills []*TraderFill
for rows.Next() {
var fill TraderFill
var createdAt sql.NullString
err := rows.Scan(
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
&createdAt,
)
if err != nil {
continue
}
if createdAt.Valid {
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
fill.CreatedAt = t
}
}
fills = append(fills, &fill)
}
return fills, nil
}
// GetTraderOrderStats 获取trader的订单统计
// GetTraderOrderStats gets trader's order statistics
func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) {
var totalOrders, filledOrders, canceledOrders int
var totalCommission, totalVolume float64
type result struct {
TotalOrders int
FilledOrders int
CanceledOrders int
TotalCommission float64
TotalVolume float64
}
var r result
err := s.db.QueryRow(`
SELECT
COUNT(*) as total_orders,
err := s.db.Model(&TraderOrder{}).
Select(`COUNT(*) as total_orders,
SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders,
SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders,
SUM(commission) as total_commission,
SUM(filled_quantity * avg_fill_price) as total_volume
FROM trader_orders
WHERE trader_id = ?
`, traderID).Scan(&totalOrders, &filledOrders, &canceledOrders, &totalCommission, &totalVolume)
SUM(filled_quantity * avg_fill_price) as total_volume`).
Where("trader_id = ?", traderID).
Scan(&r).Error
if err != nil {
return nil, fmt.Errorf("failed to get order stats: %w", err)
}
return map[string]interface{}{
"total_orders": totalOrders,
"filled_orders": filledOrders,
"canceled_orders": canceledOrders,
"total_commission": totalCommission,
"total_volume": totalVolume,
"total_orders": r.TotalOrders,
"filled_orders": r.FilledOrders,
"canceled_orders": r.CanceledOrders,
"total_commission": r.TotalCommission,
"total_volume": r.TotalVolume,
}, nil
}
// CleanupDuplicateOrders 清理重复的订单记录(保留最早创建的记录)
// CleanupDuplicateOrders cleans up duplicate order records
func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
result, err := s.db.Exec(`
result := s.db.Exec(`
DELETE FROM trader_orders
WHERE id NOT IN (
SELECT MIN(id)
@@ -497,17 +257,15 @@ func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
GROUP BY exchange_id, exchange_order_id
)
`)
if err != nil {
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", err)
if result.Error != nil {
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", result.Error)
}
rowsAffected, _ := result.RowsAffected()
return int(rowsAffected), nil
return int(result.RowsAffected), nil
}
// CleanupDuplicateFills 清理重复的成交记录(保留最早创建的记录)
// CleanupDuplicateFills cleans up duplicate fill records
func (s *OrderStore) CleanupDuplicateFills() (int, error) {
result, err := s.db.Exec(`
result := s.db.Exec(`
DELETE FROM trader_fills
WHERE id NOT IN (
SELECT MIN(id)
@@ -515,73 +273,66 @@ func (s *OrderStore) CleanupDuplicateFills() (int, error) {
GROUP BY exchange_id, exchange_trade_id
)
`)
if err != nil {
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", err)
if result.Error != nil {
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", result.Error)
}
rowsAffected, _ := result.RowsAffected()
return int(rowsAffected), nil
return int(result.RowsAffected), nil
}
// GetDuplicateOrdersCount 获取重复订单的数量(用于诊断)
// GetDuplicateOrdersCount gets duplicate orders count
func (s *OrderStore) GetDuplicateOrdersCount() (int, error) {
var count int
err := s.db.QueryRow(`
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_order_id)
FROM trader_orders
`).Scan(&count)
return count, err
var total, distinct int64
s.db.Model(&TraderOrder{}).Count(&total)
// Count distinct combinations
var distinctResult struct{ Count int64 }
s.db.Model(&TraderOrder{}).
Select("COUNT(DISTINCT exchange_id || ',' || exchange_order_id) as count").
Scan(&distinctResult)
distinct = distinctResult.Count
return int(total - distinct), nil
}
// GetDuplicateFillsCount 获取重复成交的数量(用于诊断)
// GetDuplicateFillsCount gets duplicate fills count
func (s *OrderStore) GetDuplicateFillsCount() (int, error) {
var count int
err := s.db.QueryRow(`
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_trade_id)
FROM trader_fills
`).Scan(&count)
return count, err
var total, distinct int64
s.db.Model(&TraderFill{}).Count(&total)
var distinctResult struct{ Count int64 }
s.db.Model(&TraderFill{}).
Select("COUNT(DISTINCT exchange_id || ',' || exchange_trade_id) as count").
Scan(&distinctResult)
distinct = distinctResult.Count
return int(total - distinct), nil
}
// GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange
// Used for incremental sync - only fetch trades with ID > maxTradeID
func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) {
rows, err := s.db.Query(`
SELECT symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id
FROM trader_fills
WHERE exchange_id = ? AND exchange_trade_id != ''
GROUP BY symbol
`, exchangeID)
type symbolMaxID struct {
Symbol string
MaxTradeID int64
}
var results []symbolMaxID
err := s.db.Model(&TraderFill{}).
Select("symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id").
Where("exchange_id = ? AND exchange_trade_id != ''", exchangeID).
Group("symbol").
Find(&results).Error
if err != nil {
// If CAST fails (non-numeric trade IDs), fallback to string comparison
if strings.Contains(err.Error(), "CAST") || strings.Contains(err.Error(), "invalid") {
return make(map[string]int64), nil
}
return nil, fmt.Errorf("failed to query max trade IDs: %w", err)
}
defer rows.Close()
result := make(map[string]int64)
for rows.Next() {
var symbol string
var maxID int64
if err := rows.Scan(&symbol, &maxID); err != nil {
continue
}
result[symbol] = maxID
for _, r := range results {
result[r.Symbol] = r.MaxTradeID
}
return result, nil
}
// formatTimePtr formats time.Time to RFC3339 string, returns NULL for zero time
func formatTimePtr(t time.Time) interface{} {
if t.IsZero() {
return nil
}
return t.Format(time.RFC3339)
}
// formatTimeOrNow returns the formatted time if not zero, otherwise returns now
func formatTimeOrNow(t time.Time, now time.Time) string {
if t.IsZero() {
return now.Format(time.RFC3339)
}
return t.Format(time.RFC3339)
}
+521 -813
View File
File diff suppressed because it is too large Load Diff
+104 -80
View File
@@ -7,12 +7,15 @@ import (
"fmt"
"nofx/logger"
"sync"
"gorm.io/gorm"
)
// Store unified data storage interface
type Store struct {
db *sql.DB
driver *DBDriver // Database driver for abstraction
gdb *gorm.DB // GORM database connection
db *sql.DB // Legacy sql.DB for backward compatibility
driver *DBDriver // Database driver for abstraction (legacy)
// Sub-stores (lazy initialization)
user *UserStore
@@ -26,105 +29,103 @@ type Store struct {
equity *EquityStore
order *OrderStore
// Encryption functions
encryptFunc func(string) string
decryptFunc func(string) string
mu sync.RWMutex
}
// New creates new Store instance (SQLite mode for backward compatibility)
func New(dbPath string) (*Store, error) {
driver, err := NewDBDriver(DBConfig{Type: DBTypeSQLite, Path: dbPath})
gdb, err := InitGorm(dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
s := &Store{db: driver.DB(), driver: driver}
// Get underlying sql.DB for legacy compatibility
sqlDB, err := gdb.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
s := &Store{gdb: gdb, db: sqlDB}
// Initialize all table structures
if err := s.initTables(); err != nil {
driver.Close()
sqlDB.Close()
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
}
// Initialize default data
if err := s.initDefaultData(); err != nil {
driver.Close()
sqlDB.Close()
return nil, fmt.Errorf("failed to initialize default data: %w", err)
}
logger.Infof("✅ Database initialized (type: %s)", driver.Type)
logger.Infof("✅ Database initialized (GORM, SQLite)")
return s, nil
}
// NewFromEnv creates new Store instance from environment variables
// DB_TYPE: sqlite (default) or postgres
// For SQLite: DB_PATH (default: data/data.db)
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
func NewFromEnv() (*Store, error) {
driver, err := NewDBDriverFromEnv()
// NewWithConfig creates new Store instance with provided database configuration
func NewWithConfig(cfg DBConfig) (*Store, error) {
gdb, err := InitGormWithConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
s := &Store{db: driver.DB(), driver: driver}
// Get underlying sql.DB for legacy compatibility
sqlDB, err := gdb.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
s := &Store{gdb: gdb, db: sqlDB}
// Initialize all table structures
if err := s.initTables(); err != nil {
driver.Close()
sqlDB.Close()
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
}
// Initialize default data
if err := s.initDefaultData(); err != nil {
driver.Close()
sqlDB.Close()
return nil, fmt.Errorf("failed to initialize default data: %w", err)
}
logger.Infof("✅ Database initialized (type: %s)", driver.Type)
dbTypeStr := "SQLite"
if cfg.Type == DBTypePostgres {
dbTypeStr = "PostgreSQL"
}
logger.Infof("✅ Database initialized (GORM, %s)", dbTypeStr)
return s, nil
}
// NewFromDB creates Store from existing database connection
// NewFromGorm creates Store from existing GORM connection
func NewFromGorm(gdb *gorm.DB) (*Store, error) {
sqlDB, err := gdb.DB()
if err != nil {
return nil, err
}
return &Store{gdb: gdb, db: sqlDB}, nil
}
// NewFromDB creates Store from existing database connection (legacy)
// Deprecated: Use NewFromGorm instead
func NewFromDB(db *sql.DB) *Store {
return &Store{db: db}
}
// SetCryptoFuncs sets encryption/decryption functions
func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
s.mu.Lock()
defer s.mu.Unlock()
s.encryptFunc = encrypt
s.decryptFunc = decrypt
// Update already initialized sub-stores
if s.aiModel != nil {
s.aiModel.encryptFunc = encrypt
s.aiModel.decryptFunc = decrypt
}
if s.exchange != nil {
s.exchange.encryptFunc = encrypt
s.exchange.decryptFunc = decrypt
}
if s.trader != nil {
s.trader.decryptFunc = decrypt
}
}
// initTables initializes all database tables
// initTables initializes all database tables using GORM AutoMigrate
func (s *Store) initTables() error {
// Initialize system config table first
if _, err := s.db.Exec(`
// Create system_config table (GORM handles this via raw SQL for simplicity)
if err := s.gdb.Exec(`
CREATE TABLE IF NOT EXISTS system_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)
`); err != nil {
`).Error; err != nil {
return fmt.Errorf("failed to create system_config table: %w", err)
}
// Initialize in dependency order
// Initialize sub-store tables
if err := s.User().initTables(); err != nil {
return fmt.Errorf("failed to initialize user tables: %w", err)
}
@@ -183,7 +184,7 @@ func (s *Store) User() *UserStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.user == nil {
s.user = &UserStore{db: s.db}
s.user = NewUserStore(s.gdb)
}
return s.user
}
@@ -193,11 +194,7 @@ func (s *Store) AIModel() *AIModelStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.aiModel == nil {
s.aiModel = &AIModelStore{
db: s.db,
encryptFunc: s.encryptFunc,
decryptFunc: s.decryptFunc,
}
s.aiModel = NewAIModelStore(s.gdb)
}
return s.aiModel
}
@@ -207,11 +204,7 @@ func (s *Store) Exchange() *ExchangeStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.exchange == nil {
s.exchange = &ExchangeStore{
db: s.db,
encryptFunc: s.encryptFunc,
decryptFunc: s.decryptFunc,
}
s.exchange = NewExchangeStore(s.gdb)
}
return s.exchange
}
@@ -221,10 +214,7 @@ func (s *Store) Trader() *TraderStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.trader == nil {
s.trader = &TraderStore{
db: s.db,
decryptFunc: s.decryptFunc,
}
s.trader = NewTraderStore(s.gdb)
}
return s.trader
}
@@ -234,7 +224,7 @@ func (s *Store) Decision() *DecisionStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.decision == nil {
s.decision = &DecisionStore{db: s.db}
s.decision = NewDecisionStore(s.gdb)
}
return s.decision
}
@@ -244,7 +234,7 @@ func (s *Store) Backtest() *BacktestStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.backtest == nil {
s.backtest = &BacktestStore{db: s.db}
s.backtest = NewBacktestStore(s.gdb)
}
return s.backtest
}
@@ -254,7 +244,7 @@ func (s *Store) Position() *PositionStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.position == nil {
s.position = NewPositionStore(s.db)
s.position = NewPositionStore(s.gdb)
}
return s.position
}
@@ -264,7 +254,7 @@ func (s *Store) Strategy() *StrategyStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.strategy == nil {
s.strategy = &StrategyStore{db: s.db}
s.strategy = NewStrategyStore(s.gdb)
}
return s.strategy
}
@@ -274,7 +264,7 @@ func (s *Store) Equity() *EquityStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.equity == nil {
s.equity = &EquityStore{db: s.db}
s.equity = NewEquityStore(s.gdb)
}
return s.equity
}
@@ -284,7 +274,7 @@ func (s *Store) Order() *OrderStore {
s.mu.Lock()
defer s.mu.Unlock()
if s.order == nil {
s.order = NewOrderStore(s.db)
s.order = NewOrderStore(s.gdb)
}
return s.order
}
@@ -294,10 +284,18 @@ func (s *Store) Close() error {
if s.driver != nil {
return s.driver.Close()
}
if s.db != nil {
return s.db.Close()
}
return nil
}
// Driver returns database driver for abstraction
// GormDB returns the GORM database connection
func (s *Store) GormDB() *gorm.DB {
return s.gdb
}
// Driver returns database driver for abstraction (legacy)
func (s *Store) Driver() *DBDriver {
return s.driver
}
@@ -307,11 +305,25 @@ func (s *Store) DBType() DBType {
if s.driver != nil {
return s.driver.Type
}
// Detect from GORM dialector
if s.gdb != nil {
switch s.gdb.Dialector.Name() {
case "postgres":
return DBTypePostgres
default:
return DBTypeSQLite
}
}
return DBTypeSQLite
}
// DB gets underlying database connection (for legacy code compatibility, gradually deprecated)
// Deprecated: use Store methods instead
// q converts query placeholders for current database type (legacy helper)
func (s *Store) q(query string) string {
return convertQuery(query, s.DBType())
}
// DB gets underlying database connection (for legacy code compatibility)
// Deprecated: use GormDB() instead
func (s *Store) DB() *sql.DB {
return s.db
}
@@ -319,24 +331,36 @@ func (s *Store) DB() *sql.DB {
// GetSystemConfig gets a system configuration value by key
func (s *Store) GetSystemConfig(key string) (string, error) {
var value string
err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value)
if err == sql.ErrNoRows {
result := s.gdb.Raw("SELECT value FROM system_config WHERE key = ?", key).Scan(&value)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return "", nil
}
return value, err
return "", result.Error
}
if result.RowsAffected == 0 {
return "", nil
}
return value, nil
}
// SetSystemConfig sets a system configuration value
func (s *Store) SetSystemConfig(key, value string) error {
_, err := s.db.Exec(`
// Use GORM-compatible upsert
return s.gdb.Exec(`
INSERT INTO system_config (key, value) VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value = excluded.value
`, key, value)
return err
`, key, value).Error
}
// Transaction executes transaction
func (s *Store) Transaction(fn func(tx *sql.Tx) error) error {
// Transaction executes transaction with GORM
func (s *Store) Transaction(fn func(tx *gorm.DB) error) error {
return s.gdb.Transaction(fn)
}
// TransactionSQL executes transaction with sql.Tx (legacy)
// Deprecated: Use Transaction() instead
func (s *Store) TransactionSQL(fn func(tx *sql.Tx) error) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
+52 -150
View File
@@ -1,30 +1,33 @@
package store
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"gorm.io/gorm"
)
// StrategyStore strategy storage
type StrategyStore struct {
db *sql.DB
db *gorm.DB
}
// Strategy strategy configuration
type Strategy struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Description string `json:"description"`
IsActive bool `json:"is_active"` // whether it is active (a user can only have one active strategy)
IsDefault bool `json:"is_default"` // whether it is a system default strategy
Config string `json:"config"` // strategy configuration in JSON format
ID string `gorm:"primaryKey" json:"id"`
UserID string `gorm:"column:user_id;not null;default:'';index" json:"user_id"`
Name string `gorm:"not null" json:"name"`
Description string `gorm:"default:''" json:"description"`
IsActive bool `gorm:"column:is_active;default:false;index" json:"is_active"`
IsDefault bool `gorm:"column:is_default;default:false" json:"is_default"`
Config string `gorm:"not null;default:'{}'" json:"config"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (Strategy) TableName() string { return "strategies" }
// StrategyConfig strategy configuration details (JSON structure)
type StrategyConfig struct {
// coin source configuration
@@ -136,24 +139,6 @@ type ExternalDataSource struct {
}
// RiskControlConfig risk control configuration
// All parameters are clearly defined without ambiguity:
//
// Position Limits:
// - MaxPositions: max number of coins held simultaneously (CODE ENFORCED)
//
// Trading Leverage (exchange leverage for opening positions):
// - BTCETHMaxLeverage: BTC/ETH max exchange leverage (AI guided)
// - AltcoinMaxLeverage: Altcoin max exchange leverage (AI guided)
//
// Position Value Limits (single position notional value / account equity):
// - BTCETHMaxPositionValueRatio: BTC/ETH max = equity × ratio (CODE ENFORCED)
// - AltcoinMaxPositionValueRatio: Altcoin max = equity × ratio (CODE ENFORCED)
//
// Risk Controls:
// - MaxMarginUsage: max margin utilization percentage (CODE ENFORCED)
// - MinPositionSize: minimum position size in USDT (CODE ENFORCED)
// - MinRiskRewardRatio: min take_profit / stop_loss ratio (AI guided)
// - MinConfidence: min AI confidence to open position (AI guided)
type RiskControlConfig struct {
// Max number of coins held simultaneously (CODE ENFORCED)
MaxPositions int `json:"max_positions"`
@@ -179,38 +164,21 @@ type RiskControlConfig struct {
MinConfidence int `json:"min_confidence"`
}
// NewStrategyStore creates a new StrategyStore
func NewStrategyStore(db *gorm.DB) *StrategyStore {
return &StrategyStore{db: db}
}
func (s *StrategyStore) initTables() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS strategies (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT '',
name TEXT NOT NULL,
description TEXT DEFAULT '',
is_active BOOLEAN DEFAULT 0,
is_default BOOLEAN DEFAULT 0,
config TEXT NOT NULL DEFAULT '{}',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'strategies'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
// create indexes
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_user_id ON strategies(user_id)`)
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_is_active ON strategies(is_active)`)
// trigger: automatically update updated_at on update
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_strategies_updated_at
AFTER UPDATE ON strategies
BEGIN
UPDATE strategies SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
return err
}
return s.db.AutoMigrate(&Strategy{})
}
func (s *StrategyStore) initDefaultData() error {
@@ -322,159 +290,93 @@ Only enter positions when multiple signals resonate. Freely use any effective an
// Create create a strategy
func (s *StrategyStore) Create(strategy *Strategy) error {
_, err := s.db.Exec(`
INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config)
VALUES (?, ?, ?, ?, ?, ?, ?)
`, strategy.ID, strategy.UserID, strategy.Name, strategy.Description, strategy.IsActive, strategy.IsDefault, strategy.Config)
return err
return s.db.Create(strategy).Error
}
// Update update a strategy
func (s *StrategyStore) Update(strategy *Strategy) error {
_, err := s.db.Exec(`
UPDATE strategies SET
name = ?, description = ?, config = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ? AND user_id = ?
`, strategy.Name, strategy.Description, strategy.Config, strategy.ID, strategy.UserID)
return err
return s.db.Model(&Strategy{}).
Where("id = ? AND user_id = ?", strategy.ID, strategy.UserID).
Updates(map[string]interface{}{
"name": strategy.Name,
"description": strategy.Description,
"config": strategy.Config,
"updated_at": time.Now(),
}).Error
}
// Delete delete a strategy
func (s *StrategyStore) Delete(userID, id string) error {
// do not allow deleting system default strategy
var isDefault bool
s.db.QueryRow(`SELECT is_default FROM strategies WHERE id = ?`, id).Scan(&isDefault)
if isDefault {
var st Strategy
if err := s.db.Where("id = ?", id).First(&st).Error; err == nil && st.IsDefault {
return fmt.Errorf("cannot delete system default strategy")
}
_, err := s.db.Exec(`DELETE FROM strategies WHERE id = ? AND user_id = ?`, id, userID)
return err
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Strategy{}).Error
}
// List get user's strategy list
func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
// get user's own strategies + system default strategy
rows, err := s.db.Query(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE user_id = ? OR is_default = 1
ORDER BY is_default DESC, created_at DESC
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var strategies []*Strategy
for rows.Next() {
var st Strategy
var createdAt, updatedAt string
err := rows.Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
err := s.db.Where("user_id = ? OR is_default = ?", userID, true).
Order("is_default DESC, created_at DESC").
Find(&strategies).Error
if err != nil {
return nil, err
}
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
strategies = append(strategies, &st)
}
return strategies, nil
}
// Get get a single strategy
func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
var st Strategy
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE id = ? AND (user_id = ? OR is_default = 1)
`, id, userID).Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", id, userID, true).
First(&st).Error
if err != nil {
return nil, err
}
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &st, nil
}
// GetActive get user's currently active strategy
func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
var st Strategy
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE user_id = ? AND is_active = 1
`, userID).Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
if err == sql.ErrNoRows {
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&st).Error
if err == gorm.ErrRecordNotFound {
// no active strategy, return system default strategy
return s.GetDefault()
}
if err != nil {
return nil, err
}
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &st, nil
}
// GetDefault get system default strategy
func (s *StrategyStore) GetDefault() (*Strategy, error) {
var st Strategy
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
WHERE is_default = 1
LIMIT 1
`).Scan(
&st.ID, &st.UserID, &st.Name, &st.Description,
&st.IsActive, &st.IsDefault, &st.Config,
&createdAt, &updatedAt,
)
err := s.db.Where("is_default = ?", true).First(&st).Error
if err != nil {
return nil, err
}
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &st, nil
}
// SetActive set active strategy (will first deactivate other strategies)
func (s *StrategyStore) SetActive(userID, strategyID string) error {
// begin transaction
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
return s.db.Transaction(func(tx *gorm.DB) error {
// first deactivate all strategies for the user
_, err = tx.Exec(`UPDATE strategies SET is_active = 0 WHERE user_id = ?`, userID)
if err != nil {
if err := tx.Model(&Strategy{}).Where("user_id = ?", userID).
Update("is_active", false).Error; err != nil {
return err
}
// activate specified strategy
_, err = tx.Exec(`UPDATE strategies SET is_active = 1 WHERE id = ? AND (user_id = ? OR is_default = 1)`, strategyID, userID)
if err != nil {
return err
}
return tx.Commit()
return tx.Model(&Strategy{}).
Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
Update("is_active", true).Error
})
}
// Duplicate duplicate a strategy (used to create custom strategy based on default strategy)
+108 -371
View File
@@ -1,43 +1,52 @@
package store
import (
"database/sql"
"fmt"
"strings"
"time"
"gorm.io/gorm"
)
// TraderStore trader storage
type TraderStore struct {
db *sql.DB
decryptFunc func(string) string
db *gorm.DB
}
// NewTraderStore creates a new trader store
func NewTraderStore(db *gorm.DB) *TraderStore {
return &TraderStore{db: db}
}
// Trader trader configuration
type Trader struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
AIModelID string `json:"ai_model_id"`
ExchangeID string `json:"exchange_id"`
StrategyID string `json:"strategy_id"` // Associated strategy ID
InitialBalance float64 `json:"initial_balance"`
ScanIntervalMinutes int `json:"scan_interval_minutes"`
IsRunning bool `json:"is_running"`
IsCrossMargin bool `json:"is_cross_margin"`
ShowInCompetition bool `json:"show_in_competition"` // Whether to show in competition page
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID string `gorm:"primaryKey" json:"id"`
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `gorm:"column:name;not null" json:"name"`
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"`
StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"`
InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"`
ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"`
IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"`
IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"`
ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"`
AltcoinLeverage int `json:"altcoin_leverage,omitempty"`
TradingSymbols string `json:"trading_symbols,omitempty"`
UseCoinPool bool `json:"use_coin_pool,omitempty"`
UseOITop bool `json:"use_oi_top,omitempty"`
CustomPrompt string `json:"custom_prompt,omitempty"`
OverrideBasePrompt bool `json:"override_base_prompt,omitempty"`
SystemPromptTemplate string `json:"system_prompt_template,omitempty"`
BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"`
AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"`
TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"`
UseCoinPool bool `gorm:"column:use_coin_pool;default:false" json:"use_coin_pool,omitempty"`
UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"`
CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"`
OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"`
SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"`
}
// TableName returns the table name for Trader
func (Trader) TableName() string {
return "traders"
}
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
@@ -45,331 +54,130 @@ type TraderFullConfig struct {
Trader *Trader
AIModel *AIModel
Exchange *Exchange
Strategy *Strategy // Associated strategy configuration
Strategy *Strategy
}
func (s *TraderStore) initTables() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS traders (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT 'default',
name TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
exchange_id TEXT NOT NULL,
initial_balance REAL NOT NULL,
scan_interval_minutes INTEGER DEFAULT 3,
is_running BOOLEAN DEFAULT 0,
btc_eth_leverage INTEGER DEFAULT 5,
altcoin_leverage INTEGER DEFAULT 5,
trading_symbols TEXT DEFAULT '',
use_coin_pool BOOLEAN DEFAULT 0,
use_oi_top BOOLEAN DEFAULT 0,
custom_prompt TEXT DEFAULT '',
override_base_prompt BOOLEAN DEFAULT 0,
system_prompt_template TEXT DEFAULT 'default',
is_cross_margin BOOLEAN DEFAULT 1,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
AFTER UPDATE ON traders
BEGIN
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
if err != nil {
return err
}
// Backward compatibility
alterQueries := []string{
`ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`,
`ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`,
`ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`,
`ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`,
`ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`,
`ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`,
`ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`,
`ALTER TABLE traders ADD COLUMN strategy_id TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN show_in_competition BOOLEAN DEFAULT 1`,
}
for _, q := range alterQueries {
s.db.Exec(q)
}
// Migration: Remove FOREIGN KEY constraint from existing traders table
// SQLite doesn't support ALTER TABLE DROP CONSTRAINT, so we need to recreate the table
if err := s.migrateTradersRemoveFK(); err != nil {
// Log but don't fail - this is a best-effort migration
// The constraint may not exist in older databases
}
return nil
}
// migrateTradersRemoveFK removes FOREIGN KEY constraint from traders table if it exists
func (s *TraderStore) migrateTradersRemoveFK() error {
// Check if the table has a foreign key constraint by examining the schema
var sql string
err := s.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='traders'`).Scan(&sql)
if err != nil {
return err
}
// If no FOREIGN KEY in schema, no migration needed
if !strings.Contains(sql, "FOREIGN KEY") {
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
// Recreate table without FOREIGN KEY constraint
_, err = s.db.Exec(`
-- Create new table without FOREIGN KEY
CREATE TABLE IF NOT EXISTS traders_new (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT 'default',
name TEXT NOT NULL,
ai_model_id TEXT NOT NULL,
exchange_id TEXT NOT NULL,
initial_balance REAL NOT NULL,
scan_interval_minutes INTEGER DEFAULT 3,
is_running BOOLEAN DEFAULT 0,
btc_eth_leverage INTEGER DEFAULT 5,
altcoin_leverage INTEGER DEFAULT 5,
trading_symbols TEXT DEFAULT '',
use_coin_pool BOOLEAN DEFAULT 0,
use_oi_top BOOLEAN DEFAULT 0,
custom_prompt TEXT DEFAULT '',
override_base_prompt BOOLEAN DEFAULT 0,
system_prompt_template TEXT DEFAULT 'default',
is_cross_margin BOOLEAN DEFAULT 1,
strategy_id TEXT DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Copy data from old table
INSERT OR IGNORE INTO traders_new
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance,
scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage,
trading_symbols, use_coin_pool, use_oi_top, custom_prompt,
override_base_prompt, system_prompt_template, is_cross_margin,
COALESCE(strategy_id, ''), created_at, updated_at
FROM traders;
-- Drop old table
DROP TABLE traders;
-- Rename new table
ALTER TABLE traders_new RENAME TO traders;
`)
if err != nil {
return err
}
// Recreate trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
AFTER UPDATE ON traders
BEGIN
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
return err
}
func (s *TraderStore) decrypt(encrypted string) string {
if s.decryptFunc != nil {
return s.decryptFunc(encrypted)
// Use GORM AutoMigrate
if err := s.db.AutoMigrate(&Trader{}); err != nil {
return fmt.Errorf("failed to migrate traders table: %w", err)
}
return encrypted
return nil
}
// Create creates trader
func (s *TraderStore) Create(trader *Trader) error {
_, err := s.db.Exec(`
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, strategy_id, initial_balance,
scan_interval_minutes, is_running, is_cross_margin, show_in_competition,
btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool,
use_oi_top, custom_prompt, override_base_prompt, system_prompt_template)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning, trader.IsCrossMargin, trader.ShowInCompetition,
trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, trader.UseCoinPool,
trader.UseOITop, trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate)
return err
return s.db.Create(trader).Error
}
// List gets user's trader list
func (s *TraderStore) List(userID string) ([]*Trader, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
COALESCE(show_in_competition, 1),
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
created_at, updated_at
FROM traders WHERE user_id = ? ORDER BY created_at DESC
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var traders []*Trader
for rows.Next() {
var t Trader
var createdAt, updatedAt string
err := rows.Scan(
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
&t.ShowInCompetition,
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
&t.SystemPromptTemplate, &createdAt, &updatedAt,
)
err := s.db.Where("user_id = ?", userID).
Order("created_at DESC").
Find(&traders).Error
if err != nil {
return nil, err
}
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
traders = append(traders, &t)
}
return traders, nil
}
// UpdateStatus updates trader running status
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
_, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID)
return err
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Update("is_running", isRunning).Error
}
// UpdateShowInCompetition updates trader competition visibility
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
_, err := s.db.Exec(`UPDATE traders SET show_in_competition = ? WHERE id = ? AND user_id = ?`, showInCompetition, id, userID)
return err
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Update("show_in_competition", showInCompetition).Error
}
// Update updates trader configuration
func (s *TraderStore) Update(trader *Trader) error {
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
_, err := s.db.Exec(`
UPDATE traders SET
name = ?,
ai_model_id = ?,
exchange_id = ?,
strategy_id = ?,
initial_balance = CASE WHEN ? > 0 THEN ? ELSE initial_balance END,
scan_interval_minutes = CASE WHEN ? > 0 THEN ? ELSE scan_interval_minutes END,
is_cross_margin = ?,
show_in_competition = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ? AND user_id = ?
`, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
trader.InitialBalance, trader.InitialBalance,
trader.ScanIntervalMinutes, trader.ScanIntervalMinutes,
trader.IsCrossMargin, trader.ShowInCompetition,
trader.ID, trader.UserID)
return err
updates := map[string]interface{}{
"name": trader.Name,
"ai_model_id": trader.AIModelID,
"exchange_id": trader.ExchangeID,
"strategy_id": trader.StrategyID,
"is_cross_margin": trader.IsCrossMargin,
"show_in_competition": trader.ShowInCompetition,
}
// Only update these if > 0
if trader.InitialBalance > 0 {
updates["initial_balance"] = trader.InitialBalance
}
if trader.ScanIntervalMinutes > 0 {
updates["scan_interval_minutes"] = trader.ScanIntervalMinutes
}
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", trader.ID, trader.UserID).
Updates(updates).Error
}
// UpdateInitialBalance updates initial balance
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
_, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID)
return err
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Update("initial_balance", newBalance).Error
}
// UpdateCustomPrompt updates custom prompt
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
_, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`,
customPrompt, overrideBase, id, userID)
return err
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Updates(map[string]interface{}{
"custom_prompt": customPrompt,
"override_base_prompt": overrideBase,
}).Error
}
// Delete deletes trader and associated data
func (s *TraderStore) Delete(userID, id string) error {
// Delete associated equity snapshots first
_, _ = s.db.Exec(`DELETE FROM trader_equity_snapshots WHERE trader_id = ?`, id)
s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{})
// Delete the trader
_, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID)
return err
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error
}
// GetFullConfig gets trader full configuration
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
var trader Trader
var aiModel AIModel
var exchange Exchange
var traderCreatedAt, traderUpdatedAt string
var aiModelCreatedAt, aiModelUpdatedAt string
var exchangeCreatedAt, exchangeUpdatedAt string
err := s.db.QueryRow(`
SELECT
t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, COALESCE(t.strategy_id, ''),
t.initial_balance, t.scan_interval_minutes, t.is_running, COALESCE(t.is_cross_margin, 1),
COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''),
COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''),
COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'),
t.created_at, t.updated_at,
a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key,
COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at,
e.id, COALESCE(e.exchange_type, '') as exchange_type, COALESCE(e.account_name, '') as account_name,
e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, COALESCE(e.passphrase, ''), e.testnet,
COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''),
COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''),
COALESCE(e.lighter_api_key_private_key, ''), COALESCE(e.lighter_api_key_index, 0), e.created_at, e.updated_at
FROM traders t
JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id
JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id
WHERE t.id = ? AND t.user_id = ?
`, traderID, userID).Scan(
&trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, &trader.StrategyID,
&trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, &trader.IsCrossMargin,
&trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols,
&trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt,
&trader.SystemPromptTemplate, &traderCreatedAt, &traderUpdatedAt,
&aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey,
&aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt,
&exchange.ID, &exchange.ExchangeType, &exchange.AccountName,
&exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled,
&exchange.APIKey, &exchange.SecretKey, &exchange.Passphrase, &exchange.Testnet, &exchange.HyperliquidWalletAddr,
&exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey,
&exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, &exchange.LighterAPIKeyIndex,
&exchangeCreatedAt, &exchangeUpdatedAt,
)
err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error
if err != nil {
return nil, err
}
trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt)
trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt)
aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt)
aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt)
exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt)
exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt)
// Get AI model
var aiModel AIModel
err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error
if err != nil {
return nil, fmt.Errorf("failed to get AI model: %w", err)
}
// Decrypt
aiModel.APIKey = s.decrypt(aiModel.APIKey)
exchange.APIKey = s.decrypt(exchange.APIKey)
exchange.SecretKey = s.decrypt(exchange.SecretKey)
exchange.Passphrase = s.decrypt(exchange.Passphrase)
exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey)
exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey)
exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey)
// Get exchange
var exchange Exchange
err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error
if err != nil {
return nil, fmt.Errorf("failed to get exchange: %w", err)
}
// Load associated strategy
var strategy *Strategy
@@ -392,119 +200,48 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
// getStrategyByID internal method: gets strategy by ID
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
var strategy Strategy
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE id = ? AND (user_id = ? OR is_default = 1)
`, strategyID, userID).Scan(
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
)
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
First(&strategy).Error
if err != nil {
return nil, err
}
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &strategy, nil
}
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
var strategy Strategy
var createdAt, updatedAt string
// First try to get user's active strategy
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE user_id = ? AND is_active = 1
`, userID).Scan(
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
)
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error
if err == nil {
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &strategy, nil
}
// Fallback to system default strategy
err = s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE is_default = 1 LIMIT 1
`).Scan(
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
)
err = s.db.Where("is_default = ?", true).First(&strategy).Error
if err != nil {
return nil, err
}
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &strategy, nil
}
// ListAll gets all users' trader list
// GetByID gets a trader by ID without requiring userID (for public APIs)
func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
var t Trader
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
created_at, updated_at
FROM traders WHERE id = ?
`, traderID).Scan(
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
&t.SystemPromptTemplate, &createdAt, &updatedAt,
)
var trader Trader
err := s.db.Where("id = ?", traderID).First(&trader).Error
if err != nil {
return nil, err
}
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &t, nil
return &trader, nil
}
// ListAll gets all traders
func (s *TraderStore) ListAll() ([]*Trader, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
COALESCE(show_in_competition, 1),
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
created_at, updated_at
FROM traders ORDER BY created_at DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
var traders []*Trader
for rows.Next() {
var t Trader
var createdAt, updatedAt string
err := rows.Scan(
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
&t.ShowInCompetition,
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
&t.SystemPromptTemplate, &createdAt, &updatedAt,
)
err := s.db.Order("created_at DESC").Find(&traders).Error
if err != nil {
return nil, err
}
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
traders = append(traders, &t)
}
return traders, nil
}
+59 -86
View File
@@ -2,27 +2,30 @@ package store
import (
"crypto/rand"
"database/sql"
"encoding/base32"
"time"
"gorm.io/gorm"
)
// UserStore user storage
type UserStore struct {
db *sql.DB
db *gorm.DB
}
// User user
// User user model
type User struct {
ID string `json:"id"`
Email string `json:"email"`
PasswordHash string `json:"-"`
OTPSecret string `json:"-"`
OTPVerified bool `json:"otp_verified"`
ID string `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex:idx_users_email;not null" json:"email"`
PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
OTPSecret string `gorm:"column:otp_secret" json:"-"`
OTPVerified bool `gorm:"column:otp_verified;default:false" json:"otp_verified"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (User) TableName() string { return "users" }
// GenerateOTPSecret generates OTP secret
func GenerateOTPSecret() (string, error) {
secret := make([]byte, 20)
@@ -33,131 +36,101 @@ func GenerateOTPSecret() (string, error) {
return base32.StdEncoding.EncodeToString(secret), nil
}
func (s *UserStore) initTables() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
otp_secret TEXT,
otp_verified BOOLEAN DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
// NewUserStore creates a new UserStore
func NewUserStore(db *gorm.DB) *UserStore {
return &UserStore{db: db}
}
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_users_updated_at
AFTER UPDATE ON users
BEGIN
UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END
`)
if err != nil {
return err
func (s *UserStore) initTables() error {
// For PostgreSQL with existing table, skip AutoMigrate to avoid index conflicts
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'users'`).Scan(&tableExists)
if tableExists > 0 {
// Table exists - manually ensure all columns exist
// Core columns (should already exist)
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT NOT NULL DEFAULT ''`)
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS password_hash TEXT NOT NULL DEFAULT ''`)
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
// OTP columns (added later)
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_secret TEXT DEFAULT ''`)
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_verified BOOLEAN DEFAULT FALSE`)
// Ensure unique index exists on email (don't care about the name)
var indexExists int64
s.db.Raw(`
SELECT COUNT(*) FROM pg_indexes
WHERE tablename = 'users' AND indexdef LIKE '%email%' AND indexdef LIKE '%UNIQUE%'
`).Scan(&indexExists)
if indexExists == 0 {
s.db.Exec("CREATE UNIQUE INDEX idx_users_email ON users(email)")
}
return nil
}
}
return s.db.AutoMigrate(&User{})
}
// Create creates user
func (s *UserStore) Create(user *User) error {
_, err := s.db.Exec(`
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
VALUES (?, ?, ?, ?, ?)
`, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified)
return err
return s.db.Create(user).Error
}
// GetByEmail gets user by email
func (s *UserStore) GetByEmail(email string) (*User, error) {
var user User
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
FROM users WHERE email = ?
`, email).Scan(
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
&user.OTPVerified, &createdAt, &updatedAt,
)
err := s.db.Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
}
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &user, nil
}
// GetByID gets user by ID
func (s *UserStore) GetByID(userID string) (*User, error) {
var user User
var createdAt, updatedAt string
err := s.db.QueryRow(`
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
FROM users WHERE id = ?
`, userID).Scan(
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
&user.OTPVerified, &createdAt, &updatedAt,
)
err := s.db.Where("id = ?", userID).First(&user).Error
if err != nil {
return nil, err
}
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
return &user, nil
}
// Count returns the total number of users
func (s *UserStore) Count() (int, error) {
var count int
err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count)
return count, err
var count int64
err := s.db.Model(&User{}).Count(&count).Error
return int(count), err
}
// GetAllIDs gets all user IDs
func (s *UserStore) GetAllIDs() ([]string, error) {
rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`)
if err != nil {
return nil, err
}
defer rows.Close()
var userIDs []string
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
userIDs = append(userIDs, userID)
}
return userIDs, nil
err := s.db.Model(&User{}).Order("id").Pluck("id", &userIDs).Error
return userIDs, err
}
// UpdateOTPVerified updates OTP verification status
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
_, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID)
return err
return s.db.Model(&User{}).Where("id = ?", userID).Update("otp_verified", verified).Error
}
// UpdatePassword updates password
func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
_, err := s.db.Exec(`
UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?
`, passwordHash, userID)
return err
return s.db.Model(&User{}).Where("id = ?", userID).Updates(map[string]interface{}{
"password_hash": passwordHash,
"updated_at": time.Now(),
}).Error
}
// EnsureAdmin ensures admin user exists
func (s *UserStore) EnsureAdmin() error {
var count int
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count)
if err != nil {
return err
}
var count int64
s.db.Model(&User{}).Where("id = ?", "admin").Count(&count)
if count > 0 {
return nil
}
+12 -8
View File
@@ -1,12 +1,13 @@
package trader
import (
"database/sql"
"nofx/store"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// TestScenario represents a trading scenario to test
@@ -116,11 +117,12 @@ func runStandardTests(t *testing.T, exchangeName string) {
for _, scenario := range scenarios {
t.Run(scenario.Name, func(t *testing.T) {
// Setup database
db, err := sql.Open("sqlite3", ":memory:")
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer db.Close()
positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil {
@@ -199,11 +201,12 @@ func TestAllExchangesStandardScenarios(t *testing.T) {
// TestPositionAccumulationBug tests that positions don't accumulate incorrectly
func TestPositionAccumulationBug(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer db.Close()
positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil {
@@ -283,11 +286,12 @@ func TestPositionAccumulationBug(t *testing.T) {
// TestQuantityPrecision tests handling of quantity precision issues
func TestQuantityPrecision(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer db.Close()
positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil {
+9 -6
View File
@@ -1,13 +1,14 @@
package trader
import (
"database/sql"
"math"
"nofx/store"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// TestHyperliquidOrderDirectionParsing tests Dir field parsing
@@ -75,11 +76,12 @@ func TestHyperliquidOrderDirectionParsing(t *testing.T) {
// TestHyperliquidPositionBuilding tests the complete flow of position building
func TestHyperliquidPositionBuilding(t *testing.T) {
// Setup in-memory database
db, err := sql.Open("sqlite3", ":memory:")
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer db.Close()
// Initialize stores
positionStore := store.NewPositionStore(db)
@@ -304,11 +306,12 @@ func TestHyperliquidPositionBuilding(t *testing.T) {
// TestHyperliquidBugScenario tests the exact bug we fixed
func TestHyperliquidBugScenario(t *testing.T) {
// Setup database
db, err := sql.Open("sqlite3", ":memory:")
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer db.Close()
positionStore := store.NewPositionStore(db)
if err := positionStore.InitTables(); err != nil {
@@ -1,5 +1,5 @@
import { motion } from 'framer-motion'
import { Bot, TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
import { TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
const agents = [
{
+10 -3
View File
@@ -1,8 +1,15 @@
import { motion } from 'framer-motion'
import { Activity, BarChart3, Globe, Wifi, Server, Database, Lock } from 'lucide-react'
import { useState, useEffect } from 'react'
const generateLog = (id) => {
interface LogEntry {
id: number
time: string
type: string
msg: string
color: string
}
const generateLog = (id: number): LogEntry => {
const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS']
const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD']
const actions = ['BUY', 'SELL', 'SHORT', 'LONG']
@@ -37,7 +44,7 @@ const generateLog = (id) => {
}
export default function LiveFeed() {
const [logs, setLogs] = useState([])
const [logs, setLogs] = useState<LogEntry[]>([])
useEffect(() => {
// Initial population
@@ -159,11 +159,24 @@ export default function TerminalHero() {
<span className="text-stroke-1 text-transparent bg-clip-text bg-gradient-to-r from-nofx-gold via-white to-nofx-gold animate-shimmer bg-[length:200%_auto]">TRADING</span>
</h1>
<p className="max-w-xl text-zinc-400 text-lg mb-10 font-light leading-relaxed">
<p className="max-w-xl text-zinc-400 text-lg mb-6 font-light leading-relaxed">
The World's First Open-Source Agentic Trading OS.
Deploy autonomous high-frequency trading agents powered by advanced LLMs.
</p>
{/* Market Access Strip - Prominent Display */}
<div className="flex flex-wrap gap-4 mb-12 font-mono">
{['CRYPTO', 'US STOCKS', 'FOREX', 'METALS'].map((market) => (
<div key={market} className="flex items-center gap-3 px-4 py-2 rounded bg-zinc-900 border border-zinc-700 text-white font-bold tracking-wider hover:border-nofx-gold hover:shadow-[0_0_15px_rgba(255,215,0,0.3)] transition-all duration-300">
<span className="relative flex h-2 w-2">
<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-nofx-success opacity-75"></span>
<span className="relative inline-flex rounded-full h-2 w-2 bg-nofx-success"></span>
</span>
{market}
</div>
))}
</div>
{/* Command Line Input Simulation */}
<div className="w-full max-w-lg h-12 bg-black/50 border border-zinc-800 rounded flex items-center px-4 mb-10 font-mono text-sm shadow-2xl backdrop-blur-sm group hover:border-nofx-gold/50 transition-colors cursor-text" onClick={() => document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}>
<span className="text-nofx-success mr-2"></span>
@@ -269,7 +282,7 @@ export default function TerminalHero() {
import { OFFICIAL_LINKS } from '../../../constants/branding'
function CommunityStats() {
const { stars, forks, contributors, isLoading, error } = useGitHubStats('tinkle-community', 'nofx')
const { stars, forks, contributors, isLoading, error } = useGitHubStats('NoFxAiOS', 'nofx')
const stats = [
{