From c1003ca3e82be29e862565bb470aca121032b058 Mon Sep 17 00:00:00 2001 From: Shui <88711385+hzb1115@users.noreply.github.com> Date: Sat, 8 Nov 2025 20:02:30 -0500 Subject: [PATCH] feat(hook): Add hook module to help decouple some specific logic (#784) --- api/server.go | 25 +++- bootstrap/bootstrap.go | 23 ++-- hook/README.md | 270 ++++++++++++++++++++++++++++++++++++++ hook/hooks.go | 41 ++++++ hook/http_client_hook.go | 23 ++++ hook/ip_hook.go | 19 +++ hook/trader_hook.go | 42 ++++++ market/api_client.go | 16 ++- market/data.go | 7 +- trader/aster_trader.go | 24 ++-- trader/auto_trader.go | 2 +- trader/binance_futures.go | 9 +- 12 files changed, 466 insertions(+), 35 deletions(-) create mode 100644 hook/README.md create mode 100644 hook/hooks.go create mode 100644 hook/http_client_hook.go create mode 100644 hook/ip_hook.go create mode 100644 hook/trader_hook.go diff --git a/api/server.go b/api/server.go index e98db44a..543fb0a6 100644 --- a/api/server.go +++ b/api/server.go @@ -10,6 +10,7 @@ import ( "nofx/config" "nofx/crypto" "nofx/decision" + "nofx/hook" "nofx/manager" "nofx/trader" "strconv" @@ -204,6 +205,17 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) { // handleGetServerIP 获取服务器IP地址(用于白名单配置) func (s *Server) handleGetServerIP(c *gin.Context) { + + // 首先尝试从Hook获取用户专用IP + userIP := hook.HookExec[hook.IpResult](hook.GETIP, c.GetString("user_id")) + if userIP != nil && userIP.Error() == nil { + c.JSON(http.StatusOK, gin.H{ + "public_ip": userIP.GetResult(), + "message": "请将此IP地址添加到白名单中", + }) + return + } + // 尝试通过第三方API获取公网IP publicIP := getPublicIPFromAPI() @@ -392,8 +404,8 @@ type SafeModelConfig struct { Name string `json:"name"` Provider string `json:"provider"` Enabled bool `json:"enabled"` - CustomAPIURL string `json:"customApiUrl"` // 自定义API URL(通常不敏感) - CustomModelName string `json:"customModelName"` // 自定义模型名(不敏感) + CustomAPIURL string `json:"customApiUrl"` // 自定义API URL(通常不敏感) + CustomModelName string `json:"customModelName"` // 自定义模型名(不敏感) } type ExchangeConfig struct { @@ -414,8 +426,8 @@ type SafeExchangeConfig struct { Enabled bool `json:"enabled"` Testnet bool `json:"testnet,omitempty"` HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid钱包地址(不敏感) - AsterUser string `json:"asterUser"` // Aster用户名(不敏感) - AsterSigner string `json:"asterSigner"` // Aster签名者(不敏感) + AsterUser string `json:"asterUser"` // Aster用户名(不敏感) + AsterSigner string `json:"asterSigner"` // Aster签名者(不敏感) } type UpdateModelConfigRequest struct { @@ -543,7 +555,7 @@ func (s *Server) handleCreateTrader(c *gin.Context) { switch req.ExchangeID { case "binance": - tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey) + tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) case "hyperliquid": tempTrader, createErr = trader.NewHyperliquidTrader( exchangeCfg.APIKey, // private key @@ -904,7 +916,7 @@ func (s *Server) handleSyncBalance(c *gin.Context) { switch traderConfig.ExchangeID { case "binance": - tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey) + tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID) case "hyperliquid": tempTrader, createErr = trader.NewHyperliquidTrader( exchangeCfg.APIKey, @@ -1638,7 +1650,6 @@ func (s *Server) authMiddleware() gin.HandlerFunc { } } - // handleLogout 将当前token加入黑名单 func (s *Server) handleLogout(c *gin.Context) { authHeader := c.GetHeader("Authorization") diff --git a/bootstrap/bootstrap.go b/bootstrap/bootstrap.go index 6e28cbe7..88fd2063 100644 --- a/bootstrap/bootstrap.go +++ b/bootstrap/bootstrap.go @@ -6,6 +6,7 @@ import ( "sort" "sync" "time" + "log" ) // Priority 初始化优先级常量 @@ -68,7 +69,7 @@ func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error { hooksMu.Unlock() if len(hooksCopy) == 0 { - logger.Log.Warnf("⚠️ 没有注册任何初始化钩子") + log.Printf("⚠️ 没有注册任何初始化钩子") return nil } @@ -77,7 +78,7 @@ func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error { return hooksCopy[i].Priority < hooksCopy[j].Priority }) - logger.Log.Infof("🔄 开始初始化 %d 个模块...", len(hooksCopy)) + log.Printf("🔄 开始初始化 %d 个模块...", len(hooksCopy)) startTime := time.Now() var errors []error @@ -87,13 +88,13 @@ func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error { for i, hook := range hooksCopy { // 检查是否启用 if hook.Enabled != nil && !hook.Enabled(ctx) { - logger.Log.Infof(" [%d/%d] 跳过: %s (条件未满足)", + log.Printf(" [%d/%d] 跳过: %s (条件未满足)", i+1, len(hooksCopy), hook.Name) skippedCount++ continue } - logger.Log.Infof(" [%d/%d] 初始化: %s (优先级: %d)", + log.Printf(" [%d/%d] 初始化: %s (优先级: %d)", i+1, len(hooksCopy), hook.Name, hook.Priority) hookStart := time.Now() @@ -111,16 +112,16 @@ func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error { switch policy { case FailFast: - logger.Log.Errorf(" ❌ 失败: %s (耗时: %v)", hook.Name, elapsed) + log.Printf(" ❌ 失败: %s (耗时: %v)", hook.Name, elapsed) return errMsg case ContinueOnError: - logger.Log.Errorf(" ❌ 失败: %s (耗时: %v) - 继续执行", hook.Name, elapsed) + log.Printf(" ❌ 失败: %s (耗时: %v) - 继续执行", hook.Name, elapsed) errors = append(errors, errMsg) case WarnOnError: - logger.Log.Warnf(" ⚠️ 警告: %s (耗时: %v) - %v", hook.Name, elapsed, err) + log.Printf(" ⚠️ 警告: %s (耗时: %v) - %v", hook.Name, elapsed, err) } } else { - logger.Log.Infof(" ✓ 完成: %s (耗时: %v)", hook.Name, elapsed) + log.Printf(" ✓ 完成: %s (耗时: %v)", hook.Name, elapsed) successCount++ } } @@ -131,15 +132,15 @@ func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error { if len(errors) > 0 { logger.Log.Warnf("⚠️ 初始化完成,但有 %d 个模块失败 (总耗时: %v)", len(errors), totalElapsed) - logger.Log.Infof("📊 统计: 成功=%d, 失败=%d, 跳过=%d", + log.Printf("📊 统计: 成功=%d, 失败=%d, 跳过=%d", successCount, len(errors), skippedCount) // 返回合并的错误 return fmt.Errorf("以下模块初始化失败: %v", errors) } - logger.Log.Infof("✅ 所有模块初始化完成 (总耗时: %v)", totalElapsed) - logger.Log.Infof("📊 统计: 成功=%d, 跳过=%d", successCount, skippedCount) + log.Printf("✅ 所有模块初始化完成 (总耗时: %v)", totalElapsed) + log.Printf("📊 统计: 成功=%d, 跳过=%d", successCount, skippedCount) return nil } diff --git a/hook/README.md b/hook/README.md new file mode 100644 index 00000000..a5cce891 --- /dev/null +++ b/hook/README.md @@ -0,0 +1,270 @@ +# Hook 模块使用文档 + +## 简介 + +Hook模块提供了一个通用的扩展点机制,允许在不修改核心代码的前提下注入自定义逻辑。 + +**核心特点**: +- 类型安全的泛型API +- Hook未注册时自动fallback +- 支持任意参数和返回值 + +## 快速开始 + +### 基本用法 + +```go +// 1. 注册Hook +hook.RegisterHook(hook.GETIP, func(args ...any) any { + userId := args[0].(string) + return &hook.IpResult{IP: "192.168.1.1"} +}) + +// 2. 调用Hook +result := hook.HookExec[hook.IpResult](hook.GETIP, "user123") +if result != nil && result.Error() == nil { + ip := result.GetResult() +} +``` + +### 核心API + +```go +// 注册Hook函数 +func RegisterHook(key string, hook HookFunc) + +// 执行Hook(泛型) +func HookExec[T any](key string, args ...any) *T +``` + +## 可用的Hook扩展点 + +### 1. `GETIP` - 获取用户IP + +**调用位置**:`api/server.go:210` + +**参数**:`userId string` + +**返回**:`*IpResult` +```go +type IpResult struct { + Err error + IP string +} +``` + +**用途**:返回用户专用IP(如代理IP) + +--- + +### 2. `NEW_BINANCE_TRADER` - Binance客户端创建 + +**调用位置**:`trader/binance_futures.go:68` + +**参数**:`userId string, client *futures.Client` + +**返回**:`*NewBinanceTraderResult` +```go +type NewBinanceTraderResult struct { + Err error + Client *futures.Client // 可修改client配置 +} +``` + +**用途**:为Binance客户端注入代理、日志等 + +--- + +### 3. `NEW_ASTER_TRADER` - Aster客户端创建 + +**调用位置**:`trader/aster_trader.go:68` + +**参数**:`user string, client *http.Client` + +**返回**:`*NewAsterTraderResult` +```go +type NewAsterTraderResult struct { + Err error + Client *http.Client // 可修改HTTP client +} +``` + +**用途**:为Aster客户端注入代理等 + +## 使用示例 + +### 示例1:代理模块注册Hook + +```go +// proxy/init.go +package proxy + +import "nofx/hook" + +func InitHooks(enabled bool) { + if !enabled { + return // 条件不满足,不注册 + } + + // 注册IP获取Hook + hook.RegisterHook(hook.GETIP, func(args ...any) any { + userId := args[0].(string) + proxyIP, err := getProxyIP(userId) + return &hook.IpResult{Err: err, IP: proxyIP} + }) + + // 注册Binance客户端Hook + hook.RegisterHook(hook.NEW_BINANCE_TRADER, func(args ...any) any { + userId := args[0].(string) + client := args[1].(*futures.Client) + + // 修改client配置 + if client.HTTPClient != nil { + client.HTTPClient.Transport = getProxyTransport() + } + + return &hook.NewBinanceTraderResult{Client: client} + }) +} +``` + +## 最佳实践 + +### ✅ 推荐做法 + +```go +// 1. 在注册时判断条件 +func InitHooks(enabled bool) { + if !enabled { + return // 不注册 + } + hook.RegisterHook(KEY, hookFunc) +} + +// 2. 总是返回正确的Result类型 +hook.RegisterHook(hook.GETIP, func(args ...any) any { + ip, err := getIP() + return &hook.IpResult{Err: err, IP: ip} // ✅ +}) + +// 3. 安全的类型断言 +userId, ok := args[0].(string) +if !ok { + return &hook.IpResult{Err: fmt.Errorf("参数类型错误")} +} +``` + +### ❌ 避免的做法 + +```go +// 1. 不要在Hook内部判断条件(浪费性能) +hook.RegisterHook(KEY, func(args ...any) any { + if !enabled { + return nil // ❌ + } + // ... +}) + +// 2. 不要直接panic +hook.RegisterHook(KEY, func(args ...any) any { + if err != nil { + panic(err) // ❌ 会导致程序崩溃 + } +}) + +// 3. 不要跳过类型检查 +userId := args[0].(string) // ❌ 可能panic +``` + +## 添加新Hook扩展点 + +### 步骤1:定义Result类型 + +```go +// hook/my_hook.go +package hook + +type MyHookResult struct { + Err error + Data string +} + +func (r *MyHookResult) Error() error { + if r.Err != nil { + log.Printf("⚠️ Hook出错: %v", r.Err) + } + return r.Err +} + +func (r *MyHookResult) GetResult() string { + r.Error() + return r.Data +} +``` + +### 步骤2:定义Hook常量 + +```go +// hook/hooks.go +const ( + GETIP = "GETIP" + NEW_BINANCE_TRADER = "NEW_BINANCE_TRADER" + NEW_ASTER_TRADER = "NEW_ASTER_TRADER" + MY_HOOK = "MY_HOOK" // 新增 +) +``` + +### 步骤3:在业务代码调用 + +```go +result := hook.HookExec[hook.MyHookResult](hook.MY_HOOK, arg1, arg2) +if result != nil && result.Error() == nil { + data := result.GetResult() + // 使用data +} +``` + +### 步骤4:注册实现 + +```go +hook.RegisterHook(hook.MY_HOOK, func(args ...any) any { + // 处理逻辑 + return &hook.MyHookResult{Data: "result"} +}) +``` + +## 常见问题 + +**Q: Hook可以注册多个吗?** +A: 不可以,每个Key只能注册一个Hook,后注册会覆盖前面的。如需多个逻辑,请在一个Hook函数内组合。 + +**Q: Hook执行失败会影响主流程吗?** +A: 不会,主流程会检查返回值,失败时会fallback到默认逻辑。 + +**Q: 如何调试Hook?** +A: Hook执行时会自动打印日志: +- `🔌 Execute hook: {KEY}` - Hook存在并执行 +- `🔌 Do not find hook: {KEY}` - Hook未注册 + +**Q: 如何测试Hook?** +```go +func TestHook(t *testing.T) { + // 清空全局Hook + hook.Hooks = make(map[string]hook.HookFunc) + + // 注册测试Hook + hook.RegisterHook(hook.GETIP, func(args ...any) any { + return &hook.IpResult{IP: "127.0.0.1"} + }) + + // 验证 + result := hook.HookExec[hook.IpResult](hook.GETIP, "test") + assert.Equal(t, "127.0.0.1", result.IP) +} +``` + +## 参考 + +- 核心实现:`hook/hooks.go` +- Result类型:`hook/trader_hook.go`, `hook/ip_hook.go` +- 调用示例:`api/server.go`, `trader/binance_futures.go`, `trader/aster_trader.go` diff --git a/hook/hooks.go b/hook/hooks.go new file mode 100644 index 00000000..e94e28aa --- /dev/null +++ b/hook/hooks.go @@ -0,0 +1,41 @@ +package hook + +import ( + "log" +) + +type HookFunc func(args ...any) any + +var ( + Hooks map[string]HookFunc = map[string]HookFunc{} + EnableHooks = true +) + +func HookExec[T any](key string, args ...any) *T { + if !EnableHooks { + log.Printf("🔌 Hooks are disabled, skip hook: %s", key) + var zero *T + return zero + } + if hook, exists := Hooks[key]; exists && hook != nil { + log.Printf("🔌 Execute hook: %s", key) + res := hook(args...) + return res.(*T) + } else { + log.Printf("🔌 Do not find hook: %s", key) + } + var zero *T + return zero +} + +func RegisterHook(key string, hook HookFunc) { + Hooks[key] = hook +} + +// hook list +const ( + GETIP = "GETIP" // func (userID string) *IpResult + NEW_BINANCE_TRADER = "NEW_BINANCE_TRADER" // func (userID string, client *futures.Client) *NewBinanceTraderResult + NEW_ASTER_TRADER = "NEW_ASTER_TRADER" // func (userID string, client *http.Client) *NewAsterTraderResult + SET_HTTP_CLIENT = "SET_HTTP_CLIENT" // func (client *http.Client) *SetHttpClientResult +) diff --git a/hook/http_client_hook.go b/hook/http_client_hook.go new file mode 100644 index 00000000..5540b23a --- /dev/null +++ b/hook/http_client_hook.go @@ -0,0 +1,23 @@ +package hook + +import ( + "log" + "net/http" +) + +type SetHttpClientResult struct { + Err error + Client *http.Client +} + +func (r *SetHttpClientResult) Error() error { + if r.Err != nil { + log.Printf("⚠️ 执行NewAsterTraderResult时出错: %v", r.Err) + } + return r.Err +} + +func (r *SetHttpClientResult) GetResult() *http.Client { + r.Error() + return r.Client +} diff --git a/hook/ip_hook.go b/hook/ip_hook.go new file mode 100644 index 00000000..9ad597d3 --- /dev/null +++ b/hook/ip_hook.go @@ -0,0 +1,19 @@ +package hook + +import "github.com/rs/zerolog/log" + +type IpResult struct { + Err error + IP string +} + +func (r *IpResult) Error() error { + return r.Err +} + +func (r *IpResult) GetResult() string { + if r.Err != nil { + log.Printf("⚠️ 执行GetIP时出错: %v", r.Err) + } + return r.IP +} diff --git a/hook/trader_hook.go b/hook/trader_hook.go new file mode 100644 index 00000000..cbd7a1f1 --- /dev/null +++ b/hook/trader_hook.go @@ -0,0 +1,42 @@ +package hook + +import ( + "log" + "net/http" + + "github.com/adshao/go-binance/v2/futures" +) + +type NewBinanceTraderResult struct { + Err error + Client *futures.Client +} + +func (r *NewBinanceTraderResult) Error() error { + if r.Err != nil { + log.Printf("⚠️ 执行NewBinanceTraderResult时出错: %v", r.Err) + } + return r.Err +} + +func (r *NewBinanceTraderResult) GetResult() *futures.Client { + r.Error() + return r.Client +} + +type NewAsterTraderResult struct { + Err error + Client *http.Client +} + +func (r *NewAsterTraderResult) Error() error { + if r.Err != nil { + log.Printf("⚠️ 执行NewAsterTraderResult时出错: %v", r.Err) + } + return r.Err +} + +func (r *NewAsterTraderResult) GetResult() *http.Client { + r.Error() + return r.Client +} diff --git a/market/api_client.go b/market/api_client.go index 70bb1150..3b9c268e 100644 --- a/market/api_client.go +++ b/market/api_client.go @@ -6,6 +6,7 @@ import ( "io" "log" "net/http" + "nofx/hook" "strconv" "time" ) @@ -19,10 +20,18 @@ type APIClient struct { } func NewAPIClient() *APIClient { + client := &http.Client{ + Timeout: 30 * time.Second, + } + + hookRes := hook.HookExec[hook.SetHttpClientResult](hook.SET_HTTP_CLIENT, client) + if hookRes != nil && hookRes.Error() == nil { + log.Printf("使用Hook设置的HTTP客户端") + client = hookRes.GetResult() + } + return &APIClient{ - client: &http.Client{ - Timeout: 30 * time.Second, - }, + client: client, } } @@ -74,6 +83,7 @@ func (c *APIClient) GetKlines(symbol, interval string, limit int) ([]Kline, erro var klineResponses []KlineResponse err = json.Unmarshal(body, &klineResponses) if err != nil { + log.Printf("获取K线数据失败,响应内容: %s", string(body)) return nil, err } diff --git a/market/data.go b/market/data.go index e5675d43..b69281dc 100644 --- a/market/data.go +++ b/market/data.go @@ -5,7 +5,6 @@ import ( "fmt" "io/ioutil" "math" - "net/http" "strconv" "strings" "sync" @@ -315,7 +314,8 @@ func calculateLongerTermData(klines []Kline) *LongerTermData { func getOpenInterestData(symbol string) (*OIData, error) { url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/openInterest?symbol=%s", symbol) - resp, err := http.Get(url) + apiClient := NewAPIClient() + resp, err := apiClient.client.Get(url) if err != nil { return nil, err } @@ -359,7 +359,8 @@ func getFundingRate(symbol string) (float64, error) { // 缓存过期或不存在,调用 API url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/premiumIndex?symbol=%s", symbol) - resp, err := http.Get(url) + apiClient := NewAPIClient() + resp, err := apiClient.client.Get(url) if err != nil { return 0, err } diff --git a/trader/aster_trader.go b/trader/aster_trader.go index 362bb894..e33c1b0e 100644 --- a/trader/aster_trader.go +++ b/trader/aster_trader.go @@ -13,6 +13,7 @@ import ( "math/big" "net/http" "net/url" + "nofx/hook" "sort" "strconv" "strings" @@ -56,6 +57,18 @@ func NewAsterTrader(user, signer, privateKeyHex string) (*AsterTrader, error) { if err != nil { return nil, fmt.Errorf("解析私钥失败: %w", err) } + client := &http.Client{ + Timeout: 30 * time.Second, // 增加到30秒 + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + IdleConnTimeout: 90 * time.Second, + }, + } + res := hook.HookExec[hook.NewAsterTraderResult](hook.NEW_ASTER_TRADER, user, client) + if res != nil && res.Error() == nil { + client = res.GetResult() + } return &AsterTrader{ ctx: context.Background(), @@ -63,15 +76,8 @@ func NewAsterTrader(user, signer, privateKeyHex string) (*AsterTrader, error) { signer: signer, privateKey: privKey, symbolPrecision: make(map[string]SymbolPrecision), - client: &http.Client{ - Timeout: 30 * time.Second, // 增加到30秒 - Transport: &http.Transport{ - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - IdleConnTimeout: 90 * time.Second, - }, - }, - baseURL: "https://fapi.asterdex.com", + client: client, + baseURL: "https://fapi.asterdex.com", }, nil } diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 117af718..79879542 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -175,7 +175,7 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) switch config.Exchange { case "binance": log.Printf("🏦 [%s] 使用币安合约交易", config.Name) - trader = NewFuturesTrader(config.BinanceAPIKey, config.BinanceSecretKey) + trader = NewFuturesTrader(config.BinanceAPIKey, config.BinanceSecretKey, userID) case "hyperliquid": log.Printf("🏦 [%s] 使用Hyperliquid交易", config.Name) trader, err = NewHyperliquidTrader(config.HyperliquidPrivateKey, config.HyperliquidWalletAddr, config.HyperliquidTestnet) diff --git a/trader/binance_futures.go b/trader/binance_futures.go index 68089c15..f2489f6b 100644 --- a/trader/binance_futures.go +++ b/trader/binance_futures.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "nofx/hook" "strconv" "strings" "sync" @@ -61,8 +62,14 @@ type FuturesTrader struct { } // NewFuturesTrader 创建合约交易器 -func NewFuturesTrader(apiKey, secretKey string) *FuturesTrader { +func NewFuturesTrader(apiKey, secretKey string, userId string) *FuturesTrader { client := futures.NewClient(apiKey, secretKey) + + hookRes := hook.HookExec[hook.NewBinanceTraderResult](hook.NEW_BINANCE_TRADER, userId, client) + if hookRes != nil && hookRes.GetResult() != nil { + client = hookRes.GetResult() + } + // 同步时间,避免 Timestamp ahead 错误 syncBinanceServerTime(client) trader := &FuturesTrader{