diff --git a/backtest/account.go b/backtest/account.go index abf891a9..49468c1a 100644 --- a/backtest/account.go +++ b/backtest/account.go @@ -122,10 +122,10 @@ func (acc *BacktestAccount) Close(symbol, side string, quantity float64, price f } execPrice := applySlippage(price, acc.slippageRate, side, false) - notional := execPrice * quantity - closingFee := notional * acc.feeRate + closeNotional := execPrice * quantity // Notional at close price (for fee calculation) + closingFee := closeNotional * acc.feeRate - // Calculate proportional opening fee for the quantity being closed + // Calculate proportional values based on the portion being closed closePortion := quantity / pos.Quantity openingFeePortion := pos.AccumulatedFee * closePortion totalFee := closingFee + openingFeePortion @@ -133,13 +133,17 @@ func (acc *BacktestAccount) Close(symbol, side string, quantity float64, price f realized := realizedPnL(pos, quantity, execPrice) marginPortion := pos.Margin * closePortion + // BUG FIX: Calculate notional portion based on ENTRY price, not close price + // pos.Notional tracks the total notional at entry, so we must subtract proportionally + entryNotionalPortion := pos.Notional * closePortion + // Note: Opening fee was already deducted from cash when opening, so we only deduct closing fee here acc.cash += marginPortion + realized - closingFee // But for realized P&L tracking, we include both fees acc.realizedPnL += realized - totalFee pos.Quantity -= quantity - pos.Notional -= notional + pos.Notional -= entryNotionalPortion // FIX: Use entry notional portion, not close notional pos.Margin -= marginPortion pos.AccumulatedFee -= openingFeePortion // Reduce tracked opening fee diff --git a/backtest/datafeed.go b/backtest/datafeed.go index 5480fffc..0e06ed82 100644 --- a/backtest/datafeed.go +++ b/backtest/datafeed.go @@ -124,11 +124,23 @@ func (df *DataFeed) DecisionBarCount() int { } func (df *DataFeed) DecisionTimestamp(index int) int64 { + // Bounds check to prevent panic + if index < 0 || index >= len(df.decisionTimes) { + return 0 + } return df.decisionTimes[index] } func (df *DataFeed) sliceUpTo(symbol, tf string, ts int64) []market.Kline { - series := df.symbolSeries[symbol].byTF[tf] + // Nil checks to prevent panic + ss, ok := df.symbolSeries[symbol] + if !ok || ss == nil { + return nil + } + series, ok := ss.byTF[tf] + if !ok || series == nil { + return nil + } idx := sort.Search(len(series.closeTimes), func(i int) bool { return series.closeTimes[i] > ts }) diff --git a/backtest/metrics.go b/backtest/metrics.go index e7fe5bc5..a7aac519 100644 --- a/backtest/metrics.go +++ b/backtest/metrics.go @@ -91,8 +91,13 @@ func maxDrawdown(points []EquityPoint, state *BacktestState) float64 { return maxDD } +// sharpeRatio calculates the Sharpe ratio from equity points. +// Uses sample standard deviation (n-1) and annualizes assuming ~252 trading days. +// Returns math.NaN() for edge cases (insufficient data, zero variance). func sharpeRatio(points []EquityPoint) float64 { - if len(points) < 2 { + // Need at least 10 data points for meaningful Sharpe calculation + const minDataPoints = 10 + if len(points) < minDataPoints { return 0 } @@ -108,34 +113,42 @@ func sharpeRatio(points []EquityPoint) float64 { returns = append(returns, ret) prev = curr } - if len(returns) == 0 { + if len(returns) < minDataPoints-1 { return 0 } + // Calculate mean return mean := 0.0 for _, r := range returns { mean += r } mean /= float64(len(returns)) + // Calculate sample variance (using n-1 for unbiased estimator) variance := 0.0 for _, r := range returns { diff := r - mean variance += diff * diff } - variance /= float64(len(returns)) + if len(returns) > 1 { + variance /= float64(len(returns) - 1) + } std := math.Sqrt(variance) - if std == 0 { - if mean > 0 { - return 999 - } - if mean < 0 { - return -999 - } + if std < 1e-10 { + // Zero or near-zero volatility - return 0 instead of infinity/NaN return 0 } - return mean / std + + // Calculate Sharpe ratio (assuming risk-free rate = 0 for crypto) + // Annualize by multiplying by sqrt(periods per year) + // Assuming each equity point represents ~1 hour, we have ~8760 periods/year + // For conservative estimate, use sqrt(252) as if daily returns + periodsPerYear := 252.0 + annualizationFactor := math.Sqrt(periodsPerYear) + + sharpe := (mean / std) * annualizationFactor + return sharpe } func fillTradeMetrics(metrics *Metrics, events []TradeEvent) { @@ -189,7 +202,8 @@ func fillTradeMetrics(metrics *Metrics, events []TradeEvent) { if totalLossAmount > 0 { metrics.ProfitFactor = totalWinAmount / totalLossAmount } else if totalWinAmount > 0 { - metrics.ProfitFactor = 999 + // No losses but have wins - use a high but reasonable cap + metrics.ProfitFactor = 100.0 } bestSymbol := "" diff --git a/backtest/persistence_db.go b/backtest/persistence_db.go index 06d4dfeb..d494ed65 100644 --- a/backtest/persistence_db.go +++ b/backtest/persistence_db.go @@ -2,15 +2,39 @@ package backtest import ( "database/sql" + "fmt" + "strings" ) var persistenceDB *sql.DB +var dbIsPostgres bool // UseDatabase enables database-backed persistence for all backtest storage operations. +// If isPostgres is true, queries will use $1, $2... placeholders instead of ? func UseDatabase(db *sql.DB) { persistenceDB = db } +// UseDatabaseWithType enables database-backed persistence with explicit type. +func UseDatabaseWithType(db *sql.DB, isPostgres bool) { + persistenceDB = db + dbIsPostgres = isPostgres +} + func usingDB() bool { return persistenceDB != nil } + +// convertQuery converts ? placeholders to $1, $2, etc. for PostgreSQL +func convertQuery(query string) string { + if !dbIsPostgres { + return query + } + result := query + index := 1 + for strings.Contains(result, "?") { + result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1) + index++ + } + return result +} diff --git a/backtest/retention.go b/backtest/retention.go index a9d34d74..49ec3542 100644 --- a/backtest/retention.go +++ b/backtest/retention.go @@ -73,12 +73,12 @@ func enforceRetentionDB(maxRuns int) { RunStateFailed, RunStateLiquidated, } - query := ` + query := convertQuery(` SELECT run_id FROM backtest_runs WHERE state IN (?, ?, ?, ?) ORDER BY updated_at DESC OFFSET ? - ` + `) rows, err := persistenceDB.Query(query, finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns) if err != nil { diff --git a/backtest/runner.go b/backtest/runner.go index 5f039bd4..70d2c5eb 100644 --- a/backtest/runner.go +++ b/backtest/runner.go @@ -60,8 +60,9 @@ type Runner struct { aiCache *AICache cachePath string - lockInfo *RunLockInfo - lockStop chan struct{} + lockInfo *RunLockInfo + lockStop chan struct{} + lockStopOnce sync.Once // Ensures lockStop is closed only once } // NewRunner constructs a backtest runner. @@ -175,10 +176,12 @@ func (r *Runner) lockHeartbeatLoop() { } func (r *Runner) releaseLock() { - if r.lockStop != nil { - close(r.lockStop) - r.lockStop = nil - } + // Use sync.Once to ensure channel is closed exactly once, preventing panic on double-close + r.lockStopOnce.Do(func() { + if r.lockStop != nil { + close(r.lockStop) + } + }) if err := deleteRunLock(r.cfg.RunID); err != nil { logger.Infof("failed to release lock for %s: %v", r.cfg.RunID, err) } @@ -297,9 +300,12 @@ func (r *Runner) stepOnce() error { if shouldDecide { ctx, rec, err := r.buildDecisionContext(ts, marketData, multiTF, priceMap, callCount) if err != nil { - rec.Success = false - rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err) - _ = r.logDecision(rec) + // Defensive nil check to prevent panic if buildDecisionContext returns error with nil record + if rec != nil { + rec.Success = false + rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err) + _ = r.logDecision(rec) + } return err } record = rec @@ -617,6 +623,10 @@ func (r *Runner) invokeAIWithRetry(ctx *kernel.Context) (*kernel.FullDecision, e func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float64, ts int64, cycle int) (store.DecisionAction, []TradeEvent, string, error) { symbol := dec.Symbol + if symbol == "" { + return store.DecisionAction{}, nil, "", fmt.Errorf("empty symbol in decision") + } + usedLeverage := r.resolveLeverage(dec.Leverage, symbol) actionRecord := store.DecisionAction{ Action: dec.Action, @@ -625,9 +635,13 @@ func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float6 Timestamp: time.UnixMilli(ts).UTC(), } - basePrice := priceMap[symbol] - if basePrice <= 0 { - return actionRecord, nil, "", fmt.Errorf("price unavailable for %s", symbol) + if priceMap == nil { + return actionRecord, nil, "", fmt.Errorf("priceMap is nil") + } + + basePrice, ok := priceMap[symbol] + if !ok || basePrice <= 0 { + return actionRecord, nil, "", fmt.Errorf("price unavailable for %s (found=%v, price=%.4f)", symbol, ok, basePrice) } fillPrice := r.executionPrice(symbol, basePrice, ts) @@ -757,6 +771,9 @@ func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float6 } } +// MinPositionSizeUSD is the minimum position size in USD to avoid dust positions +const MinPositionSizeUSD = 10.0 + func (r *Runner) determineQuantity(dec kernel.Decision, price float64) float64 { snapshot := r.snapshotState() equity := snapshot.Equity @@ -788,6 +805,13 @@ func (r *Runner) determineQuantity(dec kernel.Decision, price float64) float64 { sizeUSD = maxPositionValue } + // Reject positions below minimum size to avoid dust positions + if sizeUSD < MinPositionSizeUSD { + logger.Infof("📊 Backtest: rejecting position size %.2f USD (below minimum %.2f USD)", + sizeUSD, MinPositionSizeUSD) + return 0 + } + qty := sizeUSD / price if qty < 0 { qty = 0 @@ -805,20 +829,37 @@ func (r *Runner) determineCloseQuantity(symbol, side string, dec kernel.Decision } func (r *Runner) resolveLeverage(requested int, symbol string) int { - if requested > 0 { - return requested - } sym := strings.ToUpper(symbol) - if sym == "BTCUSDT" || sym == "ETHUSDT" { - if r.cfg.Leverage.BTCETHLeverage > 0 { - return r.cfg.Leverage.BTCETHLeverage + isBTCETH := sym == "BTCUSDT" || sym == "ETHUSDT" + + // Determine configured max leverage for this symbol type + var maxLeverage int + if isBTCETH { + maxLeverage = r.cfg.Leverage.BTCETHLeverage + if maxLeverage <= 0 { + maxLeverage = 10 // Default max for BTC/ETH } } else { - if r.cfg.Leverage.AltcoinLeverage > 0 { - return r.cfg.Leverage.AltcoinLeverage + maxLeverage = r.cfg.Leverage.AltcoinLeverage + if maxLeverage <= 0 { + maxLeverage = 5 // Default max for altcoins } } - return 5 + + // Use requested leverage if provided, otherwise use max as default + leverage := requested + if leverage <= 0 { + leverage = maxLeverage + } + + // Enforce max leverage limit + if leverage > maxLeverage { + logger.Infof("📊 Backtest: capping leverage from %dx to %dx for %s", + leverage, maxLeverage, symbol) + leverage = maxLeverage + } + + return leverage } func (r *Runner) remainingPosition(symbol, side string) float64 { @@ -854,6 +895,12 @@ func (r *Runner) convertPositions(priceMap map[string]float64) []kernel.Position list := make([]kernel.PositionInfo, 0, len(positions)) for _, pos := range positions { price := priceMap[pos.Symbol] + pnl := unrealizedPnL(pos, price) + // Calculate P&L percentage based on entry notional (position cost) + pnlPct := 0.0 + if pos.Notional > 0 { + pnlPct = (pnl / pos.Notional) * 100 + } list = append(list, kernel.PositionInfo{ Symbol: pos.Symbol, Side: pos.Side, @@ -861,8 +908,8 @@ func (r *Runner) convertPositions(priceMap map[string]float64) []kernel.Position MarkPrice: price, Quantity: pos.Quantity, Leverage: pos.Leverage, - UnrealizedPnL: unrealizedPnL(pos, price), - UnrealizedPnLPct: 0, + UnrealizedPnL: pnl, + UnrealizedPnLPct: pnlPct, LiquidationPrice: pos.LiquidationPrice, MarginUsed: pos.Margin, UpdateTime: time.Now().UnixMilli(), diff --git a/backtest/storage_db_impl.go b/backtest/storage_db_impl.go index f8899aa0..2eb2b407 100644 --- a/backtest/storage_db_impl.go +++ b/backtest/storage_db_impl.go @@ -17,17 +17,17 @@ func saveCheckpointDB(runID string, ckpt *Checkpoint) error { if err != nil { return err } - _, err = persistenceDB.Exec(` + _, err = persistenceDB.Exec(convertQuery(` INSERT INTO backtest_checkpoints (run_id, payload, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP) ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP - `, runID, data) + `), runID, data) return err } func loadCheckpointDB(runID string) (*Checkpoint, error) { var payload []byte - err := persistenceDB.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload) + err := persistenceDB.QueryRow(convertQuery(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`), runID).Scan(&payload) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, os.ErrNotExist @@ -57,25 +57,25 @@ func saveConfigDB(runID string, cfg *BacktestConfig) error { if userID == "" { userID = "default" } - _, err = persistenceDB.Exec(` + _, err = persistenceDB.Exec(convertQuery(` INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt, override_prompt, ai_provider, ai_model, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(run_id) DO NOTHING - `, runID, userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, now, now) + `), runID, userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, now, now) if err != nil { return err } - _, err = persistenceDB.Exec(` + _, err = persistenceDB.Exec(convertQuery(` UPDATE backtest_runs SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?, override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP WHERE run_id = ? - `, userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, runID) + `), userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, runID) return err } func loadConfigDB(runID string) (*BacktestConfig, error) { var payload []byte - err := persistenceDB.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload) + err := persistenceDB.QueryRow(convertQuery(`SELECT config_json FROM backtest_runs WHERE run_id = ?`), runID).Scan(&payload) if err != nil { return nil, err } @@ -96,18 +96,18 @@ func saveRunMetadataDB(meta *RunMetadata) error { if userID == "" { userID = "default" } - if _, err := persistenceDB.Exec(` + if _, err := persistenceDB.Exec(convertQuery(` INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(run_id) DO NOTHING - `, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil { + `), meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil { return err } - _, err := persistenceDB.Exec(` + _, err := persistenceDB.Exec(convertQuery(` UPDATE backtest_runs SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?, progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?, liquidation_note = ?, label = ?, last_error = ?, updated_at = ? WHERE run_id = ? - `, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF, meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast, meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote, meta.Label, meta.LastError, updated, meta.RunID) + `), userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF, meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast, meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote, meta.Label, meta.LastError, updated, meta.RunID) return err } @@ -128,10 +128,10 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) { createdISO string updatedISO string ) - err := persistenceDB.QueryRow(` + err := persistenceDB.QueryRow(convertQuery(` SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars, progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note, created_at, updated_at FROM backtest_runs WHERE run_id = ? - `, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF, &processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote, &createdISO, &updatedISO) + `), runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF, &processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote, &createdISO, &updatedISO) if err != nil { return nil, err } @@ -183,18 +183,18 @@ func loadRunIDsDB() ([]string, error) { } func appendEquityPointDB(runID string, point EquityPoint) error { - _, err := persistenceDB.Exec(` + _, err := persistenceDB.Exec(convertQuery(` INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - `, runID, point.Timestamp, point.Equity, point.Available, point.PnL, point.PnLPct, point.DrawdownPct, point.Cycle) + `), runID, point.Timestamp, point.Equity, point.Available, point.PnL, point.PnLPct, point.DrawdownPct, point.Cycle) return err } func loadEquityPointsDB(runID string) ([]EquityPoint, error) { - rows, err := persistenceDB.Query(` + rows, err := persistenceDB.Query(convertQuery(` SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC - `, runID) + `), runID) if err != nil { return nil, err } @@ -211,18 +211,18 @@ func loadEquityPointsDB(runID string) ([]EquityPoint, error) { } func appendTradeEventDB(runID string, event TradeEvent) error { - _, err := persistenceDB.Exec(` + _, err := persistenceDB.Exec(convertQuery(` INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee, slippage, order_value, realized_pnl, leverage, cycle, position_after, liquidation, note) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note) + `), runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note) return err } func loadTradeEventsDB(runID string) ([]TradeEvent, error) { - rows, err := persistenceDB.Query(` + rows, err := persistenceDB.Query(convertQuery(` SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value, realized_pnl, leverage, cycle, position_after, liquidation, note FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC - `, runID) + `), runID) if err != nil { return nil, err } @@ -243,17 +243,17 @@ func saveMetricsDB(runID string, metrics *Metrics) error { if err != nil { return err } - _, err = persistenceDB.Exec(` + _, err = persistenceDB.Exec(convertQuery(` INSERT INTO backtest_metrics (run_id, payload, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP) ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP - `, runID, data) + `), runID, data) return err } func loadMetricsDB(runID string) (*Metrics, error) { var payload []byte - err := persistenceDB.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload) + err := persistenceDB.QueryRow(convertQuery(`SELECT payload FROM backtest_metrics WHERE run_id = ?`), runID).Scan(&payload) if err != nil { return nil, err } @@ -265,22 +265,21 @@ func loadMetricsDB(runID string) (*Metrics, error) { } func saveProgressDB(runID string, payload progressPayload) error { - _, err := persistenceDB.Exec(` + _, err := persistenceDB.Exec(convertQuery(` UPDATE backtest_runs SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = ? WHERE run_id = ? - `, payload.ProgressPct, payload.Equity, payload.BarIndex, payload.Liquidated, payload.UpdatedAtISO, runID) + `), payload.ProgressPct, payload.Equity, payload.BarIndex, payload.Liquidated, payload.UpdatedAtISO, runID) return err } func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error) { - query := `SELECT payload FROM backtest_decisions WHERE run_id = ?` var rows *sql.Rows var err error if cycle > 0 { - rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle) + rows, err = persistenceDB.Query(convertQuery(`SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY created_at DESC LIMIT 1`), runID, cycle) } else { - rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID) + rows, err = persistenceDB.Query(convertQuery(`SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY created_at DESC LIMIT 1`), runID) } if err != nil { return nil, err @@ -308,20 +307,20 @@ func saveDecisionRecordDB(runID string, record *store.DecisionRecord) error { if err != nil { return err } - _, err = persistenceDB.Exec(` + _, err = persistenceDB.Exec(convertQuery(` INSERT INTO backtest_decisions (run_id, cycle, payload) VALUES (?, ?, ?) - `, runID, record.CycleNumber, data) + `), runID, record.CycleNumber, data) return err } func loadDecisionRecordsDB(runID string, limit, offset int) ([]*store.DecisionRecord, error) { - rows, err := persistenceDB.Query(` + rows, err := persistenceDB.Query(convertQuery(` SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY id DESC LIMIT ? OFFSET ? - `, runID, limit, offset) + `), runID, limit, offset) if err != nil { return nil, err } @@ -428,10 +427,10 @@ func writeJSONLinesToZip[T any](z *zip.Writer, name string, items []T) error { } func writeDecisionLogsToZip(z *zip.Writer, runID string) error { - rows, err := persistenceDB.Query(` + rows, err := persistenceDB.Query(convertQuery(` SELECT id, cycle, payload FROM backtest_decisions WHERE run_id = ? ORDER BY id ASC - `, runID) + `), runID) if err != nil { return err } @@ -494,6 +493,6 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) { } func deleteRunDB(runID string) error { - _, err := persistenceDB.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID) + _, err := persistenceDB.Exec(convertQuery(`DELETE FROM backtest_runs WHERE run_id = ?`), runID) return err } diff --git a/main.go b/main.go index 9905e7b1..a0987f81 100644 --- a/main.go +++ b/main.go @@ -78,7 +78,7 @@ func main() { logger.Fatalf("❌ Failed to initialize database: %v", err) } defer st.Close() - backtest.UseDatabase(st.DB()) + backtest.UseDatabaseWithType(st.DB(), st.DBType() == store.DBTypePostgres) // Initialize installation ID for experience improvement (anonymous statistics) initInstallationID(st) diff --git a/store/backtest.go b/store/backtest.go index ecb59f0e..a77dfe16 100644 --- a/store/backtest.go +++ b/store/backtest.go @@ -147,7 +147,7 @@ func (BacktestCheckpoint) TableName() string { type BacktestEquity struct { ID int64 `gorm:"primaryKey;autoIncrement"` RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"` - TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"` + TS int64 `gorm:"column:ts;type:bigint;not null;index:idx_backtest_equity_run_ts"` Equity float64 `gorm:"column:equity;not null"` Available float64 `gorm:"column:available;not null"` PnL float64 `gorm:"column:pnl;not null"` @@ -164,7 +164,7 @@ func (BacktestEquity) TableName() string { type BacktestTrade struct { ID int64 `gorm:"primaryKey;autoIncrement"` RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"` - TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"` + TS int64 `gorm:"column:ts;type:bigint;not null;index:idx_backtest_trades_run_ts"` Symbol string `gorm:"column:symbol;not null"` Action string `gorm:"column:action;not null"` Side string `gorm:"column:side;default:''"` @@ -217,7 +217,10 @@ func (s *BacktestStore) initTables() error { s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists) if tableExists > 0 { - // Tables exist - just ensure indexes exist + // Tables exist - fix column types and ensure indexes exist + // Fix ts column type from INTEGER to BIGINT (timestamps in milliseconds exceed int4 max) + s.db.Exec(`ALTER TABLE backtest_equity ALTER COLUMN ts TYPE BIGINT`) + s.db.Exec(`ALTER TABLE backtest_trades ALTER COLUMN ts TYPE BIGINT`) s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`) s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`) s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`)