Files
nofx/config/database_pg.go
T
icy 7d58f56e49 feat: implement hybrid database architecture and frontend encryption
- Add PostgreSQL + SQLite hybrid database support with automatic switching
- Implement frontend AES-GCM + RSA-OAEP encryption for sensitive data
- Add comprehensive DatabaseInterface with all required methods
- Fix compilation issues with interface consistency
- Update all database method signatures to use DatabaseInterface
- Add missing UpdateTraderInitialBalance method to PostgreSQL implementation
- Integrate RSA public key distribution via /api/config endpoint
- Add frontend crypto service with proper error handling
- Support graceful degradation between encrypted and plaintext transmission
- Add directory creation for RSA keys and PEM parsing fixes
- Test both SQLite and PostgreSQL modes successfully
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: tinkle-community <tinklefund@gmail.com>
2025-11-06 01:50:06 +08:00

707 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package config
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"nofx/market"
"os"
"slices"
"strings"
"time"
_ "github.com/lib/pq"
)
// PostgreSQLDatabase is the PostgreSQL-backed database type of this package.
// All of its methods run their queries through the wrapped *sql.DB handle.
type PostgreSQLDatabase struct {
// db is the database/sql handle used by every method on this type.
db *sql.DB
}
// NewPostgreSQLDatabase opens a PostgreSQL connection configured from the
// POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB, POSTGRES_USER and
// POSTGRES_PASSWORD environment variables (each with a built-in default),
// applies the connection-pool limits and verifies the connection with a ping.
//
// It returns an error if the driver cannot be opened or the ping fails; in the
// latter case the handle is closed before returning so it is not leaked.
func NewPostgreSQLDatabase() (*PostgreSQLDatabase, error) {
	host := getEnv("POSTGRES_HOST", "localhost")
	port := getEnv("POSTGRES_PORT", "5432")
	dbname := getEnv("POSTGRES_DB", "nofx")
	user := getEnv("POSTGRES_USER", "nofx")
	password := getEnv("POSTGRES_PASSWORD", "nofx123456")

	// Quote every value so that spaces, quotes or backslashes (most likely in
	// the password) cannot corrupt the key/value connection string.
	dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
		pgQuoteDSNValue(host), pgQuoteDSNValue(port), pgQuoteDSNValue(user),
		pgQuoteDSNValue(password), pgQuoteDSNValue(dbname))

	// The password is deliberately left out of the log line.
	log.Printf("📋 连接PostgreSQL数据库: %s:%s/%s", host, port, dbname)

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, fmt.Errorf("打开PostgreSQL数据库失败: %w", err)
	}

	// Connection-pool limits.
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(time.Hour)

	// sql.Open does not necessarily connect; Ping forces a real round trip.
	if err := db.Ping(); err != nil {
		// Best-effort cleanup: the ping error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err)
	}

	log.Printf("✅ PostgreSQL数据库连接成功")
	return &PostgreSQLDatabase{db: db}, nil
}

// pgQuoteDSNValue wraps v in single quotes, backslash-escaping embedded
// backslashes and single quotes, as the key/value connection-string syntax
// requires.
func pgQuoteDSNValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
// CreateUser inserts the given user record into the users table and returns
// any error reported by the database.
func (d *PostgreSQLDatabase) CreateUser(user *User) error {
	const insertUser = `
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
VALUES ($1, $2, $3, $4, $5)
`
	_, err := d.db.Exec(insertUser,
		user.ID,
		user.Email,
		user.PasswordHash,
		user.OTPSecret,
		user.OTPVerified,
	)
	return err
}
// EnsureAdminUser makes sure a user with id "admin" exists (used by admin
// mode). If the row is already present it does nothing; otherwise it creates
// the user with an empty password hash and OTP marked as verified.
func (d *PostgreSQLDatabase) EnsureAdminUser() error {
	var existing int
	row := d.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`)
	if err := row.Scan(&existing); err != nil {
		return err
	}
	if existing > 0 {
		// Already provisioned.
		return nil
	}

	// No password is stored: admin mode does not use one.
	return d.CreateUser(&User{
		ID:           "admin",
		Email:        "admin@localhost",
		PasswordHash: "",
		OTPSecret:    "",
		OTPVerified:  true,
	})
}
// GetUserByEmail loads the user whose email matches exactly. The error from
// the row scan is returned unchanged, so a missing user surfaces as the
// database/sql no-rows error.
func (d *PostgreSQLDatabase) GetUserByEmail(email string) (*User, error) {
	const selectByEmail = `
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
FROM users WHERE email = $1
`
	found := new(User)
	row := d.db.QueryRow(selectByEmail, email)
	if err := row.Scan(
		&found.ID, &found.Email, &found.PasswordHash, &found.OTPSecret,
		&found.OTPVerified, &found.CreatedAt, &found.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return found, nil
}
// GetUserByID loads the user with the given id. The error from the row scan
// is returned unchanged, so a missing user surfaces as the database/sql
// no-rows error.
func (d *PostgreSQLDatabase) GetUserByID(userID string) (*User, error) {
	const selectByID = `
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
FROM users WHERE id = $1
`
	found := new(User)
	row := d.db.QueryRow(selectByID, userID)
	if err := row.Scan(
		&found.ID, &found.Email, &found.PasswordHash, &found.OTPSecret,
		&found.OTPVerified, &found.CreatedAt, &found.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return found, nil
}
// GetAllUsers returns the ids of all users, ordered by id.
//
// The result is nil when there are no users. An error is returned if the
// query, a row scan, or the row iteration itself fails.
func (d *PostgreSQLDatabase) GetAllUsers() ([]string, error) {
	rows, err := d.db.Query(`SELECT id FROM users ORDER BY id`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var userIDs []string
	for rows.Next() {
		var userID string
		if err := rows.Scan(&userID); err != nil {
			return nil, err
		}
		userIDs = append(userIDs, userID)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated list would be reported as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return userIDs, nil
}
// UpdateUserOTPVerified sets the otp_verified flag of the user with the given
// id.
func (d *PostgreSQLDatabase) UpdateUserOTPVerified(userID string, verified bool) error {
	const updateOTP = `UPDATE users SET otp_verified = $1 WHERE id = $2`
	if _, err := d.db.Exec(updateOTP, verified, userID); err != nil {
		return err
	}
	return nil
}
// GetAIModels returns the AI model configurations owned by userID, ordered by
// id. NULL custom_api_url / custom_model_name columns are read as "".
//
// The returned slice is never nil on success, so it JSON-encodes as [] rather
// than null. An error is returned if the query, a row scan, or the row
// iteration itself fails.
func (d *PostgreSQLDatabase) GetAIModels(userID string) ([]*AIModelConfig, error) {
	rows, err := d.db.Query(`
SELECT id, user_id, name, provider, enabled, api_key,
COALESCE(custom_api_url, '') as custom_api_url,
COALESCE(custom_model_name, '') as custom_model_name,
created_at, updated_at
FROM ai_models WHERE user_id = $1 ORDER BY id
`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Start from an empty, non-nil slice (see the JSON note above).
	models := make([]*AIModelConfig, 0)
	for rows.Next() {
		var model AIModelConfig
		if err := rows.Scan(
			&model.ID, &model.UserID, &model.Name, &model.Provider,
			&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
			&model.CreatedAt, &model.UpdatedAt,
		); err != nil {
			return nil, err
		}
		models = append(models, &model)
	}
	// Surface iteration failures instead of returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return models, nil
}
// UpdateAIModel updates a user's AI model configuration, creating a
// user-specific row when none exists.
//
// Lookup order:
//  1. a row whose id equals id (current scheme, allows several models per
//     provider);
//  2. a row whose provider equals id (legacy scheme);
//  3. otherwise a new row is inserted.
//
// Only a no-rows result moves on to the next step; any other lookup error is
// returned so that a transient database failure is not mistaken for "not
// found" and answered with an insert.
func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
	const updateModel = `
UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, updated_at = CURRENT_TIMESTAMP
WHERE id = $5 AND user_id = $6
`
	var existingID string

	// Step 1: exact id match.
	err := d.db.QueryRow(`
SELECT id FROM ai_models WHERE user_id = $1 AND id = $2 LIMIT 1
`, userID, id).Scan(&existingID)
	if err == nil {
		_, err = d.db.Exec(updateModel, enabled, apiKey, customAPIURL, customModelName, existingID, userID)
		return err
	}
	if err != sql.ErrNoRows {
		return err
	}

	// Step 2: legacy behaviour, where id is interpreted as a provider name.
	provider := id
	err = d.db.QueryRow(`
SELECT id FROM ai_models WHERE user_id = $1 AND provider = $2 LIMIT 1
`, userID, provider).Scan(&existingID)
	if err == nil {
		log.Printf("⚠️  使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID)
		_, err = d.db.Exec(updateModel, enabled, apiKey, customAPIURL, customModelName, existingID, userID)
		return err
	}
	if err != sql.ErrNoRows {
		return err
	}

	// Step 3: nothing to update, so create a new row. When id contains an
	// underscore (e.g. "userID_provider" or "timestamp_userID_provider") the
	// provider is its last segment; otherwise id itself is the provider.
	if parts := strings.Split(id, "_"); len(parts) >= 2 {
		provider = parts[len(parts)-1]
	}

	// Reuse the display name of any existing row for this provider. This
	// lookup is best-effort: on any failure a default name is used instead.
	var name string
	err = d.db.QueryRow(`
SELECT name FROM ai_models WHERE provider = $1 LIMIT 1
`, provider).Scan(&name)
	if err != nil {
		switch provider {
		case "deepseek":
			name = "DeepSeek AI"
		case "qwen":
			name = "Qwen AI"
		default:
			name = provider + " AI"
		}
	}

	// An id that already has the full form (e.g. "admin_deepseek_custom1") is
	// used as-is; a bare provider name gets a user-specific id.
	newModelID := id
	if id == provider {
		newModelID = fmt.Sprintf("%s_%s", userID, provider)
	}

	log.Printf("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
	_, err = d.db.Exec(`
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
`, newModelID, userID, name, provider, enabled, apiKey, customAPIURL, customModelName)
	return err
}
// GetExchanges returns the exchange configurations owned by userID, ordered by
// id. NULL Hyperliquid/Aster columns are read as "".
//
// The returned slice is never nil on success, so it JSON-encodes as [] rather
// than null. An error is returned if the query, a row scan, or the row
// iteration itself fails.
func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, error) {
	rows, err := d.db.Query(`
SELECT id, user_id, name, type, enabled, api_key, secret_key, testnet,
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(aster_user, '') as aster_user,
COALESCE(aster_signer, '') as aster_signer,
COALESCE(aster_private_key, '') as aster_private_key,
created_at, updated_at
FROM exchanges WHERE user_id = $1 ORDER BY id
`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Start from an empty, non-nil slice (see the JSON note above).
	exchanges := make([]*ExchangeConfig, 0)
	for rows.Next() {
		var exchange ExchangeConfig
		if err := rows.Scan(
			&exchange.ID, &exchange.UserID, &exchange.Name, &exchange.Type,
			&exchange.Enabled, &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet,
			&exchange.HyperliquidWalletAddr, &exchange.AsterUser,
			&exchange.AsterSigner, &exchange.AsterPrivateKey,
			&exchange.CreatedAt, &exchange.UpdatedAt,
		); err != nil {
			return nil, err
		}
		exchanges = append(exchanges, &exchange)
	}
	// Surface iteration failures instead of returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return exchanges, nil
}
// UpdateExchange updates the exchange row identified by (id, userID). When the
// UPDATE touches no row, a new row is inserted for that user under the same
// exchange id, with a display name and type derived from the id. Each step is
// logged.
func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error {
	const updateExchange = `
UPDATE exchanges SET enabled = $1, api_key = $2, secret_key = $3, testnet = $4,
hyperliquid_wallet_addr = $5, aster_user = $6, aster_signer = $7, aster_private_key = $8, updated_at = CURRENT_TIMESTAMP
WHERE id = $9 AND user_id = $10
`
	const insertExchange = `
INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet,
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
`
	log.Printf("🔧 UpdateExchange: userID=%s, id=%s, enabled=%v", userID, id, enabled)

	// Try to update an existing per-user row first.
	res, err := d.db.Exec(updateExchange,
		enabled, apiKey, secretKey, testnet,
		hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey,
		id, userID)
	if err != nil {
		log.Printf("❌ UpdateExchange: 更新失败: %v", err)
		return err
	}

	affected, err := res.RowsAffected()
	if err != nil {
		log.Printf("❌ UpdateExchange: 获取影响行数失败: %v", err)
		return err
	}
	log.Printf("📊 UpdateExchange: 影响行数 = %d", affected)

	if affected != 0 {
		log.Printf("✅ UpdateExchange: 更新现有记录成功")
		return nil
	}

	// Nothing was updated: the user has no row for this exchange yet.
	log.Printf("💡 UpdateExchange: 没有现有记录,创建新记录")

	// Derive display name and type from the exchange id.
	var name, typ string
	switch id {
	case "binance":
		name, typ = "Binance Futures", "cex"
	case "hyperliquid":
		name, typ = "Hyperliquid", "dex"
	case "aster":
		name, typ = "Aster DEX", "dex"
	default:
		name, typ = id+" Exchange", "cex"
	}
	log.Printf("🆕 UpdateExchange: 创建新记录 ID=%s, name=%s, type=%s", id, name, typ)

	// The new row keeps the original exchange id.
	_, err = d.db.Exec(insertExchange,
		id, userID, name, typ, enabled, apiKey, secretKey, testnet,
		hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey)
	if err != nil {
		log.Printf("❌ UpdateExchange: 创建记录失败: %v", err)
		return err
	}
	log.Printf("✅ UpdateExchange: 创建记录成功")
	return nil
}
// CreateAIModel inserts an AI model configuration row. If a row with the same
// id already exists the statement does nothing (ON CONFLICT (id) DO NOTHING).
func (d *PostgreSQLDatabase) CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
	const insertModel = `
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (id) DO NOTHING
`
	if _, err := d.db.Exec(insertModel, id, userID, name, provider, enabled, apiKey, customAPIURL); err != nil {
		return err
	}
	return nil
}
// CreateExchange inserts an exchange configuration row. If a row with the same
// (id, user_id) pair already exists the statement does nothing.
func (d *PostgreSQLDatabase) CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error {
	const insertExchange = `
INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (id, user_id) DO NOTHING
`
	_, err := d.db.Exec(insertExchange,
		id, userID, name, typ,
		enabled, apiKey, secretKey, testnet,
		hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey,
	)
	return err
}
// CreateTrader inserts the given trader record into the traders table.
func (d *PostgreSQLDatabase) CreateTrader(trader *TraderRecord) error {
	const insertTrader = `
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool, use_oi_top, custom_prompt, override_base_prompt, system_prompt_template, is_cross_margin)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
`
	// Argument order follows the column list of the INSERT statement.
	args := []any{
		trader.ID, trader.UserID, trader.Name,
		trader.AIModelID, trader.ExchangeID,
		trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning,
		trader.BTCETHLeverage, trader.AltcoinLeverage,
		trader.TradingSymbols, trader.UseCoinPool, trader.UseOITop,
		trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate,
		trader.IsCrossMargin,
	}
	_, err := d.db.Exec(insertTrader, args...)
	return err
}
// GetTraders returns the traders owned by userID, newest first (ordered by
// created_at descending). NULL optional columns are replaced by the defaults
// written in the query: leverage 5, empty strings, false flags, template
// "default", and cross margin enabled.
//
// The result is nil when the user has no traders. An error is returned if the
// query, a row scan, or the row iteration itself fails.
func (d *PostgreSQLDatabase) GetTraders(userID string) ([]*TraderRecord, error) {
	rows, err := d.db.Query(`
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running,
COALESCE(btc_eth_leverage, 5) as btc_eth_leverage, COALESCE(altcoin_leverage, 5) as altcoin_leverage,
COALESCE(trading_symbols, '') as trading_symbols,
COALESCE(use_coin_pool, false) as use_coin_pool, COALESCE(use_oi_top, false) as use_oi_top,
COALESCE(custom_prompt, '') as custom_prompt, COALESCE(override_base_prompt, false) as override_base_prompt,
COALESCE(system_prompt_template, 'default') as system_prompt_template,
COALESCE(is_cross_margin, true) as is_cross_margin, created_at, updated_at
FROM traders WHERE user_id = $1 ORDER BY created_at DESC
`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var traders []*TraderRecord
	for rows.Next() {
		var trader TraderRecord
		if err := rows.Scan(
			&trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID,
			&trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning,
			&trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols,
			&trader.UseCoinPool, &trader.UseOITop,
			&trader.CustomPrompt, &trader.OverrideBasePrompt, &trader.SystemPromptTemplate,
			&trader.IsCrossMargin,
			&trader.CreatedAt, &trader.UpdatedAt,
		); err != nil {
			return nil, err
		}
		traders = append(traders, &trader)
	}
	// Surface iteration failures instead of returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return traders, nil
}
// UpdateTraderStatus sets the is_running flag of the trader identified by
// (id, userID).
func (d *PostgreSQLDatabase) UpdateTraderStatus(userID, id string, isRunning bool) error {
	const updateStatus = `UPDATE traders SET is_running = $1 WHERE id = $2 AND user_id = $3`
	if _, err := d.db.Exec(updateStatus, isRunning, id, userID); err != nil {
		return err
	}
	return nil
}
// UpdateTrader writes the editable configuration fields of trader back to the
// row identified by (trader.ID, trader.UserID) and refreshes updated_at.
func (d *PostgreSQLDatabase) UpdateTrader(trader *TraderRecord) error {
	const updateTrader = `
UPDATE traders SET
name = $1, ai_model_id = $2, exchange_id = $3, initial_balance = $4,
scan_interval_minutes = $5, btc_eth_leverage = $6, altcoin_leverage = $7,
trading_symbols = $8, custom_prompt = $9, override_base_prompt = $10,
system_prompt_template = $11, is_cross_margin = $12, updated_at = CURRENT_TIMESTAMP
WHERE id = $13 AND user_id = $14
`
	// Argument order follows the $1..$14 placeholders above.
	args := []any{
		trader.Name, trader.AIModelID, trader.ExchangeID, trader.InitialBalance,
		trader.ScanIntervalMinutes, trader.BTCETHLeverage, trader.AltcoinLeverage,
		trader.TradingSymbols, trader.CustomPrompt, trader.OverrideBasePrompt,
		trader.SystemPromptTemplate, trader.IsCrossMargin,
		trader.ID, trader.UserID,
	}
	_, err := d.db.Exec(updateTrader, args...)
	return err
}
// UpdateTraderCustomPrompt stores the custom prompt and the override-base
// flag for the trader identified by (id, userID).
func (d *PostgreSQLDatabase) UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
	const updatePrompt = `UPDATE traders SET custom_prompt = $1, override_base_prompt = $2 WHERE id = $3 AND user_id = $4`
	if _, err := d.db.Exec(updatePrompt, customPrompt, overrideBase, id, userID); err != nil {
		return err
	}
	return nil
}
// UpdateTraderInitialBalance overwrites the initial balance of the trader
// identified by (id, userID) and refreshes updated_at. It is intended for
// syncing the stored value with the real exchange balance.
func (d *PostgreSQLDatabase) UpdateTraderInitialBalance(userID, id string, newBalance float64) error {
	const updateBalance = `UPDATE traders SET initial_balance = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2 AND user_id = $3`
	if _, err := d.db.Exec(updateBalance, newBalance, id, userID); err != nil {
		return err
	}
	return nil
}
// DeleteTrader removes the trader identified by (id, userID).
func (d *PostgreSQLDatabase) DeleteTrader(userID, id string) error {
	const deleteTrader = `DELETE FROM traders WHERE id = $1 AND user_id = $2`
	if _, err := d.db.Exec(deleteTrader, id, userID); err != nil {
		return err
	}
	return nil
}
// GetTraderConfig loads a trader together with the AI model and exchange it
// references, using one query that joins the three tables on matching ids and
// user_id. A scan error (including no matching row) is returned with all
// three results nil.
//
// NOTE(review): the SELECT list omits the traders leverage, symbol and prompt
// columns and the ai_models custom_api_url / custom_model_name columns, so
// those fields keep their zero values here. Confirm callers do not rely on
// them.
func (d *PostgreSQLDatabase) GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) {
	const selectConfig = `
SELECT
t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, t.initial_balance, t.scan_interval_minutes, t.is_running, t.created_at, t.updated_at,
a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key, a.created_at, a.updated_at,
e.id, e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, e.testnet,
COALESCE(e.hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
COALESCE(e.aster_user, '') as aster_user,
COALESCE(e.aster_signer, '') as aster_signer,
COALESCE(e.aster_private_key, '') as aster_private_key,
e.created_at, e.updated_at
FROM traders t
JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id
JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id
WHERE t.id = $1 AND t.user_id = $2
`
	var (
		tr = new(TraderRecord)
		ai = new(AIModelConfig)
		ex = new(ExchangeConfig)
	)

	// Scan destinations, in the same order as the SELECT list.
	dest := []any{
		// traders
		&tr.ID, &tr.UserID, &tr.Name, &tr.AIModelID, &tr.ExchangeID,
		&tr.InitialBalance, &tr.ScanIntervalMinutes, &tr.IsRunning,
		&tr.CreatedAt, &tr.UpdatedAt,
		// ai_models
		&ai.ID, &ai.UserID, &ai.Name, &ai.Provider, &ai.Enabled, &ai.APIKey,
		&ai.CreatedAt, &ai.UpdatedAt,
		// exchanges
		&ex.ID, &ex.UserID, &ex.Name, &ex.Type, &ex.Enabled,
		&ex.APIKey, &ex.SecretKey, &ex.Testnet,
		&ex.HyperliquidWalletAddr, &ex.AsterUser, &ex.AsterSigner, &ex.AsterPrivateKey,
		&ex.CreatedAt, &ex.UpdatedAt,
	}

	if err := d.db.QueryRow(selectConfig, traderID, userID).Scan(dest...); err != nil {
		return nil, nil, nil, err
	}
	return tr, ai, ex, nil
}
// GetSystemConfig returns the value stored under key in system_config. The
// scan error is returned alongside the value, which stays "" when no row was
// read.
func (d *PostgreSQLDatabase) GetSystemConfig(key string) (string, error) {
	const selectValue = `SELECT value FROM system_config WHERE key = $1`
	var stored string
	row := d.db.QueryRow(selectValue, key)
	err := row.Scan(&stored)
	return stored, err
}
// SetSystemConfig stores value under key in system_config, inserting the row
// or, when the key already exists, overwriting its value and refreshing
// updated_at.
func (d *PostgreSQLDatabase) SetSystemConfig(key, value string) error {
	const upsertConfig = `
INSERT INTO system_config (key, value) VALUES ($1, $2)
ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = CURRENT_TIMESTAMP
`
	if _, err := d.db.Exec(upsertConfig, key, value); err != nil {
		return err
	}
	return nil
}
// CreateUserSignalSource stores the coin-pool and OI-top URLs for userID. If
// the user already has a row, its URLs are overwritten and updated_at is
// refreshed.
func (d *PostgreSQLDatabase) CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error {
	const upsertSource = `
INSERT INTO user_signal_sources (user_id, coin_pool_url, oi_top_url, updated_at)
VALUES ($1, $2, $3, CURRENT_TIMESTAMP)
ON CONFLICT (user_id) DO UPDATE SET
coin_pool_url = $2, oi_top_url = $3, updated_at = CURRENT_TIMESTAMP
`
	if _, err := d.db.Exec(upsertSource, userID, coinPoolURL, oiTopURL); err != nil {
		return err
	}
	return nil
}
// GetUserSignalSource loads the signal-source configuration row of userID.
// The scan error (including no matching row) is returned unchanged.
func (d *PostgreSQLDatabase) GetUserSignalSource(userID string) (*UserSignalSource, error) {
	const selectSource = `
SELECT id, user_id, coin_pool_url, oi_top_url, created_at, updated_at
FROM user_signal_sources WHERE user_id = $1
`
	src := new(UserSignalSource)
	row := d.db.QueryRow(selectSource, userID)
	if err := row.Scan(
		&src.ID, &src.UserID, &src.CoinPoolURL, &src.OITopURL,
		&src.CreatedAt, &src.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return src, nil
}
// UpdateUserSignalSource overwrites the coin-pool and OI-top URLs in the
// existing row of userID and refreshes updated_at. It does not insert a row.
func (d *PostgreSQLDatabase) UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error {
	const updateSource = `
UPDATE user_signal_sources SET coin_pool_url = $1, oi_top_url = $2, updated_at = CURRENT_TIMESTAMP
WHERE user_id = $3
`
	if _, err := d.db.Exec(updateSource, coinPoolURL, oiTopURL, userID); err != nil {
		return err
	}
	return nil
}
// GetCustomCoins returns the de-duplicated list of symbols configured in the
// custom_coins column across all traders, each passed through
// market.Normalize.
//
// When no trader has custom coins (or the query fails), it falls back to the
// JSON array stored under the "default_coins" system config key, and to a
// hard-coded list if that cannot be parsed. Errors are logged, never returned.
func (d *PostgreSQLDatabase) GetCustomCoins() []string {
	// STRING_AGG yields NULL when no row matches; NullString lets that case be
	// told apart from a genuine query failure.
	var aggregated sql.NullString
	err := d.db.QueryRow(`
SELECT STRING_AGG(custom_coins, ',') as symbol
FROM traders WHERE custom_coins != ''
`).Scan(&aggregated)
	if err != nil {
		log.Printf("⚠️  查询自定义币种失败: %v", err)
	}
	symbol := aggregated.String // "" when NULL or when the scan failed

	var symbols []string
	// No custom coins configured: use the defaults (compatibility path).
	if err != nil || symbol == "" {
		// A lookup error leaves symbolJSON empty, which fails to parse below
		// and therefore triggers the hard-coded default.
		symbolJSON, _ := d.GetSystemConfig("default_coins")
		if err := json.Unmarshal([]byte(symbolJSON), &symbols); err != nil {
			log.Printf("⚠️  解析default_coins配置失败: %v,使用硬编码默认值", err)
			symbols = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT"}
		}
	}

	// Normalize every custom symbol and append the ones not already present.
	for _, s := range strings.Split(symbol, ",") {
		if s == "" {
			continue
		}
		coin := market.Normalize(s)
		if !slices.Contains(symbols, coin) {
			symbols = append(symbols, coin)
		}
	}
	return symbols
}
// LoadBetaCodesFromFile reads beta codes from filePath (one per line; blank
// lines and lines starting with "#" are skipped) and inserts them into
// beta_codes inside a single transaction. Codes that already exist are left
// untouched.
//
// It returns an error if the file cannot be read, the transaction cannot be
// started or committed, or any insert fails. On an insert failure the whole
// transaction is rolled back: PostgreSQL aborts a transaction at the first
// failing statement, so continuing with the remaining codes could not succeed.
func (d *PostgreSQLDatabase) LoadBetaCodesFromFile(filePath string) error {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("读取内测码文件失败: %w", err)
	}

	// Collect the non-empty, non-comment lines.
	var codes []string
	for _, line := range strings.Split(string(content), "\n") {
		code := strings.TrimSpace(line)
		if code != "" && !strings.HasPrefix(code, "#") {
			codes = append(codes, code)
		}
	}

	tx, err := d.db.Begin()
	if err != nil {
		return fmt.Errorf("开始事务失败: %w", err)
	}
	// No-op after a successful Commit; undoes everything on any early return.
	defer tx.Rollback()

	stmt, err := tx.Prepare(`INSERT INTO beta_codes (code) VALUES ($1) ON CONFLICT (code) DO NOTHING`)
	if err != nil {
		return fmt.Errorf("准备语句失败: %w", err)
	}
	defer stmt.Close()

	insertedCount := 0
	for _, code := range codes {
		result, err := stmt.Exec(code)
		if err != nil {
			return fmt.Errorf("插入内测码 %s 失败: %w", code, err)
		}
		// Count only rows actually inserted; conflicts affect zero rows.
		if rowsAffected, err := result.RowsAffected(); err == nil && rowsAffected > 0 {
			insertedCount++
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	log.Printf("✅ 成功加载 %d 个内测码到数据库 (总计 %d 个)", insertedCount, len(codes))
	return nil
}
// ValidateBetaCode reports whether code exists in beta_codes and has not been
// used yet. An unknown code yields (false, nil); any other database error is
// returned.
func (d *PostgreSQLDatabase) ValidateBetaCode(code string) (bool, error) {
	var alreadyUsed bool
	row := d.db.QueryRow(`SELECT used FROM beta_codes WHERE code = $1`, code)
	switch err := row.Scan(&alreadyUsed); {
	case err == nil:
		// The code exists; it is valid only while unused.
		return !alreadyUsed, nil
	case err == sql.ErrNoRows:
		// The code does not exist.
		return false, nil
	default:
		return false, err
	}
}
// UseBetaCode marks code as used by userEmail and stamps used_at. The UPDATE
// only matches codes that are still unused; if it changes no row, an error
// saying the code is invalid or already used is returned.
func (d *PostgreSQLDatabase) UseBetaCode(code, userEmail string) error {
	const markUsed = `
UPDATE beta_codes SET used = true, used_by = $1, used_at = CURRENT_TIMESTAMP
WHERE code = $2 AND used = false
`
	res, err := d.db.Exec(markUsed, userEmail, code)
	if err != nil {
		return err
	}

	changed, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case changed == 0:
		return fmt.Errorf("内测码无效或已被使用")
	}
	return nil
}
// GetBetaCodeStats returns the total number of beta codes and how many of
// them are marked used. If either count query fails, both counts are reported
// as zero together with the error.
func (d *PostgreSQLDatabase) GetBetaCodeStats() (total, used int, err error) {
	var totalCount, usedCount int

	allCodes := d.db.QueryRow(`SELECT COUNT(*) FROM beta_codes`)
	if scanErr := allCodes.Scan(&totalCount); scanErr != nil {
		return 0, 0, scanErr
	}

	usedCodes := d.db.QueryRow(`SELECT COUNT(*) FROM beta_codes WHERE used = true`)
	if scanErr := usedCodes.Scan(&usedCount); scanErr != nil {
		return 0, 0, scanErr
	}

	return totalCount, usedCount, nil
}
// Close closes the underlying database handle and returns the error, if any,
// reported by it.
func (d *PostgreSQLDatabase) Close() error {
return d.db.Close()
}