From e2d702c662c9641e3ffdcdab325771afbc8b1c73 Mon Sep 17 00:00:00 2001 From: tinkle-community Date: Sat, 20 Dec 2025 01:10:11 +0800 Subject: [PATCH] feat: enhance backtest with real-time positions, P&L fixes, and strategy integration - Add real-time position display with unrealized P&L during backtest - Fix P&L calculation by tracking accumulated opening fees - Add strategy coin source resolution (AI500, OI Top, mixed) - Infer AI provider from model name for better compatibility - Cap position size to available margin to prevent insufficient cash errors - Fix trade markers on K-line chart (long/short instead of buy/sell) - Add QuantData and OI ranking to backtest decision context --- api/backtest.go | 301 ++++++++++++- backtest/account.go | 23 +- backtest/config.go | 50 ++- backtest/runner.go | 107 ++++- backtest/types.go | 39 +- config/config.go | 2 +- web/src/components/BacktestPage.tsx | 635 ++++++++++++++++++++++++++-- web/src/lib/api.ts | 14 + web/src/types.ts | 35 ++ 9 files changed, 1144 insertions(+), 62 deletions(-) diff --git a/api/backtest.go b/api/backtest.go index 3adc55d1..b8bc840a 100644 --- a/api/backtest.go +++ b/api/backtest.go @@ -3,6 +3,7 @@ package api import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "net/http" @@ -12,6 +13,9 @@ import ( "time" "nofx/backtest" + "nofx/logger" + "nofx/market" + "nofx/provider" "nofx/store" "github.com/gin-gonic/gin" @@ -32,6 +36,7 @@ func (s *Server) registerBacktestRoutes(router *gin.RouterGroup) { router.GET("/trace", s.handleBacktestTrace) router.GET("/decisions", s.handleBacktestDecisions) router.GET("/export", s.handleBacktestExport) + router.GET("/klines", s.handleBacktestKlines) } type backtestStartRequest struct { @@ -65,11 +70,54 @@ func (s *Server) handleBacktestStart(c *gin.Context) { } cfg.CustomPrompt = strings.TrimSpace(cfg.CustomPrompt) cfg.UserID = normalizeUserID(c.GetString("user_id")) + + logger.Infof("📊 Backtest request - symbols from request: %v (count=%d), strategyID: %s", + cfg.Symbols, len(cfg.Symbols), cfg.StrategyID) + + // Load strategy config if strategy_id is provided + if cfg.StrategyID != "" { + strategy, err := s.store.Strategy().Get(cfg.UserID, cfg.StrategyID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to load strategy: %v", err)}) + return + } + if strategy == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("strategy not found: %s", cfg.StrategyID)}) + return + } + var strategyConfig store.StrategyConfig + if err := json.Unmarshal([]byte(strategy.Config), &strategyConfig); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to parse strategy config: %v", err)}) + return + } + cfg.SetLoadedStrategy(&strategyConfig) + logger.Infof("📊 Backtest using saved strategy: %s (%s)", strategy.Name, strategy.ID) + logger.Infof("📊 Strategy coin source: type=%s, use_coin_pool=%v, use_oi_top=%v, static_coins=%v", + strategyConfig.CoinSource.SourceType, + strategyConfig.CoinSource.UseCoinPool, + strategyConfig.CoinSource.UseOITop, + strategyConfig.CoinSource.StaticCoins) + + // If no symbols provided, fetch from strategy's coin source + if len(cfg.Symbols) == 0 { + symbols, err := s.resolveStrategyCoins(&strategyConfig) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to resolve coins from strategy: %v", err)}) + return + } + cfg.Symbols = symbols + logger.Infof("📊 Resolved %d coins from strategy: %v", len(symbols), symbols) + } + } + if err := s.hydrateBacktestAIConfig(&cfg); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + logger.Infof("📊 Starting backtest with final config: runID=%s, symbols=%v (count=%d), strategyID=%s", + cfg.RunID, cfg.Symbols, len(cfg.Symbols), cfg.StrategyID) + runner, err := s.backtestManager.Start(context.Background(), cfg) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -443,6 +491,89 @@ func (s *Server) handleBacktestExport(c *gin.Context) { c.FileAttachment(path, filename) } +func (s *Server) handleBacktestKlines(c *gin.Context) { + if s.backtestManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "backtest manager unavailable"}) + return + } + userID := normalizeUserID(c.GetString("user_id")) + runID := c.Query("run_id") + symbol := c.Query("symbol") + timeframe := c.Query("timeframe") + + if runID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "run_id is required"}) + return + } + if symbol == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "symbol is required"}) + return + } + + meta, err := s.ensureBacktestRunOwnership(runID, userID) + if writeBacktestAccessError(c, err) { + return + } + + // Load config to get time range + cfg, err := backtest.LoadConfig(runID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "failed to load backtest config"}) + return + } + + // Use decision timeframe if not specified + if timeframe == "" { + timeframe = cfg.DecisionTimeframe + if timeframe == "" { + timeframe = "15m" + } + } + + // Fetch klines for the backtest time range + startTime := time.Unix(cfg.StartTS, 0) + endTime := time.Unix(cfg.EndTS, 0) + + klines, err := market.GetKlinesRange(symbol, timeframe, startTime, endTime) + if err != nil { + logger.Errorf("Failed to fetch klines for %s: %v", symbol, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to fetch klines: %v", err)}) + return + } + + // Convert to response format + type KlineResponse struct { + Time int64 `json:"time"` + Open float64 `json:"open"` + High float64 `json:"high"` + Low float64 `json:"low"` + Close float64 `json:"close"` + Volume float64 `json:"volume"` + } + + result := make([]KlineResponse, len(klines)) + for i, k := range klines { + result[i] = KlineResponse{ + Time: k.OpenTime / 1000, // Convert to seconds for lightweight-charts + Open: k.Open, + High: k.High, + Low: k.Low, + Close: k.Close, + Volume: k.Volume, + } + } + + c.JSON(http.StatusOK, gin.H{ + "symbol": symbol, + "timeframe": timeframe, + "start_ts": cfg.StartTS, + "end_ts": cfg.EndTS, + "count": len(result), + "klines": result, + "run_id": meta.RunID, + }) +} + func queryInt(c *gin.Context, name string, fallback int) int { if value := c.Query(name); value != "" { if v, err := strconv.Atoi(value); err == nil { @@ -498,6 +629,155 @@ func writeBacktestAccessError(c *gin.Context, err error) bool { return true } +// resolveStrategyCoins fetches coins based on strategy's coin source configuration +func (s *Server) resolveStrategyCoins(strategyConfig *store.StrategyConfig) ([]string, error) { + if strategyConfig == nil { + return nil, fmt.Errorf("strategy config is nil") + } + + coinSource := strategyConfig.CoinSource + var symbols []string + symbolSet := make(map[string]bool) + + // Set custom API URLs if provided + if coinSource.CoinPoolAPIURL != "" { + provider.SetCoinPoolAPI(coinSource.CoinPoolAPIURL) + } + if coinSource.OITopAPIURL != "" { + provider.SetOITopAPI(coinSource.OITopAPIURL) + } + + // Handle empty source_type - check flags for backward compatibility + sourceType := coinSource.SourceType + if sourceType == "" { + if coinSource.UseCoinPool && coinSource.UseOITop { + sourceType = "mixed" + } else if coinSource.UseCoinPool { + sourceType = "coinpool" + } else if coinSource.UseOITop { + sourceType = "oi_top" + } else if len(coinSource.StaticCoins) > 0 { + sourceType = "static" + } else { + return nil, fmt.Errorf("strategy has no coin source configured") + } + logger.Infof("📊 Inferred source_type=%s from flags", sourceType) + } + + switch sourceType { + case "static": + for _, sym := range coinSource.StaticCoins { + sym = market.Normalize(sym) + if !symbolSet[sym] { + symbols = append(symbols, sym) + symbolSet[sym] = true + } + } + + case "coinpool": + limit := coinSource.CoinPoolLimit + if limit <= 0 { + limit = 30 + } + logger.Infof("📊 Fetching AI500 coins with limit=%d", limit) + coins, err := provider.GetTopRatedCoins(limit) + if err != nil { + return nil, fmt.Errorf("failed to get AI500 coins: %w", err) + } + logger.Infof("📊 Got %d coins from AI500: %v", len(coins), coins) + for _, sym := range coins { + sym = market.Normalize(sym) + if !symbolSet[sym] { + symbols = append(symbols, sym) + symbolSet[sym] = true + } + } + + case "oi_top": + coins, err := provider.GetOITopSymbols() + if err != nil { + return nil, fmt.Errorf("failed to get OI Top coins: %w", err) + } + limit := coinSource.OITopLimit + if limit <= 0 || limit > len(coins) { + limit = len(coins) + } + for i, sym := range coins { + if i >= limit { + break + } + sym = market.Normalize(sym) + if !symbolSet[sym] { + symbols = append(symbols, sym) + symbolSet[sym] = true + } + } + + case "mixed": + // Get from coin pool + if coinSource.UseCoinPool { + limit := coinSource.CoinPoolLimit + if limit <= 0 { + limit = 30 + } + coins, err := provider.GetTopRatedCoins(limit) + if err != nil { + logger.Warnf("Failed to get AI500 coins: %v", err) + } else { + for _, sym := range coins { + sym = market.Normalize(sym) + if !symbolSet[sym] { + symbols = append(symbols, sym) + symbolSet[sym] = true + } + } + } + } + + // Get from OI Top + if coinSource.UseOITop { + coins, err := provider.GetOITopSymbols() + if err != nil { + logger.Warnf("Failed to get OI Top coins: %v", err) + } else { + limit := coinSource.OITopLimit + if limit <= 0 || limit > len(coins) { + limit = len(coins) + } + for i, sym := range coins { + if i >= limit { + break + } + sym = market.Normalize(sym) + if !symbolSet[sym] { + symbols = append(symbols, sym) + symbolSet[sym] = true + } + } + } + } + + // Add static coins + for _, sym := range coinSource.StaticCoins { + sym = market.Normalize(sym) + if !symbolSet[sym] { + symbols = append(symbols, sym) + symbolSet[sym] = true + } + } + + default: + return nil, fmt.Errorf("unknown coin source type: %s", sourceType) + } + + if len(symbols) == 0 { + return nil, fmt.Errorf("no coins resolved from strategy") + } + + logger.Infof("📊 Final resolved symbols: %d coins - %v", len(symbols), symbols) + return symbols, nil +} + func (s *Server) resolveBacktestAIConfig(cfg *backtest.BacktestConfig, userID string) error { if cfg == nil { return fmt.Errorf("config is nil") @@ -549,7 +829,26 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error { return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name) } - cfg.AICfg.Provider = strings.ToLower(model.Provider) + provider := strings.ToLower(strings.TrimSpace(model.Provider)) + // Ensure provider is never empty or "inherit" - infer from model name if needed + if provider == "" || provider == "inherit" { + modelNameLower := strings.ToLower(model.Name) + if strings.Contains(modelNameLower, "claude") || strings.Contains(modelNameLower, "anthropic") { + provider = "anthropic" + } else if strings.Contains(modelNameLower, "gpt") || strings.Contains(modelNameLower, "openai") { + provider = "openai" + } else if strings.Contains(modelNameLower, "gemini") || strings.Contains(modelNameLower, "google") { + provider = "google" + } else if strings.Contains(modelNameLower, "deepseek") { + provider = "deepseek" + } else if model.CustomAPIURL != "" { + provider = "custom" + } else { + provider = "openai" // default fallback + } + logger.Infof("📊 Inferred AI provider '%s' from model name '%s'", provider, model.Name) + } + cfg.AICfg.Provider = provider cfg.AICfg.APIKey = apiKey cfg.AICfg.BaseURL = strings.TrimSpace(model.CustomAPIURL) modelName := strings.TrimSpace(model.CustomModelName) diff --git a/backtest/account.go b/backtest/account.go index 2ef85762..abf891a9 100644 --- a/backtest/account.go +++ b/backtest/account.go @@ -18,6 +18,7 @@ type position struct { Notional float64 LiquidationPrice float64 OpenTime int64 + AccumulatedFee float64 // Total fees paid (opening + any additions) } type BacktestAccount struct { @@ -87,6 +88,7 @@ func (acc *BacktestAccount) Open(symbol, side string, quantity float64, leverage pos.Notional = notional pos.OpenTime = ts pos.LiquidationPrice = computeLiquidation(execPrice, leverage, side) + pos.AccumulatedFee = fee // Track opening fee } else { if leverage != pos.Leverage { // Use weighted average leverage (approximate) @@ -98,6 +100,7 @@ func (acc *BacktestAccount) Open(symbol, side string, quantity float64, leverage pos.EntryPrice = ((pos.EntryPrice * pos.Quantity) + execPrice*quantity) / (pos.Quantity + quantity) pos.Quantity += quantity pos.LiquidationPrice = computeLiquidation(pos.EntryPrice, pos.Leverage, side) + pos.AccumulatedFee += fee // Add to accumulated fee for position additions } return pos, fee, execPrice, nil @@ -120,23 +123,32 @@ func (acc *BacktestAccount) Close(symbol, side string, quantity float64, price f execPrice := applySlippage(price, acc.slippageRate, side, false) notional := execPrice * quantity - fee := notional * acc.feeRate + closingFee := notional * acc.feeRate + + // Calculate proportional opening fee for the quantity being closed + closePortion := quantity / pos.Quantity + openingFeePortion := pos.AccumulatedFee * closePortion + totalFee := closingFee + openingFeePortion realized := realizedPnL(pos, quantity, execPrice) - marginPortion := pos.Margin * (quantity / pos.Quantity) - acc.cash += marginPortion + realized - fee - acc.realizedPnL += realized - fee + marginPortion := pos.Margin * closePortion + // Note: Opening fee was already deducted from cash when opening, so we only deduct closing fee here + acc.cash += marginPortion + realized - closingFee + // But for realized P&L tracking, we include both fees + acc.realizedPnL += realized - totalFee pos.Quantity -= quantity pos.Notional -= notional pos.Margin -= marginPortion + pos.AccumulatedFee -= openingFeePortion // Reduce tracked opening fee if pos.Quantity <= epsilon { acc.removePosition(pos) } - return realized, fee, execPrice, nil + // Return total fee (opening + closing) so caller can calculate accurate P&L + return realized, totalFee, execPrice, nil } func (acc *BacktestAccount) TotalEquity(priceMap map[string]float64) (float64, float64, map[string]float64) { @@ -243,6 +255,7 @@ func (acc *BacktestAccount) RestoreFromSnapshots(cash float64, realized float64, Notional: snap.Quantity * snap.AvgPrice, LiquidationPrice: snap.LiquidationPrice, OpenTime: snap.OpenTime, + AccumulatedFee: snap.AccumulatedFee, } key := positionKey(pos.Symbol, pos.Side) acc.positions[key] = pos diff --git a/backtest/config.go b/backtest/config.go index 401d95f5..0dcf48ed 100644 --- a/backtest/config.go +++ b/backtest/config.go @@ -29,6 +29,7 @@ type BacktestConfig struct { RunID string `json:"run_id"` UserID string `json:"user_id,omitempty"` AIModelID string `json:"ai_model_id,omitempty"` + StrategyID string `json:"strategy_id,omitempty"` // Optional: use saved strategy from Strategy Studio Symbols []string `json:"symbols"` Timeframes []string `json:"timeframes"` DecisionTimeframe string `json:"decision_timeframe"` @@ -53,6 +54,9 @@ type BacktestConfig struct { CheckpointIntervalBars int `json:"checkpoint_interval_bars,omitempty"` CheckpointIntervalSeconds int `json:"checkpoint_interval_seconds,omitempty"` ReplayDecisionDir string `json:"replay_decision_dir,omitempty"` + + // Internal: loaded strategy config (set by Manager when StrategyID is provided) + loadedStrategy *store.StrategyConfig `json:"-"` } // Validate performs validity checks on the configuration and fills in default values. @@ -178,10 +182,54 @@ func validateFillPolicy(policy string) error { } } +// SetLoadedStrategy sets the loaded strategy config from database. +func (cfg *BacktestConfig) SetLoadedStrategy(strategy *store.StrategyConfig) { + cfg.loadedStrategy = strategy +} + // ToStrategyConfig converts BacktestConfig to StrategyConfig for unified prompt generation. // This ensures backtest uses the same StrategyEngine logic as live trading. +// If a strategy was loaded from database (via StrategyID), it will be used with overrides. func (cfg *BacktestConfig) ToStrategyConfig() *store.StrategyConfig { - // Determine primary and longer timeframe from the timeframes list + // If a strategy was loaded from database, use it with some overrides + if cfg.loadedStrategy != nil { + result := *cfg.loadedStrategy // Make a copy + + // Override coin source with backtest symbols (回测指定的币对优先) + if len(cfg.Symbols) > 0 { + result.CoinSource.SourceType = "static" + result.CoinSource.StaticCoins = cfg.Symbols + result.CoinSource.UseCoinPool = false + result.CoinSource.UseOITop = false + } + + // Override timeframes with backtest config + if len(cfg.Timeframes) > 0 { + result.Indicators.Klines.SelectedTimeframes = cfg.Timeframes + result.Indicators.Klines.PrimaryTimeframe = cfg.Timeframes[0] + if len(cfg.Timeframes) > 1 { + result.Indicators.Klines.LongerTimeframe = cfg.Timeframes[len(cfg.Timeframes)-1] + } + result.Indicators.Klines.EnableMultiTimeframe = len(cfg.Timeframes) > 1 + } + + // Override leverage with backtest config + if cfg.Leverage.BTCETHLeverage > 0 { + result.RiskControl.BTCETHMaxLeverage = cfg.Leverage.BTCETHLeverage + } + if cfg.Leverage.AltcoinLeverage > 0 { + result.RiskControl.AltcoinMaxLeverage = cfg.Leverage.AltcoinLeverage + } + + // Override custom prompt if provided in backtest config + if cfg.CustomPrompt != "" { + result.CustomPrompt = cfg.CustomPrompt + } + + return &result + } + + // Fallback: build strategy config from backtest config (original logic) primaryTF := "5m" longerTF := "4h" if len(cfg.Timeframes) > 0 { diff --git a/backtest/runner.go b/backtest/runner.go index 8c483ecb..8820d994 100644 --- a/backtest/runner.go +++ b/backtest/runner.go @@ -491,9 +491,14 @@ func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Da positions := r.convertPositions(priceMap) - candidateCoins := make([]decision.CandidateCoin, 0, len(r.cfg.Symbols)) - for _, sym := range r.cfg.Symbols { - candidateCoins = append(candidateCoins, decision.CandidateCoin{Symbol: sym}) + // Get candidate coins from strategy engine (includes source info) + candidateCoins, err := r.strategyEngine.GetCandidateCoins() + if err != nil { + // Fallback to simple list if strategy engine fails + candidateCoins = make([]decision.CandidateCoin, 0, len(r.cfg.Symbols)) + for _, sym := range r.cfg.Symbols { + candidateCoins = append(candidateCoins, decision.CandidateCoin{Symbol: sym, Sources: []string{"backtest"}}) + } } runtime := int((ts - int64(r.cfg.StartTS*1000)) / 60000) @@ -512,6 +517,36 @@ func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Da Timeframes: r.cfg.Timeframes, } + // Fetch quantitative data if enabled in strategy (uses current data as approximation) + strategyConfig := r.strategyEngine.GetConfig() + if strategyConfig.Indicators.EnableQuantData && strategyConfig.Indicators.QuantDataAPIURL != "" { + // Collect symbols to query (candidate coins + position coins) + symbolSet := make(map[string]bool) + for _, sym := range r.cfg.Symbols { + symbolSet[sym] = true + } + for _, pos := range positions { + symbolSet[pos.Symbol] = true + } + symbols := make([]string, 0, len(symbolSet)) + for sym := range symbolSet { + symbols = append(symbols, sym) + } + ctx.QuantDataMap = r.strategyEngine.FetchQuantDataBatch(symbols) + if len(ctx.QuantDataMap) > 0 { + logger.Infof("📊 Backtest: fetched quant data for %d symbols", len(ctx.QuantDataMap)) + } + } + + // Fetch OI ranking data if enabled in strategy (uses current data as approximation) + if strategyConfig.Indicators.EnableOIRanking { + ctx.OIRankingData = r.strategyEngine.FetchOIRankingData() + if ctx.OIRankingData != nil { + logger.Infof("📊 Backtest: OI ranking data ready: %d top, %d low positions", + len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions)) + } + } + record := &store.DecisionRecord{ AccountState: store.AccountSnapshot{ TotalBalance: accountInfo.TotalEquity, @@ -710,10 +745,31 @@ func (r *Runner) determineQuantity(dec decision.Decision, price float64) float64 if equity <= 0 { equity = r.account.InitialBalance() } + + // Get leverage for this symbol + leverage := r.resolveLeverage(dec.Leverage, dec.Symbol) + if leverage <= 0 { + leverage = 5 + } + + // Calculate available margin (leave some buffer for fees) + availableCash := r.account.Cash() + maxMarginToUse := availableCash * 0.9 // Use max 90% of available cash + maxPositionValue := maxMarginToUse * float64(leverage) + sizeUSD := dec.PositionSizeUSD if sizeUSD <= 0 { + // Default to 5% of equity, but cap to available margin sizeUSD = 0.05 * equity } + + // Cap position size to what we can actually afford + if sizeUSD > maxPositionValue { + logger.Infof("📊 Backtest: capping position from %.2f to %.2f (available margin: %.2f, leverage: %dx)", + sizeUSD, maxPositionValue, maxMarginToUse, leverage) + sizeUSD = maxPositionValue + } + qty := sizeUSD / price if qty < 0 { qty = 0 @@ -855,6 +911,7 @@ func (r *Runner) updateState(ts int64, equity, unrealized, marginUsed float64, p LiquidationPrice: pos.LiquidationPrice, MarginUsed: pos.Margin, OpenTime: pos.OpenTime, + AccumulatedFee: pos.AccumulatedFee, } } @@ -1098,6 +1155,49 @@ func (r *Runner) StatusPayload() StatusPayload { snapshot := r.snapshotState() progress := progressPercent(snapshot, r.cfg) + // Build position statuses with unrealized P&L + positions := make([]PositionStatus, 0, len(snapshot.Positions)) + for _, pos := range snapshot.Positions { + if pos.Quantity <= 0 { + continue + } + // Get mark price from feed if available + markPrice := pos.AvgPrice // fallback to entry price + if r.feed != nil && snapshot.BarTimestamp > 0 { + if md, _, err := r.feed.BuildMarketData(snapshot.BarTimestamp); err == nil { + if data, ok := md[pos.Symbol]; ok { + markPrice = data.CurrentPrice + } + } + } + + // Calculate unrealized P&L + var unrealizedPnL float64 + if pos.Side == "long" { + unrealizedPnL = (markPrice - pos.AvgPrice) * pos.Quantity + } else { + unrealizedPnL = (pos.AvgPrice - markPrice) * pos.Quantity + } + + // Calculate P&L percentage based on margin + pnlPct := 0.0 + if pos.MarginUsed > 0 { + pnlPct = (unrealizedPnL / pos.MarginUsed) * 100 + } + + positions = append(positions, PositionStatus{ + Symbol: pos.Symbol, + Side: pos.Side, + Quantity: pos.Quantity, + EntryPrice: pos.AvgPrice, + MarkPrice: markPrice, + Leverage: pos.Leverage, + UnrealizedPnL: unrealizedPnL, + UnrealizedPnLPct: pnlPct, + MarginUsed: pos.MarginUsed, + }) + } + payload := StatusPayload{ RunID: r.cfg.RunID, State: r.Status(), @@ -1108,6 +1208,7 @@ func (r *Runner) StatusPayload() StatusPayload { Equity: snapshot.Equity, UnrealizedPnL: snapshot.UnrealizedPnL, RealizedPnL: snapshot.RealizedPnL, + Positions: positions, Note: snapshot.LiquidationNote, LastError: r.lastErrorString(), LastUpdatedIso: snapshot.LastUpdate.UTC().Format(time.RFC3339), diff --git a/backtest/types.go b/backtest/types.go index dbd42abd..f9c1295c 100644 --- a/backtest/types.go +++ b/backtest/types.go @@ -25,6 +25,7 @@ type PositionSnapshot struct { LiquidationPrice float64 `json:"liquidation_price"` MarginUsed float64 `json:"margin_used"` OpenTime int64 `json:"open_time"` + AccumulatedFee float64 `json:"accumulated_fee,omitempty"` // Opening fees accumulated } // BacktestState represents the real-time state during execution (in-memory state). @@ -149,16 +150,30 @@ type RunSummary struct { // StatusPayload is used for /status API responses. type StatusPayload struct { - RunID string `json:"run_id"` - State RunState `json:"state"` - ProgressPct float64 `json:"progress_pct"` - ProcessedBars int `json:"processed_bars"` - CurrentTime int64 `json:"current_time"` - DecisionCycle int `json:"decision_cycle"` - Equity float64 `json:"equity"` - UnrealizedPnL float64 `json:"unrealized_pnl"` - RealizedPnL float64 `json:"realized_pnl"` - Note string `json:"note,omitempty"` - LastError string `json:"last_error,omitempty"` - LastUpdatedIso string `json:"last_updated_iso"` + RunID string `json:"run_id"` + State RunState `json:"state"` + ProgressPct float64 `json:"progress_pct"` + ProcessedBars int `json:"processed_bars"` + CurrentTime int64 `json:"current_time"` + DecisionCycle int `json:"decision_cycle"` + Equity float64 `json:"equity"` + UnrealizedPnL float64 `json:"unrealized_pnl"` + RealizedPnL float64 `json:"realized_pnl"` + Positions []PositionStatus `json:"positions,omitempty"` + Note string `json:"note,omitempty"` + LastError string `json:"last_error,omitempty"` + LastUpdatedIso string `json:"last_updated_iso"` +} + +// PositionStatus represents a position with unrealized P&L for status display. +type PositionStatus struct { + Symbol string `json:"symbol"` + Side string `json:"side"` + Quantity float64 `json:"quantity"` + EntryPrice float64 `json:"entry_price"` + MarkPrice float64 `json:"mark_price"` + Leverage int `json:"leverage"` + UnrealizedPnL float64 `json:"unrealized_pnl"` + UnrealizedPnLPct float64 `json:"unrealized_pnl_pct"` + MarginUsed float64 `json:"margin_used"` } diff --git a/config/config.go b/config/config.go index 94e83220..030a447d 100644 --- a/config/config.go +++ b/config/config.go @@ -35,7 +35,7 @@ func Init() { cfg := &Config{ APIServerPort: 8080, RegistrationEnabled: true, - MaxUsers: 1, // Default: only 1 user allowed + MaxUsers: 5, // Default: only 1 user allowed ExperienceImprovement: true, // Default: enabled to help improve the product } diff --git a/web/src/components/BacktestPage.tsx b/web/src/components/BacktestPage.tsx index de792250..cb49d533 100644 --- a/web/src/components/BacktestPage.tsx +++ b/web/src/components/BacktestPage.tsx @@ -1,6 +1,7 @@ -import { useEffect, useMemo, useState, useCallback, type FormEvent } from 'react' +import { useEffect, useMemo, useState, useCallback, useRef, type FormEvent } from 'react' import useSWR from 'swr' import { motion, AnimatePresence } from 'framer-motion' +import { createChart, ColorType, CrosshairMode, CandlestickSeries, createSeriesMarkers, type IChartApi, type ISeriesApi, type CandlestickData, type UTCTimestamp, type SeriesMarker } from 'lightweight-charts' import { Play, Pause, @@ -25,6 +26,7 @@ import { Eye, ArrowUpRight, ArrowDownRight, + CandlestickChart as CandlestickIcon, } from 'lucide-react' import { ResponsiveContainer, @@ -43,11 +45,14 @@ import { confirmToast } from '../lib/notify' import { DecisionCard } from './DecisionCard' import type { BacktestStatusPayload, + BacktestPositionStatus, BacktestEquityPoint, BacktestTradeEvent, BacktestMetrics, + BacktestKlinesResponse, DecisionRecord, AIModel, + Strategy, } from '../types' // ============ Types ============ @@ -261,6 +266,270 @@ function BacktestChart({ ) } +// Candlestick Chart Component with trade markers +function CandlestickChartComponent({ + runId, + trades, + language, +}: { + runId: string + trades: BacktestTradeEvent[] + language: string +}) { + const chartContainerRef = useRef(null) + const chartRef = useRef(null) + const candleSeriesRef = useRef | null>(null) + + // Get unique symbols from trades + const symbols = useMemo(() => { + const symbolSet = new Set(trades.map((t) => t.symbol)) + return Array.from(symbolSet).sort() + }, [trades]) + + const [selectedSymbol, setSelectedSymbol] = useState(symbols[0] || '') + const [selectedTimeframe, setSelectedTimeframe] = useState('15m') + const [isLoading, setIsLoading] = useState(false) + const [error, setError] = useState(null) + + const CHART_TIMEFRAMES = ['1m', '3m', '5m', '15m', '30m', '1h', '4h', '1d'] + + // Update selected symbol when symbols change + useEffect(() => { + if (symbols.length > 0 && !symbols.includes(selectedSymbol)) { + setSelectedSymbol(symbols[0]) + } + }, [symbols, selectedSymbol]) + + // Filter trades for selected symbol + const symbolTrades = useMemo(() => { + return trades.filter((t) => t.symbol === selectedSymbol) + }, [trades, selectedSymbol]) + + // Fetch klines and render chart + useEffect(() => { + if (!chartContainerRef.current || !selectedSymbol || !runId) return + + const container = chartContainerRef.current + + // Create chart + const chart = createChart(container, { + layout: { + background: { type: ColorType.Solid, color: '#0B0E11' }, + textColor: '#848E9C', + }, + grid: { + vertLines: { color: 'rgba(43, 49, 57, 0.5)' }, + horzLines: { color: 'rgba(43, 49, 57, 0.5)' }, + }, + crosshair: { + mode: CrosshairMode.Normal, + }, + rightPriceScale: { + borderColor: '#2B3139', + }, + timeScale: { + borderColor: '#2B3139', + timeVisible: true, + secondsVisible: false, + }, + width: container.clientWidth, + height: 400, + }) + + chartRef.current = chart + + // Add candlestick series + const candleSeries = chart.addSeries(CandlestickSeries, { + upColor: '#0ECB81', + downColor: '#F6465D', + borderUpColor: '#0ECB81', + borderDownColor: '#F6465D', + wickUpColor: '#0ECB81', + wickDownColor: '#F6465D', + }) + candleSeriesRef.current = candleSeries + + // Fetch klines + setIsLoading(true) + setError(null) + + api + .getBacktestKlines(runId, selectedSymbol, selectedTimeframe) + .then((data: BacktestKlinesResponse) => { + const klineData: CandlestickData[] = data.klines.map((k) => ({ + time: k.time as UTCTimestamp, + open: k.open, + high: k.high, + low: k.low, + close: k.close, + })) + candleSeries.setData(klineData) + + // Add trade markers with improved styling + const markers: SeriesMarker[] = symbolTrades + .map((trade) => { + const tradeTime = Math.floor(trade.ts / 1000) + // Find closest kline time + const closestKline = data.klines.reduce((prev, curr) => + Math.abs(curr.time - tradeTime) < Math.abs(prev.time - tradeTime) ? curr : prev + ) + const isOpen = trade.action.includes('open') + const isLong = trade.side === 'long' || trade.action.includes('long') + const pnl = trade.realized_pnl + + // Format display text + let text = '' + let color = '#0ECB81' // Default green + + if (isOpen) { + // Opening position: show direction and price + if (isLong) { + text = `▲ Long @${trade.price.toFixed(2)}` + color = '#0ECB81' // Green for long open + } else { + text = `▼ Short @${trade.price.toFixed(2)}` + color = '#F6465D' // Red for short open + } + } else { + // Closing position: show PnL + const pnlStr = pnl >= 0 ? `+$${pnl.toFixed(2)}` : `-$${Math.abs(pnl).toFixed(2)}` + text = `✕ ${pnlStr}` + color = pnl >= 0 ? '#0ECB81' : '#F6465D' // Green for profit, red for loss + } + + return { + time: closestKline.time as UTCTimestamp, + position: isOpen + ? (isLong ? 'belowBar' as const : 'aboveBar' as const) // Long below, short above + : (isLong ? 'aboveBar' as const : 'belowBar' as const), // Close opposite + color, + shape: 'circle' as const, + size: 2, + text, + } + }) + .sort((a, b) => (a.time as number) - (b.time as number)) + + createSeriesMarkers(candleSeries, markers) + chart.timeScale().fitContent() + setIsLoading(false) + }) + .catch((err) => { + setError(err.message || 'Failed to load klines') + setIsLoading(false) + }) + + // Handle resize + const handleResize = () => { + if (chartContainerRef.current) { + chart.applyOptions({ width: chartContainerRef.current.clientWidth }) + } + } + window.addEventListener('resize', handleResize) + + return () => { + window.removeEventListener('resize', handleResize) + chart.remove() + chartRef.current = null + candleSeriesRef.current = null + } + }, [runId, selectedSymbol, selectedTimeframe, symbolTrades]) + + if (symbols.length === 0) { + return ( +
+ {language === 'zh' ? '没有交易记录' : 'No trades to display'} +
+ ) + } + + return ( +
+ {/* Symbol and Timeframe selectors */} +
+
+ + + {language === 'zh' ? '币种' : 'Symbol'} + + +
+ +
+ + + {language === 'zh' ? '周期' : 'Interval'} + +
+ {CHART_TIMEFRAMES.map((tf) => ( + + ))} +
+
+ + + ({symbolTrades.length} {language === 'zh' ? '笔交易' : 'trades'}) + +
+ + {/* Chart container */} +
+ {isLoading && ( +
+ + {language === 'zh' ? '加载K线数据...' : 'Loading kline data...'} +
+ )} + {error && ( +
+ + {error} +
+ )} +
+ + {/* Legend */} +
+
+
+ {language === 'zh' ? '开仓/盈利' : 'Open/Profit'} +
+
+
+ {language === 'zh' ? '亏损平仓' : 'Loss Close'} +
+ | + ▲ Long · ▼ Short · ✕ {language === 'zh' ? '平仓' : 'Close'} +
+
+ ) +} + // Trade Timeline Component function TradeTimeline({ trades }: { trades: BacktestTradeEvent[] }) { const recentTrades = useMemo(() => [...trades].slice(-20).reverse(), [trades]) @@ -341,6 +610,128 @@ function TradeTimeline({ trades }: { trades: BacktestTradeEvent[] }) { ) } +// Real-time Positions Display Component +function PositionsDisplay({ + positions, + language, +}: { + positions: BacktestPositionStatus[] + language: string +}) { + if (!positions || positions.length === 0) { + return null + } + + const totalUnrealizedPnL = positions.reduce((sum, p) => sum + p.unrealized_pnl, 0) + const totalMargin = positions.reduce((sum, p) => sum + p.margin_used, 0) + + return ( +
+
+
+ + + {language === 'zh' ? '当前持仓' : 'Active Positions'} + + + {positions.length} + +
+
+ + {language === 'zh' ? '保证金' : 'Margin'}: ${totalMargin.toFixed(2)} + + = 0 ? '#0ECB81' : '#F6465D' }} + > + {language === 'zh' ? '浮盈' : 'Unrealized'}: {totalUnrealizedPnL >= 0 ? '+' : ''} + ${totalUnrealizedPnL.toFixed(2)} + +
+
+ +
+ {positions.map((pos) => { + const isLong = pos.side === 'long' + const pnlColor = pos.unrealized_pnl >= 0 ? '#0ECB81' : '#F6465D' + + return ( + +
+
+ {isLong ? ( + + ) : ( + + )} +
+
+
+ + {pos.symbol.replace('USDT', '')} + + + {isLong ? 'LONG' : 'SHORT'} {pos.leverage}x + +
+
+ {language === 'zh' ? '数量' : 'Qty'}: {pos.quantity.toFixed(4)} ·{' '} + {language === 'zh' ? '保证金' : 'Margin'}: ${pos.margin_used.toFixed(2)} +
+
+
+ +
+
+ + {language === 'zh' ? '开仓' : 'Entry'}: ${pos.entry_price.toFixed(2)} + + + {language === 'zh' ? '现价' : 'Mark'}: ${pos.mark_price.toFixed(2)} + +
+
+ + {pos.unrealized_pnl >= 0 ? '+' : ''}${pos.unrealized_pnl.toFixed(2)} + + + {pos.unrealized_pnl_pct >= 0 ? '+' : ''}{pos.unrealized_pnl_pct.toFixed(2)}% + +
+
+
+ ) + })} +
+
+ ) +} + // ============ Main Component ============ export function BacktestPage() { const { language } = useLanguage() @@ -380,6 +771,7 @@ export function BacktestPage() { cacheAI: true, replayOnly: false, aiModelId: '', + strategyId: '', // Optional: use saved strategy from Strategy Studio }) // Data fetching @@ -389,6 +781,7 @@ export function BacktestPage() { const runs = runsResp?.items ?? [] const { data: aiModels } = useSWR('ai-models', api.getModelConfigs, { refreshInterval: 30000 }) + const { data: strategies } = useSWR('strategies', api.getStrategies, { refreshInterval: 30000 }) const { data: status } = useSWR( selectedRunId ? ['bt-status', selectedRunId] : null, @@ -422,6 +815,69 @@ export function BacktestPage() { const selectedRun = runs.find((r) => r.run_id === selectedRunId) const selectedModel = aiModels?.find((m) => m.id === formState.aiModelId) + const selectedStrategy = strategies?.find((s) => s.id === formState.strategyId) + + // Check if selected strategy has dynamic coin source + const strategyHasDynamicCoins = useMemo(() => { + if (!selectedStrategy) return false + const coinSource = selectedStrategy.config?.coin_source + if (!coinSource) return false + + // Check explicit source_type + if (coinSource.source_type === 'coinpool' || coinSource.source_type === 'oi_top') { + return true + } + if (coinSource.source_type === 'mixed' && (coinSource.use_coin_pool || coinSource.use_oi_top)) { + return true + } + + // Also check flags for backward compatibility (when source_type is empty or not set) + const srcType = coinSource.source_type as string + if (!srcType) { + if (coinSource.use_coin_pool || coinSource.use_oi_top) { + return true + } + } + + return false + }, [selectedStrategy]) + + // Get coin source description + const coinSourceDescription = useMemo(() => { + if (!selectedStrategy?.config?.coin_source) return null + const cs = selectedStrategy.config.coin_source + + // Infer source_type from flags if empty (backward compatibility) + let sourceType = cs.source_type as string + if (!sourceType) { + if (cs.use_coin_pool && cs.use_oi_top) { + sourceType = 'mixed' + } else if (cs.use_coin_pool) { + sourceType = 'coinpool' + } else if (cs.use_oi_top) { + sourceType = 'oi_top' + } else if (cs.static_coins?.length) { + sourceType = 'static' + } + } + + switch (sourceType) { + case 'coinpool': + return { type: 'AI500', limit: cs.coin_pool_limit || 30 } + case 'oi_top': + return { type: 'OI Top', limit: cs.oi_top_limit || 30 } + case 'mixed': + const sources = [] + if (cs.use_coin_pool) sources.push(`AI500(${cs.coin_pool_limit || 30})`) + if (cs.use_oi_top) sources.push(`OI Top(${cs.oi_top_limit || 30})`) + if (cs.static_coins?.length) sources.push(`Static(${cs.static_coins.length})`) + return { type: 'Mixed', desc: sources.join(' + ') } + case 'static': + return { type: 'Static', coins: cs.static_coins || [] } + default: + return null + } + }, [selectedStrategy]) // Auto-select first model useEffect(() => { @@ -456,9 +912,16 @@ export function BacktestPage() { const end = new Date(formState.end).getTime() if (end <= start) throw new Error(tr('toasts.invalidRange')) + // Parse user symbols - if using dynamic coin strategy, allow empty + const userSymbols = formState.symbols.split(',').map((s) => s.trim()).filter(Boolean) + + // Only send empty symbols if user deliberately cleared them and strategy has dynamic coin source + const symbolsToSend = (userSymbols.length === 0 && strategyHasDynamicCoins) ? [] : userSymbols + const payload = await api.startBacktest({ run_id: formState.runId.trim() || undefined, - symbols: formState.symbols.split(',').map((s) => s.trim()).filter(Boolean), + strategy_id: formState.strategyId || undefined, // Use saved strategy from Strategy Studio + symbols: symbolsToSend, timeframes: formState.timeframes, decision_timeframe: formState.decisionTf, decision_cadence_nbars: formState.cadence, @@ -727,43 +1190,111 @@ export function BacktestPage() { )}
+ {/* Strategy Selection (Optional) */} +
+ + + {formState.strategyId && coinSourceDescription && ( +
+
+ + {language === 'zh' ? '币种来源:' : 'Coin Source:'} + + + {coinSourceDescription.type} + {coinSourceDescription.limit && ` (${coinSourceDescription.limit})`} + {coinSourceDescription.desc && ` - ${coinSourceDescription.desc}`} + +
+ {strategyHasDynamicCoins && ( +
+ {language === 'zh' + ? '⚡ 清空下方币种输入框即可使用策略的动态币种' + : '⚡ Clear the symbols field below to use strategy\'s dynamic coins'} +
+ )} +
+ )} +
+
-
- {POPULAR_SYMBOLS.map((sym) => { - const isSelected = formState.symbols.includes(sym) - return ( - - ) - })} + {!strategyHasDynamicCoins && ( +
+ {POPULAR_SYMBOLS.map((sym) => { + const isSelected = formState.symbols.includes(sym) + return ( + + ) + })} +
+ )} +
+