feat(grid): auto-adjust grid direction based on box breakouts

Add GridDirection type with 5 states:
- neutral (50% buy + 50% sell)
- long/short (100% one direction)
- long_bias/short_bias (70%/30% configurable)

Direction adjustment logic:
- Short box breakout → bias direction (long_bias/short_bias)
- Mid box breakout → full direction (long/short)
- Long box breakout → emergency handling (unchanged)
- Recovery: long → long_bias → neutral ← short_bias ← short

Config options:
- EnableDirectionAdjust (default: false)
- DirectionBiasRatio (default: 0.7)

Includes unit tests for all direction-related functions.
This commit is contained in:
tinkle-community
2026-02-04 11:25:47 +08:00
parent 382e756328
commit 773857351f
7 changed files with 691 additions and 6 deletions
+31
View File
@@ -84,6 +84,9 @@ type GridContext struct {
// Box indicators (Donchian Channels)
BoxData *market.BoxData `json:"box_data,omitempty"`
// Grid direction (neutral, long, short, long_bias, short_bias)
CurrentDirection string `json:"current_direction,omitempty"`
}
// ============================================================================
@@ -279,6 +282,20 @@ func buildGridUserPromptZh(ctx *GridContext) string {
sb.WriteString(fmt.Sprintf("- 活跃订单数: %d\n", ctx.ActiveOrderCount))
sb.WriteString(fmt.Sprintf("- 已成交层数: %d\n", ctx.FilledLevelCount))
sb.WriteString(fmt.Sprintf("- 网格已暂停: %v\n", ctx.IsPaused))
if ctx.CurrentDirection != "" {
directionDescZh := map[string]string{
"neutral": "中性 (50%买+50%卖)",
"long": "做多 (100%买)",
"short": "做空 (100%卖)",
"long_bias": "偏多 (70%买+30%卖)",
"short_bias": "偏空 (30%买+70%卖)",
}
desc := directionDescZh[ctx.CurrentDirection]
if desc == "" {
desc = ctx.CurrentDirection
}
sb.WriteString(fmt.Sprintf("- 网格方向: %s\n", desc))
}
sb.WriteString("\n")
// Grid levels detail
@@ -376,6 +393,20 @@ func buildGridUserPromptEn(ctx *GridContext) string {
sb.WriteString(fmt.Sprintf("- Active Orders: %d\n", ctx.ActiveOrderCount))
sb.WriteString(fmt.Sprintf("- Filled Levels: %d\n", ctx.FilledLevelCount))
sb.WriteString(fmt.Sprintf("- Grid Paused: %v\n", ctx.IsPaused))
if ctx.CurrentDirection != "" {
directionDescEn := map[string]string{
"neutral": "Neutral (50% buy + 50% sell)",
"long": "Long (100% buy)",
"short": "Short (100% sell)",
"long_bias": "Long Bias (70% buy + 30% sell)",
"short_bias": "Short Bias (30% buy + 70% sell)",
}
desc := directionDescEn[ctx.CurrentDirection]
if desc == "" {
desc = ctx.CurrentDirection
}
sb.WriteString(fmt.Sprintf("- Grid Direction: %s\n", desc))
}
sb.WriteString("\n")
// Grid levels detail
+34
View File
@@ -226,3 +226,37 @@ const (
BreakoutMid BreakoutLevel = "mid"
BreakoutLong BreakoutLevel = "long"
)
// GridDirection represents the current grid trading direction bias
type GridDirection string
const (
GridDirectionNeutral GridDirection = "neutral" // 50% buy + 50% sell
GridDirectionLong GridDirection = "long" // 100% buy
GridDirectionShort GridDirection = "short" // 100% sell
GridDirectionLongBias GridDirection = "long_bias" // 70% buy + 30% sell (default)
GridDirectionShortBias GridDirection = "short_bias" // 30% buy + 70% sell (default)
)
// GetBuySellRatio returns the buy and sell ratio for this direction
// biasRatio is the ratio for biased directions (default 0.7 means 70%/30%)
func (d GridDirection) GetBuySellRatio(biasRatio float64) (buyRatio, sellRatio float64) {
if biasRatio <= 0 || biasRatio > 1 {
biasRatio = 0.7 // Default 70%/30%
}
switch d {
case GridDirectionNeutral:
return 0.5, 0.5
case GridDirectionLong:
return 1.0, 0.0
case GridDirectionShort:
return 0.0, 1.0
case GridDirectionLongBias:
return biasRatio, 1.0 - biasRatio
case GridDirectionShortBias:
return 1.0 - biasRatio, biasRatio
default:
return 0.5, 0.5
}
}
+9
View File
@@ -63,6 +63,10 @@ type GridConfigModel struct {
AIProvider string `json:"ai_provider" gorm:"default:deepseek"`
AIModel string `json:"ai_model" gorm:"default:deepseek-chat"`
IsActive bool `json:"is_active" gorm:"default:false"`
// Direction adjustment settings
EnableDirectionAdjust bool `json:"enable_direction_adjust" gorm:"default:false"`
DirectionBiasRatio float64 `json:"direction_bias_ratio" gorm:"default:0.7"`
}
func (GridConfigModel) TableName() string {
@@ -108,6 +112,11 @@ type GridInstanceModel struct {
// Position adjustment due to breakout
PositionReductionPct float64 `json:"position_reduction_pct" gorm:"default:0"` // 0 = normal, 50 = reduced
// Grid direction adjustment state
CurrentDirection string `json:"current_direction" gorm:"default:neutral"`
DirectionChangedAt time.Time `json:"direction_changed_at"`
DirectionChangeCount int `json:"direction_change_count" gorm:"default:0"`
TotalProfit float64 `json:"total_profit" gorm:"default:0"`
TotalFees float64 `json:"total_fees" gorm:"default:0"`
TotalTrades int `json:"total_trades" gorm:"default:0"`
+4
View File
@@ -81,6 +81,10 @@ type GridStrategyConfig struct {
DailyLossLimitPct float64 `json:"daily_loss_limit_pct"`
// Use maker-only orders for lower fees
UseMakerOnly bool `json:"use_maker_only"`
// Enable automatic grid direction adjustment based on box breakouts
EnableDirectionAdjust bool `json:"enable_direction_adjust"`
// Direction bias ratio for long_bias/short_bias modes (default 0.7 = 70%/30%)
DirectionBiasRatio float64 `json:"direction_bias_ratio"`
}
// PromptSectionsConfig editable sections of System Prompt
+277 -6
View File
@@ -65,14 +65,20 @@ type GridState struct {
// Current regime level
CurrentRegimeLevel string
// Grid direction adjustment
CurrentDirection market.GridDirection
DirectionChangedAt time.Time
DirectionChangeCount int
}
// NewGridState creates a new grid state
func NewGridState(config *store.GridStrategyConfig) *GridState {
return &GridState{
Config: config,
Levels: make([]kernel.GridLevelInfo, 0),
OrderBook: make(map[string]int),
Config: config,
Levels: make([]kernel.GridLevelInfo, 0),
OrderBook: make(map[string]int),
CurrentDirection: market.GridDirectionNeutral,
}
}
@@ -325,7 +331,17 @@ func (at *AutoTrader) checkBoxBreakout() error {
}
// Take action based on breakout level
action := getBreakoutAction(breakoutLevel)
// Use direction-aware action if enabled
enableDirectionAdjust := gridConfig.EnableDirectionAdjust
action := getBreakoutActionWithDirection(breakoutLevel, enableDirectionAdjust)
// If direction adjustment action, determine the new direction
if action == BreakoutActionAdjustDirection {
box, _ := market.GetBoxData(gridConfig.Symbol)
newDirection := determineGridDirection(box, at.gridState.CurrentDirection, breakoutLevel, direction)
return at.executeDirectionAdjustment(newDirection)
}
return at.executeBreakoutAction(action)
}
@@ -358,11 +374,38 @@ func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error {
logger.Infof("Failed to cancel orders: %v", err)
}
return at.closeAllPositions()
case BreakoutActionAdjustDirection:
// Direction adjustment is handled separately via executeDirectionAdjustment
// This case should not be reached, but handle gracefully
logger.Infof("Direction adjustment action received via executeBreakoutAction")
return nil
}
return nil
}
// executeDirectionAdjustment handles grid direction changes based on box breakout
func (at *AutoTrader) executeDirectionAdjustment(newDirection market.GridDirection) error {
at.gridState.mu.RLock()
oldDirection := at.gridState.CurrentDirection
at.gridState.mu.RUnlock()
if oldDirection == newDirection {
return nil // No change needed
}
logger.Infof("[Grid] Direction adjustment: %s → %s", oldDirection, newDirection)
// Cancel existing orders before adjusting
if err := at.cancelAllGridOrders(); err != nil {
logger.Warnf("[Grid] Failed to cancel orders during direction adjustment: %v", err)
}
// Apply the new direction
return at.adjustGridDirection(newDirection)
}
// closeAllPositions closes all open positions for the grid symbol
func (at *AutoTrader) closeAllPositions() error {
gridConfig := at.config.StrategyConfig.GridConfig
@@ -410,10 +453,16 @@ func (at *AutoTrader) checkFalseBreakoutRecovery() error {
breakoutLevel := at.gridState.BreakoutLevel
isPaused := at.gridState.IsPaused
positionReduction := at.gridState.PositionReductionPct
currentDirection := at.gridState.CurrentDirection
at.gridState.mu.RUnlock()
// Only check if we had a breakout
if breakoutLevel == string(market.BreakoutNone) && positionReduction == 0 && !isPaused {
// Only check if we had a breakout or non-neutral direction
needsRecoveryCheck := breakoutLevel != string(market.BreakoutNone) ||
positionReduction != 0 ||
isPaused ||
(gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral)
if !needsRecoveryCheck {
return nil
}
@@ -436,6 +485,18 @@ func (at *AutoTrader) checkFalseBreakoutRecovery() error {
at.gridState.mu.Unlock()
}
// Check for direction recovery toward neutral (if direction adjustment is enabled)
if gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral {
if shouldRecoverDirection(box, currentDirection) {
newDirection := determineRecoveryDirection(box.CurrentPrice, box, currentDirection)
if newDirection != currentDirection {
logger.Infof("[Grid] Direction recovery: %s → %s (price back in short box)",
currentDirection, newDirection)
at.adjustGridDirection(newDirection)
}
}
}
return nil
}
@@ -570,6 +631,128 @@ func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.G
}
at.gridState.Levels = levels
// Apply direction-based side assignment if enabled
if config.EnableDirectionAdjust {
at.applyGridDirection(currentPrice)
}
}
// applyGridDirection adjusts grid level sides based on the current direction
// This redistributes buy/sell levels according to the direction bias ratio
func (at *AutoTrader) applyGridDirection(currentPrice float64) {
config := at.gridState.Config
direction := at.gridState.CurrentDirection
// Get bias ratio from config, default to 0.7 (70%/30%)
biasRatio := config.DirectionBiasRatio
if biasRatio <= 0 || biasRatio > 1 {
biasRatio = 0.7
}
buyRatio, _ := direction.GetBuySellRatio(biasRatio)
// Calculate how many levels should be buy vs sell based on direction
totalLevels := len(at.gridState.Levels)
targetBuyLevels := int(float64(totalLevels) * buyRatio)
// For neutral: use price-based assignment (buy below, sell above)
if direction == market.GridDirectionNeutral {
for i := range at.gridState.Levels {
if at.gridState.Levels[i].Price <= currentPrice {
at.gridState.Levels[i].Side = "buy"
} else {
at.gridState.Levels[i].Side = "sell"
}
}
return
}
// For long/long_bias: more buy levels
// For short/short_bias: more sell levels
switch direction {
case market.GridDirectionLong:
// 100% buy - all levels are buy
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "buy"
}
case market.GridDirectionShort:
// 100% sell - all levels are sell
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "sell"
}
case market.GridDirectionLongBias, market.GridDirectionShortBias:
// Assign sides based on position relative to current price
// For long_bias: keep all below as buy, convert some above to buy
// For short_bias: keep all above as sell, convert some below to sell
buyCount := 0
sellCount := 0
for i := range at.gridState.Levels {
needMoreBuys := buyCount < targetBuyLevels
needMoreSells := sellCount < (totalLevels - targetBuyLevels)
if at.gridState.Levels[i].Price <= currentPrice {
// Level below or at current price
if needMoreBuys {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else {
at.gridState.Levels[i].Side = "sell"
sellCount++
}
} else {
// Level above current price
if needMoreSells && direction == market.GridDirectionShortBias {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else if needMoreBuys && direction == market.GridDirectionLongBias {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else if needMoreSells {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else {
at.gridState.Levels[i].Side = "buy"
buyCount++
}
}
}
}
logger.Infof("[Grid] Applied direction %s: buy_ratio=%.0f%%, levels reconfigured",
direction, buyRatio*100)
}
// adjustGridDirection handles runtime direction adjustment when breakout is detected
func (at *AutoTrader) adjustGridDirection(newDirection market.GridDirection) error {
at.gridState.mu.Lock()
defer at.gridState.mu.Unlock()
oldDirection := at.gridState.CurrentDirection
if oldDirection == newDirection {
return nil // No change needed
}
at.gridState.CurrentDirection = newDirection
at.gridState.DirectionChangedAt = time.Now()
at.gridState.DirectionChangeCount++
logger.Infof("[Grid] Direction changed: %s → %s (change count: %d)",
oldDirection, newDirection, at.gridState.DirectionChangeCount)
// Get current price for recalculation
currentPrice, err := at.trader.GetMarketPrice(at.gridState.Config.Symbol)
if err != nil {
return fmt.Errorf("failed to get market price: %w", err)
}
// Reapply direction to grid levels
at.applyGridDirection(currentPrice)
return nil
}
// RunGridCycle executes one grid trading cycle
@@ -1370,6 +1553,85 @@ func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *s
}
at.gridState.Levels = levels
// Apply direction-based side assignment if enabled (note: caller holds lock)
if config.EnableDirectionAdjust {
at.applyGridDirectionLocked(currentPrice)
}
}
// applyGridDirectionLocked adjusts grid level sides based on the current direction (caller must hold lock)
func (at *AutoTrader) applyGridDirectionLocked(currentPrice float64) {
config := at.gridState.Config
direction := at.gridState.CurrentDirection
// Get bias ratio from config, default to 0.7 (70%/30%)
biasRatio := config.DirectionBiasRatio
if biasRatio <= 0 || biasRatio > 1 {
biasRatio = 0.7
}
buyRatio, _ := direction.GetBuySellRatio(biasRatio)
// For neutral: use price-based assignment (buy below, sell above)
if direction == market.GridDirectionNeutral {
for i := range at.gridState.Levels {
if at.gridState.Levels[i].Price <= currentPrice {
at.gridState.Levels[i].Side = "buy"
} else {
at.gridState.Levels[i].Side = "sell"
}
}
return
}
totalLevels := len(at.gridState.Levels)
targetBuyLevels := int(float64(totalLevels) * buyRatio)
switch direction {
case market.GridDirectionLong:
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "buy"
}
case market.GridDirectionShort:
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "sell"
}
case market.GridDirectionLongBias, market.GridDirectionShortBias:
buyCount := 0
sellCount := 0
for i := range at.gridState.Levels {
needMoreBuys := buyCount < targetBuyLevels
needMoreSells := sellCount < (totalLevels - targetBuyLevels)
if at.gridState.Levels[i].Price <= currentPrice {
if needMoreBuys {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else {
at.gridState.Levels[i].Side = "sell"
sellCount++
}
} else {
if needMoreSells && direction == market.GridDirectionShortBias {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else if needMoreBuys && direction == market.GridDirectionLongBias {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else if needMoreSells {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else {
at.gridState.Levels[i].Side = "buy"
buyCount++
}
}
}
}
}
// GridRiskInfo contains risk information for frontend display
@@ -1397,6 +1659,11 @@ type GridRiskInfo struct {
BreakoutLevel string `json:"breakout_level"`
BreakoutDirection string `json:"breakout_direction"`
// Grid direction
CurrentGridDirection string `json:"current_grid_direction"`
DirectionChangeCount int `json:"direction_change_count"`
EnableDirectionAdjust bool `json:"enable_direction_adjust"`
}
// GetGridRiskInfo returns current risk information for frontend display
@@ -1513,6 +1780,10 @@ func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo {
BreakoutLevel: at.gridState.BreakoutLevel,
BreakoutDirection: at.gridState.BreakoutDirection,
CurrentGridDirection: string(at.gridState.CurrentDirection),
DirectionChangeCount: at.gridState.DirectionChangeCount,
EnableDirectionAdjust: gridConfig.EnableDirectionAdjust,
}
}
+116
View File
@@ -194,3 +194,119 @@ func getBreakoutAction(level market.BreakoutLevel) BreakoutAction {
return BreakoutActionNone
}
}
// ============================================================================
// Task 10: Grid Direction Adjustment
// ============================================================================
const (
// BreakoutActionAdjustDirection adjusts grid direction based on breakout
BreakoutActionAdjustDirection BreakoutAction = 4
)
// determineGridDirection determines the new grid direction based on box breakout
// currentDirection: the current grid direction
// breakoutLevel: which box level has been broken (short/mid/long)
// direction: breakout direction ("up" or "down")
// Returns: the new grid direction
func determineGridDirection(box *market.BoxData, currentDirection market.GridDirection, breakoutLevel market.BreakoutLevel, direction string) market.GridDirection {
if box == nil {
return currentDirection
}
price := box.CurrentPrice
switch breakoutLevel {
case market.BreakoutShort:
// Short box breakout: bias direction
// Still within mid box, so not a full trend yet
if direction == "up" {
return market.GridDirectionLongBias
}
return market.GridDirectionShortBias
case market.BreakoutMid:
// Mid box breakout: full direction
// More significant move, commit fully
if direction == "up" {
return market.GridDirectionLong
}
return market.GridDirectionShort
case market.BreakoutLong:
// Long box breakout: handled by existing emergency logic
// Return current direction, let existing handlers take over
return currentDirection
case market.BreakoutNone:
// No breakout - check if we should recover toward neutral
return determineRecoveryDirection(price, box, currentDirection)
default:
return currentDirection
}
}
// determineRecoveryDirection determines if grid direction should recover toward neutral
// This implements the gradual recovery logic: long → long_bias → neutral ← short_bias ← short
func determineRecoveryDirection(price float64, box *market.BoxData, currentDirection market.GridDirection) market.GridDirection {
// Check if price is back inside the short box
insideShortBox := price >= box.ShortLower && price <= box.ShortUpper
if !insideShortBox {
// Still outside short box, maintain current direction
return currentDirection
}
// Price is inside short box, start recovery toward neutral
switch currentDirection {
case market.GridDirectionLong:
// Full long → bias long
return market.GridDirectionLongBias
case market.GridDirectionLongBias:
// Bias long → neutral
return market.GridDirectionNeutral
case market.GridDirectionShort:
// Full short → bias short
return market.GridDirectionShortBias
case market.GridDirectionShortBias:
// Bias short → neutral
return market.GridDirectionNeutral
default:
return currentDirection
}
}
// getBreakoutActionWithDirection returns the appropriate action for a breakout level
// when direction adjustment is enabled
func getBreakoutActionWithDirection(level market.BreakoutLevel, enableDirectionAdjust bool) BreakoutAction {
if !enableDirectionAdjust {
// Fall back to original behavior
return getBreakoutAction(level)
}
switch level {
case market.BreakoutShort:
// Short box breakout with direction adjustment: adjust direction instead of reducing position
return BreakoutActionAdjustDirection
case market.BreakoutMid:
// Mid box breakout with direction adjustment: adjust to full direction
return BreakoutActionAdjustDirection
case market.BreakoutLong:
// Long box breakout: always trigger emergency handling
return BreakoutActionCloseAll
default:
return BreakoutActionNone
}
}
// shouldRecoverDirection checks if the current grid direction should start recovering toward neutral
func shouldRecoverDirection(box *market.BoxData, currentDirection market.GridDirection) bool {
if box == nil || currentDirection == market.GridDirectionNeutral {
return false
}
price := box.CurrentPrice
// Check if price is back inside the short box
return price >= box.ShortLower && price <= box.ShortUpper
}
+220
View File
@@ -120,3 +120,223 @@ func TestGetBreakoutAction(t *testing.T) {
})
}
}
// ============================================================================
// Grid Direction Tests
// ============================================================================
func TestGetBuySellRatio(t *testing.T) {
tests := []struct {
name string
direction market.GridDirection
biasRatio float64
wantBuy float64
wantSell float64
}{
{"neutral", market.GridDirectionNeutral, 0.7, 0.5, 0.5},
{"long", market.GridDirectionLong, 0.7, 1.0, 0.0},
{"short", market.GridDirectionShort, 0.7, 0.0, 1.0},
{"long_bias_default", market.GridDirectionLongBias, 0.7, 0.7, 0.3},
{"short_bias_default", market.GridDirectionShortBias, 0.7, 0.3, 0.7},
{"long_bias_custom", market.GridDirectionLongBias, 0.8, 0.8, 0.2},
{"short_bias_custom", market.GridDirectionShortBias, 0.8, 0.2, 0.8},
{"invalid_bias_uses_default", market.GridDirectionLongBias, 0, 0.7, 0.3},
{"negative_bias_uses_default", market.GridDirectionLongBias, -1, 0.7, 0.3},
}
const tolerance = 0.0001
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buy, sell := tt.direction.GetBuySellRatio(tt.biasRatio)
buyDiff := buy - tt.wantBuy
sellDiff := sell - tt.wantSell
if buyDiff < -tolerance || buyDiff > tolerance || sellDiff < -tolerance || sellDiff > tolerance {
t.Errorf("GetBuySellRatio(%v, %v) = (%v, %v), want (%v, %v)",
tt.direction, tt.biasRatio, buy, sell, tt.wantBuy, tt.wantSell)
}
})
}
}
func TestDetermineGridDirection(t *testing.T) {
box := &market.BoxData{
ShortUpper: 100,
ShortLower: 90,
MidUpper: 105,
MidLower: 85,
LongUpper: 110,
LongLower: 80,
CurrentPrice: 95,
}
tests := []struct {
name string
currentDirection market.GridDirection
breakoutLevel market.BreakoutLevel
direction string
expected market.GridDirection
}{
// Short box breakouts
{
name: "short_breakout_up_neutral",
currentDirection: market.GridDirectionNeutral,
breakoutLevel: market.BreakoutShort,
direction: "up",
expected: market.GridDirectionLongBias,
},
{
name: "short_breakout_down_neutral",
currentDirection: market.GridDirectionNeutral,
breakoutLevel: market.BreakoutShort,
direction: "down",
expected: market.GridDirectionShortBias,
},
// Mid box breakouts
{
name: "mid_breakout_up",
currentDirection: market.GridDirectionLongBias,
breakoutLevel: market.BreakoutMid,
direction: "up",
expected: market.GridDirectionLong,
},
{
name: "mid_breakout_down",
currentDirection: market.GridDirectionShortBias,
breakoutLevel: market.BreakoutMid,
direction: "down",
expected: market.GridDirectionShort,
},
// Long box breakout - maintains current (emergency handling)
{
name: "long_breakout_maintains",
currentDirection: market.GridDirectionLong,
breakoutLevel: market.BreakoutLong,
direction: "up",
expected: market.GridDirectionLong,
},
// No breakout - tests recovery logic
{
name: "no_breakout_neutral_stays",
currentDirection: market.GridDirectionNeutral,
breakoutLevel: market.BreakoutNone,
direction: "",
expected: market.GridDirectionNeutral,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := determineGridDirection(box, tt.currentDirection, tt.breakoutLevel, tt.direction)
if result != tt.expected {
t.Errorf("determineGridDirection() = %v, want %v", result, tt.expected)
}
})
}
}
func TestDetermineRecoveryDirection(t *testing.T) {
box := &market.BoxData{
ShortUpper: 100,
ShortLower: 90,
MidUpper: 105,
MidLower: 85,
LongUpper: 110,
LongLower: 80,
CurrentPrice: 95, // Inside short box
}
tests := []struct {
name string
price float64
currentDirection market.GridDirection
expected market.GridDirection
}{
// Inside short box - should recover
{"long_to_long_bias", 95, market.GridDirectionLong, market.GridDirectionLongBias},
{"long_bias_to_neutral", 95, market.GridDirectionLongBias, market.GridDirectionNeutral},
{"short_to_short_bias", 95, market.GridDirectionShort, market.GridDirectionShortBias},
{"short_bias_to_neutral", 95, market.GridDirectionShortBias, market.GridDirectionNeutral},
{"neutral_stays_neutral", 95, market.GridDirectionNeutral, market.GridDirectionNeutral},
// Outside short box - should maintain
{"long_outside_stays", 101, market.GridDirectionLong, market.GridDirectionLong},
{"short_outside_stays", 89, market.GridDirectionShort, market.GridDirectionShort},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := determineRecoveryDirection(tt.price, box, tt.currentDirection)
if result != tt.expected {
t.Errorf("determineRecoveryDirection(%v, %v) = %v, want %v",
tt.price, tt.currentDirection, result, tt.expected)
}
})
}
}
func TestGetBreakoutActionWithDirection(t *testing.T) {
tests := []struct {
name string
level market.BreakoutLevel
enableDirectionAdjust bool
expected BreakoutAction
}{
// Direction adjustment disabled - original behavior
{"short_disabled", market.BreakoutShort, false, BreakoutActionReducePosition},
{"mid_disabled", market.BreakoutMid, false, BreakoutActionPauseGrid},
{"long_disabled", market.BreakoutLong, false, BreakoutActionCloseAll},
// Direction adjustment enabled
{"short_enabled", market.BreakoutShort, true, BreakoutActionAdjustDirection},
{"mid_enabled", market.BreakoutMid, true, BreakoutActionAdjustDirection},
{"long_enabled", market.BreakoutLong, true, BreakoutActionCloseAll}, // Long always triggers emergency
{"none_enabled", market.BreakoutNone, true, BreakoutActionNone},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
action := getBreakoutActionWithDirection(tt.level, tt.enableDirectionAdjust)
if action != tt.expected {
t.Errorf("getBreakoutActionWithDirection(%v, %v) = %v, want %v",
tt.level, tt.enableDirectionAdjust, action, tt.expected)
}
})
}
}
func TestShouldRecoverDirection(t *testing.T) {
box := &market.BoxData{
ShortUpper: 100,
ShortLower: 90,
MidUpper: 105,
MidLower: 85,
LongUpper: 110,
LongLower: 80,
CurrentPrice: 95,
}
tests := []struct {
name string
price float64
direction market.GridDirection
expected bool
}{
{"neutral_inside_no_recovery", 95, market.GridDirectionNeutral, false},
{"long_inside_should_recover", 95, market.GridDirectionLong, true},
{"long_outside_no_recovery", 101, market.GridDirectionLong, false},
{"short_inside_should_recover", 95, market.GridDirectionShort, true},
{"short_outside_no_recovery", 89, market.GridDirectionShort, false},
{"long_bias_inside_should_recover", 95, market.GridDirectionLongBias, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
box.CurrentPrice = tt.price
result := shouldRecoverDirection(box, tt.direction)
if result != tt.expected {
t.Errorf("shouldRecoverDirection(price=%v, %v) = %v, want %v",
tt.price, tt.direction, result, tt.expected)
}
})
}
}