Files
nofx/store/trader.go
T
tinkle-community bdfd8dc0d0 fix: auto-restart trader on config update and add scan interval debug logs
- RemoveTrader now stops running trader before removing from memory
- handleUpdateTrader auto-restarts trader if it was running before update
- Add debug logs to trace scan_interval_minutes through update/save/load flow
2026-01-04 01:27:30 +08:00

251 lines
8.6 KiB
Go

package store
import (
"fmt"
"time"
"gorm.io/gorm"
)
// TraderStore trader storage
type TraderStore struct {
db *gorm.DB
}
// NewTraderStore creates a new trader store
func NewTraderStore(db *gorm.DB) *TraderStore {
return &TraderStore{db: db}
}
// Trader trader configuration
type Trader struct {
ID string `gorm:"primaryKey" json:"id"`
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
Name string `gorm:"column:name;not null" json:"name"`
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"`
StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"`
InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"`
ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"`
IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"`
IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"`
ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"`
AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"`
TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"`
UseAI500 bool `gorm:"column:use_coin_pool;default:false" json:"use_ai500,omitempty"`
UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"`
CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"`
OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"`
SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"`
}
// TableName returns the table name for Trader
func (Trader) TableName() string {
return "traders"
}
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
type TraderFullConfig struct {
Trader *Trader
AIModel *AIModel
Exchange *Exchange
Strategy *Strategy
}
func (s *TraderStore) initTables() error {
// For PostgreSQL with existing table, skip AutoMigrate
if s.db.Dialector.Name() == "postgres" {
var tableExists int64
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists)
if tableExists > 0 {
return nil
}
}
// Use GORM AutoMigrate
if err := s.db.AutoMigrate(&Trader{}); err != nil {
return fmt.Errorf("failed to migrate traders table: %w", err)
}
return nil
}
// Create creates trader
func (s *TraderStore) Create(trader *Trader) error {
return s.db.Create(trader).Error
}
// List gets user's trader list
func (s *TraderStore) List(userID string) ([]*Trader, error) {
var traders []*Trader
err := s.db.Where("user_id = ?", userID).
Order("created_at DESC").
Find(&traders).Error
if err != nil {
return nil, err
}
return traders, nil
}
// UpdateStatus updates trader running status
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Update("is_running", isRunning).Error
}
// UpdateShowInCompetition updates trader competition visibility
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Update("show_in_competition", showInCompetition).Error
}
// Update updates trader configuration
func (s *TraderStore) Update(trader *Trader) error {
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
updates := map[string]interface{}{
"name": trader.Name,
"ai_model_id": trader.AIModelID,
"exchange_id": trader.ExchangeID,
"strategy_id": trader.StrategyID,
"is_cross_margin": trader.IsCrossMargin,
"show_in_competition": trader.ShowInCompetition,
}
// Only update these if > 0
if trader.InitialBalance > 0 {
updates["initial_balance"] = trader.InitialBalance
}
if trader.ScanIntervalMinutes > 0 {
updates["scan_interval_minutes"] = trader.ScanIntervalMinutes
fmt.Printf("📊 TraderStore.Update: scan_interval_minutes=%d will be saved\n", trader.ScanIntervalMinutes)
} else {
fmt.Printf("⚠️ TraderStore.Update: scan_interval_minutes=%d (<=0, NOT updating)\n", trader.ScanIntervalMinutes)
}
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", trader.ID, trader.UserID).
Updates(updates).Error
}
// UpdateInitialBalance updates initial balance
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Update("initial_balance", newBalance).Error
}
// UpdateCustomPrompt updates custom prompt
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
return s.db.Model(&Trader{}).
Where("id = ? AND user_id = ?", id, userID).
Updates(map[string]interface{}{
"custom_prompt": customPrompt,
"override_base_prompt": overrideBase,
}).Error
}
// Delete deletes trader and associated data
func (s *TraderStore) Delete(userID, id string) error {
// Delete associated equity snapshots first
s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{})
// Delete the trader
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error
}
// GetFullConfig gets trader full configuration
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
var trader Trader
err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error
if err != nil {
return nil, err
}
// Get AI model
var aiModel AIModel
err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error
if err != nil {
return nil, fmt.Errorf("failed to get AI model: %w", err)
}
// Get exchange
var exchange Exchange
err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error
if err != nil {
return nil, fmt.Errorf("failed to get exchange: %w", err)
}
// Load associated strategy
var strategy *Strategy
if trader.StrategyID != "" {
strategy, _ = s.getStrategyByID(userID, trader.StrategyID)
}
// If no associated strategy, get user's active strategy or default strategy
if strategy == nil {
strategy, _ = s.getActiveOrDefaultStrategy(userID)
}
return &TraderFullConfig{
Trader: &trader,
AIModel: &aiModel,
Exchange: &exchange,
Strategy: strategy,
}, nil
}
// getStrategyByID internal method: gets strategy by ID
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
var strategy Strategy
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
First(&strategy).Error
if err != nil {
return nil, err
}
return &strategy, nil
}
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
var strategy Strategy
// First try to get user's active strategy
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error
if err == nil {
return &strategy, nil
}
// Fallback to system default strategy
err = s.db.Where("is_default = ?", true).First(&strategy).Error
if err != nil {
return nil, err
}
return &strategy, nil
}
// GetByID gets a trader by ID without requiring userID (for public APIs)
func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
var trader Trader
err := s.db.Where("id = ?", traderID).First(&trader).Error
if err != nil {
return nil, err
}
return &trader, nil
}
// ListAll gets all traders
func (s *TraderStore) ListAll() ([]*Trader, error) {
var traders []*Trader
err := s.db.Order("created_at DESC").Find(&traders).Error
if err != nil {
return nil, err
}
return traders, nil
}