mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 01:48:22 +08:00
2d272bb7b8
- Migrate all store packages from raw database/sql to GORM ORM - Add PostgreSQL support alongside SQLite - Move EncryptedString type to crypto package for cleaner architecture - Add automatic encryption/decryption for sensitive fields (API keys, secrets) - Fix PostgreSQL AutoMigrate conflicts by skipping existing tables - Fix duplicate /klines route registration - Update tests to use GORM database connections - Add database configuration support in config package
248 lines
8.3 KiB
Go
248 lines
8.3 KiB
Go
package store
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// TraderStore trader storage
|
|
type TraderStore struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewTraderStore creates a new trader store
|
|
func NewTraderStore(db *gorm.DB) *TraderStore {
|
|
return &TraderStore{db: db}
|
|
}
|
|
|
|
// Trader trader configuration
|
|
type Trader struct {
|
|
ID string `gorm:"primaryKey" json:"id"`
|
|
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
|
Name string `gorm:"column:name;not null" json:"name"`
|
|
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
|
ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"`
|
|
StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"`
|
|
InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"`
|
|
ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"`
|
|
IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"`
|
|
IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"`
|
|
ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"`
|
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
|
|
|
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
|
|
BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"`
|
|
AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"`
|
|
TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"`
|
|
UseCoinPool bool `gorm:"column:use_coin_pool;default:false" json:"use_coin_pool,omitempty"`
|
|
UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"`
|
|
CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"`
|
|
OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"`
|
|
SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"`
|
|
}
|
|
|
|
// TableName returns the table name for Trader
|
|
func (Trader) TableName() string {
|
|
return "traders"
|
|
}
|
|
|
|
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
|
|
type TraderFullConfig struct {
|
|
Trader *Trader
|
|
AIModel *AIModel
|
|
Exchange *Exchange
|
|
Strategy *Strategy
|
|
}
|
|
|
|
func (s *TraderStore) initTables() error {
|
|
// For PostgreSQL with existing table, skip AutoMigrate
|
|
if s.db.Dialector.Name() == "postgres" {
|
|
var tableExists int64
|
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists)
|
|
if tableExists > 0 {
|
|
return nil
|
|
}
|
|
}
|
|
// Use GORM AutoMigrate
|
|
if err := s.db.AutoMigrate(&Trader{}); err != nil {
|
|
return fmt.Errorf("failed to migrate traders table: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Create creates trader
|
|
func (s *TraderStore) Create(trader *Trader) error {
|
|
return s.db.Create(trader).Error
|
|
}
|
|
|
|
// List gets user's trader list
|
|
func (s *TraderStore) List(userID string) ([]*Trader, error) {
|
|
var traders []*Trader
|
|
err := s.db.Where("user_id = ?", userID).
|
|
Order("created_at DESC").
|
|
Find(&traders).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return traders, nil
|
|
}
|
|
|
|
// UpdateStatus updates trader running status
|
|
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
|
|
return s.db.Model(&Trader{}).
|
|
Where("id = ? AND user_id = ?", id, userID).
|
|
Update("is_running", isRunning).Error
|
|
}
|
|
|
|
// UpdateShowInCompetition updates trader competition visibility
|
|
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
|
|
return s.db.Model(&Trader{}).
|
|
Where("id = ? AND user_id = ?", id, userID).
|
|
Update("show_in_competition", showInCompetition).Error
|
|
}
|
|
|
|
// Update updates trader configuration
|
|
func (s *TraderStore) Update(trader *Trader) error {
|
|
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
|
|
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
|
|
|
|
updates := map[string]interface{}{
|
|
"name": trader.Name,
|
|
"ai_model_id": trader.AIModelID,
|
|
"exchange_id": trader.ExchangeID,
|
|
"strategy_id": trader.StrategyID,
|
|
"is_cross_margin": trader.IsCrossMargin,
|
|
"show_in_competition": trader.ShowInCompetition,
|
|
}
|
|
|
|
// Only update these if > 0
|
|
if trader.InitialBalance > 0 {
|
|
updates["initial_balance"] = trader.InitialBalance
|
|
}
|
|
if trader.ScanIntervalMinutes > 0 {
|
|
updates["scan_interval_minutes"] = trader.ScanIntervalMinutes
|
|
}
|
|
|
|
return s.db.Model(&Trader{}).
|
|
Where("id = ? AND user_id = ?", trader.ID, trader.UserID).
|
|
Updates(updates).Error
|
|
}
|
|
|
|
// UpdateInitialBalance updates initial balance
|
|
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
|
|
return s.db.Model(&Trader{}).
|
|
Where("id = ? AND user_id = ?", id, userID).
|
|
Update("initial_balance", newBalance).Error
|
|
}
|
|
|
|
// UpdateCustomPrompt updates custom prompt
|
|
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
|
|
return s.db.Model(&Trader{}).
|
|
Where("id = ? AND user_id = ?", id, userID).
|
|
Updates(map[string]interface{}{
|
|
"custom_prompt": customPrompt,
|
|
"override_base_prompt": overrideBase,
|
|
}).Error
|
|
}
|
|
|
|
// Delete deletes trader and associated data
|
|
func (s *TraderStore) Delete(userID, id string) error {
|
|
// Delete associated equity snapshots first
|
|
s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{})
|
|
|
|
// Delete the trader
|
|
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error
|
|
}
|
|
|
|
// GetFullConfig gets trader full configuration
|
|
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
|
|
var trader Trader
|
|
err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get AI model
|
|
var aiModel AIModel
|
|
err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get AI model: %w", err)
|
|
}
|
|
|
|
// Get exchange
|
|
var exchange Exchange
|
|
err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get exchange: %w", err)
|
|
}
|
|
|
|
// Load associated strategy
|
|
var strategy *Strategy
|
|
if trader.StrategyID != "" {
|
|
strategy, _ = s.getStrategyByID(userID, trader.StrategyID)
|
|
}
|
|
// If no associated strategy, get user's active strategy or default strategy
|
|
if strategy == nil {
|
|
strategy, _ = s.getActiveOrDefaultStrategy(userID)
|
|
}
|
|
|
|
return &TraderFullConfig{
|
|
Trader: &trader,
|
|
AIModel: &aiModel,
|
|
Exchange: &exchange,
|
|
Strategy: strategy,
|
|
}, nil
|
|
}
|
|
|
|
// getStrategyByID internal method: gets strategy by ID
|
|
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
|
|
var strategy Strategy
|
|
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
|
|
First(&strategy).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &strategy, nil
|
|
}
|
|
|
|
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
|
|
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
|
|
var strategy Strategy
|
|
|
|
// First try to get user's active strategy
|
|
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error
|
|
if err == nil {
|
|
return &strategy, nil
|
|
}
|
|
|
|
// Fallback to system default strategy
|
|
err = s.db.Where("is_default = ?", true).First(&strategy).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &strategy, nil
|
|
}
|
|
|
|
// GetByID gets a trader by ID without requiring userID (for public APIs)
|
|
func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
|
|
var trader Trader
|
|
err := s.db.Where("id = ?", traderID).First(&trader).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &trader, nil
|
|
}
|
|
|
|
// ListAll gets all traders
|
|
func (s *TraderStore) ListAll() ([]*Trader, error) {
|
|
var traders []*Trader
|
|
err := s.db.Order("created_at DESC").Find(&traders).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return traders, nil
|
|
}
|