mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 09:58:22 +08:00
Refactor: Modularize codebase with separate decision and MCP packages
Architecture improvements: - Extract AI decision engine to dedicated `decision` package - Create `mcp` package for Model Context Protocol client - Separate market data structures into `market/data.go` - Update trader to use new modular structure New packages: - `decision/engine.go` - AI decision logic and prompt building - `mcp/client.go` - Unified AI API client (DeepSeek/Qwen) - `market/data.go` - Market data type definitions Benefits: - Better separation of concerns - Improved code organization and maintainability - Easier to test individual components - More flexible AI provider integration - Cleaner dependency management Updated imports: - trader/auto_trader.go now uses decision and mcp packages - Consistent API across different AI providers Co-Authored-By: tinkle-community <tinklefund@gmail.com>
This commit is contained in:
@@ -0,0 +1,514 @@
|
||||
package decision
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/market"
|
||||
"nofx/mcp"
|
||||
"nofx/pool"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PositionInfo 持仓信息
|
||||
type PositionInfo struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Side string `json:"side"` // "long" or "short"
|
||||
EntryPrice float64 `json:"entry_price"`
|
||||
MarkPrice float64 `json:"mark_price"`
|
||||
Quantity float64 `json:"quantity"`
|
||||
Leverage int `json:"leverage"`
|
||||
UnrealizedPnL float64 `json:"unrealized_pnl"`
|
||||
UnrealizedPnLPct float64 `json:"unrealized_pnl_pct"`
|
||||
LiquidationPrice float64 `json:"liquidation_price"`
|
||||
MarginUsed float64 `json:"margin_used"`
|
||||
}
|
||||
|
||||
// AccountInfo 账户信息
|
||||
type AccountInfo struct {
|
||||
TotalEquity float64 `json:"total_equity"` // 账户净值
|
||||
AvailableBalance float64 `json:"available_balance"` // 可用余额
|
||||
TotalPnL float64 `json:"total_pnl"` // 总盈亏
|
||||
TotalPnLPct float64 `json:"total_pnl_pct"` // 总盈亏百分比
|
||||
MarginUsed float64 `json:"margin_used"` // 已用保证金
|
||||
MarginUsedPct float64 `json:"margin_used_pct"` // 保证金使用率
|
||||
PositionCount int `json:"position_count"` // 持仓数量
|
||||
}
|
||||
|
||||
// CandidateCoin 候选币种(来自币种池)
|
||||
type CandidateCoin struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Sources []string `json:"sources"` // 来源: "ai500" 和/或 "oi_top"
|
||||
}
|
||||
|
||||
// OITopData 持仓量增长Top数据(用于AI决策参考)
|
||||
type OITopData struct {
|
||||
Rank int // OI Top排名
|
||||
OIDeltaPercent float64 // 持仓量变化百分比(1小时)
|
||||
OIDeltaValue float64 // 持仓量变化价值
|
||||
PriceDeltaPercent float64 // 价格变化百分比
|
||||
NetLong float64 // 净多仓
|
||||
NetShort float64 // 净空仓
|
||||
}
|
||||
|
||||
// Context 交易上下文(传递给AI的完整信息)
|
||||
type Context struct {
|
||||
CurrentTime string `json:"current_time"`
|
||||
RuntimeMinutes int `json:"runtime_minutes"`
|
||||
CallCount int `json:"call_count"`
|
||||
Account AccountInfo `json:"account"`
|
||||
Positions []PositionInfo `json:"positions"`
|
||||
CandidateCoins []CandidateCoin `json:"candidate_coins"`
|
||||
MarketDataMap map[string]*market.Data `json:"-"` // 不序列化,但内部使用
|
||||
OITopDataMap map[string]*OITopData `json:"-"` // OI Top数据映射
|
||||
Performance interface{} `json:"-"` // 历史表现分析(logger.PerformanceAnalysis)
|
||||
}
|
||||
|
||||
// Decision AI的交易决策
|
||||
type Decision struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "hold", "wait"
|
||||
Leverage int `json:"leverage,omitempty"`
|
||||
PositionSizeUSD float64 `json:"position_size_usd,omitempty"`
|
||||
StopLoss float64 `json:"stop_loss,omitempty"`
|
||||
TakeProfit float64 `json:"take_profit,omitempty"`
|
||||
Confidence int `json:"confidence,omitempty"` // 信心度 (0-100)
|
||||
RiskUSD float64 `json:"risk_usd,omitempty"` // 最大美元风险
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
|
||||
// FullDecision AI的完整决策(包含思维链)
|
||||
type FullDecision struct {
|
||||
UserPrompt string `json:"user_prompt"` // 发送给AI的输入prompt
|
||||
CoTTrace string `json:"cot_trace"` // 思维链分析(AI输出)
|
||||
Decisions []Decision `json:"decisions"` // 具体决策列表
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// GetFullDecision 获取AI的完整交易决策(批量分析所有币种和持仓)
|
||||
func GetFullDecision(ctx *Context) (*FullDecision, error) {
|
||||
// 1. 为所有币种获取市场数据
|
||||
if err := fetchMarketDataForContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("获取市场数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 构建 System Prompt(固定规则)和 User Prompt(动态数据)
|
||||
systemPrompt := buildSystemPrompt(ctx.Account.TotalEquity)
|
||||
userPrompt := buildUserPrompt(ctx)
|
||||
|
||||
// 3. 调用AI API(使用 system + user prompt)
|
||||
aiResponse, err := mcp.CallWithMessages(systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用AI API失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 解析AI响应
|
||||
decision, err := parseFullDecisionResponse(aiResponse, ctx.Account.TotalEquity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析AI响应失败: %w", err)
|
||||
}
|
||||
|
||||
decision.Timestamp = time.Now()
|
||||
decision.UserPrompt = userPrompt // 保存输入prompt
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
// fetchMarketDataForContext 为上下文中的所有币种获取市场数据和OI数据
|
||||
func fetchMarketDataForContext(ctx *Context) error {
|
||||
ctx.MarketDataMap = make(map[string]*market.Data)
|
||||
ctx.OITopDataMap = make(map[string]*OITopData)
|
||||
|
||||
// 收集所有需要获取数据的币种
|
||||
symbolSet := make(map[string]bool)
|
||||
|
||||
// 1. 优先获取持仓币种的数据(这是必须的)
|
||||
for _, pos := range ctx.Positions {
|
||||
symbolSet[pos.Symbol] = true
|
||||
}
|
||||
|
||||
// 2. 候选币种数量根据账户状态动态调整
|
||||
maxCandidates := calculateMaxCandidates(ctx)
|
||||
for i, coin := range ctx.CandidateCoins {
|
||||
if i >= maxCandidates {
|
||||
break
|
||||
}
|
||||
symbolSet[coin.Symbol] = true
|
||||
}
|
||||
|
||||
// 并发获取市场数据
|
||||
// 持仓币种集合(用于判断是否跳过OI检查)
|
||||
positionSymbols := make(map[string]bool)
|
||||
for _, pos := range ctx.Positions {
|
||||
positionSymbols[pos.Symbol] = true
|
||||
}
|
||||
|
||||
for symbol := range symbolSet {
|
||||
data, err := market.Get(symbol)
|
||||
if err != nil {
|
||||
// 单个币种失败不影响整体,只记录错误
|
||||
continue
|
||||
}
|
||||
|
||||
// ⚠️ 流动性过滤:持仓价值低于15M USD的币种不做(多空都不做)
|
||||
// 持仓价值 = 持仓量 × 当前价格
|
||||
// 但现有持仓必须保留(需要决策是否平仓)
|
||||
isExistingPosition := positionSymbols[symbol]
|
||||
if !isExistingPosition && data.OpenInterest != nil && data.CurrentPrice > 0 {
|
||||
// 计算持仓价值(USD)= 持仓量 × 当前价格
|
||||
oiValue := data.OpenInterest.Latest * data.CurrentPrice
|
||||
oiValueInMillions := oiValue / 1_000_000 // 转换为百万美元单位
|
||||
if oiValueInMillions < 15 {
|
||||
log.Printf("⚠️ %s 持仓价值过低(%.2fM USD < 15M),跳过此币种 [持仓量:%.0f × 价格:%.4f]",
|
||||
symbol, oiValueInMillions, data.OpenInterest.Latest, data.CurrentPrice)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
ctx.MarketDataMap[symbol] = data
|
||||
}
|
||||
|
||||
// 加载OI Top数据(不影响主流程)
|
||||
oiPositions, err := pool.GetOITopPositions()
|
||||
if err == nil {
|
||||
for _, pos := range oiPositions {
|
||||
// 标准化符号匹配
|
||||
symbol := pos.Symbol
|
||||
ctx.OITopDataMap[symbol] = &OITopData{
|
||||
Rank: pos.Rank,
|
||||
OIDeltaPercent: pos.OIDeltaPercent,
|
||||
OIDeltaValue: pos.OIDeltaValue,
|
||||
PriceDeltaPercent: pos.PriceDeltaPercent,
|
||||
NetLong: pos.NetLong,
|
||||
NetShort: pos.NetShort,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateMaxCandidates 根据账户状态计算需要分析的候选币种数量
|
||||
func calculateMaxCandidates(ctx *Context) int {
|
||||
// 直接返回候选池的全部币种数量
|
||||
// 因为候选池已经在 auto_trader.go 中筛选过了
|
||||
// 固定分析前20个评分最高的币种(来自AI500)
|
||||
return len(ctx.CandidateCoins)
|
||||
}
|
||||
|
||||
// buildSystemPrompt 构建 System Prompt(固定规则,可缓存)
|
||||
func buildSystemPrompt(accountEquity float64) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// 角色定义
|
||||
sb.WriteString("你是专业的加密货币交易AI,在币安合约市场进行自主交易。\n\n")
|
||||
sb.WriteString("**使命**: 最大化风险调整后收益(Sharpe Ratio)\n\n")
|
||||
|
||||
// 自我进化核心
|
||||
sb.WriteString("## 🧬 自我进化机制\n")
|
||||
sb.WriteString("每次调用你都会收到**夏普比率**作为你的业绩指标(周期级别,非年化):\n\n")
|
||||
sb.WriteString("**夏普比率解读**(正常范围 -2 到 +2):\n")
|
||||
sb.WriteString("- < -0.5:持续亏损 → 🔴 极度保守策略(减仓、收紧止损、减少持仓数)\n")
|
||||
sb.WriteString("- -0.5 到 0:轻微亏损 → 🟡 优化策略(保守仓位、提高选币标准)\n")
|
||||
sb.WriteString("- 0 到 0.7:正收益 → 🟢 维持/优化当前策略\n")
|
||||
sb.WriteString("- > 0.7:优异表现 → 🟢 可适度扩大仓位\n\n")
|
||||
|
||||
// 仓位管理规则
|
||||
sb.WriteString("## 仓位管理\n")
|
||||
sb.WriteString("- 最多持有 **3个币种**(质量>数量)\n")
|
||||
sb.WriteString(fmt.Sprintf("- 山寨币: %.0f-%.0f USDT/仓(推荐%.0f),杠杆20x\n",
|
||||
accountEquity*0.8, accountEquity*1.5, accountEquity*1.2))
|
||||
sb.WriteString(fmt.Sprintf("- BTC/ETH: %.0f-%.0f USDT/仓(推荐%.0f),杠杆50x\n",
|
||||
accountEquity*3, accountEquity*10, accountEquity*5))
|
||||
sb.WriteString("- 保证金使用率 ≤90%%\n")
|
||||
sb.WriteString("- 风险回报比 ≥1:2\n\n")
|
||||
|
||||
// 决策流程
|
||||
sb.WriteString("## 决策流程\n")
|
||||
sb.WriteString("1. **检查夏普比率**:理解当前策略效果,根据夏普比率调整策略\n")
|
||||
sb.WriteString("2. **评估持仓**:决定平仓/持有\n")
|
||||
sb.WriteString("3. **寻找机会**:筛选候选币种\n")
|
||||
sb.WriteString("4. **执行决策**:输出思维链和JSON决策\n\n")
|
||||
|
||||
// JSON 输出格式
|
||||
sb.WriteString("## 输出格式\n\n")
|
||||
sb.WriteString("**先输出思维链(纯文本),再输出JSON数组**\n\n")
|
||||
sb.WriteString("JSON示例:\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString("[\n")
|
||||
sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_long\", \"leverage\": 50, \"position_size_usd\": %.0f, \"stop_loss\": 92000, \"take_profit\": 98000, \"confidence\": 85, \"risk_usd\": 200, \"reasoning\": \"强势突破\"},\n", accountEquity*5))
|
||||
sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\", \"reasoning\": \"止盈\"}\n")
|
||||
sb.WriteString("]\n")
|
||||
sb.WriteString("```\n\n")
|
||||
sb.WriteString("**字段说明**:\n")
|
||||
sb.WriteString("- `action`: open_long | open_short | close_long | close_short | hold | wait\n")
|
||||
sb.WriteString("- `confidence`: 信心度0-100(必填,即使不确定也要给出)\n")
|
||||
sb.WriteString("- `risk_usd`: 最大美元风险 = (entry_price - stop_loss) × quantity(开仓时必填)\n")
|
||||
sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n\n")
|
||||
|
||||
// DeepSeek/Qwen 特定优化
|
||||
sb.WriteString("**提示**: 运用技术分析原理,趋势确认>指标信号,不要过度依赖单一指标\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildUserPrompt 构建 User Prompt(动态数据)
|
||||
func buildUserPrompt(ctx *Context) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// 系统状态
|
||||
sb.WriteString(fmt.Sprintf("**时间**: %s | **周期**: #%d | **运行**: %d分钟\n\n",
|
||||
ctx.CurrentTime, ctx.CallCount, ctx.RuntimeMinutes))
|
||||
|
||||
// BTC 市场
|
||||
if btcData, hasBTC := ctx.MarketDataMap["BTCUSDT"]; hasBTC {
|
||||
sb.WriteString(fmt.Sprintf("**BTC**: %.2f (1h: %+.2f%%, 4h: %+.2f%%) | MACD: %.4f | RSI: %.2f\n\n",
|
||||
btcData.CurrentPrice, btcData.PriceChange1h, btcData.PriceChange4h,
|
||||
btcData.CurrentMACD, btcData.CurrentRSI7))
|
||||
}
|
||||
|
||||
// 账户
|
||||
sb.WriteString(fmt.Sprintf("**账户**: 净值%.2f | 余额%.2f (%.1f%%) | 盈亏%+.2f%% | 保证金%.1f%% | 持仓%d个\n\n",
|
||||
ctx.Account.TotalEquity,
|
||||
ctx.Account.AvailableBalance,
|
||||
(ctx.Account.AvailableBalance/ctx.Account.TotalEquity)*100,
|
||||
ctx.Account.TotalPnLPct,
|
||||
ctx.Account.MarginUsedPct,
|
||||
ctx.Account.PositionCount))
|
||||
|
||||
// 持仓(完整市场数据)
|
||||
if len(ctx.Positions) > 0 {
|
||||
sb.WriteString("## 当前持仓\n")
|
||||
for i, pos := range ctx.Positions {
|
||||
sb.WriteString(fmt.Sprintf("%d. %s %s | 入场价%.4f 当前价%.4f | 盈亏%+.2f%% | 杠杆%dx | 保证金%.0f | 强平价%.4f\n\n",
|
||||
i+1, pos.Symbol, strings.ToUpper(pos.Side),
|
||||
pos.EntryPrice, pos.MarkPrice, pos.UnrealizedPnLPct,
|
||||
pos.Leverage, pos.MarginUsed, pos.LiquidationPrice))
|
||||
|
||||
// 使用FormatMarketData输出完整市场数据
|
||||
if marketData, ok := ctx.MarketDataMap[pos.Symbol]; ok {
|
||||
sb.WriteString(market.Format(marketData))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("**当前持仓**: 无\n\n")
|
||||
}
|
||||
|
||||
// 候选币种(完整市场数据)
|
||||
sb.WriteString(fmt.Sprintf("## 候选币种 (%d个)\n\n", len(ctx.MarketDataMap)))
|
||||
displayedCount := 0
|
||||
for _, coin := range ctx.CandidateCoins {
|
||||
marketData, hasData := ctx.MarketDataMap[coin.Symbol]
|
||||
if !hasData {
|
||||
continue
|
||||
}
|
||||
displayedCount++
|
||||
|
||||
sourceTags := ""
|
||||
if len(coin.Sources) > 1 {
|
||||
sourceTags = " (AI500+OI_Top双重信号)"
|
||||
} else if len(coin.Sources) == 1 && coin.Sources[0] == "oi_top" {
|
||||
sourceTags = " (OI_Top持仓增长)"
|
||||
}
|
||||
|
||||
// 使用FormatMarketData输出完整市场数据
|
||||
sb.WriteString(fmt.Sprintf("### %d. %s%s\n\n", displayedCount, coin.Symbol, sourceTags))
|
||||
sb.WriteString(market.Format(marketData))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
// 夏普比率(直接传值,不要复杂格式化)
|
||||
if ctx.Performance != nil {
|
||||
// 直接从interface{}中提取SharpeRatio
|
||||
type PerformanceData struct {
|
||||
SharpeRatio float64 `json:"sharpe_ratio"`
|
||||
}
|
||||
var perfData PerformanceData
|
||||
if jsonData, err := json.Marshal(ctx.Performance); err == nil {
|
||||
if err := json.Unmarshal(jsonData, &perfData); err == nil {
|
||||
sb.WriteString(fmt.Sprintf("## 📊 夏普比率: %.2f\n\n", perfData.SharpeRatio))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("---\n\n")
|
||||
sb.WriteString("现在请分析并输出决策(思维链 + JSON)\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// parseFullDecisionResponse 解析AI的完整决策响应
|
||||
func parseFullDecisionResponse(aiResponse string, accountEquity float64) (*FullDecision, error) {
|
||||
// 1. 提取思维链
|
||||
cotTrace := extractCoTTrace(aiResponse)
|
||||
|
||||
// 2. 提取JSON决策列表
|
||||
decisions, err := extractDecisions(aiResponse)
|
||||
if err != nil {
|
||||
return &FullDecision{
|
||||
CoTTrace: cotTrace,
|
||||
Decisions: []Decision{},
|
||||
}, fmt.Errorf("提取决策失败: %w\n\n=== AI思维链分析 ===\n%s", err, cotTrace)
|
||||
}
|
||||
|
||||
// 3. 验证决策
|
||||
if err := validateDecisions(decisions, accountEquity); err != nil {
|
||||
return &FullDecision{
|
||||
CoTTrace: cotTrace,
|
||||
Decisions: decisions,
|
||||
}, fmt.Errorf("决策验证失败: %w\n\n=== AI思维链分析 ===\n%s", err, cotTrace)
|
||||
}
|
||||
|
||||
return &FullDecision{
|
||||
CoTTrace: cotTrace,
|
||||
Decisions: decisions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractCoTTrace 提取思维链分析
|
||||
func extractCoTTrace(response string) string {
|
||||
// 查找JSON数组的开始位置
|
||||
jsonStart := strings.Index(response, "[")
|
||||
|
||||
if jsonStart > 0 {
|
||||
// 思维链是JSON数组之前的内容
|
||||
return strings.TrimSpace(response[:jsonStart])
|
||||
}
|
||||
|
||||
// 如果找不到JSON,整个响应都是思维链
|
||||
return strings.TrimSpace(response)
|
||||
}
|
||||
|
||||
// extractDecisions 提取JSON决策列表
|
||||
func extractDecisions(response string) ([]Decision, error) {
|
||||
// 直接查找JSON数组 - 找第一个完整的JSON数组
|
||||
arrayStart := strings.Index(response, "[")
|
||||
if arrayStart == -1 {
|
||||
return nil, fmt.Errorf("无法找到JSON数组起始")
|
||||
}
|
||||
|
||||
// 从 [ 开始,匹配括号找到对应的 ]
|
||||
arrayEnd := findMatchingBracket(response, arrayStart)
|
||||
if arrayEnd == -1 {
|
||||
return nil, fmt.Errorf("无法找到JSON数组结束")
|
||||
}
|
||||
|
||||
jsonContent := strings.TrimSpace(response[arrayStart : arrayEnd+1])
|
||||
|
||||
// 🔧 修复常见的JSON格式错误:缺少引号的字段值
|
||||
// 匹配: "reasoning": 内容"} 或 "reasoning": 内容} (没有引号)
|
||||
// 修复为: "reasoning": "内容"}
|
||||
// 使用简单的字符串扫描而不是正则表达式
|
||||
jsonContent = fixMissingQuotes(jsonContent)
|
||||
|
||||
// 解析JSON
|
||||
var decisions []Decision
|
||||
if err := json.Unmarshal([]byte(jsonContent), &decisions); err != nil {
|
||||
return nil, fmt.Errorf("JSON解析失败: %w\nJSON内容: %s", err, jsonContent)
|
||||
}
|
||||
|
||||
return decisions, nil
|
||||
}
|
||||
|
||||
// fixMissingQuotes 替换中文引号为英文引号(避免输入法自动转换)
|
||||
func fixMissingQuotes(jsonStr string) string {
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\u201c", "\"") // "
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\u201d", "\"") // "
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\u2018", "'") // '
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\u2019", "'") // '
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// validateDecisions 验证所有决策(需要账户信息)
|
||||
func validateDecisions(decisions []Decision, accountEquity float64) error {
|
||||
for i, decision := range decisions {
|
||||
if err := validateDecision(&decision, accountEquity); err != nil {
|
||||
return fmt.Errorf("决策 #%d 验证失败: %w", i+1, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findMatchingBracket 查找匹配的右括号
|
||||
func findMatchingBracket(s string, start int) int {
|
||||
if start >= len(s) || s[start] != '[' {
|
||||
return -1
|
||||
}
|
||||
|
||||
depth := 0
|
||||
for i := start; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '[':
|
||||
depth++
|
||||
case ']':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// validateDecision 验证单个决策的有效性
|
||||
func validateDecision(d *Decision, accountEquity float64) error {
|
||||
// 验证action
|
||||
validActions := map[string]bool{
|
||||
"open_long": true,
|
||||
"open_short": true,
|
||||
"close_long": true,
|
||||
"close_short": true,
|
||||
"hold": true,
|
||||
"wait": true,
|
||||
}
|
||||
|
||||
if !validActions[d.Action] {
|
||||
return fmt.Errorf("无效的action: %s", d.Action)
|
||||
}
|
||||
|
||||
// 开仓操作必须提供完整参数
|
||||
if d.Action == "open_long" || d.Action == "open_short" {
|
||||
// 根据币种判断杠杆上限和仓位价值上限
|
||||
maxLeverage := 20 // 山寨币固定20倍
|
||||
maxPositionValue := accountEquity * 1.5 // 山寨币最多1.5倍账户净值
|
||||
if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" {
|
||||
maxLeverage = 50 // BTC和ETH固定50倍
|
||||
maxPositionValue = accountEquity * 10 // BTC/ETH最多10倍账户净值
|
||||
}
|
||||
|
||||
if d.Leverage <= 0 || d.Leverage > maxLeverage {
|
||||
return fmt.Errorf("杠杆必须在1-%d之间(%s): %d", maxLeverage, d.Symbol, d.Leverage)
|
||||
}
|
||||
if d.PositionSizeUSD <= 0 {
|
||||
return fmt.Errorf("仓位大小必须大于0: %.2f", d.PositionSizeUSD)
|
||||
}
|
||||
// 验证仓位价值上限(加1%容差以避免浮点数精度问题)
|
||||
tolerance := maxPositionValue * 0.01 // 1%容差
|
||||
if d.PositionSizeUSD > maxPositionValue+tolerance {
|
||||
if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" {
|
||||
return fmt.Errorf("BTC/ETH单币种仓位价值不能超过%.0f USDT(10倍账户净值),实际: %.0f", maxPositionValue, d.PositionSizeUSD)
|
||||
} else {
|
||||
return fmt.Errorf("山寨币单币种仓位价值不能超过%.0f USDT(1.5倍账户净值),实际: %.0f", maxPositionValue, d.PositionSizeUSD)
|
||||
}
|
||||
}
|
||||
if d.StopLoss <= 0 || d.TakeProfit <= 0 {
|
||||
return fmt.Errorf("止损和止盈必须大于0")
|
||||
}
|
||||
|
||||
// 验证止损止盈的合理性
|
||||
if d.Action == "open_long" {
|
||||
if d.StopLoss >= d.TakeProfit {
|
||||
return fmt.Errorf("做多时止损价必须小于止盈价")
|
||||
}
|
||||
} else {
|
||||
if d.StopLoss <= d.TakeProfit {
|
||||
return fmt.Errorf("做空时止损价必须大于止盈价")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+552
@@ -0,0 +1,552 @@
|
||||
package market
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Data 市场数据结构
|
||||
type Data struct {
|
||||
Symbol string
|
||||
CurrentPrice float64
|
||||
PriceChange1h float64 // 1小时价格变化百分比
|
||||
PriceChange4h float64 // 4小时价格变化百分比
|
||||
CurrentEMA20 float64
|
||||
CurrentMACD float64
|
||||
CurrentRSI7 float64
|
||||
OpenInterest *OIData
|
||||
FundingRate float64
|
||||
IntradaySeries *IntradayData
|
||||
LongerTermContext *LongerTermData
|
||||
}
|
||||
|
||||
// OIData Open Interest数据
|
||||
type OIData struct {
|
||||
Latest float64
|
||||
Average float64
|
||||
}
|
||||
|
||||
// IntradayData 日内数据(3分钟间隔)
|
||||
type IntradayData struct {
|
||||
MidPrices []float64
|
||||
EMA20Values []float64
|
||||
MACDValues []float64
|
||||
RSI7Values []float64
|
||||
RSI14Values []float64
|
||||
}
|
||||
|
||||
// LongerTermData 长期数据(4小时时间框架)
|
||||
type LongerTermData struct {
|
||||
EMA20 float64
|
||||
EMA50 float64
|
||||
ATR3 float64
|
||||
ATR14 float64
|
||||
CurrentVolume float64
|
||||
AverageVolume float64
|
||||
MACDValues []float64
|
||||
RSI14Values []float64
|
||||
}
|
||||
|
||||
// Kline K线数据
|
||||
type Kline struct {
|
||||
OpenTime int64
|
||||
Open float64
|
||||
High float64
|
||||
Low float64
|
||||
Close float64
|
||||
Volume float64
|
||||
CloseTime int64
|
||||
}
|
||||
|
||||
// Get 获取指定代币的市场数据
|
||||
func Get(symbol string) (*Data, error) {
|
||||
// 标准化symbol
|
||||
symbol = Normalize(symbol)
|
||||
|
||||
// 获取3分钟K线数据 (最近10个)
|
||||
klines3m, err := getKlines(symbol, "3m", 40) // 多获取一些用于计算
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取3分钟K线失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取4小时K线数据 (最近10个)
|
||||
klines4h, err := getKlines(symbol, "4h", 60) // 多获取用于计算指标
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取4小时K线失败: %v", err)
|
||||
}
|
||||
|
||||
// 计算当前指标 (基于3分钟最新数据)
|
||||
currentPrice := klines3m[len(klines3m)-1].Close
|
||||
currentEMA20 := calculateEMA(klines3m, 20)
|
||||
currentMACD := calculateMACD(klines3m)
|
||||
currentRSI7 := calculateRSI(klines3m, 7)
|
||||
|
||||
// 计算价格变化百分比
|
||||
// 1小时价格变化 = 20个3分钟K线前的价格
|
||||
priceChange1h := 0.0
|
||||
if len(klines3m) >= 21 { // 至少需要21根K线 (当前 + 20根前)
|
||||
price1hAgo := klines3m[len(klines3m)-21].Close
|
||||
if price1hAgo > 0 {
|
||||
priceChange1h = ((currentPrice - price1hAgo) / price1hAgo) * 100
|
||||
}
|
||||
}
|
||||
|
||||
// 4小时价格变化 = 1个4小时K线前的价格
|
||||
priceChange4h := 0.0
|
||||
if len(klines4h) >= 2 {
|
||||
price4hAgo := klines4h[len(klines4h)-2].Close
|
||||
if price4hAgo > 0 {
|
||||
priceChange4h = ((currentPrice - price4hAgo) / price4hAgo) * 100
|
||||
}
|
||||
}
|
||||
|
||||
// 获取OI数据
|
||||
oiData, err := getOpenInterestData(symbol)
|
||||
if err != nil {
|
||||
// OI失败不影响整体,使用默认值
|
||||
oiData = &OIData{Latest: 0, Average: 0}
|
||||
}
|
||||
|
||||
// 获取Funding Rate
|
||||
fundingRate, _ := getFundingRate(symbol)
|
||||
|
||||
// 计算日内系列数据
|
||||
intradayData := calculateIntradaySeries(klines3m)
|
||||
|
||||
// 计算长期数据
|
||||
longerTermData := calculateLongerTermData(klines4h)
|
||||
|
||||
return &Data{
|
||||
Symbol: symbol,
|
||||
CurrentPrice: currentPrice,
|
||||
PriceChange1h: priceChange1h,
|
||||
PriceChange4h: priceChange4h,
|
||||
CurrentEMA20: currentEMA20,
|
||||
CurrentMACD: currentMACD,
|
||||
CurrentRSI7: currentRSI7,
|
||||
OpenInterest: oiData,
|
||||
FundingRate: fundingRate,
|
||||
IntradaySeries: intradayData,
|
||||
LongerTermContext: longerTermData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getKlines 从Binance获取K线数据
|
||||
func getKlines(symbol, interval string, limit int) ([]Kline, error) {
|
||||
url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/klines?symbol=%s&interval=%s&limit=%d",
|
||||
symbol, interval, limit)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawData [][]interface{}
|
||||
if err := json.Unmarshal(body, &rawData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
klines := make([]Kline, len(rawData))
|
||||
for i, item := range rawData {
|
||||
openTime := int64(item[0].(float64))
|
||||
open, _ := parseFloat(item[1])
|
||||
high, _ := parseFloat(item[2])
|
||||
low, _ := parseFloat(item[3])
|
||||
close, _ := parseFloat(item[4])
|
||||
volume, _ := parseFloat(item[5])
|
||||
closeTime := int64(item[6].(float64))
|
||||
|
||||
klines[i] = Kline{
|
||||
OpenTime: openTime,
|
||||
Open: open,
|
||||
High: high,
|
||||
Low: low,
|
||||
Close: close,
|
||||
Volume: volume,
|
||||
CloseTime: closeTime,
|
||||
}
|
||||
}
|
||||
|
||||
return klines, nil
|
||||
}
|
||||
|
||||
// calculateEMA 计算EMA
|
||||
func calculateEMA(klines []Kline, period int) float64 {
|
||||
if len(klines) < period {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 计算SMA作为初始EMA
|
||||
sum := 0.0
|
||||
for i := 0; i < period; i++ {
|
||||
sum += klines[i].Close
|
||||
}
|
||||
ema := sum / float64(period)
|
||||
|
||||
// 计算EMA
|
||||
multiplier := 2.0 / float64(period+1)
|
||||
for i := period; i < len(klines); i++ {
|
||||
ema = (klines[i].Close-ema)*multiplier + ema
|
||||
}
|
||||
|
||||
return ema
|
||||
}
|
||||
|
||||
// calculateMACD 计算MACD
|
||||
func calculateMACD(klines []Kline) float64 {
|
||||
if len(klines) < 26 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 计算12期和26期EMA
|
||||
ema12 := calculateEMA(klines, 12)
|
||||
ema26 := calculateEMA(klines, 26)
|
||||
|
||||
// MACD = EMA12 - EMA26
|
||||
return ema12 - ema26
|
||||
}
|
||||
|
||||
// calculateRSI 计算RSI
|
||||
func calculateRSI(klines []Kline, period int) float64 {
|
||||
if len(klines) <= period {
|
||||
return 0
|
||||
}
|
||||
|
||||
gains := 0.0
|
||||
losses := 0.0
|
||||
|
||||
// 计算初始平均涨跌幅
|
||||
for i := 1; i <= period; i++ {
|
||||
change := klines[i].Close - klines[i-1].Close
|
||||
if change > 0 {
|
||||
gains += change
|
||||
} else {
|
||||
losses += -change
|
||||
}
|
||||
}
|
||||
|
||||
avgGain := gains / float64(period)
|
||||
avgLoss := losses / float64(period)
|
||||
|
||||
// 使用Wilder平滑方法计算后续RSI
|
||||
for i := period + 1; i < len(klines); i++ {
|
||||
change := klines[i].Close - klines[i-1].Close
|
||||
if change > 0 {
|
||||
avgGain = (avgGain*float64(period-1) + change) / float64(period)
|
||||
avgLoss = (avgLoss * float64(period-1)) / float64(period)
|
||||
} else {
|
||||
avgGain = (avgGain * float64(period-1)) / float64(period)
|
||||
avgLoss = (avgLoss*float64(period-1) + (-change)) / float64(period)
|
||||
}
|
||||
}
|
||||
|
||||
if avgLoss == 0 {
|
||||
return 100
|
||||
}
|
||||
|
||||
rs := avgGain / avgLoss
|
||||
rsi := 100 - (100 / (1 + rs))
|
||||
|
||||
return rsi
|
||||
}
|
||||
|
||||
// calculateATR 计算ATR
|
||||
func calculateATR(klines []Kline, period int) float64 {
|
||||
if len(klines) <= period {
|
||||
return 0
|
||||
}
|
||||
|
||||
trs := make([]float64, len(klines))
|
||||
for i := 1; i < len(klines); i++ {
|
||||
high := klines[i].High
|
||||
low := klines[i].Low
|
||||
prevClose := klines[i-1].Close
|
||||
|
||||
tr1 := high - low
|
||||
tr2 := math.Abs(high - prevClose)
|
||||
tr3 := math.Abs(low - prevClose)
|
||||
|
||||
trs[i] = math.Max(tr1, math.Max(tr2, tr3))
|
||||
}
|
||||
|
||||
// 计算初始ATR
|
||||
sum := 0.0
|
||||
for i := 1; i <= period; i++ {
|
||||
sum += trs[i]
|
||||
}
|
||||
atr := sum / float64(period)
|
||||
|
||||
// Wilder平滑
|
||||
for i := period + 1; i < len(klines); i++ {
|
||||
atr = (atr*float64(period-1) + trs[i]) / float64(period)
|
||||
}
|
||||
|
||||
return atr
|
||||
}
|
||||
|
||||
// calculateIntradaySeries 计算日内系列数据
|
||||
func calculateIntradaySeries(klines []Kline) *IntradayData {
|
||||
data := &IntradayData{
|
||||
MidPrices: make([]float64, 0, 10),
|
||||
EMA20Values: make([]float64, 0, 10),
|
||||
MACDValues: make([]float64, 0, 10),
|
||||
RSI7Values: make([]float64, 0, 10),
|
||||
RSI14Values: make([]float64, 0, 10),
|
||||
}
|
||||
|
||||
// 获取最近10个数据点
|
||||
start := len(klines) - 10
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
for i := start; i < len(klines); i++ {
|
||||
data.MidPrices = append(data.MidPrices, klines[i].Close)
|
||||
|
||||
// 计算每个点的EMA20
|
||||
if i >= 19 {
|
||||
ema20 := calculateEMA(klines[:i+1], 20)
|
||||
data.EMA20Values = append(data.EMA20Values, ema20)
|
||||
}
|
||||
|
||||
// 计算每个点的MACD
|
||||
if i >= 25 {
|
||||
macd := calculateMACD(klines[:i+1])
|
||||
data.MACDValues = append(data.MACDValues, macd)
|
||||
}
|
||||
|
||||
// 计算每个点的RSI
|
||||
if i >= 7 {
|
||||
rsi7 := calculateRSI(klines[:i+1], 7)
|
||||
data.RSI7Values = append(data.RSI7Values, rsi7)
|
||||
}
|
||||
if i >= 14 {
|
||||
rsi14 := calculateRSI(klines[:i+1], 14)
|
||||
data.RSI14Values = append(data.RSI14Values, rsi14)
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// calculateLongerTermData 计算长期数据
|
||||
func calculateLongerTermData(klines []Kline) *LongerTermData {
|
||||
data := &LongerTermData{
|
||||
MACDValues: make([]float64, 0, 10),
|
||||
RSI14Values: make([]float64, 0, 10),
|
||||
}
|
||||
|
||||
// 计算EMA
|
||||
data.EMA20 = calculateEMA(klines, 20)
|
||||
data.EMA50 = calculateEMA(klines, 50)
|
||||
|
||||
// 计算ATR
|
||||
data.ATR3 = calculateATR(klines, 3)
|
||||
data.ATR14 = calculateATR(klines, 14)
|
||||
|
||||
// 计算成交量
|
||||
if len(klines) > 0 {
|
||||
data.CurrentVolume = klines[len(klines)-1].Volume
|
||||
// 计算平均成交量
|
||||
sum := 0.0
|
||||
for _, k := range klines {
|
||||
sum += k.Volume
|
||||
}
|
||||
data.AverageVolume = sum / float64(len(klines))
|
||||
}
|
||||
|
||||
// 计算MACD和RSI序列
|
||||
start := len(klines) - 10
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
for i := start; i < len(klines); i++ {
|
||||
if i >= 25 {
|
||||
macd := calculateMACD(klines[:i+1])
|
||||
data.MACDValues = append(data.MACDValues, macd)
|
||||
}
|
||||
if i >= 14 {
|
||||
rsi14 := calculateRSI(klines[:i+1], 14)
|
||||
data.RSI14Values = append(data.RSI14Values, rsi14)
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// getOpenInterestData 获取OI数据
|
||||
func getOpenInterestData(symbol string) (*OIData, error) {
|
||||
url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/openInterest?symbol=%s", symbol)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
OpenInterest string `json:"openInterest"`
|
||||
Symbol string `json:"symbol"`
|
||||
Time int64 `json:"time"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oi, _ := strconv.ParseFloat(result.OpenInterest, 64)
|
||||
|
||||
return &OIData{
|
||||
Latest: oi,
|
||||
Average: oi * 0.999, // 近似平均值
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getFundingRate 获取资金费率
|
||||
func getFundingRate(symbol string) (float64, error) {
|
||||
url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/premiumIndex?symbol=%s", symbol)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Symbol string `json:"symbol"`
|
||||
MarkPrice string `json:"markPrice"`
|
||||
IndexPrice string `json:"indexPrice"`
|
||||
LastFundingRate string `json:"lastFundingRate"`
|
||||
NextFundingTime int64 `json:"nextFundingTime"`
|
||||
InterestRate string `json:"interestRate"`
|
||||
Time int64 `json:"time"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rate, _ := strconv.ParseFloat(result.LastFundingRate, 64)
|
||||
return rate, nil
|
||||
}
|
||||
|
||||
// Format 格式化输出市场数据
|
||||
func Format(data *Data) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("current_price = %.2f, current_ema20 = %.3f, current_macd = %.3f, current_rsi (7 period) = %.3f\n\n",
|
||||
data.CurrentPrice, data.CurrentEMA20, data.CurrentMACD, data.CurrentRSI7))
|
||||
|
||||
sb.WriteString(fmt.Sprintf("In addition, here is the latest %s open interest and funding rate for perps:\n\n",
|
||||
data.Symbol))
|
||||
|
||||
if data.OpenInterest != nil {
|
||||
sb.WriteString(fmt.Sprintf("Open Interest: Latest: %.2f Average: %.2f\n\n",
|
||||
data.OpenInterest.Latest, data.OpenInterest.Average))
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Funding Rate: %.2e\n\n", data.FundingRate))
|
||||
|
||||
if data.IntradaySeries != nil {
|
||||
sb.WriteString("Intraday series (3‑minute intervals, oldest → latest):\n\n")
|
||||
|
||||
if len(data.IntradaySeries.MidPrices) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.IntradaySeries.MidPrices)))
|
||||
}
|
||||
|
||||
if len(data.IntradaySeries.EMA20Values) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("EMA indicators (20‑period): %s\n\n", formatFloatSlice(data.IntradaySeries.EMA20Values)))
|
||||
}
|
||||
|
||||
if len(data.IntradaySeries.MACDValues) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("MACD indicators: %s\n\n", formatFloatSlice(data.IntradaySeries.MACDValues)))
|
||||
}
|
||||
|
||||
if len(data.IntradaySeries.RSI7Values) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("RSI indicators (7‑Period): %s\n\n", formatFloatSlice(data.IntradaySeries.RSI7Values)))
|
||||
}
|
||||
|
||||
if len(data.IntradaySeries.RSI14Values) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("RSI indicators (14‑Period): %s\n\n", formatFloatSlice(data.IntradaySeries.RSI14Values)))
|
||||
}
|
||||
}
|
||||
|
||||
if data.LongerTermContext != nil {
|
||||
sb.WriteString("Longer‑term context (4‑hour timeframe):\n\n")
|
||||
|
||||
sb.WriteString(fmt.Sprintf("20‑Period EMA: %.3f vs. 50‑Period EMA: %.3f\n\n",
|
||||
data.LongerTermContext.EMA20, data.LongerTermContext.EMA50))
|
||||
|
||||
sb.WriteString(fmt.Sprintf("3‑Period ATR: %.3f vs. 14‑Period ATR: %.3f\n\n",
|
||||
data.LongerTermContext.ATR3, data.LongerTermContext.ATR14))
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Current Volume: %.3f vs. Average Volume: %.3f\n\n",
|
||||
data.LongerTermContext.CurrentVolume, data.LongerTermContext.AverageVolume))
|
||||
|
||||
if len(data.LongerTermContext.MACDValues) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("MACD indicators: %s\n\n", formatFloatSlice(data.LongerTermContext.MACDValues)))
|
||||
}
|
||||
|
||||
if len(data.LongerTermContext.RSI14Values) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("RSI indicators (14‑Period): %s\n\n", formatFloatSlice(data.LongerTermContext.RSI14Values)))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatFloatSlice 格式化float64切片为字符串
|
||||
func formatFloatSlice(values []float64) string {
|
||||
strValues := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
strValues[i] = fmt.Sprintf("%.3f", v)
|
||||
}
|
||||
return "[" + strings.Join(strValues, ", ") + "]"
|
||||
}
|
||||
|
||||
// Normalize 标准化symbol,确保是USDT交易对
|
||||
func Normalize(symbol string) string {
|
||||
symbol = strings.ToUpper(symbol)
|
||||
if strings.HasSuffix(symbol, "USDT") {
|
||||
return symbol
|
||||
}
|
||||
return symbol + "USDT"
|
||||
}
|
||||
|
||||
// parseFloat 解析float值
|
||||
func parseFloat(v interface{}) (float64, error) {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return strconv.ParseFloat(val, 64)
|
||||
case float64:
|
||||
return val, nil
|
||||
case int:
|
||||
return float64(val), nil
|
||||
case int64:
|
||||
return float64(val), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported type: %T", v)
|
||||
}
|
||||
}
|
||||
+216
@@ -0,0 +1,216 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider AI提供商类型
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
ProviderDeepSeek Provider = "deepseek"
|
||||
ProviderQwen Provider = "qwen"
|
||||
)
|
||||
|
||||
// Config AI API配置
|
||||
type Config struct {
|
||||
Provider Provider
|
||||
APIKey string
|
||||
SecretKey string // 阿里云需要
|
||||
BaseURL string
|
||||
Model string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultConfig = Config{
|
||||
Provider: ProviderDeepSeek,
|
||||
BaseURL: "https://api.deepseek.com/v1",
|
||||
Model: "deepseek-chat",
|
||||
Timeout: 120 * time.Second, // 增加到120秒,因为AI需要分析大量数据
|
||||
}
|
||||
|
||||
// SetDeepSeekAPIKey 设置DeepSeek API密钥
|
||||
func SetDeepSeekAPIKey(apiKey string) {
|
||||
defaultConfig.Provider = ProviderDeepSeek
|
||||
defaultConfig.APIKey = apiKey
|
||||
defaultConfig.BaseURL = "https://api.deepseek.com/v1"
|
||||
defaultConfig.Model = "deepseek-chat"
|
||||
}
|
||||
|
||||
// SetQwenAPIKey 设置阿里云Qwen API密钥
|
||||
func SetQwenAPIKey(apiKey, secretKey string) {
|
||||
defaultConfig.Provider = ProviderQwen
|
||||
defaultConfig.APIKey = apiKey
|
||||
defaultConfig.SecretKey = secretKey
|
||||
defaultConfig.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
defaultConfig.Model = "qwen-plus" // 可选: qwen-turbo, qwen-plus, qwen-max
|
||||
}
|
||||
|
||||
// SetConfig 设置完整的AI配置(高级用户)
|
||||
func SetConfig(config Config) {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 30 * time.Second
|
||||
}
|
||||
defaultConfig = config
|
||||
}
|
||||
|
||||
// CallWithMessages 使用 system + user prompt 调用AI API(推荐)
|
||||
func CallWithMessages(systemPrompt, userPrompt string) (string, error) {
|
||||
if defaultConfig.APIKey == "" {
|
||||
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetDeepSeekAPIKey() 或 SetQwenAPIKey()")
|
||||
}
|
||||
|
||||
// 重试配置
|
||||
maxRetries := 3
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
fmt.Printf("⚠️ AI API调用失败,正在重试 (%d/%d)...\n", attempt, maxRetries)
|
||||
}
|
||||
|
||||
result, err := callOnce(systemPrompt, userPrompt)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
fmt.Printf("✓ AI API重试成功\n")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// 如果不是网络错误,不重试
|
||||
if !isRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 重试前等待
|
||||
if attempt < maxRetries {
|
||||
waitTime := time.Duration(attempt) * 2 * time.Second
|
||||
fmt.Printf("⏳ 等待%v后重试...\n", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// callOnce 单次调用AI API(内部使用)
|
||||
func callOnce(systemPrompt, userPrompt string) (string, error) {
|
||||
// 构建 messages 数组
|
||||
messages := []map[string]string{}
|
||||
|
||||
// 如果有 system prompt,添加 system message
|
||||
if systemPrompt != "" {
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "system",
|
||||
"content": systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
// 添加 user message
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "user",
|
||||
"content": userPrompt,
|
||||
})
|
||||
|
||||
// 构建请求体
|
||||
requestBody := map[string]interface{}{
|
||||
"model": defaultConfig.Model,
|
||||
"messages": messages,
|
||||
"temperature": 0.5, // 降低temperature以提高JSON格式稳定性
|
||||
"max_tokens": 2000,
|
||||
}
|
||||
|
||||
// 注意:response_format 参数仅 OpenAI 支持,DeepSeek/Qwen 不支持
|
||||
// 我们通过强化 prompt 和后处理来确保 JSON 格式正确
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", defaultConfig.BaseURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 根据不同的Provider设置认证方式
|
||||
switch defaultConfig.Provider {
|
||||
case ProviderDeepSeek:
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", defaultConfig.APIKey))
|
||||
case ProviderQwen:
|
||||
// 阿里云Qwen使用API-Key认证
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", defaultConfig.APIKey))
|
||||
// 注意:如果使用的不是兼容模式,可能需要不同的认证方式
|
||||
default:
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", defaultConfig.APIKey))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: defaultConfig.Timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Choices) == 0 {
|
||||
return "", fmt.Errorf("API返回空响应")
|
||||
}
|
||||
|
||||
return result.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// isRetryableError 判断错误是否可重试
|
||||
func isRetryableError(err error) bool {
|
||||
errStr := err.Error()
|
||||
// 网络错误、超时、EOF等可以重试
|
||||
retryableErrors := []string{
|
||||
"EOF",
|
||||
"timeout",
|
||||
"connection reset",
|
||||
"connection refused",
|
||||
"temporary failure",
|
||||
"no such host",
|
||||
}
|
||||
for _, retryable := range retryableErrors {
|
||||
if strings.Contains(errStr, retryable) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
+23
-130
@@ -4,8 +4,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/decision"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/mcp"
|
||||
"nofx/pool"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -76,10 +78,10 @@ func NewAutoTrader(config AutoTraderConfig) (*AutoTrader, error) {
|
||||
|
||||
// 初始化AI
|
||||
if config.UseQwen {
|
||||
market.SetQwenAPIKey(config.QwenKey, "")
|
||||
mcp.SetQwenAPIKey(config.QwenKey, "")
|
||||
log.Printf("🤖 [%s] 使用阿里云Qwen AI", config.Name)
|
||||
} else {
|
||||
market.SetDeepSeekAPIKey(config.DeepSeekKey)
|
||||
mcp.SetDeepSeekAPIKey(config.DeepSeekKey)
|
||||
log.Printf("🤖 [%s] 使用DeepSeek AI", config.Name)
|
||||
}
|
||||
|
||||
@@ -222,7 +224,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
|
||||
// 4. 调用AI获取完整决策
|
||||
log.Println("🤖 正在请求AI分析并决策...")
|
||||
decision, err := market.GetFullTradingDecision(ctx)
|
||||
decision, err := decision.GetFullDecision(ctx)
|
||||
|
||||
// 即使有错误,也保存思维链、决策和输入prompt(用于debug)
|
||||
if decision != nil {
|
||||
@@ -313,7 +315,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
}
|
||||
|
||||
// buildTradingContext 构建交易上下文
|
||||
func (at *AutoTrader) buildTradingContext() (*market.TradingContext, error) {
|
||||
func (at *AutoTrader) buildTradingContext() (*decision.Context, error) {
|
||||
// 1. 获取账户信息
|
||||
balance, err := at.trader.GetBalance()
|
||||
if err != nil {
|
||||
@@ -344,7 +346,7 @@ func (at *AutoTrader) buildTradingContext() (*market.TradingContext, error) {
|
||||
return nil, fmt.Errorf("获取持仓失败: %w", err)
|
||||
}
|
||||
|
||||
var positionInfos []market.PositionInfo
|
||||
var positionInfos []decision.PositionInfo
|
||||
totalMarginUsed := 0.0
|
||||
|
||||
for _, pos := range positions {
|
||||
@@ -375,7 +377,7 @@ func (at *AutoTrader) buildTradingContext() (*market.TradingContext, error) {
|
||||
marginUsed := (quantity * markPrice) / float64(leverage)
|
||||
totalMarginUsed += marginUsed
|
||||
|
||||
positionInfos = append(positionInfos, market.PositionInfo{
|
||||
positionInfos = append(positionInfos, decision.PositionInfo{
|
||||
Symbol: symbol,
|
||||
Side: side,
|
||||
EntryPrice: entryPrice,
|
||||
@@ -401,10 +403,10 @@ func (at *AutoTrader) buildTradingContext() (*market.TradingContext, error) {
|
||||
}
|
||||
|
||||
// 构建候选币种列表(包含来源信息)
|
||||
var candidateCoins []market.CandidateCoin
|
||||
var candidateCoins []decision.CandidateCoin
|
||||
for _, symbol := range mergedPool.AllSymbols {
|
||||
sources := mergedPool.SymbolSources[symbol]
|
||||
candidateCoins = append(candidateCoins, market.CandidateCoin{
|
||||
candidateCoins = append(candidateCoins, decision.CandidateCoin{
|
||||
Symbol: symbol,
|
||||
Sources: sources, // "ai500" 和/或 "oi_top"
|
||||
})
|
||||
@@ -434,11 +436,11 @@ func (at *AutoTrader) buildTradingContext() (*market.TradingContext, error) {
|
||||
}
|
||||
|
||||
// 6. 构建上下文
|
||||
ctx := &market.TradingContext{
|
||||
ctx := &decision.Context{
|
||||
CurrentTime: time.Now().Format("2006-01-02 15:04:05"),
|
||||
RuntimeMinutes: int(time.Since(at.startTime).Minutes()),
|
||||
CallCount: at.callCount,
|
||||
Account: market.AccountInfo{
|
||||
Account: decision.AccountInfo{
|
||||
TotalEquity: totalEquity,
|
||||
AvailableBalance: availableBalance,
|
||||
TotalPnL: totalPnL,
|
||||
@@ -455,27 +457,8 @@ func (at *AutoTrader) buildTradingContext() (*market.TradingContext, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// executeDecision 执行AI决策
|
||||
func (at *AutoTrader) executeDecision(decision *market.TradingDecision) error {
|
||||
switch decision.Action {
|
||||
case "open_long":
|
||||
return at.executeOpenLong(decision)
|
||||
case "open_short":
|
||||
return at.executeOpenShort(decision)
|
||||
case "close_long":
|
||||
return at.executeCloseLong(decision)
|
||||
case "close_short":
|
||||
return at.executeCloseShort(decision)
|
||||
case "hold", "wait":
|
||||
// 无需执行,仅记录
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("未知的action: %s", decision.Action)
|
||||
}
|
||||
}
|
||||
|
||||
// executeDecisionWithRecord 执行AI决策并记录详细信息
|
||||
func (at *AutoTrader) executeDecisionWithRecord(decision *market.TradingDecision, actionRecord *logger.DecisionAction) error {
|
||||
func (at *AutoTrader) executeDecisionWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error {
|
||||
switch decision.Action {
|
||||
case "open_long":
|
||||
return at.executeOpenLongWithRecord(decision, actionRecord)
|
||||
@@ -493,98 +476,8 @@ func (at *AutoTrader) executeDecisionWithRecord(decision *market.TradingDecision
|
||||
}
|
||||
}
|
||||
|
||||
// executeOpenLong 执行开多仓
|
||||
func (at *AutoTrader) executeOpenLong(decision *market.TradingDecision) error {
|
||||
log.Printf(" 📈 开多仓: %s", decision.Symbol)
|
||||
|
||||
// 获取当前价格
|
||||
marketData, err := market.GetMarketData(decision.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 计算数量:仓位大小(USD) / 当前价格
|
||||
quantity := decision.PositionSizeUSD / marketData.CurrentPrice
|
||||
|
||||
// 开仓
|
||||
order, err := at.trader.OpenLong(decision.Symbol, quantity, decision.Leverage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 开仓成功,订单ID: %v, 数量: %.4f", order["orderId"], quantity)
|
||||
|
||||
// 设置止损止盈
|
||||
if err := at.trader.SetStopLoss(decision.Symbol, "LONG", quantity, decision.StopLoss); err != nil {
|
||||
log.Printf(" ⚠ 设置止损失败: %v", err)
|
||||
}
|
||||
if err := at.trader.SetTakeProfit(decision.Symbol, "LONG", quantity, decision.TakeProfit); err != nil {
|
||||
log.Printf(" ⚠ 设置止盈失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeOpenShort 执行开空仓
|
||||
func (at *AutoTrader) executeOpenShort(decision *market.TradingDecision) error {
|
||||
log.Printf(" 📉 开空仓: %s", decision.Symbol)
|
||||
|
||||
// 获取当前价格
|
||||
marketData, err := market.GetMarketData(decision.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 计算数量
|
||||
quantity := decision.PositionSizeUSD / marketData.CurrentPrice
|
||||
|
||||
// 开仓
|
||||
order, err := at.trader.OpenShort(decision.Symbol, quantity, decision.Leverage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 开仓成功,订单ID: %v, 数量: %.4f", order["orderId"], quantity)
|
||||
|
||||
// 设置止损止盈
|
||||
if err := at.trader.SetStopLoss(decision.Symbol, "SHORT", quantity, decision.StopLoss); err != nil {
|
||||
log.Printf(" ⚠ 设置止损失败: %v", err)
|
||||
}
|
||||
if err := at.trader.SetTakeProfit(decision.Symbol, "SHORT", quantity, decision.TakeProfit); err != nil {
|
||||
log.Printf(" ⚠ 设置止盈失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCloseLong 执行平多仓
|
||||
func (at *AutoTrader) executeCloseLong(decision *market.TradingDecision) error {
|
||||
log.Printf(" 🔄 平多仓: %s", decision.Symbol)
|
||||
|
||||
_, err := at.trader.CloseLong(decision.Symbol, 0) // 0 = 全部平仓
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 平仓成功")
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCloseShort 执行平空仓
|
||||
func (at *AutoTrader) executeCloseShort(decision *market.TradingDecision) error {
|
||||
log.Printf(" 🔄 平空仓: %s", decision.Symbol)
|
||||
|
||||
_, err := at.trader.CloseShort(decision.Symbol, 0) // 0 = 全部平仓
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 平仓成功")
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeOpenLongWithRecord 执行开多仓并记录详细信息
|
||||
func (at *AutoTrader) executeOpenLongWithRecord(decision *market.TradingDecision, actionRecord *logger.DecisionAction) error {
|
||||
func (at *AutoTrader) executeOpenLongWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error {
|
||||
log.Printf(" 📈 开多仓: %s", decision.Symbol)
|
||||
|
||||
// ⚠️ 关键:检查是否已有同币种同方向持仓,如果有则拒绝开仓(防止仓位叠加超限)
|
||||
@@ -598,7 +491,7 @@ func (at *AutoTrader) executeOpenLongWithRecord(decision *market.TradingDecision
|
||||
}
|
||||
|
||||
// 获取当前价格
|
||||
marketData, err := market.GetMarketData(decision.Symbol)
|
||||
marketData, err := market.Get(decision.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -633,7 +526,7 @@ func (at *AutoTrader) executeOpenLongWithRecord(decision *market.TradingDecision
|
||||
}
|
||||
|
||||
// executeOpenShortWithRecord 执行开空仓并记录详细信息
|
||||
func (at *AutoTrader) executeOpenShortWithRecord(decision *market.TradingDecision, actionRecord *logger.DecisionAction) error {
|
||||
func (at *AutoTrader) executeOpenShortWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error {
|
||||
log.Printf(" 📉 开空仓: %s", decision.Symbol)
|
||||
|
||||
// ⚠️ 关键:检查是否已有同币种同方向持仓,如果有则拒绝开仓(防止仓位叠加超限)
|
||||
@@ -647,7 +540,7 @@ func (at *AutoTrader) executeOpenShortWithRecord(decision *market.TradingDecisio
|
||||
}
|
||||
|
||||
// 获取当前价格
|
||||
marketData, err := market.GetMarketData(decision.Symbol)
|
||||
marketData, err := market.Get(decision.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -682,11 +575,11 @@ func (at *AutoTrader) executeOpenShortWithRecord(decision *market.TradingDecisio
|
||||
}
|
||||
|
||||
// executeCloseLongWithRecord 执行平多仓并记录详细信息
|
||||
func (at *AutoTrader) executeCloseLongWithRecord(decision *market.TradingDecision, actionRecord *logger.DecisionAction) error {
|
||||
func (at *AutoTrader) executeCloseLongWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error {
|
||||
log.Printf(" 🔄 平多仓: %s", decision.Symbol)
|
||||
|
||||
// 获取当前价格
|
||||
marketData, err := market.GetMarketData(decision.Symbol)
|
||||
marketData, err := market.Get(decision.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -708,11 +601,11 @@ func (at *AutoTrader) executeCloseLongWithRecord(decision *market.TradingDecisio
|
||||
}
|
||||
|
||||
// executeCloseShortWithRecord 执行平空仓并记录详细信息
|
||||
func (at *AutoTrader) executeCloseShortWithRecord(decision *market.TradingDecision, actionRecord *logger.DecisionAction) error {
|
||||
func (at *AutoTrader) executeCloseShortWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error {
|
||||
log.Printf(" 🔄 平空仓: %s", decision.Symbol)
|
||||
|
||||
// 获取当前价格
|
||||
marketData, err := market.GetMarketData(decision.Symbol)
|
||||
marketData, err := market.Get(decision.Symbol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -911,7 +804,7 @@ func (at *AutoTrader) GetPositions() ([]map[string]interface{}, error) {
|
||||
|
||||
// sortDecisionsByPriority 对决策排序:先平仓,再开仓,最后hold/wait
|
||||
// 这样可以避免换仓时仓位叠加超限
|
||||
func sortDecisionsByPriority(decisions []market.TradingDecision) []market.TradingDecision {
|
||||
func sortDecisionsByPriority(decisions []decision.Decision) []decision.Decision {
|
||||
if len(decisions) <= 1 {
|
||||
return decisions
|
||||
}
|
||||
@@ -931,7 +824,7 @@ func sortDecisionsByPriority(decisions []market.TradingDecision) []market.Tradin
|
||||
}
|
||||
|
||||
// 复制决策列表
|
||||
sorted := make([]market.TradingDecision, len(decisions))
|
||||
sorted := make([]decision.Decision, len(decisions))
|
||||
copy(sorted, decisions)
|
||||
|
||||
// 按优先级排序
|
||||
|
||||
Reference in New Issue
Block a user