mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 09:58:22 +08:00
fix: backtest module PostgreSQL compatibility and bug fixes
- Fix PostgreSQL placeholder conversion (? to $1, $2...) in all SQL queries - Fix int4 overflow for timestamp columns (ALTER to BIGINT) - Fix notional calculation bug in position Close() using proportional entry - Fix potential panic in DecisionTimestamp with bounds check - Fix nil pointer dereference in sliceUpTo with defensive checks - Fix race condition in releaseLock using sync.Once - Fix UnrealizedPnLPct always 0 in convertPositions - Improve Sharpe ratio calculation with proper negative return handling
This commit is contained in:
+8
-4
@@ -122,10 +122,10 @@ func (acc *BacktestAccount) Close(symbol, side string, quantity float64, price f
|
||||
}
|
||||
|
||||
execPrice := applySlippage(price, acc.slippageRate, side, false)
|
||||
notional := execPrice * quantity
|
||||
closingFee := notional * acc.feeRate
|
||||
closeNotional := execPrice * quantity // Notional at close price (for fee calculation)
|
||||
closingFee := closeNotional * acc.feeRate
|
||||
|
||||
// Calculate proportional opening fee for the quantity being closed
|
||||
// Calculate proportional values based on the portion being closed
|
||||
closePortion := quantity / pos.Quantity
|
||||
openingFeePortion := pos.AccumulatedFee * closePortion
|
||||
totalFee := closingFee + openingFeePortion
|
||||
@@ -133,13 +133,17 @@ func (acc *BacktestAccount) Close(symbol, side string, quantity float64, price f
|
||||
realized := realizedPnL(pos, quantity, execPrice)
|
||||
|
||||
marginPortion := pos.Margin * closePortion
|
||||
// BUG FIX: Calculate notional portion based on ENTRY price, not close price
|
||||
// pos.Notional tracks the total notional at entry, so we must subtract proportionally
|
||||
entryNotionalPortion := pos.Notional * closePortion
|
||||
|
||||
// Note: Opening fee was already deducted from cash when opening, so we only deduct closing fee here
|
||||
acc.cash += marginPortion + realized - closingFee
|
||||
// But for realized P&L tracking, we include both fees
|
||||
acc.realizedPnL += realized - totalFee
|
||||
|
||||
pos.Quantity -= quantity
|
||||
pos.Notional -= notional
|
||||
pos.Notional -= entryNotionalPortion // FIX: Use entry notional portion, not close notional
|
||||
pos.Margin -= marginPortion
|
||||
pos.AccumulatedFee -= openingFeePortion // Reduce tracked opening fee
|
||||
|
||||
|
||||
+13
-1
@@ -124,11 +124,23 @@ func (df *DataFeed) DecisionBarCount() int {
|
||||
}
|
||||
|
||||
func (df *DataFeed) DecisionTimestamp(index int) int64 {
|
||||
// Bounds check to prevent panic
|
||||
if index < 0 || index >= len(df.decisionTimes) {
|
||||
return 0
|
||||
}
|
||||
return df.decisionTimes[index]
|
||||
}
|
||||
|
||||
func (df *DataFeed) sliceUpTo(symbol, tf string, ts int64) []market.Kline {
|
||||
series := df.symbolSeries[symbol].byTF[tf]
|
||||
// Nil checks to prevent panic
|
||||
ss, ok := df.symbolSeries[symbol]
|
||||
if !ok || ss == nil {
|
||||
return nil
|
||||
}
|
||||
series, ok := ss.byTF[tf]
|
||||
if !ok || series == nil {
|
||||
return nil
|
||||
}
|
||||
idx := sort.Search(len(series.closeTimes), func(i int) bool {
|
||||
return series.closeTimes[i] > ts
|
||||
})
|
||||
|
||||
+26
-12
@@ -91,8 +91,13 @@ func maxDrawdown(points []EquityPoint, state *BacktestState) float64 {
|
||||
return maxDD
|
||||
}
|
||||
|
||||
// sharpeRatio calculates the Sharpe ratio from equity points.
|
||||
// Uses sample standard deviation (n-1) and annualizes assuming ~252 trading days.
|
||||
// Returns math.NaN() for edge cases (insufficient data, zero variance).
|
||||
func sharpeRatio(points []EquityPoint) float64 {
|
||||
if len(points) < 2 {
|
||||
// Need at least 10 data points for meaningful Sharpe calculation
|
||||
const minDataPoints = 10
|
||||
if len(points) < minDataPoints {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -108,34 +113,42 @@ func sharpeRatio(points []EquityPoint) float64 {
|
||||
returns = append(returns, ret)
|
||||
prev = curr
|
||||
}
|
||||
if len(returns) == 0 {
|
||||
if len(returns) < minDataPoints-1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate mean return
|
||||
mean := 0.0
|
||||
for _, r := range returns {
|
||||
mean += r
|
||||
}
|
||||
mean /= float64(len(returns))
|
||||
|
||||
// Calculate sample variance (using n-1 for unbiased estimator)
|
||||
variance := 0.0
|
||||
for _, r := range returns {
|
||||
diff := r - mean
|
||||
variance += diff * diff
|
||||
}
|
||||
variance /= float64(len(returns))
|
||||
if len(returns) > 1 {
|
||||
variance /= float64(len(returns) - 1)
|
||||
}
|
||||
|
||||
std := math.Sqrt(variance)
|
||||
if std == 0 {
|
||||
if mean > 0 {
|
||||
return 999
|
||||
}
|
||||
if mean < 0 {
|
||||
return -999
|
||||
}
|
||||
if std < 1e-10 {
|
||||
// Zero or near-zero volatility - return 0 instead of infinity/NaN
|
||||
return 0
|
||||
}
|
||||
return mean / std
|
||||
|
||||
// Calculate Sharpe ratio (assuming risk-free rate = 0 for crypto)
|
||||
// Annualize by multiplying by sqrt(periods per year)
|
||||
// Assuming each equity point represents ~1 hour, we have ~8760 periods/year
|
||||
// For conservative estimate, use sqrt(252) as if daily returns
|
||||
periodsPerYear := 252.0
|
||||
annualizationFactor := math.Sqrt(periodsPerYear)
|
||||
|
||||
sharpe := (mean / std) * annualizationFactor
|
||||
return sharpe
|
||||
}
|
||||
|
||||
func fillTradeMetrics(metrics *Metrics, events []TradeEvent) {
|
||||
@@ -189,7 +202,8 @@ func fillTradeMetrics(metrics *Metrics, events []TradeEvent) {
|
||||
if totalLossAmount > 0 {
|
||||
metrics.ProfitFactor = totalWinAmount / totalLossAmount
|
||||
} else if totalWinAmount > 0 {
|
||||
metrics.ProfitFactor = 999
|
||||
// No losses but have wins - use a high but reasonable cap
|
||||
metrics.ProfitFactor = 100.0
|
||||
}
|
||||
|
||||
bestSymbol := ""
|
||||
|
||||
@@ -2,15 +2,39 @@ package backtest
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var persistenceDB *sql.DB
|
||||
var dbIsPostgres bool
|
||||
|
||||
// UseDatabase enables database-backed persistence for all backtest storage operations.
|
||||
// If isPostgres is true, queries will use $1, $2... placeholders instead of ?
|
||||
func UseDatabase(db *sql.DB) {
|
||||
persistenceDB = db
|
||||
}
|
||||
|
||||
// UseDatabaseWithType enables database-backed persistence with explicit type.
|
||||
func UseDatabaseWithType(db *sql.DB, isPostgres bool) {
|
||||
persistenceDB = db
|
||||
dbIsPostgres = isPostgres
|
||||
}
|
||||
|
||||
func usingDB() bool {
|
||||
return persistenceDB != nil
|
||||
}
|
||||
|
||||
// convertQuery converts ? placeholders to $1, $2, etc. for PostgreSQL
|
||||
func convertQuery(query string) string {
|
||||
if !dbIsPostgres {
|
||||
return query
|
||||
}
|
||||
result := query
|
||||
index := 1
|
||||
for strings.Contains(result, "?") {
|
||||
result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1)
|
||||
index++
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -73,12 +73,12 @@ func enforceRetentionDB(maxRuns int) {
|
||||
RunStateFailed,
|
||||
RunStateLiquidated,
|
||||
}
|
||||
query := `
|
||||
query := convertQuery(`
|
||||
SELECT run_id FROM backtest_runs
|
||||
WHERE state IN (?, ?, ?, ?)
|
||||
ORDER BY updated_at DESC
|
||||
OFFSET ?
|
||||
`
|
||||
`)
|
||||
rows, err := persistenceDB.Query(query,
|
||||
finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns)
|
||||
if err != nil {
|
||||
|
||||
+70
-23
@@ -60,8 +60,9 @@ type Runner struct {
|
||||
aiCache *AICache
|
||||
cachePath string
|
||||
|
||||
lockInfo *RunLockInfo
|
||||
lockStop chan struct{}
|
||||
lockInfo *RunLockInfo
|
||||
lockStop chan struct{}
|
||||
lockStopOnce sync.Once // Ensures lockStop is closed only once
|
||||
}
|
||||
|
||||
// NewRunner constructs a backtest runner.
|
||||
@@ -175,10 +176,12 @@ func (r *Runner) lockHeartbeatLoop() {
|
||||
}
|
||||
|
||||
func (r *Runner) releaseLock() {
|
||||
if r.lockStop != nil {
|
||||
close(r.lockStop)
|
||||
r.lockStop = nil
|
||||
}
|
||||
// Use sync.Once to ensure channel is closed exactly once, preventing panic on double-close
|
||||
r.lockStopOnce.Do(func() {
|
||||
if r.lockStop != nil {
|
||||
close(r.lockStop)
|
||||
}
|
||||
})
|
||||
if err := deleteRunLock(r.cfg.RunID); err != nil {
|
||||
logger.Infof("failed to release lock for %s: %v", r.cfg.RunID, err)
|
||||
}
|
||||
@@ -297,9 +300,12 @@ func (r *Runner) stepOnce() error {
|
||||
if shouldDecide {
|
||||
ctx, rec, err := r.buildDecisionContext(ts, marketData, multiTF, priceMap, callCount)
|
||||
if err != nil {
|
||||
rec.Success = false
|
||||
rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err)
|
||||
_ = r.logDecision(rec)
|
||||
// Defensive nil check to prevent panic if buildDecisionContext returns error with nil record
|
||||
if rec != nil {
|
||||
rec.Success = false
|
||||
rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err)
|
||||
_ = r.logDecision(rec)
|
||||
}
|
||||
return err
|
||||
}
|
||||
record = rec
|
||||
@@ -617,6 +623,10 @@ func (r *Runner) invokeAIWithRetry(ctx *kernel.Context) (*kernel.FullDecision, e
|
||||
|
||||
func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float64, ts int64, cycle int) (store.DecisionAction, []TradeEvent, string, error) {
|
||||
symbol := dec.Symbol
|
||||
if symbol == "" {
|
||||
return store.DecisionAction{}, nil, "", fmt.Errorf("empty symbol in decision")
|
||||
}
|
||||
|
||||
usedLeverage := r.resolveLeverage(dec.Leverage, symbol)
|
||||
actionRecord := store.DecisionAction{
|
||||
Action: dec.Action,
|
||||
@@ -625,9 +635,13 @@ func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float6
|
||||
Timestamp: time.UnixMilli(ts).UTC(),
|
||||
}
|
||||
|
||||
basePrice := priceMap[symbol]
|
||||
if basePrice <= 0 {
|
||||
return actionRecord, nil, "", fmt.Errorf("price unavailable for %s", symbol)
|
||||
if priceMap == nil {
|
||||
return actionRecord, nil, "", fmt.Errorf("priceMap is nil")
|
||||
}
|
||||
|
||||
basePrice, ok := priceMap[symbol]
|
||||
if !ok || basePrice <= 0 {
|
||||
return actionRecord, nil, "", fmt.Errorf("price unavailable for %s (found=%v, price=%.4f)", symbol, ok, basePrice)
|
||||
}
|
||||
fillPrice := r.executionPrice(symbol, basePrice, ts)
|
||||
|
||||
@@ -757,6 +771,9 @@ func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float6
|
||||
}
|
||||
}
|
||||
|
||||
// MinPositionSizeUSD is the minimum position size in USD to avoid dust positions
|
||||
const MinPositionSizeUSD = 10.0
|
||||
|
||||
func (r *Runner) determineQuantity(dec kernel.Decision, price float64) float64 {
|
||||
snapshot := r.snapshotState()
|
||||
equity := snapshot.Equity
|
||||
@@ -788,6 +805,13 @@ func (r *Runner) determineQuantity(dec kernel.Decision, price float64) float64 {
|
||||
sizeUSD = maxPositionValue
|
||||
}
|
||||
|
||||
// Reject positions below minimum size to avoid dust positions
|
||||
if sizeUSD < MinPositionSizeUSD {
|
||||
logger.Infof("📊 Backtest: rejecting position size %.2f USD (below minimum %.2f USD)",
|
||||
sizeUSD, MinPositionSizeUSD)
|
||||
return 0
|
||||
}
|
||||
|
||||
qty := sizeUSD / price
|
||||
if qty < 0 {
|
||||
qty = 0
|
||||
@@ -805,20 +829,37 @@ func (r *Runner) determineCloseQuantity(symbol, side string, dec kernel.Decision
|
||||
}
|
||||
|
||||
func (r *Runner) resolveLeverage(requested int, symbol string) int {
|
||||
if requested > 0 {
|
||||
return requested
|
||||
}
|
||||
sym := strings.ToUpper(symbol)
|
||||
if sym == "BTCUSDT" || sym == "ETHUSDT" {
|
||||
if r.cfg.Leverage.BTCETHLeverage > 0 {
|
||||
return r.cfg.Leverage.BTCETHLeverage
|
||||
isBTCETH := sym == "BTCUSDT" || sym == "ETHUSDT"
|
||||
|
||||
// Determine configured max leverage for this symbol type
|
||||
var maxLeverage int
|
||||
if isBTCETH {
|
||||
maxLeverage = r.cfg.Leverage.BTCETHLeverage
|
||||
if maxLeverage <= 0 {
|
||||
maxLeverage = 10 // Default max for BTC/ETH
|
||||
}
|
||||
} else {
|
||||
if r.cfg.Leverage.AltcoinLeverage > 0 {
|
||||
return r.cfg.Leverage.AltcoinLeverage
|
||||
maxLeverage = r.cfg.Leverage.AltcoinLeverage
|
||||
if maxLeverage <= 0 {
|
||||
maxLeverage = 5 // Default max for altcoins
|
||||
}
|
||||
}
|
||||
return 5
|
||||
|
||||
// Use requested leverage if provided, otherwise use max as default
|
||||
leverage := requested
|
||||
if leverage <= 0 {
|
||||
leverage = maxLeverage
|
||||
}
|
||||
|
||||
// Enforce max leverage limit
|
||||
if leverage > maxLeverage {
|
||||
logger.Infof("📊 Backtest: capping leverage from %dx to %dx for %s",
|
||||
leverage, maxLeverage, symbol)
|
||||
leverage = maxLeverage
|
||||
}
|
||||
|
||||
return leverage
|
||||
}
|
||||
|
||||
func (r *Runner) remainingPosition(symbol, side string) float64 {
|
||||
@@ -854,6 +895,12 @@ func (r *Runner) convertPositions(priceMap map[string]float64) []kernel.Position
|
||||
list := make([]kernel.PositionInfo, 0, len(positions))
|
||||
for _, pos := range positions {
|
||||
price := priceMap[pos.Symbol]
|
||||
pnl := unrealizedPnL(pos, price)
|
||||
// Calculate P&L percentage based on entry notional (position cost)
|
||||
pnlPct := 0.0
|
||||
if pos.Notional > 0 {
|
||||
pnlPct = (pnl / pos.Notional) * 100
|
||||
}
|
||||
list = append(list, kernel.PositionInfo{
|
||||
Symbol: pos.Symbol,
|
||||
Side: pos.Side,
|
||||
@@ -861,8 +908,8 @@ func (r *Runner) convertPositions(priceMap map[string]float64) []kernel.Position
|
||||
MarkPrice: price,
|
||||
Quantity: pos.Quantity,
|
||||
Leverage: pos.Leverage,
|
||||
UnrealizedPnL: unrealizedPnL(pos, price),
|
||||
UnrealizedPnLPct: 0,
|
||||
UnrealizedPnL: pnl,
|
||||
UnrealizedPnLPct: pnlPct,
|
||||
LiquidationPrice: pos.LiquidationPrice,
|
||||
MarginUsed: pos.Margin,
|
||||
UpdateTime: time.Now().UnixMilli(),
|
||||
|
||||
+36
-37
@@ -17,17 +17,17 @@ func saveCheckpointDB(runID string, ckpt *Checkpoint) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = persistenceDB.Exec(`
|
||||
_, err = persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_checkpoints (run_id, payload, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
||||
`, runID, data)
|
||||
`), runID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadCheckpointDB(runID string) (*Checkpoint, error) {
|
||||
var payload []byte
|
||||
err := persistenceDB.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload)
|
||||
err := persistenceDB.QueryRow(convertQuery(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`), runID).Scan(&payload)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, os.ErrNotExist
|
||||
@@ -57,25 +57,25 @@ func saveConfigDB(runID string, cfg *BacktestConfig) error {
|
||||
if userID == "" {
|
||||
userID = "default"
|
||||
}
|
||||
_, err = persistenceDB.Exec(`
|
||||
_, err = persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt, override_prompt, ai_provider, ai_model, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(run_id) DO NOTHING
|
||||
`, runID, userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, now, now)
|
||||
`), runID, userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, now, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = persistenceDB.Exec(`
|
||||
_, err = persistenceDB.Exec(convertQuery(`
|
||||
UPDATE backtest_runs
|
||||
SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?, override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE run_id = ?
|
||||
`, userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, runID)
|
||||
`), userID, data, template, cfg.CustomPrompt, cfg.OverrideBasePrompt, cfg.AICfg.Provider, cfg.AICfg.Model, runID)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadConfigDB(runID string) (*BacktestConfig, error) {
|
||||
var payload []byte
|
||||
err := persistenceDB.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload)
|
||||
err := persistenceDB.QueryRow(convertQuery(`SELECT config_json FROM backtest_runs WHERE run_id = ?`), runID).Scan(&payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -96,18 +96,18 @@ func saveRunMetadataDB(meta *RunMetadata) error {
|
||||
if userID == "" {
|
||||
userID = "default"
|
||||
}
|
||||
if _, err := persistenceDB.Exec(`
|
||||
if _, err := persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(run_id) DO NOTHING
|
||||
`, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil {
|
||||
`), meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := persistenceDB.Exec(`
|
||||
_, err := persistenceDB.Exec(convertQuery(`
|
||||
UPDATE backtest_runs
|
||||
SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?, progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?, liquidation_note = ?, label = ?, last_error = ?, updated_at = ?
|
||||
WHERE run_id = ?
|
||||
`, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF, meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast, meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote, meta.Label, meta.LastError, updated, meta.RunID)
|
||||
`), userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF, meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast, meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote, meta.Label, meta.LastError, updated, meta.RunID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -128,10 +128,10 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) {
|
||||
createdISO string
|
||||
updatedISO string
|
||||
)
|
||||
err := persistenceDB.QueryRow(`
|
||||
err := persistenceDB.QueryRow(convertQuery(`
|
||||
SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars, progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note, created_at, updated_at
|
||||
FROM backtest_runs WHERE run_id = ?
|
||||
`, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF, &processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote, &createdISO, &updatedISO)
|
||||
`), runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF, &processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote, &createdISO, &updatedISO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -183,18 +183,18 @@ func loadRunIDsDB() ([]string, error) {
|
||||
}
|
||||
|
||||
func appendEquityPointDB(runID string, point EquityPoint) error {
|
||||
_, err := persistenceDB.Exec(`
|
||||
_, err := persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, runID, point.Timestamp, point.Equity, point.Available, point.PnL, point.PnLPct, point.DrawdownPct, point.Cycle)
|
||||
`), runID, point.Timestamp, point.Equity, point.Available, point.PnL, point.PnLPct, point.DrawdownPct, point.Cycle)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadEquityPointsDB(runID string) ([]EquityPoint, error) {
|
||||
rows, err := persistenceDB.Query(`
|
||||
rows, err := persistenceDB.Query(convertQuery(`
|
||||
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle
|
||||
FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC
|
||||
`, runID)
|
||||
`), runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -211,18 +211,18 @@ func loadEquityPointsDB(runID string) ([]EquityPoint, error) {
|
||||
}
|
||||
|
||||
func appendTradeEventDB(runID string, event TradeEvent) error {
|
||||
_, err := persistenceDB.Exec(`
|
||||
_, err := persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee, slippage, order_value, realized_pnl, leverage, cycle, position_after, liquidation, note)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note)
|
||||
`), runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadTradeEventsDB(runID string) ([]TradeEvent, error) {
|
||||
rows, err := persistenceDB.Query(`
|
||||
rows, err := persistenceDB.Query(convertQuery(`
|
||||
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value, realized_pnl, leverage, cycle, position_after, liquidation, note
|
||||
FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC
|
||||
`, runID)
|
||||
`), runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -243,17 +243,17 @@ func saveMetricsDB(runID string, metrics *Metrics) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = persistenceDB.Exec(`
|
||||
_, err = persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_metrics (run_id, payload, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
||||
`, runID, data)
|
||||
`), runID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadMetricsDB(runID string) (*Metrics, error) {
|
||||
var payload []byte
|
||||
err := persistenceDB.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload)
|
||||
err := persistenceDB.QueryRow(convertQuery(`SELECT payload FROM backtest_metrics WHERE run_id = ?`), runID).Scan(&payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -265,22 +265,21 @@ func loadMetricsDB(runID string) (*Metrics, error) {
|
||||
}
|
||||
|
||||
func saveProgressDB(runID string, payload progressPayload) error {
|
||||
_, err := persistenceDB.Exec(`
|
||||
_, err := persistenceDB.Exec(convertQuery(`
|
||||
UPDATE backtest_runs
|
||||
SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = ?
|
||||
WHERE run_id = ?
|
||||
`, payload.ProgressPct, payload.Equity, payload.BarIndex, payload.Liquidated, payload.UpdatedAtISO, runID)
|
||||
`), payload.ProgressPct, payload.Equity, payload.BarIndex, payload.Liquidated, payload.UpdatedAtISO, runID)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error) {
|
||||
query := `SELECT payload FROM backtest_decisions WHERE run_id = ?`
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if cycle > 0 {
|
||||
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle)
|
||||
rows, err = persistenceDB.Query(convertQuery(`SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY created_at DESC LIMIT 1`), runID, cycle)
|
||||
} else {
|
||||
rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID)
|
||||
rows, err = persistenceDB.Query(convertQuery(`SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY created_at DESC LIMIT 1`), runID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -308,20 +307,20 @@ func saveDecisionRecordDB(runID string, record *store.DecisionRecord) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = persistenceDB.Exec(`
|
||||
_, err = persistenceDB.Exec(convertQuery(`
|
||||
INSERT INTO backtest_decisions (run_id, cycle, payload)
|
||||
VALUES (?, ?, ?)
|
||||
`, runID, record.CycleNumber, data)
|
||||
`), runID, record.CycleNumber, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadDecisionRecordsDB(runID string, limit, offset int) ([]*store.DecisionRecord, error) {
|
||||
rows, err := persistenceDB.Query(`
|
||||
rows, err := persistenceDB.Query(convertQuery(`
|
||||
SELECT payload FROM backtest_decisions
|
||||
WHERE run_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`, runID, limit, offset)
|
||||
`), runID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -428,10 +427,10 @@ func writeJSONLinesToZip[T any](z *zip.Writer, name string, items []T) error {
|
||||
}
|
||||
|
||||
func writeDecisionLogsToZip(z *zip.Writer, runID string) error {
|
||||
rows, err := persistenceDB.Query(`
|
||||
rows, err := persistenceDB.Query(convertQuery(`
|
||||
SELECT id, cycle, payload FROM backtest_decisions
|
||||
WHERE run_id = ? ORDER BY id ASC
|
||||
`, runID)
|
||||
`), runID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -494,6 +493,6 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) {
|
||||
}
|
||||
|
||||
func deleteRunDB(runID string) error {
|
||||
_, err := persistenceDB.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID)
|
||||
_, err := persistenceDB.Exec(convertQuery(`DELETE FROM backtest_runs WHERE run_id = ?`), runID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func main() {
|
||||
logger.Fatalf("❌ Failed to initialize database: %v", err)
|
||||
}
|
||||
defer st.Close()
|
||||
backtest.UseDatabase(st.DB())
|
||||
backtest.UseDatabaseWithType(st.DB(), st.DBType() == store.DBTypePostgres)
|
||||
|
||||
// Initialize installation ID for experience improvement (anonymous statistics)
|
||||
initInstallationID(st)
|
||||
|
||||
+6
-3
@@ -147,7 +147,7 @@ func (BacktestCheckpoint) TableName() string {
|
||||
type BacktestEquity struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"`
|
||||
TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"`
|
||||
TS int64 `gorm:"column:ts;type:bigint;not null;index:idx_backtest_equity_run_ts"`
|
||||
Equity float64 `gorm:"column:equity;not null"`
|
||||
Available float64 `gorm:"column:available;not null"`
|
||||
PnL float64 `gorm:"column:pnl;not null"`
|
||||
@@ -164,7 +164,7 @@ func (BacktestEquity) TableName() string {
|
||||
type BacktestTrade struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"`
|
||||
TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"`
|
||||
TS int64 `gorm:"column:ts;type:bigint;not null;index:idx_backtest_trades_run_ts"`
|
||||
Symbol string `gorm:"column:symbol;not null"`
|
||||
Action string `gorm:"column:action;not null"`
|
||||
Side string `gorm:"column:side;default:''"`
|
||||
@@ -217,7 +217,10 @@ func (s *BacktestStore) initTables() error {
|
||||
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists)
|
||||
|
||||
if tableExists > 0 {
|
||||
// Tables exist - just ensure indexes exist
|
||||
// Tables exist - fix column types and ensure indexes exist
|
||||
// Fix ts column type from INTEGER to BIGINT (timestamps in milliseconds exceed int4 max)
|
||||
s.db.Exec(`ALTER TABLE backtest_equity ALTER COLUMN ts TYPE BIGINT`)
|
||||
s.db.Exec(`ALTER TABLE backtest_trades ALTER COLUMN ts TYPE BIGINT`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`)
|
||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`)
|
||||
|
||||
Reference in New Issue
Block a user