diff --git a/.github/workflows/pr-checks-advisory.yml.old b/.github/workflows/pr-checks-advisory.yml.old deleted file mode 100644 index 898fdc10..00000000 --- a/.github/workflows/pr-checks-advisory.yml.old +++ /dev/null @@ -1,331 +0,0 @@ -name: PR Checks (Advisory) - -on: - pull_request: - types: [opened, synchronize, reopened] - branches: [main, dev] - -# These checks are advisory only - they won't block PR merging -# Results will be posted as comments to help contributors improve their PRs - -permissions: - contents: write - pull-requests: write - checks: write - issues: write - -jobs: - pr-info: - name: PR Information - runs-on: ubuntu-latest - steps: - - name: Check PR title format - id: check-title - run: | - PR_TITLE="${{ github.event.pull_request.title }}" - - # Check if title follows conventional commits - if echo "$PR_TITLE" | grep -qE "^(feat|fix|docs|style|refactor|perf|test|chore|ci|security)(\(.+\))?: .+"; then - echo "status=✅ Good" >> $GITHUB_OUTPUT - echo "message=PR title follows Conventional Commits format" >> $GITHUB_OUTPUT - else - echo "status=⚠️ Suggestion" >> $GITHUB_OUTPUT - echo "message=Consider using Conventional Commits format: type(scope): description" >> $GITHUB_OUTPUT - fi - - - name: Calculate PR size - id: pr-size - run: | - ADDITIONS=${{ github.event.pull_request.additions }} - DELETIONS=${{ github.event.pull_request.deletions }} - TOTAL=$((ADDITIONS + DELETIONS)) - - if [ $TOTAL -lt 100 ]; then - echo "size=🟢 Small" >> $GITHUB_OUTPUT - echo "label=size: small" >> $GITHUB_OUTPUT - elif [ $TOTAL -lt 500 ]; then - echo "size=🟡 Medium" >> $GITHUB_OUTPUT - echo "label=size: medium" >> $GITHUB_OUTPUT - else - echo "size=🔴 Large" >> $GITHUB_OUTPUT - echo "label=size: large" >> $GITHUB_OUTPUT - echo "suggestion=Consider breaking this into smaller PRs for easier review" >> $GITHUB_OUTPUT - fi - echo "lines=$TOTAL" >> $GITHUB_OUTPUT - - - name: Post advisory comment - uses: actions/github-script@v7 - with: - script: | - const titleStatus = '${{ steps.check-title.outputs.status }}'; - const titleMessage = '${{ steps.check-title.outputs.message }}'; - const prSize = '${{ steps.pr-size.outputs.size }}'; - const prLines = '${{ steps.pr-size.outputs.lines }}'; - const sizeSuggestion = '${{ steps.pr-size.outputs.suggestion }}' || ''; - - let comment = '## 🤖 PR Advisory Feedback\n\n'; - comment += 'Thank you for your contribution! Here\'s some automated feedback to help improve your PR:\n\n'; - comment += '### PR Title\n'; - comment += titleStatus + ' ' + titleMessage + '\n\n'; - comment += '### PR Size\n'; - comment += prSize + ' (' + prLines + ' lines changed)\n'; - if (sizeSuggestion) { - comment += '\n💡 **Suggestion:** ' + sizeSuggestion + '\n'; - } - comment += '\n---\n\n'; - comment += '### 📖 New PR Management System\n\n'; - comment += 'We\'re introducing a new PR management system! These checks are **advisory only** and won\'t block your PR.\n\n'; - comment += '**Want to check your PR against new standards?**\n'; - comment += '```bash\n'; - comment += '# Run the PR health check tool\n'; - comment += './scripts/pr-check.sh\n'; - comment += '```\n\n'; - comment += 'This tool will:\n'; - comment += '- 🔍 Analyze your PR (doesn\'t modify anything)\n'; - comment += '- ✅ Show what\'s already good\n'; - comment += '- ⚠️ Point out issues\n'; - comment += '- 💡 Give specific suggestions on how to fix\n\n'; - comment += '**Learn more:**\n'; - comment += '- [Migration Guide](https://github.com/NoFxAiOS/nofx/blob/dev/docs/community/MIGRATION_ANNOUNCEMENT.md)\n'; - comment += '- [Contributing Guidelines](https://github.com/NoFxAiOS/nofx/blob/dev/CONTRIBUTING.md)\n\n'; - comment += '**Questions?** Just ask in the comments! We\'re here to help. 🙏\n\n'; - comment += '---\n\n'; - comment += '*This is an automated message. It won\'t affect your PR being merged.*'; - - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: comment - }); - - backend-checks: - name: Backend Checks (Advisory) - runs-on: ubuntu-latest - continue-on-error: true - steps: - - uses: actions/checkout@v4 - - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y libta-lib-dev || true - go mod download || true - - - name: Check Go formatting - id: go-fmt - continue-on-error: true - run: | - UNFORMATTED=$(gofmt -l . 2>/dev/null || echo "") - if [ -n "$UNFORMATTED" ]; then - echo "status=⚠️ Needs formatting" >> $GITHUB_OUTPUT - echo "files<> $GITHUB_OUTPUT - echo "$UNFORMATTED" | head -10 >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - else - echo "status=✅ Good" >> $GITHUB_OUTPUT - echo "files=" >> $GITHUB_OUTPUT - fi - - - name: Run go vet - id: go-vet - continue-on-error: true - run: | - if go vet ./... 2>&1 | tee vet-output.txt; then - echo "status=✅ Good" >> $GITHUB_OUTPUT - echo "output=" >> $GITHUB_OUTPUT - else - echo "status=⚠️ Issues found" >> $GITHUB_OUTPUT - echo "output<> $GITHUB_OUTPUT - cat vet-output.txt | head -20 >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - fi - - - name: Run tests - id: go-test - continue-on-error: true - run: | - if go test ./... -v 2>&1 | tee test-output.txt; then - echo "status=✅ Passed" >> $GITHUB_OUTPUT - echo "output=" >> $GITHUB_OUTPUT - else - echo "status=⚠️ Failed" >> $GITHUB_OUTPUT - echo "output<> $GITHUB_OUTPUT - cat test-output.txt | tail -30 >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - fi - - - name: Post backend feedback - if: always() - uses: actions/github-script@v7 - with: - script: | - const fmtStatus = '${{ steps.go-fmt.outputs.status }}' || '⚠️ Skipped'; - const vetStatus = '${{ steps.go-vet.outputs.status }}' || '⚠️ Skipped'; - const testStatus = '${{ steps.go-test.outputs.status }}' || '⚠️ Skipped'; - const fmtFiles = `${{ steps.go-fmt.outputs.files }}`; - const vetOutput = `${{ steps.go-vet.outputs.output }}`; - const testOutput = `${{ steps.go-test.outputs.output }}`; - - let comment = '## 🔧 Backend Checks (Advisory)\n\n'; - comment += '### Go Formatting\n'; - comment += fmtStatus + '\n'; - if (fmtFiles) { - comment += '\nFiles needing formatting:\n```\n' + fmtFiles + '\n```\n'; - } - comment += '\n### Go Vet\n'; - comment += vetStatus + '\n'; - if (vetOutput) { - comment += '\n```\n' + vetOutput.substring(0, 500) + '\n```\n'; - } - comment += '\n### Tests\n'; - comment += testStatus + '\n'; - if (testOutput) { - comment += '\n```\n' + testOutput.substring(0, 1000) + '\n```\n'; - } - comment += '\n---\n\n'; - comment += '💡 **To fix locally:**\n'; - comment += '```bash\n'; - comment += '# Format code\n'; - comment += 'go fmt ./...\n\n'; - comment += '# Check for issues\n'; - comment += 'go vet ./...\n\n'; - comment += '# Run tests\n'; - comment += 'go test ./...\n'; - comment += '```\n\n'; - comment += '*These checks are advisory and won\'t block merging. Need help? Just ask!*'; - - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: comment - }); - - frontend-checks: - name: Frontend Checks (Advisory) - runs-on: ubuntu-latest - continue-on-error: true - steps: - - uses: actions/checkout@v4 - - - name: Set up Node.js - uses: actions/setup-node@v4 - with: - node-version: '18' - - - name: Check if web directory exists - id: check-web - run: | - if [ -d "web" ]; then - echo "exists=true" >> $GITHUB_OUTPUT - else - echo "exists=false" >> $GITHUB_OUTPUT - fi - - - name: Install dependencies - if: steps.check-web.outputs.exists == 'true' - working-directory: ./web - continue-on-error: true - run: npm ci - - - name: Run linter - if: steps.check-web.outputs.exists == 'true' - id: lint - working-directory: ./web - continue-on-error: true - run: | - if npm run lint 2>&1 | tee lint-output.txt; then - echo "status=✅ Good" >> $GITHUB_OUTPUT - echo "output=" >> $GITHUB_OUTPUT - else - echo "status=⚠️ Issues found" >> $GITHUB_OUTPUT - echo "output<> $GITHUB_OUTPUT - cat lint-output.txt | head -20 >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - fi - - - name: Type check - if: steps.check-web.outputs.exists == 'true' - id: typecheck - working-directory: ./web - continue-on-error: true - run: | - if npm run type-check 2>&1 | tee typecheck-output.txt; then - echo "status=✅ Good" >> $GITHUB_OUTPUT - echo "output=" >> $GITHUB_OUTPUT - else - echo "status=⚠️ Issues found" >> $GITHUB_OUTPUT - echo "output<> $GITHUB_OUTPUT - cat typecheck-output.txt | head -20 >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - fi - - - name: Build - if: steps.check-web.outputs.exists == 'true' - id: build - working-directory: ./web - continue-on-error: true - run: | - if npm run build 2>&1 | tee build-output.txt; then - echo "status=✅ Success" >> $GITHUB_OUTPUT - echo "output=" >> $GITHUB_OUTPUT - else - echo "status=⚠️ Failed" >> $GITHUB_OUTPUT - echo "output<> $GITHUB_OUTPUT - cat build-output.txt | tail -20 >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - fi - - - name: Post frontend feedback - if: always() && steps.check-web.outputs.exists == 'true' - uses: actions/github-script@v7 - with: - script: | - const lintStatus = '${{ steps.lint.outputs.status }}' || '⚠️ Skipped'; - const typecheckStatus = '${{ steps.typecheck.outputs.status }}' || '⚠️ Skipped'; - const buildStatus = '${{ steps.build.outputs.status }}' || '⚠️ Skipped'; - const lintOutput = `${{ steps.lint.outputs.output }}`; - const typecheckOutput = `${{ steps.typecheck.outputs.output }}`; - const buildOutput = `${{ steps.build.outputs.output }}`; - - let comment = '## ⚛️ Frontend Checks (Advisory)\n\n'; - comment += '### Linting\n'; - comment += lintStatus + '\n'; - if (lintOutput) { - comment += '\n```\n' + lintOutput.substring(0, 500) + '\n```\n'; - } - comment += '\n### Type Checking\n'; - comment += typecheckStatus + '\n'; - if (typecheckOutput) { - comment += '\n```\n' + typecheckOutput.substring(0, 500) + '\n```\n'; - } - comment += '\n### Build\n'; - comment += buildStatus + '\n'; - if (buildOutput) { - comment += '\n```\n' + buildOutput.substring(0, 500) + '\n```\n'; - } - comment += '\n---\n\n'; - comment += '💡 **To fix locally:**\n'; - comment += '```bash\n'; - comment += 'cd web\n\n'; - comment += '# Fix linting issues\n'; - comment += 'npm run lint -- --fix\n\n'; - comment += '# Check types\n'; - comment += 'npm run type-check\n\n'; - comment += '# Test build\n'; - comment += 'npm run build\n'; - comment += '```\n\n'; - comment += '*These checks are advisory and won\'t block merging. Need help? Just ask!*'; - - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: comment - }); diff --git a/api/handler_trader.go b/api/handler_trader.go index ead84200..d26d0edd 100644 --- a/api/handler_trader.go +++ b/api/handler_trader.go @@ -598,615 +598,3 @@ func (s *Server) handleStopTrader(c *gin.Context) { logger.Infof("⏹ Trader %s stopped", trader.GetName()) c.JSON(http.StatusOK, gin.H{"message": "Trader stopped"}) } - -// handleUpdateTraderPrompt Update trader custom prompt -func (s *Server) handleUpdateTraderPrompt(c *gin.Context) { - traderID := c.Param("id") - userID := c.GetString("user_id") - - var req struct { - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Update database - err := s.store.Trader().UpdateCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt) - if err != nil { - SafeInternalError(c, "Failed to update custom prompt", err) - return - } - - // If trader is in memory, update its custom prompt and override settings - trader, err := s.traderManager.GetTrader(traderID) - if err == nil { - trader.SetCustomPrompt(req.CustomPrompt) - trader.SetOverrideBasePrompt(req.OverrideBasePrompt) - logger.Infof("✓ Updated trader %s custom prompt (override base=%v)", trader.GetName(), req.OverrideBasePrompt) - } - - c.JSON(http.StatusOK, gin.H{"message": "Custom prompt updated"}) -} - -// handleToggleCompetition Toggle trader competition visibility -func (s *Server) handleToggleCompetition(c *gin.Context) { - traderID := c.Param("id") - userID := c.GetString("user_id") - - var req struct { - ShowInCompetition bool `json:"show_in_competition"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Update database - err := s.store.Trader().UpdateShowInCompetition(userID, traderID, req.ShowInCompetition) - if err != nil { - SafeInternalError(c, "Update competition visibility", err) - return - } - - // Update in-memory trader if it exists - if trader, err := s.traderManager.GetTrader(traderID); err == nil { - trader.SetShowInCompetition(req.ShowInCompetition) - } - - status := "shown" - if !req.ShowInCompetition { - status = "hidden" - } - logger.Infof("✓ Trader %s competition visibility updated: %s", traderID, status) - c.JSON(http.StatusOK, gin.H{ - "message": "Competition visibility updated", - "show_in_competition": req.ShowInCompetition, - }) -} - -// handleGetGridRiskInfo returns current risk information for a grid trader -func (s *Server) handleGetGridRiskInfo(c *gin.Context) { - traderID := c.Param("id") - - autoTrader, err := s.traderManager.GetTrader(traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) - return - } - - riskInfo := autoTrader.GetGridRiskInfo() - c.JSON(http.StatusOK, riskInfo) -} - -// handleSyncBalance Sync exchange balance to initial_balance (Option B: Manual Sync + Option C: Smart Detection) -func (s *Server) handleSyncBalance(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - logger.Infof("🔄 User %s requested balance sync for trader %s", userID, traderID) - - // Get trader configuration from database (including exchange info) - fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - traderConfig := fullConfig.Trader - exchangeCfg := fullConfig.Exchange - - if exchangeCfg == nil || !exchangeCfg.Enabled { - c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) - return - } - - // Create temporary trader to query balance - var tempTrader trader.Trader - var createErr error - - // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) - // Convert EncryptedString fields to string - switch exchangeCfg.ExchangeType { - case "binance": - tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) - case "hyperliquid": - tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( - string(exchangeCfg.APIKey), - exchangeCfg.HyperliquidWalletAddr, - exchangeCfg.Testnet, - exchangeCfg.HyperliquidUnifiedAcct, - ) - case "aster": - tempTrader, createErr = aster.NewAsterTrader( - exchangeCfg.AsterUser, - exchangeCfg.AsterSigner, - string(exchangeCfg.AsterPrivateKey), - ) - case "bybit": - tempTrader = bybit.NewBybitTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "okx": - tempTrader = okx.NewOKXTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "bitget": - tempTrader = bitget.NewBitgetTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "gate": - tempTrader = gate.NewGateTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "kucoin": - tempTrader = kucoin.NewKuCoinTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "lighter": - if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { - // Lighter only supports mainnet - tempTrader, createErr = lighter.NewLighterTraderV2( - exchangeCfg.LighterWalletAddr, - string(exchangeCfg.LighterAPIKeyPrivateKey), - exchangeCfg.LighterAPIKeyIndex, - false, // Always use mainnet for Lighter - ) - } else { - createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") - } - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) - return - } - - if createErr != nil { - logger.Infof("⚠️ Failed to create temporary trader: %v", createErr) - SafeInternalError(c, "Failed to connect to exchange", createErr) - return - } - - // Query actual balance - balanceInfo, balanceErr := tempTrader.GetBalance() - if balanceErr != nil { - logger.Infof("⚠️ Failed to query exchange balance: %v", balanceErr) - SafeInternalError(c, "Failed to query balance", balanceErr) - return - } - - // Extract total equity (for P&L calculation, we need total account value, not available balance) - var actualBalance float64 - // Priority: total_equity > totalWalletBalance > wallet_balance > totalEq > balance - balanceKeys := []string{"total_equity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} - for _, key := range balanceKeys { - if balance, ok := balanceInfo[key].(float64); ok && balance > 0 { - actualBalance = balance - break - } - } - if actualBalance <= 0 { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Unable to get total equity"}) - return - } - - oldBalance := traderConfig.InitialBalance - - // ✅ Option C: Smart balance change detection - changePercent := ((actualBalance - oldBalance) / oldBalance) * 100 - changeType := "increase" - if changePercent < 0 { - changeType = "decrease" - } - - logger.Infof("✓ Queried actual exchange balance: %.2f USDT (current config: %.2f USDT, change: %.2f%%)", - actualBalance, oldBalance, changePercent) - - // Update initial_balance in database - err = s.store.Trader().UpdateInitialBalance(userID, traderID, actualBalance) - if err != nil { - logger.Infof("❌ Failed to update initial_balance: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update balance"}) - return - } - - // Reload traders into memory - err = s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("⚠️ Failed to reload user traders into memory: %v", err) - } - - logger.Infof("✅ Synced balance: %.2f → %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent) - - c.JSON(http.StatusOK, gin.H{ - "message": "Balance synced successfully", - "old_balance": oldBalance, - "new_balance": actualBalance, - "change_percent": changePercent, - "change_type": changeType, - }) -} - -// handleClosePosition One-click close position -func (s *Server) handleClosePosition(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - var req struct { - Symbol string `json:"symbol" binding:"required"` - Side string `json:"side" binding:"required"` // "LONG" or "SHORT" - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Parameter error: symbol and side are required"}) - return - } - - logger.Infof("🔻 User %s requested position close: trader=%s, symbol=%s, side=%s", userID, traderID, req.Symbol, req.Side) - - // Get trader configuration from database (including exchange info) - fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - exchangeCfg := fullConfig.Exchange - - if exchangeCfg == nil || !exchangeCfg.Enabled { - c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) - return - } - - // Create temporary trader to execute close position - var tempTrader trader.Trader - var createErr error - - // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) - // Convert EncryptedString fields to string - switch exchangeCfg.ExchangeType { - case "binance": - tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) - case "hyperliquid": - tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( - string(exchangeCfg.APIKey), - exchangeCfg.HyperliquidWalletAddr, - exchangeCfg.Testnet, - exchangeCfg.HyperliquidUnifiedAcct, - ) - case "aster": - tempTrader, createErr = aster.NewAsterTrader( - exchangeCfg.AsterUser, - exchangeCfg.AsterSigner, - string(exchangeCfg.AsterPrivateKey), - ) - case "bybit": - tempTrader = bybit.NewBybitTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "okx": - tempTrader = okx.NewOKXTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "bitget": - tempTrader = bitget.NewBitgetTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "gate": - tempTrader = gate.NewGateTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "kucoin": - tempTrader = kucoin.NewKuCoinTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "lighter": - if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { - // Lighter only supports mainnet - tempTrader, createErr = lighter.NewLighterTraderV2( - exchangeCfg.LighterWalletAddr, - string(exchangeCfg.LighterAPIKeyPrivateKey), - exchangeCfg.LighterAPIKeyIndex, - false, // Always use mainnet for Lighter - ) - } else { - createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") - } - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) - return - } - - if createErr != nil { - logger.Infof("⚠️ Failed to create temporary trader: %v", createErr) - SafeInternalError(c, "Failed to connect to exchange", createErr) - return - } - - // Get current position info BEFORE closing (to get quantity and price) - positions, err := tempTrader.GetPositions() - if err != nil { - logger.Infof("⚠️ Failed to get positions: %v", err) - } - - var posQty float64 - var entryPrice float64 - for _, pos := range positions { - if pos["symbol"] == req.Symbol && pos["side"] == strings.ToLower(req.Side) { - if amt, ok := pos["positionAmt"].(float64); ok { - posQty = amt - if posQty < 0 { - posQty = -posQty // Make positive - } - } - if price, ok := pos["entryPrice"].(float64); ok { - entryPrice = price - } - break - } - } - - // Execute close position operation - var result map[string]interface{} - var closeErr error - - if req.Side == "LONG" { - result, closeErr = tempTrader.CloseLong(req.Symbol, 0) // 0 means close all - } else if req.Side == "SHORT" { - result, closeErr = tempTrader.CloseShort(req.Symbol, 0) // 0 means close all - } else { - c.JSON(http.StatusBadRequest, gin.H{"error": "side must be LONG or SHORT"}) - return - } - - if closeErr != nil { - logger.Infof("❌ Close position failed: symbol=%s, side=%s, error=%v", req.Symbol, req.Side, closeErr) - SafeInternalError(c, "Close position", closeErr) - return - } - - logger.Infof("✅ Position closed successfully: symbol=%s, side=%s, qty=%.6f, result=%v", req.Symbol, req.Side, posQty, result) - - // Record order to database (for chart markers and history) - s.recordClosePositionOrder(traderID, exchangeCfg.ID, exchangeCfg.ExchangeType, req.Symbol, req.Side, posQty, entryPrice, result) - - c.JSON(http.StatusOK, gin.H{ - "message": "Position closed successfully", - "symbol": req.Symbol, - "side": req.Side, - "result": result, - }) -} - -// recordClosePositionOrder Record close position order to database (Lighter version - direct FILLED status) -func (s *Server) recordClosePositionOrder(traderID, exchangeID, exchangeType, symbol, side string, quantity, exitPrice float64, result map[string]interface{}) { - // Skip for exchanges with OrderSync - let the background sync handle it to avoid duplicates - switch exchangeType { - case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "gate": - logger.Infof(" 📝 Close order will be synced by OrderSync, skipping immediate record") - return - } - - // Check if order was placed (skip if NO_POSITION) - status, _ := result["status"].(string) - if status == "NO_POSITION" { - logger.Infof(" ⚠️ No position to close, skipping order record") - return - } - - // Get order ID from result - var orderID string - switch v := result["orderId"].(type) { - case int64: - orderID = fmt.Sprintf("%d", v) - case float64: - orderID = fmt.Sprintf("%.0f", v) - case string: - orderID = v - default: - orderID = fmt.Sprintf("%v", v) - } - - if orderID == "" || orderID == "0" { - logger.Infof(" ⚠️ Order ID is empty, skipping record") - return - } - - // Determine order action based on side - var orderAction string - if side == "LONG" { - orderAction = "close_long" - } else { - orderAction = "close_short" - } - - // Use entry price if exit price not available - if exitPrice == 0 { - exitPrice = quantity * 100 // Rough estimate if we don't have price - } - - // Estimate fee (0.04% for Lighter taker) - fee := exitPrice * quantity * 0.0004 - - // Create order record - DIRECTLY as FILLED (Lighter market orders fill immediately) - orderRecord := &store.TraderOrder{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - ExchangeOrderID: orderID, - Symbol: symbol, - PositionSide: side, - OrderAction: orderAction, - Type: "MARKET", - Side: getSideFromAction(orderAction), - Quantity: quantity, - Price: 0, // Market order - Status: "FILLED", - FilledQuantity: quantity, - AvgFillPrice: exitPrice, - Commission: fee, - FilledAt: time.Now().UTC().UnixMilli(), - CreatedAt: time.Now().UTC().UnixMilli(), - UpdatedAt: time.Now().UTC().UnixMilli(), - } - - if err := s.store.Order().CreateOrder(orderRecord); err != nil { - logger.Infof(" ⚠️ Failed to record order: %v", err) - return - } - - logger.Infof(" ✅ Order recorded as FILLED: %s [%s] %s qty=%.6f price=%.6f", orderID, orderAction, symbol, quantity, exitPrice) - - // Create fill record immediately - tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) - fillRecord := &store.TraderFill{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - OrderID: orderRecord.ID, - ExchangeOrderID: orderID, - ExchangeTradeID: tradeID, - Symbol: symbol, - Side: getSideFromAction(orderAction), - Price: exitPrice, - Quantity: quantity, - QuoteQuantity: exitPrice * quantity, - Commission: fee, - CommissionAsset: "USDT", - RealizedPnL: 0, - IsMaker: false, - CreatedAt: time.Now().UTC().UnixMilli(), - } - - if err := s.store.Order().CreateFill(fillRecord); err != nil { - logger.Infof(" ⚠️ Failed to record fill: %v", err) - } else { - logger.Infof(" ✅ Fill record created: price=%.6f qty=%.6f", exitPrice, quantity) - } -} - -// pollAndUpdateOrderStatus Poll order status and update with fill data -func (s *Server) pollAndUpdateOrderStatus(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { - var actualPrice float64 - var actualQty float64 - var fee float64 - - // Wait a bit for order to be filled - time.Sleep(500 * time.Millisecond) - - // For Lighter, use GetTrades instead of GetOrderStatus (market orders are filled immediately) - if exchangeType == "lighter" { - s.pollLighterTradeHistory(orderRecordID, traderID, exchangeID, exchangeType, orderID, symbol, orderAction, tempTrader) - return - } - - // For other exchanges, poll GetOrderStatus - for i := 0; i < 5; i++ { - status, err := tempTrader.GetOrderStatus(symbol, orderID) - if err != nil { - logger.Infof(" ⚠️ GetOrderStatus failed (attempt %d/5): %v", i+1, err) - time.Sleep(500 * time.Millisecond) - continue - } - if err == nil { - statusStr, _ := status["status"].(string) - if statusStr == "FILLED" { - // Get actual fill price - if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 { - actualPrice = avgPrice - } - // Get actual executed quantity - if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 { - actualQty = execQty - } - // Get commission/fee - if commission, ok := status["commission"].(float64); ok { - fee = commission - } - - logger.Infof(" ✅ Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee) - - // Update order status to FILLED - if err := s.store.Order().UpdateOrderStatus(orderRecordID, "FILLED", actualQty, actualPrice, fee); err != nil { - logger.Infof(" ⚠️ Failed to update order status: %v", err) - return - } - - // Record fill details - tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) - fillRecord := &store.TraderFill{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - OrderID: orderRecordID, - ExchangeOrderID: orderID, - ExchangeTradeID: tradeID, - Symbol: symbol, - Side: getSideFromAction(orderAction), - Price: actualPrice, - Quantity: actualQty, - QuoteQuantity: actualPrice * actualQty, - Commission: fee, - CommissionAsset: "USDT", - RealizedPnL: 0, - IsMaker: false, - CreatedAt: time.Now().UTC().UnixMilli(), - } - - if err := s.store.Order().CreateFill(fillRecord); err != nil { - logger.Infof(" ⚠️ Failed to record fill: %v", err) - } else { - logger.Infof(" 📝 Fill recorded: price=%.6f, qty=%.6f", actualPrice, actualQty) - } - - return - } else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" { - logger.Infof(" ⚠️ Order %s, updating status", statusStr) - s.store.Order().UpdateOrderStatus(orderRecordID, statusStr, 0, 0, 0) - return - } - } - time.Sleep(500 * time.Millisecond) - } - - logger.Infof(" ⚠️ Failed to confirm order fill after polling, order may still be pending") -} - -// pollLighterTradeHistory No longer used - Lighter orders are marked as FILLED immediately -// Keeping this function stub for compatibility with other exchanges -func (s *Server) pollLighterTradeHistory(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { - // For Lighter, orders are now recorded as FILLED immediately in recordClosePositionOrder - // This function is no longer called for Lighter exchange - logger.Infof(" ℹ️ pollLighterTradeHistory called but not needed (order already marked FILLED)") -} - -// getSideFromAction Get order side (BUY/SELL) from order action -func getSideFromAction(action string) string { - switch action { - case "open_long", "close_short": - return "BUY" - case "open_short", "close_long": - return "SELL" - default: - return "BUY" - } -} diff --git a/api/handler_trader_config.go b/api/handler_trader_config.go new file mode 100644 index 00000000..405d666d --- /dev/null +++ b/api/handler_trader_config.go @@ -0,0 +1,79 @@ +package api + +import ( + "net/http" + + "nofx/logger" + + "github.com/gin-gonic/gin" +) + +// handleUpdateTraderPrompt Update trader custom prompt +func (s *Server) handleUpdateTraderPrompt(c *gin.Context) { + traderID := c.Param("id") + userID := c.GetString("user_id") + + var req struct { + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Update database + err := s.store.Trader().UpdateCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt) + if err != nil { + SafeInternalError(c, "Failed to update custom prompt", err) + return + } + + // If trader is in memory, update its custom prompt and override settings + trader, err := s.traderManager.GetTrader(traderID) + if err == nil { + trader.SetCustomPrompt(req.CustomPrompt) + trader.SetOverrideBasePrompt(req.OverrideBasePrompt) + logger.Infof("✓ Updated trader %s custom prompt (override base=%v)", trader.GetName(), req.OverrideBasePrompt) + } + + c.JSON(http.StatusOK, gin.H{"message": "Custom prompt updated"}) +} + +// handleToggleCompetition Toggle trader competition visibility +func (s *Server) handleToggleCompetition(c *gin.Context) { + traderID := c.Param("id") + userID := c.GetString("user_id") + + var req struct { + ShowInCompetition bool `json:"show_in_competition"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Update database + err := s.store.Trader().UpdateShowInCompetition(userID, traderID, req.ShowInCompetition) + if err != nil { + SafeInternalError(c, "Update competition visibility", err) + return + } + + // Update in-memory trader if it exists + if trader, err := s.traderManager.GetTrader(traderID); err == nil { + trader.SetShowInCompetition(req.ShowInCompetition) + } + + status := "shown" + if !req.ShowInCompetition { + status = "hidden" + } + logger.Infof("✓ Trader %s competition visibility updated: %s", traderID, status) + c.JSON(http.StatusOK, gin.H{ + "message": "Competition visibility updated", + "show_in_competition": req.ShowInCompetition, + }) +} diff --git a/api/handler_trader_status.go b/api/handler_trader_status.go new file mode 100644 index 00000000..cbae3a84 --- /dev/null +++ b/api/handler_trader_status.go @@ -0,0 +1,565 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + "time" + + "nofx/logger" + "nofx/store" + "nofx/trader" + "nofx/trader/aster" + "nofx/trader/binance" + "nofx/trader/bitget" + "nofx/trader/bybit" + "nofx/trader/gate" + hyperliquidtrader "nofx/trader/hyperliquid" + "nofx/trader/kucoin" + "nofx/trader/lighter" + "nofx/trader/okx" + + "github.com/gin-gonic/gin" +) + +// handleGetGridRiskInfo returns current risk information for a grid trader +func (s *Server) handleGetGridRiskInfo(c *gin.Context) { + traderID := c.Param("id") + + autoTrader, err := s.traderManager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) + return + } + + riskInfo := autoTrader.GetGridRiskInfo() + c.JSON(http.StatusOK, riskInfo) +} + +// handleSyncBalance Sync exchange balance to initial_balance (Option B: Manual Sync + Option C: Smart Detection) +func (s *Server) handleSyncBalance(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + logger.Infof("🔄 User %s requested balance sync for trader %s", userID, traderID) + + // Get trader configuration from database (including exchange info) + fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + traderConfig := fullConfig.Trader + exchangeCfg := fullConfig.Exchange + + if exchangeCfg == nil || !exchangeCfg.Enabled { + c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) + return + } + + // Create temporary trader to query balance + var tempTrader trader.Trader + var createErr error + + // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) + // Convert EncryptedString fields to string + switch exchangeCfg.ExchangeType { + case "binance": + tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) + case "hyperliquid": + tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( + string(exchangeCfg.APIKey), + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.Testnet, + exchangeCfg.HyperliquidUnifiedAcct, + ) + case "aster": + tempTrader, createErr = aster.NewAsterTrader( + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + ) + case "bybit": + tempTrader = bybit.NewBybitTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "okx": + tempTrader = okx.NewOKXTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "bitget": + tempTrader = bitget.NewBitgetTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "gate": + tempTrader = gate.NewGateTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "kucoin": + tempTrader = kucoin.NewKuCoinTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "lighter": + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { + // Lighter only supports mainnet + tempTrader, createErr = lighter.NewLighterTraderV2( + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + exchangeCfg.LighterAPIKeyIndex, + false, // Always use mainnet for Lighter + ) + } else { + createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") + } + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) + return + } + + if createErr != nil { + logger.Infof("⚠️ Failed to create temporary trader: %v", createErr) + SafeInternalError(c, "Failed to connect to exchange", createErr) + return + } + + // Query actual balance + balanceInfo, balanceErr := tempTrader.GetBalance() + if balanceErr != nil { + logger.Infof("⚠️ Failed to query exchange balance: %v", balanceErr) + SafeInternalError(c, "Failed to query balance", balanceErr) + return + } + + // Extract total equity (for P&L calculation, we need total account value, not available balance) + var actualBalance float64 + // Priority: total_equity > totalWalletBalance > wallet_balance > totalEq > balance + balanceKeys := []string{"total_equity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} + for _, key := range balanceKeys { + if balance, ok := balanceInfo[key].(float64); ok && balance > 0 { + actualBalance = balance + break + } + } + if actualBalance <= 0 { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Unable to get total equity"}) + return + } + + oldBalance := traderConfig.InitialBalance + + // Smart balance change detection + changePercent := ((actualBalance - oldBalance) / oldBalance) * 100 + changeType := "increase" + if changePercent < 0 { + changeType = "decrease" + } + + logger.Infof("✓ Queried actual exchange balance: %.2f USDT (current config: %.2f USDT, change: %.2f%%)", + actualBalance, oldBalance, changePercent) + + // Update initial_balance in database + err = s.store.Trader().UpdateInitialBalance(userID, traderID, actualBalance) + if err != nil { + logger.Infof("❌ Failed to update initial_balance: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update balance"}) + return + } + + // Reload traders into memory + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("⚠️ Failed to reload user traders into memory: %v", err) + } + + logger.Infof("✅ Synced balance: %.2f → %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent) + + c.JSON(http.StatusOK, gin.H{ + "message": "Balance synced successfully", + "old_balance": oldBalance, + "new_balance": actualBalance, + "change_percent": changePercent, + "change_type": changeType, + }) +} + +// handleClosePosition One-click close position +func (s *Server) handleClosePosition(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + var req struct { + Symbol string `json:"symbol" binding:"required"` + Side string `json:"side" binding:"required"` // "LONG" or "SHORT" + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Parameter error: symbol and side are required"}) + return + } + + logger.Infof("🔻 User %s requested position close: trader=%s, symbol=%s, side=%s", userID, traderID, req.Symbol, req.Side) + + // Get trader configuration from database (including exchange info) + fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + exchangeCfg := fullConfig.Exchange + + if exchangeCfg == nil || !exchangeCfg.Enabled { + c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) + return + } + + // Create temporary trader to execute close position + var tempTrader trader.Trader + var createErr error + + // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) + // Convert EncryptedString fields to string + switch exchangeCfg.ExchangeType { + case "binance": + tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) + case "hyperliquid": + tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( + string(exchangeCfg.APIKey), + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.Testnet, + exchangeCfg.HyperliquidUnifiedAcct, + ) + case "aster": + tempTrader, createErr = aster.NewAsterTrader( + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + ) + case "bybit": + tempTrader = bybit.NewBybitTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "okx": + tempTrader = okx.NewOKXTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "bitget": + tempTrader = bitget.NewBitgetTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "gate": + tempTrader = gate.NewGateTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "kucoin": + tempTrader = kucoin.NewKuCoinTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "lighter": + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { + // Lighter only supports mainnet + tempTrader, createErr = lighter.NewLighterTraderV2( + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + exchangeCfg.LighterAPIKeyIndex, + false, // Always use mainnet for Lighter + ) + } else { + createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") + } + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) + return + } + + if createErr != nil { + logger.Infof("⚠️ Failed to create temporary trader: %v", createErr) + SafeInternalError(c, "Failed to connect to exchange", createErr) + return + } + + // Get current position info BEFORE closing (to get quantity and price) + positions, err := tempTrader.GetPositions() + if err != nil { + logger.Infof("⚠️ Failed to get positions: %v", err) + } + + var posQty float64 + var entryPrice float64 + for _, pos := range positions { + if pos["symbol"] == req.Symbol && pos["side"] == strings.ToLower(req.Side) { + if amt, ok := pos["positionAmt"].(float64); ok { + posQty = amt + if posQty < 0 { + posQty = -posQty // Make positive + } + } + if price, ok := pos["entryPrice"].(float64); ok { + entryPrice = price + } + break + } + } + + // Execute close position operation + var result map[string]interface{} + var closeErr error + + if req.Side == "LONG" { + result, closeErr = tempTrader.CloseLong(req.Symbol, 0) // 0 means close all + } else if req.Side == "SHORT" { + result, closeErr = tempTrader.CloseShort(req.Symbol, 0) // 0 means close all + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": "side must be LONG or SHORT"}) + return + } + + if closeErr != nil { + logger.Infof("❌ Close position failed: symbol=%s, side=%s, error=%v", req.Symbol, req.Side, closeErr) + SafeInternalError(c, "Close position", closeErr) + return + } + + logger.Infof("✅ Position closed successfully: symbol=%s, side=%s, qty=%.6f, result=%v", req.Symbol, req.Side, posQty, result) + + // Record order to database (for chart markers and history) + s.recordClosePositionOrder(traderID, exchangeCfg.ID, exchangeCfg.ExchangeType, req.Symbol, req.Side, posQty, entryPrice, result) + + c.JSON(http.StatusOK, gin.H{ + "message": "Position closed successfully", + "symbol": req.Symbol, + "side": req.Side, + "result": result, + }) +} + +// recordClosePositionOrder Record close position order to database (Lighter version - direct FILLED status) +func (s *Server) recordClosePositionOrder(traderID, exchangeID, exchangeType, symbol, side string, quantity, exitPrice float64, result map[string]interface{}) { + // Skip for exchanges with OrderSync - let the background sync handle it to avoid duplicates + switch exchangeType { + case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "gate": + logger.Infof(" 📝 Close order will be synced by OrderSync, skipping immediate record") + return + } + + // Check if order was placed (skip if NO_POSITION) + status, _ := result["status"].(string) + if status == "NO_POSITION" { + logger.Infof(" ⚠️ No position to close, skipping order record") + return + } + + // Get order ID from result + var orderID string + switch v := result["orderId"].(type) { + case int64: + orderID = fmt.Sprintf("%d", v) + case float64: + orderID = fmt.Sprintf("%.0f", v) + case string: + orderID = v + default: + orderID = fmt.Sprintf("%v", v) + } + + if orderID == "" || orderID == "0" { + logger.Infof(" ⚠️ Order ID is empty, skipping record") + return + } + + // Determine order action based on side + var orderAction string + if side == "LONG" { + orderAction = "close_long" + } else { + orderAction = "close_short" + } + + // Use entry price if exit price not available + if exitPrice == 0 { + exitPrice = quantity * 100 // Rough estimate if we don't have price + } + + // Estimate fee (0.04% for Lighter taker) + fee := exitPrice * quantity * 0.0004 + + // Create order record - DIRECTLY as FILLED (Lighter market orders fill immediately) + orderRecord := &store.TraderOrder{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + ExchangeOrderID: orderID, + Symbol: symbol, + PositionSide: side, + OrderAction: orderAction, + Type: "MARKET", + Side: getSideFromAction(orderAction), + Quantity: quantity, + Price: 0, // Market order + Status: "FILLED", + FilledQuantity: quantity, + AvgFillPrice: exitPrice, + Commission: fee, + FilledAt: time.Now().UTC().UnixMilli(), + CreatedAt: time.Now().UTC().UnixMilli(), + UpdatedAt: time.Now().UTC().UnixMilli(), + } + + if err := s.store.Order().CreateOrder(orderRecord); err != nil { + logger.Infof(" ⚠️ Failed to record order: %v", err) + return + } + + logger.Infof(" ✅ Order recorded as FILLED: %s [%s] %s qty=%.6f price=%.6f", orderID, orderAction, symbol, quantity, exitPrice) + + // Create fill record immediately + tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) + fillRecord := &store.TraderFill{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + OrderID: orderRecord.ID, + ExchangeOrderID: orderID, + ExchangeTradeID: tradeID, + Symbol: symbol, + Side: getSideFromAction(orderAction), + Price: exitPrice, + Quantity: quantity, + QuoteQuantity: exitPrice * quantity, + Commission: fee, + CommissionAsset: "USDT", + RealizedPnL: 0, + IsMaker: false, + CreatedAt: time.Now().UTC().UnixMilli(), + } + + if err := s.store.Order().CreateFill(fillRecord); err != nil { + logger.Infof(" ⚠️ Failed to record fill: %v", err) + } else { + logger.Infof(" ✅ Fill record created: price=%.6f qty=%.6f", exitPrice, quantity) + } +} + +// pollAndUpdateOrderStatus Poll order status and update with fill data +func (s *Server) pollAndUpdateOrderStatus(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { + var actualPrice float64 + var actualQty float64 + var fee float64 + + // Wait a bit for order to be filled + time.Sleep(500 * time.Millisecond) + + // For Lighter, use GetTrades instead of GetOrderStatus (market orders are filled immediately) + if exchangeType == "lighter" { + s.pollLighterTradeHistory(orderRecordID, traderID, exchangeID, exchangeType, orderID, symbol, orderAction, tempTrader) + return + } + + // For other exchanges, poll GetOrderStatus + for i := 0; i < 5; i++ { + status, err := tempTrader.GetOrderStatus(symbol, orderID) + if err != nil { + logger.Infof(" ⚠️ GetOrderStatus failed (attempt %d/5): %v", i+1, err) + time.Sleep(500 * time.Millisecond) + continue + } + if err == nil { + statusStr, _ := status["status"].(string) + if statusStr == "FILLED" { + // Get actual fill price + if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 { + actualPrice = avgPrice + } + // Get actual executed quantity + if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 { + actualQty = execQty + } + // Get commission/fee + if commission, ok := status["commission"].(float64); ok { + fee = commission + } + + logger.Infof(" ✅ Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee) + + // Update order status to FILLED + if err := s.store.Order().UpdateOrderStatus(orderRecordID, "FILLED", actualQty, actualPrice, fee); err != nil { + logger.Infof(" ⚠️ Failed to update order status: %v", err) + return + } + + // Record fill details + tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) + fillRecord := &store.TraderFill{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + OrderID: orderRecordID, + ExchangeOrderID: orderID, + ExchangeTradeID: tradeID, + Symbol: symbol, + Side: getSideFromAction(orderAction), + Price: actualPrice, + Quantity: actualQty, + QuoteQuantity: actualPrice * actualQty, + Commission: fee, + CommissionAsset: "USDT", + RealizedPnL: 0, + IsMaker: false, + CreatedAt: time.Now().UTC().UnixMilli(), + } + + if err := s.store.Order().CreateFill(fillRecord); err != nil { + logger.Infof(" ⚠️ Failed to record fill: %v", err) + } else { + logger.Infof(" 📝 Fill recorded: price=%.6f, qty=%.6f", actualPrice, actualQty) + } + + return + } else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" { + logger.Infof(" ⚠️ Order %s, updating status", statusStr) + s.store.Order().UpdateOrderStatus(orderRecordID, statusStr, 0, 0, 0) + return + } + } + time.Sleep(500 * time.Millisecond) + } + + logger.Infof(" ⚠️ Failed to confirm order fill after polling, order may still be pending") +} + +// pollLighterTradeHistory No longer used - Lighter orders are marked as FILLED immediately +// Keeping this function stub for compatibility with other exchanges +func (s *Server) pollLighterTradeHistory(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { + // For Lighter, orders are now recorded as FILLED immediately in recordClosePositionOrder + // This function is no longer called for Lighter exchange + logger.Infof(" ℹ️ pollLighterTradeHistory called but not needed (order already marked FILLED)") +} + +// getSideFromAction Get order side (BUY/SELL) from order action +func getSideFromAction(action string) string { + switch action { + case "open_long", "close_short": + return "BUY" + case "open_short", "close_long": + return "SELL" + default: + return "BUY" + } +} diff --git a/backtest/runner.go b/backtest/runner.go index 70d2c5eb..7b31c97f 100644 --- a/backtest/runner.go +++ b/backtest/runner.go @@ -2,21 +2,16 @@ package backtest import ( "context" - "encoding/json" "errors" "fmt" - "nofx/logger" "os" "path/filepath" - "sort" - "strings" "sync" "time" "nofx/kernel" - "nofx/market" + "nofx/logger" "nofx/mcp" - "nofx/store" ) var ( @@ -232,954 +227,6 @@ func (r *Runner) CurrentMetadata() *RunMetadata { return meta } -func (r *Runner) loop(ctx context.Context) { - defer close(r.doneCh) - - for { - select { - case <-ctx.Done(): - r.handleStop(fmt.Errorf("context canceled: %w", ctx.Err())) - return - case <-r.stopCh: - r.handleStop(nil) - return - case <-r.pauseCh: - r.handlePause() - <-r.resumeCh - r.resumeFromPause() - default: - } - - err := r.stepOnce() - if errors.Is(err, errBacktestCompleted) { - r.handleCompletion() - return - } - if errors.Is(err, errLiquidated) { - r.handleLiquidation() - return - } - if err != nil { - r.handleFailure(err) - return - } - } -} - -func (r *Runner) stepOnce() error { - state := r.snapshotState() - if state.BarIndex >= r.feed.DecisionBarCount() { - return errBacktestCompleted - } - - ts := r.feed.DecisionTimestamp(state.BarIndex) - - marketData, multiTF, err := r.feed.BuildMarketData(ts) - if err != nil { - return err - } - - priceMap := make(map[string]float64, len(marketData)) - for symbol, data := range marketData { - priceMap[symbol] = data.CurrentPrice - } - - callCount := state.DecisionCycle + 1 - shouldDecide := r.shouldTriggerDecision(state.BarIndex) - - var ( - record *store.DecisionRecord - decisionActions []store.DecisionAction - tradeEvents = make([]TradeEvent, 0) - execLog []string - hadError bool - ) - - decisionAttempted := shouldDecide - - if shouldDecide { - ctx, rec, err := r.buildDecisionContext(ts, marketData, multiTF, priceMap, callCount) - if err != nil { - // Defensive nil check to prevent panic if buildDecisionContext returns error with nil record - if rec != nil { - rec.Success = false - rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err) - _ = r.logDecision(rec) - } - return err - } - record = rec - - var ( - fullDecision *kernel.FullDecision - fromCache bool - cacheKey string - ) - if r.aiCache != nil { - if key, err := computeCacheKey(ctx, r.cfg.PromptVariant, ts); err == nil { - cacheKey = key - if cached, ok := r.aiCache.Get(cacheKey); ok { - fullDecision = cached - fromCache = true - } else if r.cfg.ReplayOnly { - decisionErr := fmt.Errorf("replay_only enabled but cache miss at %d", ts) - record.Success = false - record.ErrorMessage = fmt.Sprintf("cached decision not found for ts=%d", ts) - _ = r.logDecision(record) - return decisionErr - } - } else { - logger.Infof("failed to compute ai cache key: %v", err) - } - } - - if !fromCache { - fd, err := r.invokeAIWithRetry(ctx) - if err != nil { - decisionAttempted = true - hadError = true - record.Success = false - record.ErrorMessage = fmt.Sprintf("AI decision failed: %v", err) - execLog = append(execLog, fmt.Sprintf("⚠️ AI decision failed: %v", err)) - r.setLastError(err) - } else { - fullDecision = fd - if r.cfg.CacheAI && r.aiCache != nil && cacheKey != "" { - if err := r.aiCache.Put(cacheKey, r.cfg.PromptVariant, ts, fullDecision); err != nil { - logger.Infof("failed to persist ai cache for %s: %v", r.cfg.RunID, err) - } - } - } - } - - if fullDecision != nil { - r.fillDecisionRecord(record, fullDecision) - - sorted := sortDecisionsByPriority(fullDecision.Decisions) - - prevLogs := execLog - decisionActions = make([]store.DecisionAction, 0, len(sorted)) - execLog = make([]string, 0, len(sorted)+len(prevLogs)) - if len(prevLogs) > 0 { - execLog = append(execLog, prevLogs...) - } - - for _, dec := range sorted { - actionRecord, trades, logEntry, execErr := r.executeDecision(dec, priceMap, ts, callCount) - if execErr != nil { - actionRecord.Success = false - actionRecord.Error = execErr.Error() - hadError = true - execLog = append(execLog, fmt.Sprintf("❌ %s %s: %v", dec.Symbol, dec.Action, execErr)) - } else { - actionRecord.Success = true - execLog = append(execLog, fmt.Sprintf("✓ %s %s", dec.Symbol, dec.Action)) - } - if len(trades) > 0 { - tradeEvents = append(tradeEvents, trades...) - } - if logEntry != "" { - execLog = append(execLog, logEntry) - } - decisionActions = append(decisionActions, actionRecord) - } - } - } - - cycleForLog := state.DecisionCycle - if decisionAttempted { - cycleForLog = callCount - } - - liquidationEvents, liquidationNote, err := r.checkLiquidation(ts, priceMap, cycleForLog) - if err != nil { - if record != nil { - record.Success = false - record.ErrorMessage = err.Error() - _ = r.logDecision(record) - } - return err - } - if len(liquidationEvents) > 0 { - hadError = true - tradeEvents = append(tradeEvents, liquidationEvents...) - if record != nil { - execLog = append(execLog, fmt.Sprintf("⚠️ Forced liquidation: %s", liquidationNote)) - } - } - - if record != nil { - record.Decisions = decisionActions - record.ExecutionLog = execLog - record.Success = !hadError && liquidationNote == "" - if liquidationNote != "" { - record.ErrorMessage = liquidationNote - } - } - - equity, unrealized, _ := r.account.TotalEquity(priceMap) - marginUsed := r.totalMarginUsed() - - r.updateState(ts, equity, unrealized, marginUsed, priceMap, decisionAttempted) - - snapshot := r.snapshotState() - drawdownPct := 0.0 - if snapshot.MaxEquity > 0 { - drawdownPct = ((snapshot.MaxEquity - snapshot.Equity) / snapshot.MaxEquity) * 100 - } - - equityPoint := EquityPoint{ - Timestamp: ts, - Equity: snapshot.Equity, - Available: snapshot.Cash, - PnL: snapshot.Equity - r.account.InitialBalance(), - PnLPct: ((snapshot.Equity - r.account.InitialBalance()) / r.account.InitialBalance()) * 100, - DrawdownPct: drawdownPct, - Cycle: snapshot.DecisionCycle, - } - - if err := appendEquityPoint(r.cfg.RunID, equityPoint); err != nil { - return err - } - - for _, evt := range tradeEvents { - if err := appendTradeEvent(r.cfg.RunID, evt); err != nil { - return err - } - } - - if record != nil { - if err := r.logDecision(record); err != nil { - return err - } - } - - if err := saveProgress(r.cfg.RunID, &snapshot, &r.cfg); err != nil { - return err - } - - if err := r.maybeCheckpoint(); err != nil { - return err - } - - r.persistMetadata() - r.persistMetrics(false) - - if !hadError && liquidationNote == "" { - r.setLastError(nil) - } - - if snapshot.Liquidated { - return errLiquidated - } - - return nil -} - -func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Data, multiTF map[string]map[string]*market.Data, priceMap map[string]float64, callCount int) (*kernel.Context, *store.DecisionRecord, error) { - equity, unrealized, _ := r.account.TotalEquity(priceMap) - available := r.account.Cash() - marginUsed := r.totalMarginUsed() - marginPct := 0.0 - if equity > 0 { - marginPct = (marginUsed / equity) * 100 - } - - accountInfo := kernel.AccountInfo{ - TotalEquity: equity, - AvailableBalance: available, - TotalPnL: equity - r.account.InitialBalance(), - TotalPnLPct: ((equity - r.account.InitialBalance()) / r.account.InitialBalance()) * 100, - MarginUsed: marginUsed, - MarginUsedPct: marginPct, - PositionCount: len(r.account.Positions()), - } - - positions := r.convertPositions(priceMap) - - // Get candidate coins from strategy engine (includes source info) - candidateCoins, err := r.strategyEngine.GetCandidateCoins() - if err != nil { - // Fallback to simple list if strategy engine fails - candidateCoins = make([]kernel.CandidateCoin, 0, len(r.cfg.Symbols)) - for _, sym := range r.cfg.Symbols { - candidateCoins = append(candidateCoins, kernel.CandidateCoin{Symbol: sym, Sources: []string{"backtest"}}) - } - } - - runtime := int((ts - int64(r.cfg.StartTS*1000)) / 60000) - ctx := &kernel.Context{ - CurrentTime: time.UnixMilli(ts).UTC().Format("2006-01-02 15:04:05 UTC"), - RuntimeMinutes: runtime, - CallCount: callCount, - Account: accountInfo, - Positions: positions, - CandidateCoins: candidateCoins, - PromptVariant: r.cfg.PromptVariant, - MarketDataMap: marketData, - MultiTFMarket: multiTF, - BTCETHLeverage: r.cfg.Leverage.BTCETHLeverage, - AltcoinLeverage: r.cfg.Leverage.AltcoinLeverage, - Timeframes: r.cfg.Timeframes, - } - - // Fetch quantitative data if enabled in strategy (uses current data as approximation) - strategyConfig := r.strategyEngine.GetConfig() - if strategyConfig.Indicators.EnableQuantData { - // Collect symbols to query (candidate coins + position coins) - symbolSet := make(map[string]bool) - for _, sym := range r.cfg.Symbols { - symbolSet[sym] = true - } - for _, pos := range positions { - symbolSet[pos.Symbol] = true - } - symbols := make([]string, 0, len(symbolSet)) - for sym := range symbolSet { - symbols = append(symbols, sym) - } - ctx.QuantDataMap = r.strategyEngine.FetchQuantDataBatch(symbols) - if len(ctx.QuantDataMap) > 0 { - logger.Infof("📊 Backtest: fetched quant data for %d symbols", len(ctx.QuantDataMap)) - } - } - - // Fetch OI ranking data if enabled in strategy (uses current data as approximation) - if strategyConfig.Indicators.EnableOIRanking { - ctx.OIRankingData = r.strategyEngine.FetchOIRankingData() - if ctx.OIRankingData != nil { - logger.Infof("📊 Backtest: OI ranking data ready: %d top, %d low positions", - len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions)) - } - } - - // Fetch NetFlow ranking data if enabled in strategy - if strategyConfig.Indicators.EnableNetFlowRanking { - ctx.NetFlowRankingData = r.strategyEngine.FetchNetFlowRankingData() - if ctx.NetFlowRankingData != nil { - logger.Infof("💰 Backtest: NetFlow ranking data ready: inst_in=%d, inst_out=%d", - len(ctx.NetFlowRankingData.InstitutionFutureTop), len(ctx.NetFlowRankingData.InstitutionFutureLow)) - } - } - - // Fetch Price ranking data if enabled in strategy - if strategyConfig.Indicators.EnablePriceRanking { - ctx.PriceRankingData = r.strategyEngine.FetchPriceRankingData() - if ctx.PriceRankingData != nil { - logger.Infof("📈 Backtest: Price ranking data ready for %d durations", - len(ctx.PriceRankingData.Durations)) - } - } - - record := &store.DecisionRecord{ - AccountState: store.AccountSnapshot{ - TotalBalance: accountInfo.TotalEquity, - AvailableBalance: accountInfo.AvailableBalance, - TotalUnrealizedProfit: unrealized, - PositionCount: accountInfo.PositionCount, - MarginUsedPct: accountInfo.MarginUsedPct, - }, - CandidateCoins: make([]string, 0, len(candidateCoins)), - Positions: r.snapshotPositions(priceMap), - } - for _, coin := range candidateCoins { - record.CandidateCoins = append(record.CandidateCoins, coin.Symbol) - } - record.Timestamp = time.UnixMilli(ts).UTC() - - return ctx, record, nil -} - -func (r *Runner) fillDecisionRecord(record *store.DecisionRecord, full *kernel.FullDecision) { - record.InputPrompt = full.UserPrompt - record.CoTTrace = full.CoTTrace - if len(full.Decisions) > 0 { - if data, err := json.MarshalIndent(full.Decisions, "", " "); err == nil { - record.DecisionJSON = string(data) - } - } -} - -func (r *Runner) invokeAIWithRetry(ctx *kernel.Context) (*kernel.FullDecision, error) { - var lastErr error - for attempt := 0; attempt < aiDecisionMaxRetries; attempt++ { - // Use GetFullDecisionWithStrategy with the pre-configured strategy engine - // This ensures backtest uses the same unified prompt generation as live trading - fd, err := kernel.GetFullDecisionWithStrategy( - ctx, - r.mcpClient, - r.strategyEngine, - r.cfg.PromptVariant, - ) - if err == nil { - return fd, nil - } - lastErr = err - delay := time.Duration(attempt+1) * 500 * time.Millisecond - time.Sleep(delay) - } - return nil, lastErr -} - -func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float64, ts int64, cycle int) (store.DecisionAction, []TradeEvent, string, error) { - symbol := dec.Symbol - if symbol == "" { - return store.DecisionAction{}, nil, "", fmt.Errorf("empty symbol in decision") - } - - usedLeverage := r.resolveLeverage(dec.Leverage, symbol) - actionRecord := store.DecisionAction{ - Action: dec.Action, - Symbol: symbol, - Leverage: usedLeverage, - Timestamp: time.UnixMilli(ts).UTC(), - } - - if priceMap == nil { - return actionRecord, nil, "", fmt.Errorf("priceMap is nil") - } - - basePrice, ok := priceMap[symbol] - if !ok || basePrice <= 0 { - return actionRecord, nil, "", fmt.Errorf("price unavailable for %s (found=%v, price=%.4f)", symbol, ok, basePrice) - } - fillPrice := r.executionPrice(symbol, basePrice, ts) - - switch dec.Action { - case "open_long": - qty := r.determineQuantity(dec, basePrice) - if qty <= 0 { - return actionRecord, nil, "", fmt.Errorf("invalid qty") - } - pos, fee, execPrice, err := r.account.Open(symbol, "long", qty, usedLeverage, fillPrice, ts) - if err != nil { - return actionRecord, nil, "", err - } - actionRecord.Quantity = qty - actionRecord.Price = execPrice - actionRecord.Leverage = pos.Leverage - trade := TradeEvent{ - Timestamp: ts, - Symbol: symbol, - Action: dec.Action, - Side: "long", - Quantity: qty, - Price: execPrice, - Fee: fee, - Slippage: execPrice - basePrice, - OrderValue: execPrice * qty, - RealizedPnL: 0, - Leverage: pos.Leverage, - Cycle: cycle, - PositionAfter: pos.Quantity, - } - return actionRecord, []TradeEvent{trade}, "", nil - - case "open_short": - qty := r.determineQuantity(dec, basePrice) - if qty <= 0 { - return actionRecord, nil, "", fmt.Errorf("invalid qty") - } - pos, fee, execPrice, err := r.account.Open(symbol, "short", qty, usedLeverage, fillPrice, ts) - if err != nil { - return actionRecord, nil, "", err - } - actionRecord.Quantity = qty - actionRecord.Price = execPrice - actionRecord.Leverage = pos.Leverage - trade := TradeEvent{ - Timestamp: ts, - Symbol: symbol, - Action: dec.Action, - Side: "short", - Quantity: qty, - Price: execPrice, - Fee: fee, - Slippage: basePrice - execPrice, - OrderValue: execPrice * qty, - RealizedPnL: 0, - Leverage: pos.Leverage, - Cycle: cycle, - PositionAfter: pos.Quantity, - } - return actionRecord, []TradeEvent{trade}, "", nil - - case "close_long": - qty := r.determineCloseQuantity(symbol, "long", dec) - if qty <= 0 { - return actionRecord, nil, "", fmt.Errorf("invalid close qty") - } - posLev := r.account.positionLeverage(symbol, "long") - realized, fee, execPrice, err := r.account.Close(symbol, "long", qty, fillPrice) - if err != nil { - return actionRecord, nil, "", err - } - actionRecord.Quantity = qty - actionRecord.Price = execPrice - actionRecord.Leverage = posLev - trade := TradeEvent{ - Timestamp: ts, - Symbol: symbol, - Action: dec.Action, - Side: "long", - Quantity: qty, - Price: execPrice, - Fee: fee, - Slippage: basePrice - execPrice, - OrderValue: execPrice * qty, - RealizedPnL: realized - fee, - Leverage: posLev, - Cycle: cycle, - PositionAfter: r.remainingPosition(symbol, "long"), - } - return actionRecord, []TradeEvent{trade}, "", nil - - case "close_short": - qty := r.determineCloseQuantity(symbol, "short", dec) - if qty <= 0 { - return actionRecord, nil, "", fmt.Errorf("invalid close qty") - } - posLev := r.account.positionLeverage(symbol, "short") - realized, fee, execPrice, err := r.account.Close(symbol, "short", qty, fillPrice) - if err != nil { - return actionRecord, nil, "", err - } - actionRecord.Quantity = qty - actionRecord.Price = execPrice - actionRecord.Leverage = posLev - trade := TradeEvent{ - Timestamp: ts, - Symbol: symbol, - Action: dec.Action, - Side: "short", - Quantity: qty, - Price: execPrice, - Fee: fee, - Slippage: execPrice - basePrice, - OrderValue: execPrice * qty, - RealizedPnL: realized - fee, - Leverage: posLev, - Cycle: cycle, - PositionAfter: r.remainingPosition(symbol, "short"), - } - return actionRecord, []TradeEvent{trade}, "", nil - - case "hold", "wait": - return actionRecord, nil, fmt.Sprintf("hold position: %s", dec.Action), nil - default: - return actionRecord, nil, "", fmt.Errorf("unsupported action %s", dec.Action) - } -} - -// MinPositionSizeUSD is the minimum position size in USD to avoid dust positions -const MinPositionSizeUSD = 10.0 - -func (r *Runner) determineQuantity(dec kernel.Decision, price float64) float64 { - snapshot := r.snapshotState() - equity := snapshot.Equity - if equity <= 0 { - equity = r.account.InitialBalance() - } - - // Get leverage for this symbol - leverage := r.resolveLeverage(dec.Leverage, dec.Symbol) - if leverage <= 0 { - leverage = 5 - } - - // Calculate available margin (leave some buffer for fees) - availableCash := r.account.Cash() - maxMarginToUse := availableCash * 0.9 // Use max 90% of available cash - maxPositionValue := maxMarginToUse * float64(leverage) - - sizeUSD := dec.PositionSizeUSD - if sizeUSD <= 0 { - // Default to 5% of equity, but cap to available margin - sizeUSD = 0.05 * equity - } - - // Cap position size to what we can actually afford - if sizeUSD > maxPositionValue { - logger.Infof("📊 Backtest: capping position from %.2f to %.2f (available margin: %.2f, leverage: %dx)", - sizeUSD, maxPositionValue, maxMarginToUse, leverage) - sizeUSD = maxPositionValue - } - - // Reject positions below minimum size to avoid dust positions - if sizeUSD < MinPositionSizeUSD { - logger.Infof("📊 Backtest: rejecting position size %.2f USD (below minimum %.2f USD)", - sizeUSD, MinPositionSizeUSD) - return 0 - } - - qty := sizeUSD / price - if qty < 0 { - qty = 0 - } - return qty -} - -func (r *Runner) determineCloseQuantity(symbol, side string, dec kernel.Decision) float64 { - for _, pos := range r.account.Positions() { - if pos.Symbol == strings.ToUpper(symbol) && pos.Side == side { - return pos.Quantity - } - } - return 0 -} - -func (r *Runner) resolveLeverage(requested int, symbol string) int { - sym := strings.ToUpper(symbol) - isBTCETH := sym == "BTCUSDT" || sym == "ETHUSDT" - - // Determine configured max leverage for this symbol type - var maxLeverage int - if isBTCETH { - maxLeverage = r.cfg.Leverage.BTCETHLeverage - if maxLeverage <= 0 { - maxLeverage = 10 // Default max for BTC/ETH - } - } else { - maxLeverage = r.cfg.Leverage.AltcoinLeverage - if maxLeverage <= 0 { - maxLeverage = 5 // Default max for altcoins - } - } - - // Use requested leverage if provided, otherwise use max as default - leverage := requested - if leverage <= 0 { - leverage = maxLeverage - } - - // Enforce max leverage limit - if leverage > maxLeverage { - logger.Infof("📊 Backtest: capping leverage from %dx to %dx for %s", - leverage, maxLeverage, symbol) - leverage = maxLeverage - } - - return leverage -} - -func (r *Runner) remainingPosition(symbol, side string) float64 { - for _, pos := range r.account.Positions() { - if pos.Symbol == strings.ToUpper(symbol) && pos.Side == side { - return pos.Quantity - } - } - return 0 -} - -func (r *Runner) snapshotPositions(priceMap map[string]float64) []store.PositionSnapshot { - positions := r.account.Positions() - list := make([]store.PositionSnapshot, 0, len(positions)) - for _, pos := range positions { - price := priceMap[pos.Symbol] - list = append(list, store.PositionSnapshot{ - Symbol: pos.Symbol, - Side: pos.Side, - PositionAmt: pos.Quantity, - EntryPrice: pos.EntryPrice, - MarkPrice: price, - UnrealizedProfit: unrealizedPnL(pos, price), - Leverage: float64(pos.Leverage), - LiquidationPrice: pos.LiquidationPrice, - }) - } - return list -} - -func (r *Runner) convertPositions(priceMap map[string]float64) []kernel.PositionInfo { - positions := r.account.Positions() - list := make([]kernel.PositionInfo, 0, len(positions)) - for _, pos := range positions { - price := priceMap[pos.Symbol] - pnl := unrealizedPnL(pos, price) - // Calculate P&L percentage based on entry notional (position cost) - pnlPct := 0.0 - if pos.Notional > 0 { - pnlPct = (pnl / pos.Notional) * 100 - } - list = append(list, kernel.PositionInfo{ - Symbol: pos.Symbol, - Side: pos.Side, - EntryPrice: pos.EntryPrice, - MarkPrice: price, - Quantity: pos.Quantity, - Leverage: pos.Leverage, - UnrealizedPnL: pnl, - UnrealizedPnLPct: pnlPct, - LiquidationPrice: pos.LiquidationPrice, - MarginUsed: pos.Margin, - UpdateTime: time.Now().UnixMilli(), - }) - } - return list -} - -func (r *Runner) executionPrice(symbol string, markPrice float64, ts int64) float64 { - curr, next := r.feed.decisionBarSnapshot(symbol, ts) - switch r.cfg.FillPolicy { - case FillPolicyNextOpen: - if next != nil && next.Open > 0 { - return next.Open - } - case FillPolicyBarVWAP: - if curr != nil { - if vwap := barVWAP(*curr); vwap > 0 { - return vwap - } - } - case FillPolicyMidPrice: - if curr != nil && curr.High > 0 && curr.Low > 0 { - return (curr.High + curr.Low) / 2 - } - } - return markPrice -} - -func (r *Runner) totalMarginUsed() float64 { - sum := 0.0 - for _, pos := range r.account.Positions() { - sum += pos.Margin - } - return sum -} - -func (r *Runner) updateState(ts int64, equity, unrealized, marginUsed float64, priceMap map[string]float64, advancedDecision bool) { - r.stateMu.Lock() - defer r.stateMu.Unlock() - - if r.state.MaxEquity == 0 || equity > r.state.MaxEquity { - r.state.MaxEquity = equity - } - if r.state.MinEquity == 0 || equity < r.state.MinEquity { - r.state.MinEquity = equity - } - if r.state.MaxEquity > 0 { - drawdown := ((r.state.MaxEquity - equity) / r.state.MaxEquity) * 100 - if drawdown > r.state.MaxDrawdownPct { - r.state.MaxDrawdownPct = drawdown - } - } - - positions := make(map[string]PositionSnapshot) - for _, pos := range r.account.Positions() { - key := fmt.Sprintf("%s:%s", pos.Symbol, pos.Side) - positions[key] = PositionSnapshot{ - Symbol: pos.Symbol, - Side: pos.Side, - Quantity: pos.Quantity, - AvgPrice: pos.EntryPrice, - Leverage: pos.Leverage, - LiquidationPrice: pos.LiquidationPrice, - MarginUsed: pos.Margin, - OpenTime: pos.OpenTime, - AccumulatedFee: pos.AccumulatedFee, - } - } - - r.state.BarTimestamp = ts - r.state.BarIndex++ - if advancedDecision { - r.state.DecisionCycle++ - } - r.state.Cash = r.account.Cash() - r.state.Equity = equity - r.state.UnrealizedPnL = unrealized - r.state.RealizedPnL = r.account.RealizedPnL() - r.state.Positions = positions - r.state.LastUpdate = time.Now().UTC() -} - -func (r *Runner) maybeCheckpoint() error { - state := r.snapshotState() - shouldCheckpoint := false - - if r.cfg.CheckpointIntervalBars > 0 && state.BarIndex > 0 && state.BarIndex%r.cfg.CheckpointIntervalBars == 0 { - shouldCheckpoint = true - } - - interval := time.Duration(r.cfg.CheckpointIntervalSeconds) * time.Second - if interval <= 0 { - interval = 2 * time.Second - } - if time.Since(r.lastCheckpoint) >= interval { - shouldCheckpoint = true - } - - if !shouldCheckpoint { - return nil - } - - if err := r.saveCheckpoint(state); err != nil { - return err - } - - return nil -} - -func (r *Runner) snapshotForCheckpoint(state BacktestState) []PositionSnapshot { - res := make([]PositionSnapshot, 0, len(state.Positions)) - for _, pos := range state.Positions { - res = append(res, pos) - } - sort.Slice(res, func(i, j int) bool { - if res[i].Symbol == res[j].Symbol { - return res[i].Side < res[j].Side - } - return res[i].Symbol < res[j].Symbol - }) - return res -} - -func (r *Runner) checkLiquidation(ts int64, priceMap map[string]float64, cycle int) ([]TradeEvent, string, error) { - positions := append([]*position(nil), r.account.Positions()...) - events := make([]TradeEvent, 0) - var noteBuilder strings.Builder - - for _, pos := range positions { - price := priceMap[pos.Symbol] - liqPrice := pos.LiquidationPrice - trigger := false - execPrice := price - if pos.Side == "long" { - if price <= liqPrice && liqPrice > 0 { - trigger = true - execPrice = liqPrice - } - } else { - if price >= liqPrice && liqPrice > 0 { - trigger = true - execPrice = liqPrice - } - } - if !trigger { - continue - } - - realized, fee, finalPrice, err := r.account.Close(pos.Symbol, pos.Side, pos.Quantity, execPrice) - if err != nil { - return nil, "", err - } - - noteBuilder.WriteString(fmt.Sprintf("%s %s @ %.4f; ", pos.Symbol, pos.Side, finalPrice)) - - evt := TradeEvent{ - Timestamp: ts, - Symbol: pos.Symbol, - Action: "liquidated", - Side: pos.Side, - Quantity: pos.Quantity, - Price: finalPrice, - Fee: fee, - Slippage: 0, - OrderValue: finalPrice * pos.Quantity, - RealizedPnL: realized - fee, - Leverage: pos.Leverage, - Cycle: cycle, - PositionAfter: 0, - LiquidationFlag: true, - Note: fmt.Sprintf("forced liquidation at %.4f", finalPrice), - } - events = append(events, evt) - } - - if len(events) == 0 { - return events, "", nil - } - - note := strings.TrimSuffix(noteBuilder.String(), "; ") - - r.stateMu.Lock() - r.state.Liquidated = true - r.state.LiquidationNote = note - r.stateMu.Unlock() - - return events, note, nil -} - -func (r *Runner) shouldTriggerDecision(barIndex int) bool { - if r.cfg.DecisionCadenceNBars <= 1 { - return true - } - if barIndex < 0 { - return true - } - return barIndex%r.cfg.DecisionCadenceNBars == 0 -} - -func (r *Runner) handleStop(reason error) { - r.forceCheckpoint() - if reason != nil { - r.setLastError(reason) - } else { - r.setLastError(nil) - } - r.statusMu.Lock() - r.err = reason - r.status = RunStateStopped - r.statusMu.Unlock() - r.persistMetadata() - r.persistMetrics(true) - r.releaseLock() -} - -func (r *Runner) handlePause() { - r.forceCheckpoint() - r.setLastError(nil) - r.statusMu.Lock() - r.status = RunStatePaused - r.statusMu.Unlock() - r.persistMetadata() - r.persistMetrics(true) -} - -func (r *Runner) resumeFromPause() { - r.setLastError(nil) - r.statusMu.Lock() - r.status = RunStateRunning - r.statusMu.Unlock() - r.persistMetadata() -} - -func (r *Runner) handleCompletion() { - r.setLastError(nil) - r.statusMu.Lock() - r.status = RunStateCompleted - r.statusMu.Unlock() - r.persistMetadata() - r.persistMetrics(true) - r.releaseLock() -} - -func (r *Runner) handleFailure(err error) { - r.forceCheckpoint() - if err != nil { - r.setLastError(err) - } - r.statusMu.Lock() - r.err = err - r.status = RunStateFailed - r.statusMu.Unlock() - r.persistMetadata() - r.persistMetrics(true) - r.releaseLock() -} - -func (r *Runner) handleLiquidation() { - r.forceCheckpoint() - r.setLastError(errLiquidated) - r.statusMu.Lock() - r.err = errLiquidated - r.status = RunStateLiquidated - r.statusMu.Unlock() - r.persistMetadata() - r.persistMetrics(true) - r.releaseLock() -} - func (r *Runner) Pause() { select { case r.pauseCh <- struct{}{}: @@ -1292,240 +339,3 @@ func (r *Runner) snapshotState() BacktestState { } return copyState } - -func (r *Runner) persistMetadata() { - state := r.snapshotState() - meta := r.buildMetadata(state, r.Status()) - meta.CreatedAt = r.createdAt - if err := SaveRunMetadata(meta); err != nil { - logger.Infof("failed to save run metadata for %s: %v", r.cfg.RunID, err) - } else { - if err := updateRunIndex(meta, &r.cfg); err != nil { - logger.Infof("failed to update index for %s: %v", r.cfg.RunID, err) - } - } -} - -func (r *Runner) logDecision(record *store.DecisionRecord) error { - if record == nil { - return nil - } - persistDecisionRecord(r.cfg.RunID, record) - return nil -} - -func (r *Runner) persistMetrics(force bool) { - if r.cfg.RunID == "" { - return - } - - if !force && !r.lastMetricsWrite.IsZero() { - if time.Since(r.lastMetricsWrite) < metricsWriteInterval { - return - } - } - - state := r.snapshotState() - metrics, err := CalculateMetrics(r.cfg.RunID, &r.cfg, &state) - if err != nil { - logger.Infof("failed to compute metrics for %s: %v", r.cfg.RunID, err) - return - } - if metrics == nil { - return - } - if err := PersistMetrics(r.cfg.RunID, metrics); err != nil { - logger.Infof("failed to persist metrics for %s: %v", r.cfg.RunID, err) - return - } - r.lastMetricsWrite = time.Now() -} - -func (r *Runner) buildMetadata(state BacktestState, runState RunState) *RunMetadata { - if state.Liquidated && runState != RunStateLiquidated { - runState = RunStateLiquidated - } - - progress := progressPercent(state, r.cfg) - - summary := RunSummary{ - SymbolCount: len(r.cfg.Symbols), - DecisionTF: r.cfg.DecisionTimeframe, - ProcessedBars: state.BarIndex, - ProgressPct: progress, - EquityLast: state.Equity, - MaxDrawdownPct: state.MaxDrawdownPct, - Liquidated: state.Liquidated, - LiquidationNote: state.LiquidationNote, - } - - meta := &RunMetadata{ - RunID: r.cfg.RunID, - UserID: r.cfg.UserID, - State: runState, - LastError: r.lastErrorString(), - Summary: summary, - } - - return meta -} - -func progressPercent(state BacktestState, cfg BacktestConfig) float64 { - duration := cfg.Duration() - if duration <= 0 { - return 0 - } - if state.BarTimestamp == 0 { - return 0 - } - - start := time.Unix(cfg.StartTS, 0) - end := time.Unix(cfg.EndTS, 0) - current := time.UnixMilli(state.BarTimestamp) - - if !current.After(start) { - return 0 - } - if current.After(end) { - return 100 - } - - elapsed := current.Sub(start) - pct := float64(elapsed) / float64(duration) * 100 - if pct > 100 { - pct = 100 - } - if pct < 0 { - pct = 0 - } - return pct -} - -func (r *Runner) buildCheckpointFromState(state BacktestState) *Checkpoint { - return &Checkpoint{ - BarIndex: state.BarIndex, - BarTimestamp: state.BarTimestamp, - Cash: state.Cash, - Equity: state.Equity, - UnrealizedPnL: state.UnrealizedPnL, - RealizedPnL: state.RealizedPnL, - Positions: r.snapshotForCheckpoint(state), - DecisionCycle: state.DecisionCycle, - Liquidated: state.Liquidated, - LiquidationNote: state.LiquidationNote, - MaxEquity: state.MaxEquity, - MinEquity: state.MinEquity, - MaxDrawdownPct: state.MaxDrawdownPct, - AICacheRef: r.cachePath, - } -} - -func (r *Runner) saveCheckpoint(state BacktestState) error { - ckpt := r.buildCheckpointFromState(state) - if ckpt == nil { - return nil - } - if err := SaveCheckpoint(r.cfg.RunID, ckpt); err != nil { - return err - } - r.lastCheckpoint = time.Now() - return nil -} - -func (r *Runner) forceCheckpoint() { - state := r.snapshotState() - if err := r.saveCheckpoint(state); err != nil { - logger.Infof("failed to save checkpoint for %s: %v", r.cfg.RunID, err) - } -} - -func (r *Runner) RestoreFromCheckpoint() error { - ckpt, err := LoadCheckpoint(r.cfg.RunID) - if err != nil { - return err - } - return r.applyCheckpoint(ckpt) -} - -func (r *Runner) applyCheckpoint(ckpt *Checkpoint) error { - if ckpt == nil { - return fmt.Errorf("checkpoint is nil") - } - r.account.RestoreFromSnapshots(ckpt.Cash, ckpt.RealizedPnL, ckpt.Positions) - r.stateMu.Lock() - defer r.stateMu.Unlock() - r.state.BarIndex = ckpt.BarIndex - r.state.BarTimestamp = ckpt.BarTimestamp - r.state.Cash = ckpt.Cash - r.state.Equity = ckpt.Equity - r.state.UnrealizedPnL = ckpt.UnrealizedPnL - r.state.RealizedPnL = ckpt.RealizedPnL - r.state.DecisionCycle = ckpt.DecisionCycle - r.state.Liquidated = ckpt.Liquidated - r.state.LiquidationNote = ckpt.LiquidationNote - r.state.MaxEquity = ckpt.MaxEquity - r.state.MinEquity = ckpt.MinEquity - r.state.MaxDrawdownPct = ckpt.MaxDrawdownPct - r.state.Positions = snapshotsToMap(ckpt.Positions) - r.state.LastUpdate = time.Now().UTC() - r.lastCheckpoint = time.Now() - return nil -} - -func snapshotsToMap(snaps []PositionSnapshot) map[string]PositionSnapshot { - positions := make(map[string]PositionSnapshot, len(snaps)) - for _, snap := range snaps { - key := fmt.Sprintf("%s:%s", snap.Symbol, snap.Side) - positions[key] = snap - } - return positions -} - -func sortDecisionsByPriority(decisions []kernel.Decision) []kernel.Decision { - if len(decisions) <= 1 { - return decisions - } - - priority := func(action string) int { - switch action { - case "close_long", "close_short": - return 1 - case "open_long", "open_short": - return 2 - case "hold", "wait": - return 3 - default: - return 99 - } - } - - result := make([]kernel.Decision, len(decisions)) - copy(result, decisions) - - sort.Slice(result, func(i, j int) bool { - pi := priority(result[i].Action) - pj := priority(result[j].Action) - if pi != pj { - return pi < pj - } - return i < j - }) - - return result -} - -func barVWAP(k market.Kline) float64 { - values := []float64{k.Open, k.High, k.Low, k.Close} - sum := 0.0 - count := 0.0 - for _, v := range values { - if v > 0 { - sum += v - count++ - } - } - if count == 0 { - return 0 - } - return sum / count -} diff --git a/backtest/runner_loop.go b/backtest/runner_loop.go new file mode 100644 index 00000000..81703bc1 --- /dev/null +++ b/backtest/runner_loop.go @@ -0,0 +1,563 @@ +package backtest + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "time" + + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" +) + +func (r *Runner) loop(ctx context.Context) { + defer close(r.doneCh) + + for { + select { + case <-ctx.Done(): + r.handleStop(fmt.Errorf("context canceled: %w", ctx.Err())) + return + case <-r.stopCh: + r.handleStop(nil) + return + case <-r.pauseCh: + r.handlePause() + <-r.resumeCh + r.resumeFromPause() + default: + } + + err := r.stepOnce() + if errors.Is(err, errBacktestCompleted) { + r.handleCompletion() + return + } + if errors.Is(err, errLiquidated) { + r.handleLiquidation() + return + } + if err != nil { + r.handleFailure(err) + return + } + } +} + +func (r *Runner) stepOnce() error { + state := r.snapshotState() + if state.BarIndex >= r.feed.DecisionBarCount() { + return errBacktestCompleted + } + + ts := r.feed.DecisionTimestamp(state.BarIndex) + + marketData, multiTF, err := r.feed.BuildMarketData(ts) + if err != nil { + return err + } + + priceMap := make(map[string]float64, len(marketData)) + for symbol, data := range marketData { + priceMap[symbol] = data.CurrentPrice + } + + callCount := state.DecisionCycle + 1 + shouldDecide := r.shouldTriggerDecision(state.BarIndex) + + var ( + record *store.DecisionRecord + decisionActions []store.DecisionAction + tradeEvents = make([]TradeEvent, 0) + execLog []string + hadError bool + ) + + decisionAttempted := shouldDecide + + if shouldDecide { + ctx, rec, err := r.buildDecisionContext(ts, marketData, multiTF, priceMap, callCount) + if err != nil { + // Defensive nil check to prevent panic if buildDecisionContext returns error with nil record + if rec != nil { + rec.Success = false + rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err) + _ = r.logDecision(rec) + } + return err + } + record = rec + + var ( + fullDecision *kernel.FullDecision + fromCache bool + cacheKey string + ) + if r.aiCache != nil { + if key, err := computeCacheKey(ctx, r.cfg.PromptVariant, ts); err == nil { + cacheKey = key + if cached, ok := r.aiCache.Get(cacheKey); ok { + fullDecision = cached + fromCache = true + } else if r.cfg.ReplayOnly { + decisionErr := fmt.Errorf("replay_only enabled but cache miss at %d", ts) + record.Success = false + record.ErrorMessage = fmt.Sprintf("cached decision not found for ts=%d", ts) + _ = r.logDecision(record) + return decisionErr + } + } else { + logger.Infof("failed to compute ai cache key: %v", err) + } + } + + if !fromCache { + fd, err := r.invokeAIWithRetry(ctx) + if err != nil { + decisionAttempted = true + hadError = true + record.Success = false + record.ErrorMessage = fmt.Sprintf("AI decision failed: %v", err) + execLog = append(execLog, fmt.Sprintf("⚠️ AI decision failed: %v", err)) + r.setLastError(err) + } else { + fullDecision = fd + if r.cfg.CacheAI && r.aiCache != nil && cacheKey != "" { + if err := r.aiCache.Put(cacheKey, r.cfg.PromptVariant, ts, fullDecision); err != nil { + logger.Infof("failed to persist ai cache for %s: %v", r.cfg.RunID, err) + } + } + } + } + + if fullDecision != nil { + r.fillDecisionRecord(record, fullDecision) + + sorted := sortDecisionsByPriority(fullDecision.Decisions) + + prevLogs := execLog + decisionActions = make([]store.DecisionAction, 0, len(sorted)) + execLog = make([]string, 0, len(sorted)+len(prevLogs)) + if len(prevLogs) > 0 { + execLog = append(execLog, prevLogs...) + } + + for _, dec := range sorted { + actionRecord, trades, logEntry, execErr := r.executeDecision(dec, priceMap, ts, callCount) + if execErr != nil { + actionRecord.Success = false + actionRecord.Error = execErr.Error() + hadError = true + execLog = append(execLog, fmt.Sprintf("❌ %s %s: %v", dec.Symbol, dec.Action, execErr)) + } else { + actionRecord.Success = true + execLog = append(execLog, fmt.Sprintf("✓ %s %s", dec.Symbol, dec.Action)) + } + if len(trades) > 0 { + tradeEvents = append(tradeEvents, trades...) + } + if logEntry != "" { + execLog = append(execLog, logEntry) + } + decisionActions = append(decisionActions, actionRecord) + } + } + } + + cycleForLog := state.DecisionCycle + if decisionAttempted { + cycleForLog = callCount + } + + liquidationEvents, liquidationNote, err := r.checkLiquidation(ts, priceMap, cycleForLog) + if err != nil { + if record != nil { + record.Success = false + record.ErrorMessage = err.Error() + _ = r.logDecision(record) + } + return err + } + if len(liquidationEvents) > 0 { + hadError = true + tradeEvents = append(tradeEvents, liquidationEvents...) + if record != nil { + execLog = append(execLog, fmt.Sprintf("⚠️ Forced liquidation: %s", liquidationNote)) + } + } + + if record != nil { + record.Decisions = decisionActions + record.ExecutionLog = execLog + record.Success = !hadError && liquidationNote == "" + if liquidationNote != "" { + record.ErrorMessage = liquidationNote + } + } + + equity, unrealized, _ := r.account.TotalEquity(priceMap) + marginUsed := r.totalMarginUsed() + + r.updateState(ts, equity, unrealized, marginUsed, priceMap, decisionAttempted) + + snapshot := r.snapshotState() + drawdownPct := 0.0 + if snapshot.MaxEquity > 0 { + drawdownPct = ((snapshot.MaxEquity - snapshot.Equity) / snapshot.MaxEquity) * 100 + } + + equityPoint := EquityPoint{ + Timestamp: ts, + Equity: snapshot.Equity, + Available: snapshot.Cash, + PnL: snapshot.Equity - r.account.InitialBalance(), + PnLPct: ((snapshot.Equity - r.account.InitialBalance()) / r.account.InitialBalance()) * 100, + DrawdownPct: drawdownPct, + Cycle: snapshot.DecisionCycle, + } + + if err := appendEquityPoint(r.cfg.RunID, equityPoint); err != nil { + return err + } + + for _, evt := range tradeEvents { + if err := appendTradeEvent(r.cfg.RunID, evt); err != nil { + return err + } + } + + if record != nil { + if err := r.logDecision(record); err != nil { + return err + } + } + + if err := saveProgress(r.cfg.RunID, &snapshot, &r.cfg); err != nil { + return err + } + + if err := r.maybeCheckpoint(); err != nil { + return err + } + + r.persistMetadata() + r.persistMetrics(false) + + if !hadError && liquidationNote == "" { + r.setLastError(nil) + } + + if snapshot.Liquidated { + return errLiquidated + } + + return nil +} + +func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Data, multiTF map[string]map[string]*market.Data, priceMap map[string]float64, callCount int) (*kernel.Context, *store.DecisionRecord, error) { + equity, unrealized, _ := r.account.TotalEquity(priceMap) + available := r.account.Cash() + marginUsed := r.totalMarginUsed() + marginPct := 0.0 + if equity > 0 { + marginPct = (marginUsed / equity) * 100 + } + + accountInfo := kernel.AccountInfo{ + TotalEquity: equity, + AvailableBalance: available, + TotalPnL: equity - r.account.InitialBalance(), + TotalPnLPct: ((equity - r.account.InitialBalance()) / r.account.InitialBalance()) * 100, + MarginUsed: marginUsed, + MarginUsedPct: marginPct, + PositionCount: len(r.account.Positions()), + } + + positions := r.convertPositions(priceMap) + + // Get candidate coins from strategy engine (includes source info) + candidateCoins, err := r.strategyEngine.GetCandidateCoins() + if err != nil { + // Fallback to simple list if strategy engine fails + candidateCoins = make([]kernel.CandidateCoin, 0, len(r.cfg.Symbols)) + for _, sym := range r.cfg.Symbols { + candidateCoins = append(candidateCoins, kernel.CandidateCoin{Symbol: sym, Sources: []string{"backtest"}}) + } + } + + runtime := int((ts - int64(r.cfg.StartTS*1000)) / 60000) + ctx := &kernel.Context{ + CurrentTime: time.UnixMilli(ts).UTC().Format("2006-01-02 15:04:05 UTC"), + RuntimeMinutes: runtime, + CallCount: callCount, + Account: accountInfo, + Positions: positions, + CandidateCoins: candidateCoins, + PromptVariant: r.cfg.PromptVariant, + MarketDataMap: marketData, + MultiTFMarket: multiTF, + BTCETHLeverage: r.cfg.Leverage.BTCETHLeverage, + AltcoinLeverage: r.cfg.Leverage.AltcoinLeverage, + Timeframes: r.cfg.Timeframes, + } + + // Fetch quantitative data if enabled in strategy (uses current data as approximation) + strategyConfig := r.strategyEngine.GetConfig() + if strategyConfig.Indicators.EnableQuantData { + // Collect symbols to query (candidate coins + position coins) + symbolSet := make(map[string]bool) + for _, sym := range r.cfg.Symbols { + symbolSet[sym] = true + } + for _, pos := range positions { + symbolSet[pos.Symbol] = true + } + symbols := make([]string, 0, len(symbolSet)) + for sym := range symbolSet { + symbols = append(symbols, sym) + } + ctx.QuantDataMap = r.strategyEngine.FetchQuantDataBatch(symbols) + if len(ctx.QuantDataMap) > 0 { + logger.Infof("📊 Backtest: fetched quant data for %d symbols", len(ctx.QuantDataMap)) + } + } + + // Fetch OI ranking data if enabled in strategy (uses current data as approximation) + if strategyConfig.Indicators.EnableOIRanking { + ctx.OIRankingData = r.strategyEngine.FetchOIRankingData() + if ctx.OIRankingData != nil { + logger.Infof("📊 Backtest: OI ranking data ready: %d top, %d low positions", + len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions)) + } + } + + // Fetch NetFlow ranking data if enabled in strategy + if strategyConfig.Indicators.EnableNetFlowRanking { + ctx.NetFlowRankingData = r.strategyEngine.FetchNetFlowRankingData() + if ctx.NetFlowRankingData != nil { + logger.Infof("💰 Backtest: NetFlow ranking data ready: inst_in=%d, inst_out=%d", + len(ctx.NetFlowRankingData.InstitutionFutureTop), len(ctx.NetFlowRankingData.InstitutionFutureLow)) + } + } + + // Fetch Price ranking data if enabled in strategy + if strategyConfig.Indicators.EnablePriceRanking { + ctx.PriceRankingData = r.strategyEngine.FetchPriceRankingData() + if ctx.PriceRankingData != nil { + logger.Infof("📈 Backtest: Price ranking data ready for %d durations", + len(ctx.PriceRankingData.Durations)) + } + } + + record := &store.DecisionRecord{ + AccountState: store.AccountSnapshot{ + TotalBalance: accountInfo.TotalEquity, + AvailableBalance: accountInfo.AvailableBalance, + TotalUnrealizedProfit: unrealized, + PositionCount: accountInfo.PositionCount, + MarginUsedPct: accountInfo.MarginUsedPct, + }, + CandidateCoins: make([]string, 0, len(candidateCoins)), + Positions: r.snapshotPositions(priceMap), + } + for _, coin := range candidateCoins { + record.CandidateCoins = append(record.CandidateCoins, coin.Symbol) + } + record.Timestamp = time.UnixMilli(ts).UTC() + + return ctx, record, nil +} + +func (r *Runner) fillDecisionRecord(record *store.DecisionRecord, full *kernel.FullDecision) { + record.InputPrompt = full.UserPrompt + record.CoTTrace = full.CoTTrace + if len(full.Decisions) > 0 { + if data, err := json.MarshalIndent(full.Decisions, "", " "); err == nil { + record.DecisionJSON = string(data) + } + } +} + +func (r *Runner) invokeAIWithRetry(ctx *kernel.Context) (*kernel.FullDecision, error) { + var lastErr error + for attempt := 0; attempt < aiDecisionMaxRetries; attempt++ { + // Use GetFullDecisionWithStrategy with the pre-configured strategy engine + // This ensures backtest uses the same unified prompt generation as live trading + fd, err := kernel.GetFullDecisionWithStrategy( + ctx, + r.mcpClient, + r.strategyEngine, + r.cfg.PromptVariant, + ) + if err == nil { + return fd, nil + } + lastErr = err + delay := time.Duration(attempt+1) * 500 * time.Millisecond + time.Sleep(delay) + } + return nil, lastErr +} + +func (r *Runner) shouldTriggerDecision(barIndex int) bool { + if r.cfg.DecisionCadenceNBars <= 1 { + return true + } + if barIndex < 0 { + return true + } + return barIndex%r.cfg.DecisionCadenceNBars == 0 +} + +func (r *Runner) updateState(ts int64, equity, unrealized, marginUsed float64, priceMap map[string]float64, advancedDecision bool) { + r.stateMu.Lock() + defer r.stateMu.Unlock() + + if r.state.MaxEquity == 0 || equity > r.state.MaxEquity { + r.state.MaxEquity = equity + } + if r.state.MinEquity == 0 || equity < r.state.MinEquity { + r.state.MinEquity = equity + } + if r.state.MaxEquity > 0 { + drawdown := ((r.state.MaxEquity - equity) / r.state.MaxEquity) * 100 + if drawdown > r.state.MaxDrawdownPct { + r.state.MaxDrawdownPct = drawdown + } + } + + positions := make(map[string]PositionSnapshot) + for _, pos := range r.account.Positions() { + key := fmt.Sprintf("%s:%s", pos.Symbol, pos.Side) + positions[key] = PositionSnapshot{ + Symbol: pos.Symbol, + Side: pos.Side, + Quantity: pos.Quantity, + AvgPrice: pos.EntryPrice, + Leverage: pos.Leverage, + LiquidationPrice: pos.LiquidationPrice, + MarginUsed: pos.Margin, + OpenTime: pos.OpenTime, + AccumulatedFee: pos.AccumulatedFee, + } + } + + r.state.BarTimestamp = ts + r.state.BarIndex++ + if advancedDecision { + r.state.DecisionCycle++ + } + r.state.Cash = r.account.Cash() + r.state.Equity = equity + r.state.UnrealizedPnL = unrealized + r.state.RealizedPnL = r.account.RealizedPnL() + r.state.Positions = positions + r.state.LastUpdate = time.Now().UTC() +} + +func (r *Runner) handleStop(reason error) { + r.forceCheckpoint() + if reason != nil { + r.setLastError(reason) + } else { + r.setLastError(nil) + } + r.statusMu.Lock() + r.err = reason + r.status = RunStateStopped + r.statusMu.Unlock() + r.persistMetadata() + r.persistMetrics(true) + r.releaseLock() +} + +func (r *Runner) handlePause() { + r.forceCheckpoint() + r.setLastError(nil) + r.statusMu.Lock() + r.status = RunStatePaused + r.statusMu.Unlock() + r.persistMetadata() + r.persistMetrics(true) +} + +func (r *Runner) resumeFromPause() { + r.setLastError(nil) + r.statusMu.Lock() + r.status = RunStateRunning + r.statusMu.Unlock() + r.persistMetadata() +} + +func (r *Runner) handleCompletion() { + r.setLastError(nil) + r.statusMu.Lock() + r.status = RunStateCompleted + r.statusMu.Unlock() + r.persistMetadata() + r.persistMetrics(true) + r.releaseLock() +} + +func (r *Runner) handleFailure(err error) { + r.forceCheckpoint() + if err != nil { + r.setLastError(err) + } + r.statusMu.Lock() + r.err = err + r.status = RunStateFailed + r.statusMu.Unlock() + r.persistMetadata() + r.persistMetrics(true) + r.releaseLock() +} + +func (r *Runner) handleLiquidation() { + r.forceCheckpoint() + r.setLastError(errLiquidated) + r.statusMu.Lock() + r.err = errLiquidated + r.status = RunStateLiquidated + r.statusMu.Unlock() + r.persistMetadata() + r.persistMetrics(true) + r.releaseLock() +} + +func sortDecisionsByPriority(decisions []kernel.Decision) []kernel.Decision { + if len(decisions) <= 1 { + return decisions + } + + priority := func(action string) int { + switch action { + case "close_long", "close_short": + return 1 + case "open_long", "open_short": + return 2 + case "hold", "wait": + return 3 + default: + return 99 + } + } + + result := make([]kernel.Decision, len(decisions)) + copy(result, decisions) + + sort.Slice(result, func(i, j int) bool { + pi := priority(result[i].Action) + pj := priority(result[j].Action) + if pi != pj { + return pi < pj + } + return i < j + }) + + return result +} diff --git a/backtest/runner_metrics.go b/backtest/runner_metrics.go new file mode 100644 index 00000000..e3dea04e --- /dev/null +++ b/backtest/runner_metrics.go @@ -0,0 +1,239 @@ +package backtest + +import ( + "fmt" + "sort" + "time" + + "nofx/logger" + "nofx/store" +) + +func (r *Runner) persistMetadata() { + state := r.snapshotState() + meta := r.buildMetadata(state, r.Status()) + meta.CreatedAt = r.createdAt + if err := SaveRunMetadata(meta); err != nil { + logger.Infof("failed to save run metadata for %s: %v", r.cfg.RunID, err) + } else { + if err := updateRunIndex(meta, &r.cfg); err != nil { + logger.Infof("failed to update index for %s: %v", r.cfg.RunID, err) + } + } +} + +func (r *Runner) logDecision(record *store.DecisionRecord) error { + if record == nil { + return nil + } + persistDecisionRecord(r.cfg.RunID, record) + return nil +} + +func (r *Runner) persistMetrics(force bool) { + if r.cfg.RunID == "" { + return + } + + if !force && !r.lastMetricsWrite.IsZero() { + if time.Since(r.lastMetricsWrite) < metricsWriteInterval { + return + } + } + + state := r.snapshotState() + metrics, err := CalculateMetrics(r.cfg.RunID, &r.cfg, &state) + if err != nil { + logger.Infof("failed to compute metrics for %s: %v", r.cfg.RunID, err) + return + } + if metrics == nil { + return + } + if err := PersistMetrics(r.cfg.RunID, metrics); err != nil { + logger.Infof("failed to persist metrics for %s: %v", r.cfg.RunID, err) + return + } + r.lastMetricsWrite = time.Now() +} + +func (r *Runner) buildMetadata(state BacktestState, runState RunState) *RunMetadata { + if state.Liquidated && runState != RunStateLiquidated { + runState = RunStateLiquidated + } + + progress := progressPercent(state, r.cfg) + + summary := RunSummary{ + SymbolCount: len(r.cfg.Symbols), + DecisionTF: r.cfg.DecisionTimeframe, + ProcessedBars: state.BarIndex, + ProgressPct: progress, + EquityLast: state.Equity, + MaxDrawdownPct: state.MaxDrawdownPct, + Liquidated: state.Liquidated, + LiquidationNote: state.LiquidationNote, + } + + meta := &RunMetadata{ + RunID: r.cfg.RunID, + UserID: r.cfg.UserID, + State: runState, + LastError: r.lastErrorString(), + Summary: summary, + } + + return meta +} + +func progressPercent(state BacktestState, cfg BacktestConfig) float64 { + duration := cfg.Duration() + if duration <= 0 { + return 0 + } + if state.BarTimestamp == 0 { + return 0 + } + + start := time.Unix(cfg.StartTS, 0) + end := time.Unix(cfg.EndTS, 0) + current := time.UnixMilli(state.BarTimestamp) + + if !current.After(start) { + return 0 + } + if current.After(end) { + return 100 + } + + elapsed := current.Sub(start) + pct := float64(elapsed) / float64(duration) * 100 + if pct > 100 { + pct = 100 + } + if pct < 0 { + pct = 0 + } + return pct +} + +func (r *Runner) maybeCheckpoint() error { + state := r.snapshotState() + shouldCheckpoint := false + + if r.cfg.CheckpointIntervalBars > 0 && state.BarIndex > 0 && state.BarIndex%r.cfg.CheckpointIntervalBars == 0 { + shouldCheckpoint = true + } + + interval := time.Duration(r.cfg.CheckpointIntervalSeconds) * time.Second + if interval <= 0 { + interval = 2 * time.Second + } + if time.Since(r.lastCheckpoint) >= interval { + shouldCheckpoint = true + } + + if !shouldCheckpoint { + return nil + } + + if err := r.saveCheckpoint(state); err != nil { + return err + } + + return nil +} + +func (r *Runner) snapshotForCheckpoint(state BacktestState) []PositionSnapshot { + res := make([]PositionSnapshot, 0, len(state.Positions)) + for _, pos := range state.Positions { + res = append(res, pos) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Symbol == res[j].Symbol { + return res[i].Side < res[j].Side + } + return res[i].Symbol < res[j].Symbol + }) + return res +} + +func (r *Runner) buildCheckpointFromState(state BacktestState) *Checkpoint { + return &Checkpoint{ + BarIndex: state.BarIndex, + BarTimestamp: state.BarTimestamp, + Cash: state.Cash, + Equity: state.Equity, + UnrealizedPnL: state.UnrealizedPnL, + RealizedPnL: state.RealizedPnL, + Positions: r.snapshotForCheckpoint(state), + DecisionCycle: state.DecisionCycle, + Liquidated: state.Liquidated, + LiquidationNote: state.LiquidationNote, + MaxEquity: state.MaxEquity, + MinEquity: state.MinEquity, + MaxDrawdownPct: state.MaxDrawdownPct, + AICacheRef: r.cachePath, + } +} + +func (r *Runner) saveCheckpoint(state BacktestState) error { + ckpt := r.buildCheckpointFromState(state) + if ckpt == nil { + return nil + } + if err := SaveCheckpoint(r.cfg.RunID, ckpt); err != nil { + return err + } + r.lastCheckpoint = time.Now() + return nil +} + +func (r *Runner) forceCheckpoint() { + state := r.snapshotState() + if err := r.saveCheckpoint(state); err != nil { + logger.Infof("failed to save checkpoint for %s: %v", r.cfg.RunID, err) + } +} + +func (r *Runner) RestoreFromCheckpoint() error { + ckpt, err := LoadCheckpoint(r.cfg.RunID) + if err != nil { + return err + } + return r.applyCheckpoint(ckpt) +} + +func (r *Runner) applyCheckpoint(ckpt *Checkpoint) error { + if ckpt == nil { + return fmt.Errorf("checkpoint is nil") + } + r.account.RestoreFromSnapshots(ckpt.Cash, ckpt.RealizedPnL, ckpt.Positions) + r.stateMu.Lock() + defer r.stateMu.Unlock() + r.state.BarIndex = ckpt.BarIndex + r.state.BarTimestamp = ckpt.BarTimestamp + r.state.Cash = ckpt.Cash + r.state.Equity = ckpt.Equity + r.state.UnrealizedPnL = ckpt.UnrealizedPnL + r.state.RealizedPnL = ckpt.RealizedPnL + r.state.DecisionCycle = ckpt.DecisionCycle + r.state.Liquidated = ckpt.Liquidated + r.state.LiquidationNote = ckpt.LiquidationNote + r.state.MaxEquity = ckpt.MaxEquity + r.state.MinEquity = ckpt.MinEquity + r.state.MaxDrawdownPct = ckpt.MaxDrawdownPct + r.state.Positions = snapshotsToMap(ckpt.Positions) + r.state.LastUpdate = time.Now().UTC() + r.lastCheckpoint = time.Now() + return nil +} + +func snapshotsToMap(snaps []PositionSnapshot) map[string]PositionSnapshot { + positions := make(map[string]PositionSnapshot, len(snaps)) + for _, snap := range snaps { + key := fmt.Sprintf("%s:%s", snap.Symbol, snap.Side) + positions[key] = snap + } + return positions +} diff --git a/backtest/runner_orders.go b/backtest/runner_orders.go new file mode 100644 index 00000000..8e6bdfa0 --- /dev/null +++ b/backtest/runner_orders.go @@ -0,0 +1,420 @@ +package backtest + +import ( + "fmt" + "strings" + "time" + + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" +) + +func (r *Runner) executeDecision(dec kernel.Decision, priceMap map[string]float64, ts int64, cycle int) (store.DecisionAction, []TradeEvent, string, error) { + symbol := dec.Symbol + if symbol == "" { + return store.DecisionAction{}, nil, "", fmt.Errorf("empty symbol in decision") + } + + usedLeverage := r.resolveLeverage(dec.Leverage, symbol) + actionRecord := store.DecisionAction{ + Action: dec.Action, + Symbol: symbol, + Leverage: usedLeverage, + Timestamp: time.UnixMilli(ts).UTC(), + } + + if priceMap == nil { + return actionRecord, nil, "", fmt.Errorf("priceMap is nil") + } + + basePrice, ok := priceMap[symbol] + if !ok || basePrice <= 0 { + return actionRecord, nil, "", fmt.Errorf("price unavailable for %s (found=%v, price=%.4f)", symbol, ok, basePrice) + } + fillPrice := r.executionPrice(symbol, basePrice, ts) + + switch dec.Action { + case "open_long": + qty := r.determineQuantity(dec, basePrice) + if qty <= 0 { + return actionRecord, nil, "", fmt.Errorf("invalid qty") + } + pos, fee, execPrice, err := r.account.Open(symbol, "long", qty, usedLeverage, fillPrice, ts) + if err != nil { + return actionRecord, nil, "", err + } + actionRecord.Quantity = qty + actionRecord.Price = execPrice + actionRecord.Leverage = pos.Leverage + trade := TradeEvent{ + Timestamp: ts, + Symbol: symbol, + Action: dec.Action, + Side: "long", + Quantity: qty, + Price: execPrice, + Fee: fee, + Slippage: execPrice - basePrice, + OrderValue: execPrice * qty, + RealizedPnL: 0, + Leverage: pos.Leverage, + Cycle: cycle, + PositionAfter: pos.Quantity, + } + return actionRecord, []TradeEvent{trade}, "", nil + + case "open_short": + qty := r.determineQuantity(dec, basePrice) + if qty <= 0 { + return actionRecord, nil, "", fmt.Errorf("invalid qty") + } + pos, fee, execPrice, err := r.account.Open(symbol, "short", qty, usedLeverage, fillPrice, ts) + if err != nil { + return actionRecord, nil, "", err + } + actionRecord.Quantity = qty + actionRecord.Price = execPrice + actionRecord.Leverage = pos.Leverage + trade := TradeEvent{ + Timestamp: ts, + Symbol: symbol, + Action: dec.Action, + Side: "short", + Quantity: qty, + Price: execPrice, + Fee: fee, + Slippage: basePrice - execPrice, + OrderValue: execPrice * qty, + RealizedPnL: 0, + Leverage: pos.Leverage, + Cycle: cycle, + PositionAfter: pos.Quantity, + } + return actionRecord, []TradeEvent{trade}, "", nil + + case "close_long": + qty := r.determineCloseQuantity(symbol, "long", dec) + if qty <= 0 { + return actionRecord, nil, "", fmt.Errorf("invalid close qty") + } + posLev := r.account.positionLeverage(symbol, "long") + realized, fee, execPrice, err := r.account.Close(symbol, "long", qty, fillPrice) + if err != nil { + return actionRecord, nil, "", err + } + actionRecord.Quantity = qty + actionRecord.Price = execPrice + actionRecord.Leverage = posLev + trade := TradeEvent{ + Timestamp: ts, + Symbol: symbol, + Action: dec.Action, + Side: "long", + Quantity: qty, + Price: execPrice, + Fee: fee, + Slippage: basePrice - execPrice, + OrderValue: execPrice * qty, + RealizedPnL: realized - fee, + Leverage: posLev, + Cycle: cycle, + PositionAfter: r.remainingPosition(symbol, "long"), + } + return actionRecord, []TradeEvent{trade}, "", nil + + case "close_short": + qty := r.determineCloseQuantity(symbol, "short", dec) + if qty <= 0 { + return actionRecord, nil, "", fmt.Errorf("invalid close qty") + } + posLev := r.account.positionLeverage(symbol, "short") + realized, fee, execPrice, err := r.account.Close(symbol, "short", qty, fillPrice) + if err != nil { + return actionRecord, nil, "", err + } + actionRecord.Quantity = qty + actionRecord.Price = execPrice + actionRecord.Leverage = posLev + trade := TradeEvent{ + Timestamp: ts, + Symbol: symbol, + Action: dec.Action, + Side: "short", + Quantity: qty, + Price: execPrice, + Fee: fee, + Slippage: execPrice - basePrice, + OrderValue: execPrice * qty, + RealizedPnL: realized - fee, + Leverage: posLev, + Cycle: cycle, + PositionAfter: r.remainingPosition(symbol, "short"), + } + return actionRecord, []TradeEvent{trade}, "", nil + + case "hold", "wait": + return actionRecord, nil, fmt.Sprintf("hold position: %s", dec.Action), nil + default: + return actionRecord, nil, "", fmt.Errorf("unsupported action %s", dec.Action) + } +} + +// MinPositionSizeUSD is the minimum position size in USD to avoid dust positions +const MinPositionSizeUSD = 10.0 + +func (r *Runner) determineQuantity(dec kernel.Decision, price float64) float64 { + snapshot := r.snapshotState() + equity := snapshot.Equity + if equity <= 0 { + equity = r.account.InitialBalance() + } + + // Get leverage for this symbol + leverage := r.resolveLeverage(dec.Leverage, dec.Symbol) + if leverage <= 0 { + leverage = 5 + } + + // Calculate available margin (leave some buffer for fees) + availableCash := r.account.Cash() + maxMarginToUse := availableCash * 0.9 // Use max 90% of available cash + maxPositionValue := maxMarginToUse * float64(leverage) + + sizeUSD := dec.PositionSizeUSD + if sizeUSD <= 0 { + // Default to 5% of equity, but cap to available margin + sizeUSD = 0.05 * equity + } + + // Cap position size to what we can actually afford + if sizeUSD > maxPositionValue { + logger.Infof("📊 Backtest: capping position from %.2f to %.2f (available margin: %.2f, leverage: %dx)", + sizeUSD, maxPositionValue, maxMarginToUse, leverage) + sizeUSD = maxPositionValue + } + + // Reject positions below minimum size to avoid dust positions + if sizeUSD < MinPositionSizeUSD { + logger.Infof("📊 Backtest: rejecting position size %.2f USD (below minimum %.2f USD)", + sizeUSD, MinPositionSizeUSD) + return 0 + } + + qty := sizeUSD / price + if qty < 0 { + qty = 0 + } + return qty +} + +func (r *Runner) determineCloseQuantity(symbol, side string, dec kernel.Decision) float64 { + for _, pos := range r.account.Positions() { + if pos.Symbol == strings.ToUpper(symbol) && pos.Side == side { + return pos.Quantity + } + } + return 0 +} + +func (r *Runner) resolveLeverage(requested int, symbol string) int { + sym := strings.ToUpper(symbol) + isBTCETH := sym == "BTCUSDT" || sym == "ETHUSDT" + + // Determine configured max leverage for this symbol type + var maxLeverage int + if isBTCETH { + maxLeverage = r.cfg.Leverage.BTCETHLeverage + if maxLeverage <= 0 { + maxLeverage = 10 // Default max for BTC/ETH + } + } else { + maxLeverage = r.cfg.Leverage.AltcoinLeverage + if maxLeverage <= 0 { + maxLeverage = 5 // Default max for altcoins + } + } + + // Use requested leverage if provided, otherwise use max as default + leverage := requested + if leverage <= 0 { + leverage = maxLeverage + } + + // Enforce max leverage limit + if leverage > maxLeverage { + logger.Infof("📊 Backtest: capping leverage from %dx to %dx for %s", + leverage, maxLeverage, symbol) + leverage = maxLeverage + } + + return leverage +} + +func (r *Runner) remainingPosition(symbol, side string) float64 { + for _, pos := range r.account.Positions() { + if pos.Symbol == strings.ToUpper(symbol) && pos.Side == side { + return pos.Quantity + } + } + return 0 +} + +func (r *Runner) snapshotPositions(priceMap map[string]float64) []store.PositionSnapshot { + positions := r.account.Positions() + list := make([]store.PositionSnapshot, 0, len(positions)) + for _, pos := range positions { + price := priceMap[pos.Symbol] + list = append(list, store.PositionSnapshot{ + Symbol: pos.Symbol, + Side: pos.Side, + PositionAmt: pos.Quantity, + EntryPrice: pos.EntryPrice, + MarkPrice: price, + UnrealizedProfit: unrealizedPnL(pos, price), + Leverage: float64(pos.Leverage), + LiquidationPrice: pos.LiquidationPrice, + }) + } + return list +} + +func (r *Runner) convertPositions(priceMap map[string]float64) []kernel.PositionInfo { + positions := r.account.Positions() + list := make([]kernel.PositionInfo, 0, len(positions)) + for _, pos := range positions { + price := priceMap[pos.Symbol] + pnl := unrealizedPnL(pos, price) + // Calculate P&L percentage based on entry notional (position cost) + pnlPct := 0.0 + if pos.Notional > 0 { + pnlPct = (pnl / pos.Notional) * 100 + } + list = append(list, kernel.PositionInfo{ + Symbol: pos.Symbol, + Side: pos.Side, + EntryPrice: pos.EntryPrice, + MarkPrice: price, + Quantity: pos.Quantity, + Leverage: pos.Leverage, + UnrealizedPnL: pnl, + UnrealizedPnLPct: pnlPct, + LiquidationPrice: pos.LiquidationPrice, + MarginUsed: pos.Margin, + UpdateTime: time.Now().UnixMilli(), + }) + } + return list +} + +func (r *Runner) executionPrice(symbol string, markPrice float64, ts int64) float64 { + curr, next := r.feed.decisionBarSnapshot(symbol, ts) + switch r.cfg.FillPolicy { + case FillPolicyNextOpen: + if next != nil && next.Open > 0 { + return next.Open + } + case FillPolicyBarVWAP: + if curr != nil { + if vwap := barVWAP(*curr); vwap > 0 { + return vwap + } + } + case FillPolicyMidPrice: + if curr != nil && curr.High > 0 && curr.Low > 0 { + return (curr.High + curr.Low) / 2 + } + } + return markPrice +} + +func (r *Runner) totalMarginUsed() float64 { + sum := 0.0 + for _, pos := range r.account.Positions() { + sum += pos.Margin + } + return sum +} + +func (r *Runner) checkLiquidation(ts int64, priceMap map[string]float64, cycle int) ([]TradeEvent, string, error) { + positions := append([]*position(nil), r.account.Positions()...) + events := make([]TradeEvent, 0) + var noteBuilder strings.Builder + + for _, pos := range positions { + price := priceMap[pos.Symbol] + liqPrice := pos.LiquidationPrice + trigger := false + execPrice := price + if pos.Side == "long" { + if price <= liqPrice && liqPrice > 0 { + trigger = true + execPrice = liqPrice + } + } else { + if price >= liqPrice && liqPrice > 0 { + trigger = true + execPrice = liqPrice + } + } + if !trigger { + continue + } + + realized, fee, finalPrice, err := r.account.Close(pos.Symbol, pos.Side, pos.Quantity, execPrice) + if err != nil { + return nil, "", err + } + + noteBuilder.WriteString(fmt.Sprintf("%s %s @ %.4f; ", pos.Symbol, pos.Side, finalPrice)) + + evt := TradeEvent{ + Timestamp: ts, + Symbol: pos.Symbol, + Action: "liquidated", + Side: pos.Side, + Quantity: pos.Quantity, + Price: finalPrice, + Fee: fee, + Slippage: 0, + OrderValue: finalPrice * pos.Quantity, + RealizedPnL: realized - fee, + Leverage: pos.Leverage, + Cycle: cycle, + PositionAfter: 0, + LiquidationFlag: true, + Note: fmt.Sprintf("forced liquidation at %.4f", finalPrice), + } + events = append(events, evt) + } + + if len(events) == 0 { + return events, "", nil + } + + note := strings.TrimSuffix(noteBuilder.String(), "; ") + + r.stateMu.Lock() + r.state.Liquidated = true + r.state.LiquidationNote = note + r.stateMu.Unlock() + + return events, note, nil +} + +func barVWAP(k market.Kline) float64 { + values := []float64{k.Open, k.High, k.Low, k.Close} + sum := 0.0 + count := 0.0 + for _, v := range values { + if v > 0 { + sum += v + count++ + } + } + if count == 0 { + return 0 + } + return sum / count +} diff --git a/cmd/lighter_test/main.go b/cmd/lighter_test/main.go deleted file mode 100644 index 6f896a23..00000000 --- a/cmd/lighter_test/main.go +++ /dev/null @@ -1,233 +0,0 @@ -// Lighter API Authentication Test Tool -// Usage: go run cmd/lighter_test/main.go -wallet=0x... -apikey=... [-testnet] -package main - -import ( - "context" - "encoding/json" - "flag" - "fmt" - "io" - "net/http" - "net/url" - "os" - "time" - - lighterClient "github.com/elliottech/lighter-go/client" - lighterHTTP "github.com/elliottech/lighter-go/client/http" -) - -func main() { - // Parse command line flags - walletAddr := flag.String("wallet", "", "Ethereum wallet address") - apiKeyPrivateKey := flag.String("apikey", "", "API key private key (40 bytes hex)") - apiKeyIndex := flag.Int("apikeyindex", 0, "API key index (0-255)") - testnet := flag.Bool("testnet", false, "Use testnet instead of mainnet") - flag.Parse() - - if *walletAddr == "" || *apiKeyPrivateKey == "" { - fmt.Println("Usage: go run cmd/lighter_test/main.go -wallet=0x... -apikey=...") - fmt.Println("Options:") - fmt.Println(" -wallet Ethereum wallet address (required)") - fmt.Println(" -apikey API key private key, 40 bytes hex (required)") - fmt.Println(" -apikeyindex API key index, 0-255 (default: 0)") - fmt.Println(" -testnet Use testnet instead of mainnet") - os.Exit(1) - } - - fmt.Println("=== Lighter API Authentication Test ===") - fmt.Printf("Wallet: %s\n", *walletAddr) - fmt.Printf("API Key Index: %d\n", *apiKeyIndex) - fmt.Printf("Testnet: %v\n", *testnet) - fmt.Println() - - // Determine base URL - baseURL := "https://mainnet.zklighter.elliot.ai" - chainID := uint32(304) - if *testnet { - baseURL = "https://testnet.zklighter.elliot.ai" - chainID = uint32(300) - } - - // Create HTTP client - httpClient := lighterHTTP.NewClient(baseURL) - client := &http.Client{Timeout: 30 * time.Second} - - // Step 1: Get account info - fmt.Println("Step 1: Getting account info...") - accountInfo, err := getAccountByL1Address(client, baseURL, *walletAddr) - if err != nil { - fmt.Printf("ERROR: Failed to get account info: %v\n", err) - os.Exit(1) - } - fmt.Printf("SUCCESS: Account index = %d\n\n", accountInfo.AccountIndex) - - // Step 2: Create TxClient - fmt.Println("Step 2: Creating TxClient...") - txClient, err := lighterClient.NewTxClient( - httpClient, - *apiKeyPrivateKey, - accountInfo.AccountIndex, - uint8(*apiKeyIndex), - chainID, - ) - if err != nil { - fmt.Printf("ERROR: Failed to create TxClient: %v\n", err) - os.Exit(1) - } - fmt.Println("SUCCESS: TxClient created\n") - - // Step 3: Generate auth token - fmt.Println("Step 3: Generating auth token...") - deadline := time.Now().Add(1 * time.Hour) - authToken, err := txClient.GetAuthToken(deadline) - if err != nil { - fmt.Printf("ERROR: Failed to generate auth token: %v\n", err) - os.Exit(1) - } - fmt.Printf("SUCCESS: Auth token generated\n") - fmt.Printf("Token: %s...\n", authToken[:min(50, len(authToken))]) - fmt.Printf("Valid until: %s\n\n", deadline.Format(time.RFC3339)) - - // Step 4: Test GetActiveOrders API with auth query parameter - fmt.Println("Step 4: Testing GetActiveOrders API...") - encodedAuth := url.QueryEscape(authToken) - endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0&auth=%s", - baseURL, accountInfo.AccountIndex, encodedAuth) - - fmt.Printf("Endpoint: %s...\n", endpoint[:min(120, len(endpoint))]) - - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - fmt.Printf("ERROR: Failed to create request: %v\n", err) - os.Exit(1) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - fmt.Printf("ERROR: Request failed: %v\n", err) - os.Exit(1) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - fmt.Printf("Status: %d\n", resp.StatusCode) - fmt.Printf("Response: %s\n\n", string(body)) - - // Parse response - var apiResp struct { - Code int `json:"code"` - Message string `json:"message"` - Orders []struct { - OrderID string `json:"order_id"` - Side string `json:"side"` - Type string `json:"type"` - Price string `json:"price"` - } `json:"orders"` - } - if err := json.Unmarshal(body, &apiResp); err != nil { - fmt.Printf("ERROR: Failed to parse response: %v\n", err) - os.Exit(1) - } - - if apiResp.Code != 200 { - fmt.Printf("API ERROR: code=%d, message=%s\n", apiResp.Code, apiResp.Message) - fmt.Println("\n=== DIAGNOSTIC INFO ===") - fmt.Println("If you see 'invalid signature', possible causes:") - fmt.Println("1. API key is not registered on-chain") - fmt.Println("2. API key private key is incorrect") - fmt.Println("3. API key index is wrong") - fmt.Println("4. Account index mismatch") - fmt.Println("\nTo fix:") - fmt.Println("- Go to app.lighter.xyz and register/verify your API key") - fmt.Println("- Make sure you're using the correct API key private key") - os.Exit(1) - } - - fmt.Printf("SUCCESS: Retrieved %d orders\n", len(apiResp.Orders)) - for i, order := range apiResp.Orders { - if i >= 5 { - fmt.Printf("... and %d more orders\n", len(apiResp.Orders)-5) - break - } - fmt.Printf(" Order %s: %s %s @ %s\n", order.OrderID, order.Side, order.Type, order.Price) - } - - // Step 5: Test GetTrades API (also needs auth) - fmt.Println("\nStep 5: Testing GetTrades API...") - tradesEndpoint := fmt.Sprintf("%s/api/v1/trades?account_index=%d&sort_by=timestamp&sort_dir=desc&limit=5&auth=%s", - baseURL, accountInfo.AccountIndex, encodedAuth) - - tradesReq, _ := http.NewRequest("GET", tradesEndpoint, nil) - tradesResp, err := client.Do(tradesReq) - if err != nil { - fmt.Printf("ERROR: Trades request failed: %v\n", err) - } else { - defer tradesResp.Body.Close() - tradesBody, _ := io.ReadAll(tradesResp.Body) - fmt.Printf("Status: %d\n", tradesResp.StatusCode) - if tradesResp.StatusCode == 200 { - fmt.Println("SUCCESS: GetTrades API working") - } else { - fmt.Printf("Response: %s\n", string(tradesBody)) - } - } - - fmt.Println("\n=== ALL TESTS PASSED ===") -} - -// AccountInfo represents Lighter account information -type AccountInfo struct { - AccountIndex int64 `json:"account_index"` - L1Address string `json:"l1_address"` -} - -// getAccountByL1Address gets account info by L1 wallet address -func getAccountByL1Address(client *http.Client, baseURL, walletAddr string) (*AccountInfo, error) { - endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", baseURL, walletAddr) - - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - req = req.WithContext(ctx) - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - // Parse response - can be in "accounts" or "sub_accounts" field - var apiResp struct { - Code int `json:"code"` - Message string `json:"message"` - Accounts []AccountInfo `json:"accounts"` - SubAccounts []AccountInfo `json:"sub_accounts"` - } - - if err := json.Unmarshal(body, &apiResp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(body)) - } - - // Check main accounts first - if len(apiResp.Accounts) > 0 { - return &apiResp.Accounts[0], nil - } - - // Check sub-accounts - if len(apiResp.SubAccounts) > 0 { - return &apiResp.SubAccounts[0], nil - } - - return nil, fmt.Errorf("no account found for address: %s", walletAddr) -} diff --git a/config/config.go b/config/config.go index 56ad3601..1332ab3c 100644 --- a/config/config.go +++ b/config/config.go @@ -1,7 +1,7 @@ package config import ( - "nofx/experience" + "nofx/telemetry" "nofx/mcp" "os" "strconv" @@ -122,11 +122,11 @@ func Init() { global = cfg // Initialize experience improvement (installation ID will be set after database init) - experience.Init(cfg.ExperienceImprovement, "") + telemetry.Init(cfg.ExperienceImprovement, "") // Set up AI token usage tracking callback mcp.TokenUsageCallback = func(usage mcp.TokenUsage) { - experience.TrackAIUsage(experience.AIUsageEvent{ + telemetry.TrackAIUsage(telemetry.AIUsageEvent{ ModelProvider: usage.Provider, ModelName: usage.Model, InputTokens: usage.PromptTokens, diff --git a/kernel/engine.go b/kernel/engine.go index 6a309e55..79a5edd7 100644 --- a/kernel/engine.go +++ b/kernel/engine.go @@ -8,33 +8,14 @@ import ( "net/http" "nofx/logger" "nofx/market" - "nofx/mcp" "nofx/provider/hyperliquid" "nofx/provider/nofxos" "nofx/security" "nofx/store" - "regexp" "strings" "time" ) -// ============================================================================ -// Pre-compiled regular expressions (performance optimization) -// ============================================================================ - -var ( - // Safe regex: precisely match ```json code blocks - reJSONFence = regexp.MustCompile(`(?is)` + "```json\\s*(\\[\\s*\\{.*?\\}\\s*\\])\\s*```") - reJSONArray = regexp.MustCompile(`(?is)\[\s*\{.*?\}\s*\]`) - reArrayHead = regexp.MustCompile(`^\[\s*\{`) - reArrayOpenSpace = regexp.MustCompile(`^\[\s+\{`) - reInvisibleRunes = regexp.MustCompile("[\u200B\u200C\u200D\uFEFF]") - - // XML tag extraction (supports any characters in reasoning chain) - reReasoningTag = regexp.MustCompile(`(?s)(.*?)`) - reDecisionTag = regexp.MustCompile(`(?s)(.*?)`) -) - // ============================================================================ // Type Definitions // ============================================================================ @@ -108,25 +89,25 @@ type RecentOrder struct { // Context trading context (complete information passed to AI) type Context struct { - CurrentTime string `json:"current_time"` - RuntimeMinutes int `json:"runtime_minutes"` - CallCount int `json:"call_count"` - Account AccountInfo `json:"account"` - Positions []PositionInfo `json:"positions"` - CandidateCoins []CandidateCoin `json:"candidate_coins"` - PromptVariant string `json:"prompt_variant,omitempty"` - TradingStats *TradingStats `json:"trading_stats,omitempty"` - RecentOrders []RecentOrder `json:"recent_orders,omitempty"` - MarketDataMap map[string]*market.Data `json:"-"` - MultiTFMarket map[string]map[string]*market.Data `json:"-"` - OITopDataMap map[string]*OITopData `json:"-"` - QuantDataMap map[string]*QuantData `json:"-"` - OIRankingData *nofxos.OIRankingData `json:"-"` // Market-wide OI ranking data - NetFlowRankingData *nofxos.NetFlowRankingData `json:"-"` // Market-wide fund flow ranking data - PriceRankingData *nofxos.PriceRankingData `json:"-"` // Market-wide price gainers/losers - BTCETHLeverage int `json:"-"` - AltcoinLeverage int `json:"-"` - Timeframes []string `json:"-"` + CurrentTime string `json:"current_time"` + RuntimeMinutes int `json:"runtime_minutes"` + CallCount int `json:"call_count"` + Account AccountInfo `json:"account"` + Positions []PositionInfo `json:"positions"` + CandidateCoins []CandidateCoin `json:"candidate_coins"` + PromptVariant string `json:"prompt_variant,omitempty"` + TradingStats *TradingStats `json:"trading_stats,omitempty"` + RecentOrders []RecentOrder `json:"recent_orders,omitempty"` + MarketDataMap map[string]*market.Data `json:"-"` + MultiTFMarket map[string]map[string]*market.Data `json:"-"` + OITopDataMap map[string]*OITopData `json:"-"` + QuantDataMap map[string]*QuantData `json:"-"` + OIRankingData *nofxos.OIRankingData `json:"-"` // Market-wide OI ranking data + NetFlowRankingData *nofxos.NetFlowRankingData `json:"-"` // Market-wide fund flow ranking data + PriceRankingData *nofxos.PriceRankingData `json:"-"` // Market-wide price gainers/losers + BTCETHLeverage int `json:"-"` + AltcoinLeverage int `json:"-"` + Timeframes []string `json:"-"` } // Decision AI trading decision @@ -242,173 +223,6 @@ func (e *StrategyEngine) GetConfig() *store.StrategyConfig { return e.config } -// ============================================================================ -// Entry Functions - Main API -// ============================================================================ - -// GetFullDecision gets AI's complete trading decision (batch analysis of all coins and positions) -// Uses default strategy configuration - for production use GetFullDecisionWithStrategy with explicit config -func GetFullDecision(ctx *Context, mcpClient mcp.AIClient) (*FullDecision, error) { - defaultConfig := store.GetDefaultStrategyConfig("en") - engine := NewStrategyEngine(&defaultConfig) - return GetFullDecisionWithStrategy(ctx, mcpClient, engine, "") -} - -// GetFullDecisionWithStrategy uses StrategyEngine to get AI decision (unified prompt generation) -func GetFullDecisionWithStrategy(ctx *Context, mcpClient mcp.AIClient, engine *StrategyEngine, variant string) (*FullDecision, error) { - if ctx == nil { - return nil, fmt.Errorf("context is nil") - } - if engine == nil { - defaultConfig := store.GetDefaultStrategyConfig("en") - engine = NewStrategyEngine(&defaultConfig) - } - - // 1. Fetch market data using strategy config - if len(ctx.MarketDataMap) == 0 { - if err := fetchMarketDataWithStrategy(ctx, engine); err != nil { - return nil, fmt.Errorf("failed to fetch market data: %w", err) - } - } - - // Ensure OITopDataMap is initialized - if ctx.OITopDataMap == nil { - ctx.OITopDataMap = make(map[string]*OITopData) - oiPositions, err := engine.nofxosClient.GetOITopPositions() - if err == nil { - for _, pos := range oiPositions { - ctx.OITopDataMap[pos.Symbol] = &OITopData{ - Rank: pos.Rank, - OIDeltaPercent: pos.OIDeltaPercent, - OIDeltaValue: pos.OIDeltaValue, - PriceDeltaPercent: pos.PriceDeltaPercent, - } - } - } - } - - // 2. Build System Prompt using strategy engine - riskConfig := engine.GetRiskControlConfig() - systemPrompt := engine.BuildSystemPrompt(ctx.Account.TotalEquity, variant) - - // 3. Build User Prompt using strategy engine - userPrompt := engine.BuildUserPrompt(ctx) - - // 4. Call AI API - aiCallStart := time.Now() - aiResponse, err := mcpClient.CallWithMessages(systemPrompt, userPrompt) - aiCallDuration := time.Since(aiCallStart) - if err != nil { - return nil, fmt.Errorf("AI API call failed: %w", err) - } - - // 5. Parse AI response - decision, err := parseFullDecisionResponse( - aiResponse, - ctx.Account.TotalEquity, - riskConfig.BTCETHMaxLeverage, - riskConfig.AltcoinMaxLeverage, - riskConfig.BTCETHMaxPositionValueRatio, - riskConfig.AltcoinMaxPositionValueRatio, - ) - - if decision != nil { - decision.Timestamp = time.Now() - decision.SystemPrompt = systemPrompt - decision.UserPrompt = userPrompt - decision.AIRequestDurationMs = aiCallDuration.Milliseconds() - decision.RawResponse = aiResponse - } - - if err != nil { - return decision, fmt.Errorf("failed to parse AI response: %w", err) - } - - return decision, nil -} - -// ============================================================================ -// Market Data Fetching -// ============================================================================ - -// fetchMarketDataWithStrategy fetches market data using strategy config (multiple timeframes) -func fetchMarketDataWithStrategy(ctx *Context, engine *StrategyEngine) error { - config := engine.GetConfig() - ctx.MarketDataMap = make(map[string]*market.Data) - - timeframes := config.Indicators.Klines.SelectedTimeframes - primaryTimeframe := config.Indicators.Klines.PrimaryTimeframe - klineCount := config.Indicators.Klines.PrimaryCount - - // Compatible with old configuration - if len(timeframes) == 0 { - if primaryTimeframe != "" { - timeframes = append(timeframes, primaryTimeframe) - } else { - timeframes = append(timeframes, "3m") - } - if config.Indicators.Klines.LongerTimeframe != "" { - timeframes = append(timeframes, config.Indicators.Klines.LongerTimeframe) - } - } - if primaryTimeframe == "" { - primaryTimeframe = timeframes[0] - } - if klineCount <= 0 { - klineCount = 30 - } - - logger.Infof("📊 Strategy timeframes: %v, Primary: %s, Kline count: %d", timeframes, primaryTimeframe, klineCount) - - // 1. First fetch data for position coins (must fetch) - for _, pos := range ctx.Positions { - data, err := market.GetWithTimeframes(pos.Symbol, timeframes, primaryTimeframe, klineCount) - if err != nil { - logger.Infof("⚠️ Failed to fetch market data for position %s: %v", pos.Symbol, err) - continue - } - ctx.MarketDataMap[pos.Symbol] = data - } - - // 2. Fetch data for all candidate coins - positionSymbols := make(map[string]bool) - for _, pos := range ctx.Positions { - positionSymbols[pos.Symbol] = true - } - - const minOIThresholdMillions = 15.0 // 15M USD minimum open interest value - - for _, coin := range ctx.CandidateCoins { - if _, exists := ctx.MarketDataMap[coin.Symbol]; exists { - continue - } - - data, err := market.GetWithTimeframes(coin.Symbol, timeframes, primaryTimeframe, klineCount) - if err != nil { - logger.Infof("⚠️ Failed to fetch market data for %s: %v", coin.Symbol, err) - continue - } - - // Liquidity filter (skip for xyz dex assets - they don't have OI data from Binance) - isExistingPosition := positionSymbols[coin.Symbol] - isXyzAsset := market.IsXyzDexAsset(coin.Symbol) - if !isExistingPosition && !isXyzAsset && data.OpenInterest != nil && data.CurrentPrice > 0 { - oiValue := data.OpenInterest.Latest * data.CurrentPrice - oiValueInMillions := oiValue / 1_000_000 - if oiValueInMillions < minOIThresholdMillions { - logger.Infof("⚠️ %s OI value too low (%.2fM USD < %.1fM), skipping coin", - coin.Symbol, oiValueInMillions, minOIThresholdMillions) - continue - } - } - - ctx.MarketDataMap[coin.Symbol] = data - } - - logger.Infof("📊 Successfully fetched multi-timeframe market data for %d coins", len(ctx.MarketDataMap)) - return nil -} - // ============================================================================ // Candidate Coins // ============================================================================ @@ -1022,1067 +836,6 @@ func (e *StrategyEngine) FetchPriceRankingData() *nofxos.PriceRankingData { return data } -// ============================================================================ -// Prompt Building - System Prompt -// ============================================================================ - -// BuildSystemPrompt builds System Prompt according to strategy configuration -func (e *StrategyEngine) BuildSystemPrompt(accountEquity float64, variant string) string { - var sb strings.Builder - riskControl := e.config.RiskControl - promptSections := e.config.PromptSections - - // 0. Data Dictionary & Schema (ensure AI understands all fields) - lang := e.GetLanguage() - schemaPrompt := GetSchemaPrompt(lang) - sb.WriteString(schemaPrompt) - sb.WriteString("\n\n") - sb.WriteString("---\n\n") - - // 1. Role definition (editable) - if promptSections.RoleDefinition != "" { - sb.WriteString(promptSections.RoleDefinition) - sb.WriteString("\n\n") - } else { - sb.WriteString("# You are a professional cryptocurrency trading AI\n\n") - sb.WriteString("Your task is to make trading decisions based on provided market data.\n\n") - } - - // 2. Trading mode variant - switch strings.ToLower(strings.TrimSpace(variant)) { - case "aggressive": - sb.WriteString("## Mode: Aggressive\n- Prioritize capturing trend breakouts, can build positions in batches when confidence ≥ 70\n- Allow higher positions, but must strictly set stop-loss and explain risk-reward ratio\n\n") - case "conservative": - sb.WriteString("## Mode: Conservative\n- Only open positions when multiple signals resonate\n- Prioritize cash preservation, must pause for multiple periods after consecutive losses\n\n") - case "scalping": - sb.WriteString("## Mode: Scalping\n- Focus on short-term momentum, smaller profit targets but require quick action\n- If price doesn't move as expected within two bars, immediately reduce position or stop-loss\n\n") - } - - // 3. Hard constraints (risk control) - btcEthPosValueRatio := riskControl.BTCETHMaxPositionValueRatio - if btcEthPosValueRatio <= 0 { - btcEthPosValueRatio = 5.0 - } - altcoinPosValueRatio := riskControl.AltcoinMaxPositionValueRatio - if altcoinPosValueRatio <= 0 { - altcoinPosValueRatio = 1.0 - } - - sb.WriteString("# Hard Constraints (Risk Control)\n\n") - sb.WriteString("## CODE ENFORCED (Backend validation, cannot be bypassed):\n") - sb.WriteString(fmt.Sprintf("- Max Positions: %d coins simultaneously\n", riskControl.MaxPositions)) - sb.WriteString(fmt.Sprintf("- Position Value Limit (Altcoins): max %.0f USDT (= equity %.0f × %.1fx)\n", - accountEquity*altcoinPosValueRatio, accountEquity, altcoinPosValueRatio)) - sb.WriteString(fmt.Sprintf("- Position Value Limit (BTC/ETH): max %.0f USDT (= equity %.0f × %.1fx)\n", - accountEquity*btcEthPosValueRatio, accountEquity, btcEthPosValueRatio)) - sb.WriteString(fmt.Sprintf("- Max Margin Usage: ≤%.0f%%\n", riskControl.MaxMarginUsage*100)) - sb.WriteString(fmt.Sprintf("- Min Position Size: ≥%.0f USDT\n\n", riskControl.MinPositionSize)) - - sb.WriteString("## AI GUIDED (Recommended, you should follow):\n") - sb.WriteString(fmt.Sprintf("- Trading Leverage: Altcoins max %dx | BTC/ETH max %dx\n", - riskControl.AltcoinMaxLeverage, riskControl.BTCETHMaxLeverage)) - sb.WriteString(fmt.Sprintf("- Risk-Reward Ratio: ≥1:%.1f (take_profit / stop_loss)\n", riskControl.MinRiskRewardRatio)) - sb.WriteString(fmt.Sprintf("- Min Confidence: ≥%d to open position\n\n", riskControl.MinConfidence)) - - // Position sizing guidance - sb.WriteString("## Position Sizing Guidance\n") - sb.WriteString("Calculate `position_size_usd` based on your confidence and the Position Value Limits above:\n") - sb.WriteString("- High confidence (≥85): Use 80-100%% of max position value limit\n") - sb.WriteString("- Medium confidence (70-84): Use 50-80%% of max position value limit\n") - sb.WriteString("- Low confidence (60-69): Use 30-50%% of max position value limit\n") - sb.WriteString(fmt.Sprintf("- Example: With equity %.0f and BTC/ETH ratio %.1fx, max is %.0f USDT\n", - accountEquity, btcEthPosValueRatio, accountEquity*btcEthPosValueRatio)) - sb.WriteString("- **DO NOT** just use available_balance as position_size_usd. Use the Position Value Limits!\n\n") - - // 4. Trading frequency (editable) - if promptSections.TradingFrequency != "" { - sb.WriteString(promptSections.TradingFrequency) - sb.WriteString("\n\n") - } else { - sb.WriteString("# ⏱️ Trading Frequency Awareness\n\n") - sb.WriteString("- Excellent traders: 2-4 trades/day ≈ 0.1-0.2 trades/hour\n") - sb.WriteString("- >2 trades/hour = Overtrading\n") - sb.WriteString("- Single position hold time ≥ 30-60 minutes\n") - sb.WriteString("If you find yourself trading every period → standards too low; if closing positions < 30 minutes → too impatient.\n\n") - } - - // 5. Entry standards (editable) - if promptSections.EntryStandards != "" { - sb.WriteString(promptSections.EntryStandards) - sb.WriteString("\n\nYou have the following indicator data:\n") - e.writeAvailableIndicators(&sb) - sb.WriteString(fmt.Sprintf("\n**Confidence ≥ %d** required to open positions.\n\n", riskControl.MinConfidence)) - } else { - sb.WriteString("# 🎯 Entry Standards (Strict)\n\n") - sb.WriteString("Only open positions when multiple signals resonate. You have:\n") - e.writeAvailableIndicators(&sb) - sb.WriteString(fmt.Sprintf("\nFeel free to use any effective analysis method, but **confidence ≥ %d** required to open positions; avoid low-quality behaviors such as single indicators, contradictory signals, sideways consolidation, reopening immediately after closing, etc.\n\n", riskControl.MinConfidence)) - } - - // 6. Decision process (editable) - if promptSections.DecisionProcess != "" { - sb.WriteString(promptSections.DecisionProcess) - sb.WriteString("\n\n") - } else { - sb.WriteString("# 📋 Decision Process\n\n") - sb.WriteString("1. Check positions → Should we take profit/stop-loss\n") - sb.WriteString("2. Scan candidate coins + multi-timeframe → Are there strong signals\n") - sb.WriteString("3. Write chain of thought first, then output structured JSON\n\n") - } - - // 7. Output format - sb.WriteString("# Output Format (Strictly Follow)\n\n") - sb.WriteString("**Must use XML tags and to separate chain of thought and decision JSON, avoiding parsing errors**\n\n") - sb.WriteString("## Format Requirements\n\n") - sb.WriteString("\n") - sb.WriteString("Your chain of thought analysis...\n") - sb.WriteString("- Briefly analyze your thinking process \n") - sb.WriteString("\n\n") - sb.WriteString("\n") - sb.WriteString("Step 2: JSON decision array\n\n") - sb.WriteString("```json\n[\n") - // Use the actual configured position value ratio for BTC/ETH in the example - examplePositionSize := accountEquity * btcEthPosValueRatio - sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300},\n", - riskControl.BTCETHMaxLeverage, examplePositionSize)) - sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\"}\n") - sb.WriteString("]\n```\n") - sb.WriteString("\n\n") - sb.WriteString("## Field Description\n\n") - sb.WriteString("- `action`: open_long | open_short | close_long | close_short | hold | wait\n") - sb.WriteString(fmt.Sprintf("- `confidence`: 0-100 (opening recommended ≥ %d)\n", riskControl.MinConfidence)) - sb.WriteString("- Required when opening: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n") - sb.WriteString("- **IMPORTANT**: All numeric values must be calculated numbers, NOT formulas/expressions (e.g., use `27.76` not `3000 * 0.01`)\n\n") - - // 8. Custom Prompt - if e.config.CustomPrompt != "" { - sb.WriteString("# 📌 Personalized Trading Strategy\n\n") - sb.WriteString(e.config.CustomPrompt) - sb.WriteString("\n\n") - sb.WriteString("Note: The above personalized strategy is a supplement to the basic rules and cannot violate the basic risk control principles.\n") - } - - return sb.String() -} - -func (e *StrategyEngine) writeAvailableIndicators(sb *strings.Builder) { - indicators := e.config.Indicators - kline := indicators.Klines - - sb.WriteString(fmt.Sprintf("- %s price series", kline.PrimaryTimeframe)) - if kline.EnableMultiTimeframe { - sb.WriteString(fmt.Sprintf(" + %s K-line series\n", kline.LongerTimeframe)) - } else { - sb.WriteString("\n") - } - - if indicators.EnableEMA { - sb.WriteString("- EMA indicators") - if len(indicators.EMAPeriods) > 0 { - sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.EMAPeriods)) - } - sb.WriteString("\n") - } - - if indicators.EnableMACD { - sb.WriteString("- MACD indicators\n") - } - - if indicators.EnableRSI { - sb.WriteString("- RSI indicators") - if len(indicators.RSIPeriods) > 0 { - sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.RSIPeriods)) - } - sb.WriteString("\n") - } - - if indicators.EnableATR { - sb.WriteString("- ATR indicators") - if len(indicators.ATRPeriods) > 0 { - sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.ATRPeriods)) - } - sb.WriteString("\n") - } - - if indicators.EnableBOLL { - sb.WriteString("- Bollinger Bands (BOLL) - Upper/Middle/Lower bands") - if len(indicators.BOLLPeriods) > 0 { - sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.BOLLPeriods)) - } - sb.WriteString("\n") - } - - if indicators.EnableVolume { - sb.WriteString("- Volume data\n") - } - - if indicators.EnableOI { - sb.WriteString("- Open Interest (OI) data\n") - } - - if indicators.EnableFundingRate { - sb.WriteString("- Funding rate\n") - } - - if len(e.config.CoinSource.StaticCoins) > 0 || e.config.CoinSource.UseAI500 || e.config.CoinSource.UseOITop { - sb.WriteString("- AI500 / OI_Top filter tags (if available)\n") - } - - if indicators.EnableQuantData { - sb.WriteString("- Quantitative data (institutional/retail fund flow, position changes, multi-period price changes)\n") - } -} - -// ============================================================================ -// Prompt Building - User Prompt -// ============================================================================ - -// BuildUserPrompt builds User Prompt based on strategy configuration -func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string { - var sb strings.Builder - - // System status - sb.WriteString(fmt.Sprintf("Time: %s | Period: #%d | Runtime: %d minutes\n\n", - ctx.CurrentTime, ctx.CallCount, ctx.RuntimeMinutes)) - - // BTC market - if btcData, hasBTC := ctx.MarketDataMap["BTCUSDT"]; hasBTC { - sb.WriteString(fmt.Sprintf("BTC: %.2f (1h: %+.2f%%, 4h: %+.2f%%) | MACD: %.4f | RSI: %.2f\n\n", - btcData.CurrentPrice, btcData.PriceChange1h, btcData.PriceChange4h, - btcData.CurrentMACD, btcData.CurrentRSI7)) - } - - // Account information - sb.WriteString(fmt.Sprintf("Account: Equity %.2f | Balance %.2f (%.1f%%) | PnL %+.2f%% | Margin %.1f%% | Positions %d\n\n", - ctx.Account.TotalEquity, - ctx.Account.AvailableBalance, - (ctx.Account.AvailableBalance/ctx.Account.TotalEquity)*100, - ctx.Account.TotalPnLPct, - ctx.Account.MarginUsedPct, - ctx.Account.PositionCount)) - - // Recently completed orders (placed before positions to ensure visibility) - if len(ctx.RecentOrders) > 0 { - sb.WriteString("## Recent Completed Trades\n") - for i, order := range ctx.RecentOrders { - resultStr := "Profit" - if order.RealizedPnL < 0 { - resultStr = "Loss" - } - sb.WriteString(fmt.Sprintf("%d. %s %s | Entry %.4f Exit %.4f | %s: %+.2f USDT (%+.2f%%) | %s→%s (%s)\n", - i+1, order.Symbol, order.Side, - order.EntryPrice, order.ExitPrice, - resultStr, order.RealizedPnL, order.PnLPct, - order.EntryTime, order.ExitTime, order.HoldDuration)) - } - sb.WriteString("\n") - } - - // Historical trading statistics (helps AI understand past performance) - if ctx.TradingStats != nil && ctx.TradingStats.TotalTrades > 0 { - // Get language from strategy config - lang := e.GetLanguage() - - // Win/Loss ratio - var winLossRatio float64 - if ctx.TradingStats.AvgLoss > 0 { - winLossRatio = ctx.TradingStats.AvgWin / ctx.TradingStats.AvgLoss - } - - if lang == LangChinese { - sb.WriteString("## 历史交易统计\n") - sb.WriteString(fmt.Sprintf("总交易: %d 笔 | 盈利因子: %.2f | 夏普比率: %.2f | 盈亏比: %.2f\n", - ctx.TradingStats.TotalTrades, - ctx.TradingStats.ProfitFactor, - ctx.TradingStats.SharpeRatio, - winLossRatio)) - sb.WriteString(fmt.Sprintf("总盈亏: %+.2f USDT | 平均盈利: +%.2f | 平均亏损: -%.2f | 最大回撤: %.1f%%\n", - ctx.TradingStats.TotalPnL, - ctx.TradingStats.AvgWin, - ctx.TradingStats.AvgLoss, - ctx.TradingStats.MaxDrawdownPct)) - - // Performance hints based on profit factor, sharpe, and drawdown - if ctx.TradingStats.ProfitFactor >= 1.5 && ctx.TradingStats.SharpeRatio >= 1 { - sb.WriteString("表现: 良好 - 保持当前策略\n") - } else if ctx.TradingStats.ProfitFactor < 1 { - sb.WriteString("表现: 需改进 - 提高盈亏比,优化止盈止损\n") - } else if ctx.TradingStats.MaxDrawdownPct > 30 { - sb.WriteString("表现: 风险偏高 - 减少仓位,控制回撤\n") - } else { - sb.WriteString("表现: 正常 - 有优化空间\n") - } - } else { - sb.WriteString("## Historical Trading Statistics\n") - sb.WriteString(fmt.Sprintf("Total Trades: %d | Profit Factor: %.2f | Sharpe: %.2f | Win/Loss Ratio: %.2f\n", - ctx.TradingStats.TotalTrades, - ctx.TradingStats.ProfitFactor, - ctx.TradingStats.SharpeRatio, - winLossRatio)) - sb.WriteString(fmt.Sprintf("Total PnL: %+.2f USDT | Avg Win: +%.2f | Avg Loss: -%.2f | Max Drawdown: %.1f%%\n", - ctx.TradingStats.TotalPnL, - ctx.TradingStats.AvgWin, - ctx.TradingStats.AvgLoss, - ctx.TradingStats.MaxDrawdownPct)) - - // Performance hints based on profit factor, sharpe, and drawdown - if ctx.TradingStats.ProfitFactor >= 1.5 && ctx.TradingStats.SharpeRatio >= 1 { - sb.WriteString("Performance: GOOD - maintain current strategy\n") - } else if ctx.TradingStats.ProfitFactor < 1 { - sb.WriteString("Performance: NEEDS IMPROVEMENT - improve win/loss ratio, optimize TP/SL\n") - } else if ctx.TradingStats.MaxDrawdownPct > 30 { - sb.WriteString("Performance: HIGH RISK - reduce position size, control drawdown\n") - } else { - sb.WriteString("Performance: NORMAL - room for optimization\n") - } - } - sb.WriteString("\n") - } - - // Position information - if len(ctx.Positions) > 0 { - sb.WriteString("## Current Positions\n") - for i, pos := range ctx.Positions { - sb.WriteString(e.formatPositionInfo(i+1, pos, ctx)) - } - } else { - sb.WriteString("Current Positions: None\n\n") - } - - // Candidate coins (exclude coins already in positions to avoid duplicate data) - positionSymbols := make(map[string]bool) - for _, pos := range ctx.Positions { - // Normalize symbol to handle both "ETH" and "ETHUSDT" formats - normalizedSymbol := market.Normalize(pos.Symbol) - positionSymbols[normalizedSymbol] = true - } - - sb.WriteString(fmt.Sprintf("## Candidate Coins (%d coins)\n\n", len(ctx.MarketDataMap))) - displayedCount := 0 - for _, coin := range ctx.CandidateCoins { - // Skip if this coin is already a position (data already shown in positions section) - normalizedCoinSymbol := market.Normalize(coin.Symbol) - if positionSymbols[normalizedCoinSymbol] { - continue - } - - marketData, hasData := ctx.MarketDataMap[coin.Symbol] - if !hasData { - continue - } - displayedCount++ - - sourceTags := e.formatCoinSourceTag(coin.Sources) - sb.WriteString(fmt.Sprintf("### %d. %s%s\n\n", displayedCount, coin.Symbol, sourceTags)) - sb.WriteString(e.formatMarketData(marketData)) - - if ctx.QuantDataMap != nil { - if quantData, hasQuant := ctx.QuantDataMap[coin.Symbol]; hasQuant { - sb.WriteString(e.formatQuantData(quantData)) - } - } - sb.WriteString("\n") - } - sb.WriteString("\n") - - // Get language for market data formatting - nofxosLang := nofxos.LangEnglish - if e.GetLanguage() == LangChinese { - nofxosLang = nofxos.LangChinese - } - - // OI Ranking data (market-wide open interest changes) - if ctx.OIRankingData != nil { - sb.WriteString(nofxos.FormatOIRankingForAI(ctx.OIRankingData, nofxosLang)) - } - - // NetFlow Ranking data (market-wide fund flow) - if ctx.NetFlowRankingData != nil { - sb.WriteString(nofxos.FormatNetFlowRankingForAI(ctx.NetFlowRankingData, nofxosLang)) - } - - // Price Ranking data (market-wide gainers/losers) - if ctx.PriceRankingData != nil { - sb.WriteString(nofxos.FormatPriceRankingForAI(ctx.PriceRankingData, nofxosLang)) - } - - sb.WriteString("---\n\n") - sb.WriteString("Now please analyze and output your decision (Chain of Thought + JSON)\n") - - return sb.String() -} - -func (e *StrategyEngine) formatPositionInfo(index int, pos PositionInfo, ctx *Context) string { - var sb strings.Builder - - holdingDuration := "" - if pos.UpdateTime > 0 { - durationMs := time.Now().UnixMilli() - pos.UpdateTime - durationMin := durationMs / (1000 * 60) - if durationMin < 60 { - holdingDuration = fmt.Sprintf(" | Holding Duration %d min", durationMin) - } else { - durationHour := durationMin / 60 - durationMinRemainder := durationMin % 60 - holdingDuration = fmt.Sprintf(" | Holding Duration %dh %dm", durationHour, durationMinRemainder) - } - } - - positionValue := pos.Quantity * pos.MarkPrice - if positionValue < 0 { - positionValue = -positionValue - } - - sb.WriteString(fmt.Sprintf("%d. %s %s | Entry %.4f Current %.4f | Qty %.4f | Position Value %.2f USDT | PnL%+.2f%% | PnL Amount%+.2f USDT | Peak PnL%.2f%% | Leverage %dx | Margin %.0f | Liq Price %.4f%s\n\n", - index, pos.Symbol, strings.ToUpper(pos.Side), - pos.EntryPrice, pos.MarkPrice, pos.Quantity, positionValue, pos.UnrealizedPnLPct, pos.UnrealizedPnL, pos.PeakPnLPct, - pos.Leverage, pos.MarginUsed, pos.LiquidationPrice, holdingDuration)) - - if marketData, ok := ctx.MarketDataMap[pos.Symbol]; ok { - sb.WriteString(e.formatMarketData(marketData)) - - if ctx.QuantDataMap != nil { - if quantData, hasQuant := ctx.QuantDataMap[pos.Symbol]; hasQuant { - sb.WriteString(e.formatQuantData(quantData)) - } - } - sb.WriteString("\n") - } - - return sb.String() -} - -func (e *StrategyEngine) formatCoinSourceTag(sources []string) string { - if len(sources) > 1 { - // 多信号源组合 - hasAI500 := false - hasOITop := false - hasOILow := false - hasHyperAll := false - hasHyperMain := false - for _, s := range sources { - switch s { - case "ai500": - hasAI500 = true - case "oi_top": - hasOITop = true - case "oi_low": - hasOILow = true - case "hyper_all": - hasHyperAll = true - case "hyper_main": - hasHyperMain = true - } - } - if hasAI500 && hasOITop { - return " (AI500+OI_Top dual signal)" - } - if hasAI500 && hasOILow { - return " (AI500+OI_Low dual signal)" - } - if hasOITop && hasOILow { - return " (OI_Top+OI_Low)" - } - if hasHyperMain && hasAI500 { - return " (HyperMain+AI500)" - } - if hasHyperAll || hasHyperMain { - return " (Hyperliquid)" - } - return " (Multiple sources)" - } else if len(sources) == 1 { - switch sources[0] { - case "ai500": - return " (AI500)" - case "oi_top": - return " (OI_Top 持仓增加)" - case "oi_low": - return " (OI_Low 持仓减少)" - case "static": - return " (Manual selection)" - case "hyper_all": - return " (Hyperliquid All)" - case "hyper_main": - return " (Hyperliquid Top20)" - } - } - return "" -} - -// ============================================================================ -// Market Data Formatting -// ============================================================================ - -func (e *StrategyEngine) formatMarketData(data *market.Data) string { - var sb strings.Builder - indicators := e.config.Indicators - - // 明确标注币种 - sb.WriteString(fmt.Sprintf("=== %s Market Data ===\n\n", data.Symbol)) - sb.WriteString(fmt.Sprintf("current_price = %.4f", data.CurrentPrice)) - - if indicators.EnableEMA { - sb.WriteString(fmt.Sprintf(", current_ema20 = %.3f", data.CurrentEMA20)) - } - - if indicators.EnableMACD { - sb.WriteString(fmt.Sprintf(", current_macd = %.3f", data.CurrentMACD)) - } - - if indicators.EnableRSI { - sb.WriteString(fmt.Sprintf(", current_rsi7 = %.3f", data.CurrentRSI7)) - } - - sb.WriteString("\n\n") - - if indicators.EnableOI || indicators.EnableFundingRate { - sb.WriteString(fmt.Sprintf("Additional data for %s:\n\n", data.Symbol)) - - if indicators.EnableOI && data.OpenInterest != nil { - sb.WriteString(fmt.Sprintf("Open Interest: Latest: %.2f Average: %.2f\n\n", - data.OpenInterest.Latest, data.OpenInterest.Average)) - } - - if indicators.EnableFundingRate { - sb.WriteString(fmt.Sprintf("Funding Rate: %.2e\n\n", data.FundingRate)) - } - } - - if len(data.TimeframeData) > 0 { - timeframeOrder := []string{"1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"} - for _, tf := range timeframeOrder { - if tfData, ok := data.TimeframeData[tf]; ok { - sb.WriteString(fmt.Sprintf("=== %s Timeframe (oldest → latest) ===\n\n", strings.ToUpper(tf))) - e.formatTimeframeSeriesData(&sb, tfData, indicators) - } - } - } else { - // Compatible with old data format - if data.IntradaySeries != nil { - klineConfig := indicators.Klines - sb.WriteString(fmt.Sprintf("Intraday series (%s intervals, oldest → latest):\n\n", klineConfig.PrimaryTimeframe)) - - if len(data.IntradaySeries.MidPrices) > 0 { - sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.IntradaySeries.MidPrices))) - } - - if indicators.EnableEMA && len(data.IntradaySeries.EMA20Values) > 0 { - sb.WriteString(fmt.Sprintf("EMA indicators (20-period): %s\n\n", formatFloatSlice(data.IntradaySeries.EMA20Values))) - } - - if indicators.EnableMACD && len(data.IntradaySeries.MACDValues) > 0 { - sb.WriteString(fmt.Sprintf("MACD indicators: %s\n\n", formatFloatSlice(data.IntradaySeries.MACDValues))) - } - - if indicators.EnableRSI { - if len(data.IntradaySeries.RSI7Values) > 0 { - sb.WriteString(fmt.Sprintf("RSI indicators (7-Period): %s\n\n", formatFloatSlice(data.IntradaySeries.RSI7Values))) - } - if len(data.IntradaySeries.RSI14Values) > 0 { - sb.WriteString(fmt.Sprintf("RSI indicators (14-Period): %s\n\n", formatFloatSlice(data.IntradaySeries.RSI14Values))) - } - } - - if indicators.EnableVolume && len(data.IntradaySeries.Volume) > 0 { - sb.WriteString(fmt.Sprintf("Volume: %s\n\n", formatFloatSlice(data.IntradaySeries.Volume))) - } - - if indicators.EnableATR { - sb.WriteString(fmt.Sprintf("3m ATR (14-period): %.3f\n\n", data.IntradaySeries.ATR14)) - } - } - - if data.LongerTermContext != nil && indicators.Klines.EnableMultiTimeframe { - sb.WriteString(fmt.Sprintf("Longer-term context (%s timeframe):\n\n", indicators.Klines.LongerTimeframe)) - - if indicators.EnableEMA { - sb.WriteString(fmt.Sprintf("20-Period EMA: %.3f vs. 50-Period EMA: %.3f\n\n", - data.LongerTermContext.EMA20, data.LongerTermContext.EMA50)) - } - - if indicators.EnableATR { - sb.WriteString(fmt.Sprintf("3-Period ATR: %.3f vs. 14-Period ATR: %.3f\n\n", - data.LongerTermContext.ATR3, data.LongerTermContext.ATR14)) - } - - if indicators.EnableVolume { - sb.WriteString(fmt.Sprintf("Current Volume: %.3f vs. Average Volume: %.3f\n\n", - data.LongerTermContext.CurrentVolume, data.LongerTermContext.AverageVolume)) - } - - if indicators.EnableMACD && len(data.LongerTermContext.MACDValues) > 0 { - sb.WriteString(fmt.Sprintf("MACD indicators: %s\n\n", formatFloatSlice(data.LongerTermContext.MACDValues))) - } - - if indicators.EnableRSI && len(data.LongerTermContext.RSI14Values) > 0 { - sb.WriteString(fmt.Sprintf("RSI indicators (14-Period): %s\n\n", formatFloatSlice(data.LongerTermContext.RSI14Values))) - } - } - } - - return sb.String() -} - -func (e *StrategyEngine) formatTimeframeSeriesData(sb *strings.Builder, data *market.TimeframeSeriesData, indicators store.IndicatorConfig) { - if len(data.Klines) > 0 { - sb.WriteString("Time(UTC) Open High Low Close Volume\n") - for i, k := range data.Klines { - t := time.Unix(k.Time/1000, 0).UTC() - timeStr := t.Format("01-02 15:04") - marker := "" - if i == len(data.Klines)-1 { - marker = " <- current" - } - sb.WriteString(fmt.Sprintf("%-14s %-9.4f %-9.4f %-9.4f %-9.4f %-12.2f%s\n", - timeStr, k.Open, k.High, k.Low, k.Close, k.Volume, marker)) - } - sb.WriteString("\n") - } else if len(data.MidPrices) > 0 { - sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.MidPrices))) - if indicators.EnableVolume && len(data.Volume) > 0 { - sb.WriteString(fmt.Sprintf("Volume: %s\n\n", formatFloatSlice(data.Volume))) - } - } - - if indicators.EnableEMA { - if len(data.EMA20Values) > 0 { - sb.WriteString(fmt.Sprintf("EMA20: %s\n", formatFloatSlice(data.EMA20Values))) - } - if len(data.EMA50Values) > 0 { - sb.WriteString(fmt.Sprintf("EMA50: %s\n", formatFloatSlice(data.EMA50Values))) - } - } - - if indicators.EnableMACD && len(data.MACDValues) > 0 { - sb.WriteString(fmt.Sprintf("MACD: %s\n", formatFloatSlice(data.MACDValues))) - } - - if indicators.EnableRSI { - if len(data.RSI7Values) > 0 { - sb.WriteString(fmt.Sprintf("RSI7: %s\n", formatFloatSlice(data.RSI7Values))) - } - if len(data.RSI14Values) > 0 { - sb.WriteString(fmt.Sprintf("RSI14: %s\n", formatFloatSlice(data.RSI14Values))) - } - } - - if indicators.EnableATR && data.ATR14 > 0 { - sb.WriteString(fmt.Sprintf("ATR14: %.4f\n", data.ATR14)) - } - - if indicators.EnableBOLL && len(data.BOLLUpper) > 0 { - sb.WriteString(fmt.Sprintf("BOLL Upper: %s\n", formatFloatSlice(data.BOLLUpper))) - sb.WriteString(fmt.Sprintf("BOLL Middle: %s\n", formatFloatSlice(data.BOLLMiddle))) - sb.WriteString(fmt.Sprintf("BOLL Lower: %s\n", formatFloatSlice(data.BOLLLower))) - } - - sb.WriteString("\n") -} - -func (e *StrategyEngine) formatQuantData(data *QuantData) string { - if data == nil { - return "" - } - - indicators := e.config.Indicators - if !indicators.EnableQuantOI && !indicators.EnableQuantNetflow { - return "" - } - - var sb strings.Builder - sb.WriteString(fmt.Sprintf("📊 %s Quantitative Data:\n", data.Symbol)) - - if len(data.PriceChange) > 0 { - sb.WriteString("Price Change: ") - timeframes := []string{"5m", "15m", "1h", "4h", "12h", "24h"} - parts := []string{} - for _, tf := range timeframes { - if v, ok := data.PriceChange[tf]; ok { - parts = append(parts, fmt.Sprintf("%s: %+.4f%%", tf, v*100)) - } - } - sb.WriteString(strings.Join(parts, " | ")) - sb.WriteString("\n") - } - - if indicators.EnableQuantNetflow && data.Netflow != nil { - sb.WriteString("Fund Flow (Netflow):\n") - timeframes := []string{"5m", "15m", "1h", "4h", "12h", "24h"} - - if data.Netflow.Institution != nil { - if data.Netflow.Institution.Future != nil && len(data.Netflow.Institution.Future) > 0 { - sb.WriteString(" Institutional Futures:\n") - for _, tf := range timeframes { - if v, ok := data.Netflow.Institution.Future[tf]; ok { - sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) - } - } - } - if data.Netflow.Institution.Spot != nil && len(data.Netflow.Institution.Spot) > 0 { - sb.WriteString(" Institutional Spot:\n") - for _, tf := range timeframes { - if v, ok := data.Netflow.Institution.Spot[tf]; ok { - sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) - } - } - } - } - - if data.Netflow.Personal != nil { - if data.Netflow.Personal.Future != nil && len(data.Netflow.Personal.Future) > 0 { - sb.WriteString(" Retail Futures:\n") - for _, tf := range timeframes { - if v, ok := data.Netflow.Personal.Future[tf]; ok { - sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) - } - } - } - if data.Netflow.Personal.Spot != nil && len(data.Netflow.Personal.Spot) > 0 { - sb.WriteString(" Retail Spot:\n") - for _, tf := range timeframes { - if v, ok := data.Netflow.Personal.Spot[tf]; ok { - sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) - } - } - } - } - } - - if indicators.EnableQuantOI && len(data.OI) > 0 { - for exchange, oiData := range data.OI { - if len(oiData.Delta) > 0 { - sb.WriteString(fmt.Sprintf("Open Interest (%s):\n", exchange)) - for _, tf := range []string{"5m", "15m", "1h", "4h", "12h", "24h"} { - if d, ok := oiData.Delta[tf]; ok { - sb.WriteString(fmt.Sprintf(" %s: %+.4f%% (%s)\n", tf, d.OIDeltaPercent, formatFlowValue(d.OIDeltaValue))) - } - } - } - } - } - - return sb.String() -} - -func formatFlowValue(v float64) string { - sign := "" - if v >= 0 { - sign = "+" - } - absV := v - if absV < 0 { - absV = -absV - } - if absV >= 1e9 { - return fmt.Sprintf("%s%.2fB", sign, v/1e9) - } else if absV >= 1e6 { - return fmt.Sprintf("%s%.2fM", sign, v/1e6) - } else if absV >= 1e3 { - return fmt.Sprintf("%s%.2fK", sign, v/1e3) - } - return fmt.Sprintf("%s%.2f", sign, v) -} - -func formatFloatSlice(values []float64) string { - strValues := make([]string, len(values)) - for i, v := range values { - strValues[i] = fmt.Sprintf("%.4f", v) - } - return "[" + strings.Join(strValues, ", ") + "]" -} - -// ============================================================================ -// AI Response Parsing -// ============================================================================ - -func parseFullDecisionResponse(aiResponse string, accountEquity float64, btcEthLeverage, altcoinLeverage int, btcEthPosRatio, altcoinPosRatio float64) (*FullDecision, error) { - cotTrace := extractCoTTrace(aiResponse) - - decisions, err := extractDecisions(aiResponse) - if err != nil { - return &FullDecision{ - CoTTrace: cotTrace, - Decisions: []Decision{}, - }, fmt.Errorf("failed to extract decisions: %w", err) - } - - if err := validateDecisions(decisions, accountEquity, btcEthLeverage, altcoinLeverage, btcEthPosRatio, altcoinPosRatio); err != nil { - return &FullDecision{ - CoTTrace: cotTrace, - Decisions: decisions, - }, fmt.Errorf("decision validation failed: %w", err) - } - - return &FullDecision{ - CoTTrace: cotTrace, - Decisions: decisions, - }, nil -} - -func extractCoTTrace(response string) string { - if match := reReasoningTag.FindStringSubmatch(response); match != nil && len(match) > 1 { - logger.Infof("✓ Extracted reasoning chain using tag") - return strings.TrimSpace(match[1]) - } - - if decisionIdx := strings.Index(response, ""); decisionIdx > 0 { - logger.Infof("✓ Extracted content before tag as reasoning chain") - return strings.TrimSpace(response[:decisionIdx]) - } - - jsonStart := strings.Index(response, "[") - if jsonStart > 0 { - logger.Infof("⚠️ Extracted reasoning chain using old format ([ character separator)") - return strings.TrimSpace(response[:jsonStart]) - } - - return strings.TrimSpace(response) -} - -func extractDecisions(response string) ([]Decision, error) { - s := removeInvisibleRunes(response) - s = strings.TrimSpace(s) - s = fixMissingQuotes(s) - - var jsonPart string - if match := reDecisionTag.FindStringSubmatch(s); match != nil && len(match) > 1 { - jsonPart = strings.TrimSpace(match[1]) - logger.Infof("✓ Extracted JSON using tag") - } else { - jsonPart = s - logger.Infof("⚠️ tag not found, searching JSON in full text") - } - - jsonPart = fixMissingQuotes(jsonPart) - - if m := reJSONFence.FindStringSubmatch(jsonPart); m != nil && len(m) > 1 { - jsonContent := strings.TrimSpace(m[1]) - jsonContent = compactArrayOpen(jsonContent) - jsonContent = fixMissingQuotes(jsonContent) - if err := validateJSONFormat(jsonContent); err != nil { - return nil, fmt.Errorf("JSON format validation failed: %w\nJSON content: %s\nFull response:\n%s", err, jsonContent, response) - } - var decisions []Decision - if err := json.Unmarshal([]byte(jsonContent), &decisions); err != nil { - return nil, fmt.Errorf("JSON parsing failed: %w\nJSON content: %s", err, jsonContent) - } - return decisions, nil - } - - jsonContent := strings.TrimSpace(reJSONArray.FindString(jsonPart)) - if jsonContent == "" { - logger.Infof("⚠️ [SafeFallback] AI didn't output JSON decision, entering safe wait mode") - - cotSummary := jsonPart - if len(cotSummary) > 240 { - cotSummary = cotSummary[:240] + "..." - } - - fallbackDecision := Decision{ - Symbol: "ALL", - Action: "wait", - Reasoning: fmt.Sprintf("Model didn't output structured JSON decision, entering safe wait; summary: %s", cotSummary), - } - - return []Decision{fallbackDecision}, nil - } - - jsonContent = compactArrayOpen(jsonContent) - jsonContent = fixMissingQuotes(jsonContent) - - if err := validateJSONFormat(jsonContent); err != nil { - return nil, fmt.Errorf("JSON format validation failed: %w\nJSON content: %s\nFull response:\n%s", err, jsonContent, response) - } - - var decisions []Decision - if err := json.Unmarshal([]byte(jsonContent), &decisions); err != nil { - return nil, fmt.Errorf("JSON parsing failed: %w\nJSON content: %s", err, jsonContent) - } - - return decisions, nil -} - -func fixMissingQuotes(jsonStr string) string { - jsonStr = strings.ReplaceAll(jsonStr, "\u201c", "\"") - jsonStr = strings.ReplaceAll(jsonStr, "\u201d", "\"") - jsonStr = strings.ReplaceAll(jsonStr, "\u2018", "'") - jsonStr = strings.ReplaceAll(jsonStr, "\u2019", "'") - - jsonStr = strings.ReplaceAll(jsonStr, "[", "[") - jsonStr = strings.ReplaceAll(jsonStr, "]", "]") - jsonStr = strings.ReplaceAll(jsonStr, "{", "{") - jsonStr = strings.ReplaceAll(jsonStr, "}", "}") - jsonStr = strings.ReplaceAll(jsonStr, ":", ":") - jsonStr = strings.ReplaceAll(jsonStr, ",", ",") - - jsonStr = strings.ReplaceAll(jsonStr, "【", "[") - jsonStr = strings.ReplaceAll(jsonStr, "】", "]") - jsonStr = strings.ReplaceAll(jsonStr, "〔", "[") - jsonStr = strings.ReplaceAll(jsonStr, "〕", "]") - jsonStr = strings.ReplaceAll(jsonStr, "、", ",") - - jsonStr = strings.ReplaceAll(jsonStr, " ", " ") - - return jsonStr -} - -func validateJSONFormat(jsonStr string) error { - trimmed := strings.TrimSpace(jsonStr) - - if !reArrayHead.MatchString(trimmed) { - if strings.HasPrefix(trimmed, "[") && !strings.Contains(trimmed[:min(20, len(trimmed))], "{") { - return fmt.Errorf("not a valid decision array (must contain objects {}), actual content: %s", trimmed[:min(50, len(trimmed))]) - } - return fmt.Errorf("JSON must start with [{ (whitespace allowed), actual: %s", trimmed[:min(20, len(trimmed))]) - } - - if strings.Contains(jsonStr, "~") { - return fmt.Errorf("JSON cannot contain range symbol ~, all numbers must be precise single values") - } - - for i := 0; i < len(jsonStr)-4; i++ { - if jsonStr[i] >= '0' && jsonStr[i] <= '9' && - jsonStr[i+1] == ',' && - jsonStr[i+2] >= '0' && jsonStr[i+2] <= '9' && - jsonStr[i+3] >= '0' && jsonStr[i+3] <= '9' && - jsonStr[i+4] >= '0' && jsonStr[i+4] <= '9' { - return fmt.Errorf("JSON numbers cannot contain thousand separator comma, found: %s", jsonStr[i:min(i+10, len(jsonStr))]) - } - } - - return nil -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func removeInvisibleRunes(s string) string { - return reInvisibleRunes.ReplaceAllString(s, "") -} - -func compactArrayOpen(s string) string { - return reArrayOpenSpace.ReplaceAllString(strings.TrimSpace(s), "[{") -} - -// ============================================================================ -// Decision Validation -// ============================================================================ - -func validateDecisions(decisions []Decision, accountEquity float64, btcEthLeverage, altcoinLeverage int, btcEthPosRatio, altcoinPosRatio float64) error { - for i := range decisions { - if err := validateDecision(&decisions[i], accountEquity, btcEthLeverage, altcoinLeverage, btcEthPosRatio, altcoinPosRatio); err != nil { - return fmt.Errorf("decision #%d validation failed: %w", i+1, err) - } - } - return nil -} - -func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoinLeverage int, btcEthPosRatio, altcoinPosRatio float64) error { - validActions := map[string]bool{ - "open_long": true, - "open_short": true, - "close_long": true, - "close_short": true, - "hold": true, - "wait": true, - } - - if !validActions[d.Action] { - return fmt.Errorf("invalid action: %s", d.Action) - } - - if d.Action == "open_long" || d.Action == "open_short" { - maxLeverage := altcoinLeverage - posRatio := altcoinPosRatio - maxPositionValue := accountEquity * posRatio - if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" { - maxLeverage = btcEthLeverage - posRatio = btcEthPosRatio - maxPositionValue = accountEquity * posRatio - } - - if d.Leverage <= 0 { - return fmt.Errorf("leverage must be greater than 0: %d", d.Leverage) - } - if d.Leverage > maxLeverage { - logger.Infof("⚠️ [Leverage Fallback] %s leverage exceeded (%dx > %dx), auto-adjusting to limit %dx", - d.Symbol, d.Leverage, maxLeverage, maxLeverage) - d.Leverage = maxLeverage - } - if d.PositionSizeUSD <= 0 { - return fmt.Errorf("position size must be greater than 0: %.2f", d.PositionSizeUSD) - } - - const minPositionSizeGeneral = 12.0 - const minPositionSizeBTCETH = 60.0 - - if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" { - if d.PositionSizeUSD < minPositionSizeBTCETH { - return fmt.Errorf("%s opening amount too small (%.2f USDT), must be ≥%.2f USDT", d.Symbol, d.PositionSizeUSD, minPositionSizeBTCETH) - } - } else { - if d.PositionSizeUSD < minPositionSizeGeneral { - return fmt.Errorf("opening amount too small (%.2f USDT), must be ≥%.2f USDT", d.PositionSizeUSD, minPositionSizeGeneral) - } - } - - tolerance := maxPositionValue * 0.01 - if d.PositionSizeUSD > maxPositionValue+tolerance { - if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" { - return fmt.Errorf("BTC/ETH single coin position value cannot exceed %.0f USDT (%.1fx account equity), actual: %.0f", maxPositionValue, posRatio, d.PositionSizeUSD) - } else { - return fmt.Errorf("altcoin single coin position value cannot exceed %.0f USDT (%.1fx account equity), actual: %.0f", maxPositionValue, posRatio, d.PositionSizeUSD) - } - } - if d.StopLoss <= 0 || d.TakeProfit <= 0 { - return fmt.Errorf("stop loss and take profit must be greater than 0") - } - - if d.Action == "open_long" { - if d.StopLoss >= d.TakeProfit { - return fmt.Errorf("for long positions, stop loss price must be less than take profit price") - } - } else { - if d.StopLoss <= d.TakeProfit { - return fmt.Errorf("for short positions, stop loss price must be greater than take profit price") - } - } - - var entryPrice float64 - if d.Action == "open_long" { - entryPrice = d.StopLoss + (d.TakeProfit-d.StopLoss)*0.2 - } else { - entryPrice = d.StopLoss - (d.StopLoss-d.TakeProfit)*0.2 - } - - var riskPercent, rewardPercent, riskRewardRatio float64 - if d.Action == "open_long" { - riskPercent = (entryPrice - d.StopLoss) / entryPrice * 100 - rewardPercent = (d.TakeProfit - entryPrice) / entryPrice * 100 - if riskPercent > 0 { - riskRewardRatio = rewardPercent / riskPercent - } - } else { - riskPercent = (d.StopLoss - entryPrice) / entryPrice * 100 - rewardPercent = (entryPrice - d.TakeProfit) / entryPrice * 100 - if riskPercent > 0 { - riskRewardRatio = rewardPercent / riskPercent - } - } - - if riskRewardRatio < 3.0 { - return fmt.Errorf("risk/reward ratio too low (%.2f:1), must be ≥3.0:1 [risk: %.2f%% reward: %.2f%%] [stop loss: %.2f take profit: %.2f]", - riskRewardRatio, riskPercent, rewardPercent, d.StopLoss, d.TakeProfit) - } - } - - return nil -} - // ============================================================================ // Helper Functions // ============================================================================ diff --git a/kernel/engine_analysis.go b/kernel/engine_analysis.go new file mode 100644 index 00000000..4a1071bd --- /dev/null +++ b/kernel/engine_analysis.go @@ -0,0 +1,374 @@ +package kernel + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/market" + "nofx/mcp" + "nofx/store" + "regexp" + "strings" + "time" +) + +// ============================================================================ +// Pre-compiled regular expressions (performance optimization) +// ============================================================================ + +var ( + // Safe regex: precisely match ```json code blocks + reJSONFence = regexp.MustCompile(`(?is)` + "```json\\s*(\\[\\s*\\{.*?\\}\\s*\\])\\s*```") + reJSONArray = regexp.MustCompile(`(?is)\[\s*\{.*?\}\s*\]`) + reArrayHead = regexp.MustCompile(`^\[\s*\{`) + reArrayOpenSpace = regexp.MustCompile(`^\[\s+\{`) + reInvisibleRunes = regexp.MustCompile("[\u200B\u200C\u200D\uFEFF]") + + // XML tag extraction (supports any characters in reasoning chain) + reReasoningTag = regexp.MustCompile(`(?s)(.*?)`) + reDecisionTag = regexp.MustCompile(`(?s)(.*?)`) +) + +// ============================================================================ +// Entry Functions - Main API +// ============================================================================ + +// GetFullDecision gets AI's complete trading decision (batch analysis of all coins and positions) +// Uses default strategy configuration - for production use GetFullDecisionWithStrategy with explicit config +func GetFullDecision(ctx *Context, mcpClient mcp.AIClient) (*FullDecision, error) { + defaultConfig := store.GetDefaultStrategyConfig("en") + engine := NewStrategyEngine(&defaultConfig) + return GetFullDecisionWithStrategy(ctx, mcpClient, engine, "") +} + +// GetFullDecisionWithStrategy uses StrategyEngine to get AI decision (unified prompt generation) +func GetFullDecisionWithStrategy(ctx *Context, mcpClient mcp.AIClient, engine *StrategyEngine, variant string) (*FullDecision, error) { + if ctx == nil { + return nil, fmt.Errorf("context is nil") + } + if engine == nil { + defaultConfig := store.GetDefaultStrategyConfig("en") + engine = NewStrategyEngine(&defaultConfig) + } + + // 1. Fetch market data using strategy config + if len(ctx.MarketDataMap) == 0 { + if err := fetchMarketDataWithStrategy(ctx, engine); err != nil { + return nil, fmt.Errorf("failed to fetch market data: %w", err) + } + } + + // Ensure OITopDataMap is initialized + if ctx.OITopDataMap == nil { + ctx.OITopDataMap = make(map[string]*OITopData) + oiPositions, err := engine.nofxosClient.GetOITopPositions() + if err == nil { + for _, pos := range oiPositions { + ctx.OITopDataMap[pos.Symbol] = &OITopData{ + Rank: pos.Rank, + OIDeltaPercent: pos.OIDeltaPercent, + OIDeltaValue: pos.OIDeltaValue, + PriceDeltaPercent: pos.PriceDeltaPercent, + } + } + } + } + + // 2. Build System Prompt using strategy engine + riskConfig := engine.GetRiskControlConfig() + systemPrompt := engine.BuildSystemPrompt(ctx.Account.TotalEquity, variant) + + // 3. Build User Prompt using strategy engine + userPrompt := engine.BuildUserPrompt(ctx) + + // 4. Call AI API + aiCallStart := time.Now() + aiResponse, err := mcpClient.CallWithMessages(systemPrompt, userPrompt) + aiCallDuration := time.Since(aiCallStart) + if err != nil { + return nil, fmt.Errorf("AI API call failed: %w", err) + } + + // 5. Parse AI response + decision, err := parseFullDecisionResponse( + aiResponse, + ctx.Account.TotalEquity, + riskConfig.BTCETHMaxLeverage, + riskConfig.AltcoinMaxLeverage, + riskConfig.BTCETHMaxPositionValueRatio, + riskConfig.AltcoinMaxPositionValueRatio, + ) + + if decision != nil { + decision.Timestamp = time.Now() + decision.SystemPrompt = systemPrompt + decision.UserPrompt = userPrompt + decision.AIRequestDurationMs = aiCallDuration.Milliseconds() + decision.RawResponse = aiResponse + } + + if err != nil { + return decision, fmt.Errorf("failed to parse AI response: %w", err) + } + + return decision, nil +} + +// ============================================================================ +// Market Data Fetching +// ============================================================================ + +// fetchMarketDataWithStrategy fetches market data using strategy config (multiple timeframes) +func fetchMarketDataWithStrategy(ctx *Context, engine *StrategyEngine) error { + config := engine.GetConfig() + ctx.MarketDataMap = make(map[string]*market.Data) + + timeframes := config.Indicators.Klines.SelectedTimeframes + primaryTimeframe := config.Indicators.Klines.PrimaryTimeframe + klineCount := config.Indicators.Klines.PrimaryCount + + // Compatible with old configuration + if len(timeframes) == 0 { + if primaryTimeframe != "" { + timeframes = append(timeframes, primaryTimeframe) + } else { + timeframes = append(timeframes, "3m") + } + if config.Indicators.Klines.LongerTimeframe != "" { + timeframes = append(timeframes, config.Indicators.Klines.LongerTimeframe) + } + } + if primaryTimeframe == "" { + primaryTimeframe = timeframes[0] + } + if klineCount <= 0 { + klineCount = 30 + } + + logger.Infof("📊 Strategy timeframes: %v, Primary: %s, Kline count: %d", timeframes, primaryTimeframe, klineCount) + + // 1. First fetch data for position coins (must fetch) + for _, pos := range ctx.Positions { + data, err := market.GetWithTimeframes(pos.Symbol, timeframes, primaryTimeframe, klineCount) + if err != nil { + logger.Infof("⚠️ Failed to fetch market data for position %s: %v", pos.Symbol, err) + continue + } + ctx.MarketDataMap[pos.Symbol] = data + } + + // 2. Fetch data for all candidate coins + positionSymbols := make(map[string]bool) + for _, pos := range ctx.Positions { + positionSymbols[pos.Symbol] = true + } + + const minOIThresholdMillions = 15.0 // 15M USD minimum open interest value + + for _, coin := range ctx.CandidateCoins { + if _, exists := ctx.MarketDataMap[coin.Symbol]; exists { + continue + } + + data, err := market.GetWithTimeframes(coin.Symbol, timeframes, primaryTimeframe, klineCount) + if err != nil { + logger.Infof("⚠️ Failed to fetch market data for %s: %v", coin.Symbol, err) + continue + } + + // Liquidity filter (skip for xyz dex assets - they don't have OI data from Binance) + isExistingPosition := positionSymbols[coin.Symbol] + isXyzAsset := market.IsXyzDexAsset(coin.Symbol) + if !isExistingPosition && !isXyzAsset && data.OpenInterest != nil && data.CurrentPrice > 0 { + oiValue := data.OpenInterest.Latest * data.CurrentPrice + oiValueInMillions := oiValue / 1_000_000 + if oiValueInMillions < minOIThresholdMillions { + logger.Infof("⚠️ %s OI value too low (%.2fM USD < %.1fM), skipping coin", + coin.Symbol, oiValueInMillions, minOIThresholdMillions) + continue + } + } + + ctx.MarketDataMap[coin.Symbol] = data + } + + logger.Infof("📊 Successfully fetched multi-timeframe market data for %d coins", len(ctx.MarketDataMap)) + return nil +} + +// ============================================================================ +// AI Response Parsing +// ============================================================================ + +func parseFullDecisionResponse(aiResponse string, accountEquity float64, btcEthLeverage, altcoinLeverage int, btcEthPosRatio, altcoinPosRatio float64) (*FullDecision, error) { + cotTrace := extractCoTTrace(aiResponse) + + decisions, err := extractDecisions(aiResponse) + if err != nil { + return &FullDecision{ + CoTTrace: cotTrace, + Decisions: []Decision{}, + }, fmt.Errorf("failed to extract decisions: %w", err) + } + + if err := validateDecisions(decisions, accountEquity, btcEthLeverage, altcoinLeverage, btcEthPosRatio, altcoinPosRatio); err != nil { + return &FullDecision{ + CoTTrace: cotTrace, + Decisions: decisions, + }, fmt.Errorf("decision validation failed: %w", err) + } + + return &FullDecision{ + CoTTrace: cotTrace, + Decisions: decisions, + }, nil +} + +func extractCoTTrace(response string) string { + if match := reReasoningTag.FindStringSubmatch(response); match != nil && len(match) > 1 { + logger.Infof("✓ Extracted reasoning chain using tag") + return strings.TrimSpace(match[1]) + } + + if decisionIdx := strings.Index(response, ""); decisionIdx > 0 { + logger.Infof("✓ Extracted content before tag as reasoning chain") + return strings.TrimSpace(response[:decisionIdx]) + } + + jsonStart := strings.Index(response, "[") + if jsonStart > 0 { + logger.Infof("⚠️ Extracted reasoning chain using old format ([ character separator)") + return strings.TrimSpace(response[:jsonStart]) + } + + return strings.TrimSpace(response) +} + +func extractDecisions(response string) ([]Decision, error) { + s := removeInvisibleRunes(response) + s = strings.TrimSpace(s) + s = fixMissingQuotes(s) + + var jsonPart string + if match := reDecisionTag.FindStringSubmatch(s); match != nil && len(match) > 1 { + jsonPart = strings.TrimSpace(match[1]) + logger.Infof("✓ Extracted JSON using tag") + } else { + jsonPart = s + logger.Infof("⚠️ tag not found, searching JSON in full text") + } + + jsonPart = fixMissingQuotes(jsonPart) + + if m := reJSONFence.FindStringSubmatch(jsonPart); m != nil && len(m) > 1 { + jsonContent := strings.TrimSpace(m[1]) + jsonContent = compactArrayOpen(jsonContent) + jsonContent = fixMissingQuotes(jsonContent) + if err := validateJSONFormat(jsonContent); err != nil { + return nil, fmt.Errorf("JSON format validation failed: %w\nJSON content: %s\nFull response:\n%s", err, jsonContent, response) + } + var decisions []Decision + if err := json.Unmarshal([]byte(jsonContent), &decisions); err != nil { + return nil, fmt.Errorf("JSON parsing failed: %w\nJSON content: %s", err, jsonContent) + } + return decisions, nil + } + + jsonContent := strings.TrimSpace(reJSONArray.FindString(jsonPart)) + if jsonContent == "" { + logger.Infof("⚠️ [SafeFallback] AI didn't output JSON decision, entering safe wait mode") + + cotSummary := jsonPart + if len(cotSummary) > 240 { + cotSummary = cotSummary[:240] + "..." + } + + fallbackDecision := Decision{ + Symbol: "ALL", + Action: "wait", + Reasoning: fmt.Sprintf("Model didn't output structured JSON decision, entering safe wait; summary: %s", cotSummary), + } + + return []Decision{fallbackDecision}, nil + } + + jsonContent = compactArrayOpen(jsonContent) + jsonContent = fixMissingQuotes(jsonContent) + + if err := validateJSONFormat(jsonContent); err != nil { + return nil, fmt.Errorf("JSON format validation failed: %w\nJSON content: %s\nFull response:\n%s", err, jsonContent, response) + } + + var decisions []Decision + if err := json.Unmarshal([]byte(jsonContent), &decisions); err != nil { + return nil, fmt.Errorf("JSON parsing failed: %w\nJSON content: %s", err, jsonContent) + } + + return decisions, nil +} + +func fixMissingQuotes(jsonStr string) string { + jsonStr = strings.ReplaceAll(jsonStr, "\u201c", "\"") + jsonStr = strings.ReplaceAll(jsonStr, "\u201d", "\"") + jsonStr = strings.ReplaceAll(jsonStr, "\u2018", "'") + jsonStr = strings.ReplaceAll(jsonStr, "\u2019", "'") + + jsonStr = strings.ReplaceAll(jsonStr, "[", "[") + jsonStr = strings.ReplaceAll(jsonStr, "]", "]") + jsonStr = strings.ReplaceAll(jsonStr, "{", "{") + jsonStr = strings.ReplaceAll(jsonStr, "}", "}") + jsonStr = strings.ReplaceAll(jsonStr, ":", ":") + jsonStr = strings.ReplaceAll(jsonStr, ",", ",") + + jsonStr = strings.ReplaceAll(jsonStr, "【", "[") + jsonStr = strings.ReplaceAll(jsonStr, "】", "]") + jsonStr = strings.ReplaceAll(jsonStr, "〔", "[") + jsonStr = strings.ReplaceAll(jsonStr, "〕", "]") + jsonStr = strings.ReplaceAll(jsonStr, "、", ",") + + jsonStr = strings.ReplaceAll(jsonStr, " ", " ") + + return jsonStr +} + +func validateJSONFormat(jsonStr string) error { + trimmed := strings.TrimSpace(jsonStr) + + if !reArrayHead.MatchString(trimmed) { + if strings.HasPrefix(trimmed, "[") && !strings.Contains(trimmed[:min(20, len(trimmed))], "{") { + return fmt.Errorf("not a valid decision array (must contain objects {}), actual content: %s", trimmed[:min(50, len(trimmed))]) + } + return fmt.Errorf("JSON must start with [{ (whitespace allowed), actual: %s", trimmed[:min(20, len(trimmed))]) + } + + if strings.Contains(jsonStr, "~") { + return fmt.Errorf("JSON cannot contain range symbol ~, all numbers must be precise single values") + } + + for i := 0; i < len(jsonStr)-4; i++ { + if jsonStr[i] >= '0' && jsonStr[i] <= '9' && + jsonStr[i+1] == ',' && + jsonStr[i+2] >= '0' && jsonStr[i+2] <= '9' && + jsonStr[i+3] >= '0' && jsonStr[i+3] <= '9' && + jsonStr[i+4] >= '0' && jsonStr[i+4] <= '9' { + return fmt.Errorf("JSON numbers cannot contain thousand separator comma, found: %s", jsonStr[i:min(i+10, len(jsonStr))]) + } + } + + return nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func removeInvisibleRunes(s string) string { + return reInvisibleRunes.ReplaceAllString(s, "") +} + +func compactArrayOpen(s string) string { + return reArrayOpenSpace.ReplaceAllString(strings.TrimSpace(s), "[{") +} diff --git a/kernel/engine_position.go b/kernel/engine_position.go new file mode 100644 index 00000000..437aea1a --- /dev/null +++ b/kernel/engine_position.go @@ -0,0 +1,121 @@ +package kernel + +import ( + "fmt" + "nofx/logger" +) + +// ============================================================================ +// Decision Validation +// ============================================================================ + +func validateDecisions(decisions []Decision, accountEquity float64, btcEthLeverage, altcoinLeverage int, btcEthPosRatio, altcoinPosRatio float64) error { + for i := range decisions { + if err := validateDecision(&decisions[i], accountEquity, btcEthLeverage, altcoinLeverage, btcEthPosRatio, altcoinPosRatio); err != nil { + return fmt.Errorf("decision #%d validation failed: %w", i+1, err) + } + } + return nil +} + +func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoinLeverage int, btcEthPosRatio, altcoinPosRatio float64) error { + validActions := map[string]bool{ + "open_long": true, + "open_short": true, + "close_long": true, + "close_short": true, + "hold": true, + "wait": true, + } + + if !validActions[d.Action] { + return fmt.Errorf("invalid action: %s", d.Action) + } + + if d.Action == "open_long" || d.Action == "open_short" { + maxLeverage := altcoinLeverage + posRatio := altcoinPosRatio + maxPositionValue := accountEquity * posRatio + if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" { + maxLeverage = btcEthLeverage + posRatio = btcEthPosRatio + maxPositionValue = accountEquity * posRatio + } + + if d.Leverage <= 0 { + return fmt.Errorf("leverage must be greater than 0: %d", d.Leverage) + } + if d.Leverage > maxLeverage { + logger.Infof("⚠️ [Leverage Fallback] %s leverage exceeded (%dx > %dx), auto-adjusting to limit %dx", + d.Symbol, d.Leverage, maxLeverage, maxLeverage) + d.Leverage = maxLeverage + } + if d.PositionSizeUSD <= 0 { + return fmt.Errorf("position size must be greater than 0: %.2f", d.PositionSizeUSD) + } + + const minPositionSizeGeneral = 12.0 + const minPositionSizeBTCETH = 60.0 + + if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" { + if d.PositionSizeUSD < minPositionSizeBTCETH { + return fmt.Errorf("%s opening amount too small (%.2f USDT), must be ≥%.2f USDT", d.Symbol, d.PositionSizeUSD, minPositionSizeBTCETH) + } + } else { + if d.PositionSizeUSD < minPositionSizeGeneral { + return fmt.Errorf("opening amount too small (%.2f USDT), must be ≥%.2f USDT", d.PositionSizeUSD, minPositionSizeGeneral) + } + } + + tolerance := maxPositionValue * 0.01 + if d.PositionSizeUSD > maxPositionValue+tolerance { + if d.Symbol == "BTCUSDT" || d.Symbol == "ETHUSDT" { + return fmt.Errorf("BTC/ETH single coin position value cannot exceed %.0f USDT (%.1fx account equity), actual: %.0f", maxPositionValue, posRatio, d.PositionSizeUSD) + } else { + return fmt.Errorf("altcoin single coin position value cannot exceed %.0f USDT (%.1fx account equity), actual: %.0f", maxPositionValue, posRatio, d.PositionSizeUSD) + } + } + if d.StopLoss <= 0 || d.TakeProfit <= 0 { + return fmt.Errorf("stop loss and take profit must be greater than 0") + } + + if d.Action == "open_long" { + if d.StopLoss >= d.TakeProfit { + return fmt.Errorf("for long positions, stop loss price must be less than take profit price") + } + } else { + if d.StopLoss <= d.TakeProfit { + return fmt.Errorf("for short positions, stop loss price must be greater than take profit price") + } + } + + var entryPrice float64 + if d.Action == "open_long" { + entryPrice = d.StopLoss + (d.TakeProfit-d.StopLoss)*0.2 + } else { + entryPrice = d.StopLoss - (d.StopLoss-d.TakeProfit)*0.2 + } + + var riskPercent, rewardPercent, riskRewardRatio float64 + if d.Action == "open_long" { + riskPercent = (entryPrice - d.StopLoss) / entryPrice * 100 + rewardPercent = (d.TakeProfit - entryPrice) / entryPrice * 100 + if riskPercent > 0 { + riskRewardRatio = rewardPercent / riskPercent + } + } else { + riskPercent = (d.StopLoss - entryPrice) / entryPrice * 100 + rewardPercent = (entryPrice - d.TakeProfit) / entryPrice * 100 + if riskPercent > 0 { + riskRewardRatio = rewardPercent / riskPercent + } + } + + if riskRewardRatio < 3.0 { + return fmt.Errorf("risk/reward ratio too low (%.2f:1), must be ≥3.0:1 [risk: %.2f%% reward: %.2f%%] [stop loss: %.2f take profit: %.2f]", + riskRewardRatio, riskPercent, rewardPercent, d.StopLoss, d.TakeProfit) + } + } + + return nil +} diff --git a/kernel/engine_prompt.go b/kernel/engine_prompt.go new file mode 100644 index 00000000..15c79a38 --- /dev/null +++ b/kernel/engine_prompt.go @@ -0,0 +1,779 @@ +package kernel + +import ( + "fmt" + "nofx/market" + "nofx/provider/nofxos" + "nofx/store" + "strings" + "time" +) + +// ============================================================================ +// Prompt Building - System Prompt +// ============================================================================ + +// BuildSystemPrompt builds System Prompt according to strategy configuration +func (e *StrategyEngine) BuildSystemPrompt(accountEquity float64, variant string) string { + var sb strings.Builder + riskControl := e.config.RiskControl + promptSections := e.config.PromptSections + + // 0. Data Dictionary & Schema (ensure AI understands all fields) + lang := e.GetLanguage() + schemaPrompt := GetSchemaPrompt(lang) + sb.WriteString(schemaPrompt) + sb.WriteString("\n\n") + sb.WriteString("---\n\n") + + // 1. Role definition (editable) + if promptSections.RoleDefinition != "" { + sb.WriteString(promptSections.RoleDefinition) + sb.WriteString("\n\n") + } else { + sb.WriteString("# You are a professional cryptocurrency trading AI\n\n") + sb.WriteString("Your task is to make trading decisions based on provided market data.\n\n") + } + + // 2. Trading mode variant + switch strings.ToLower(strings.TrimSpace(variant)) { + case "aggressive": + sb.WriteString("## Mode: Aggressive\n- Prioritize capturing trend breakouts, can build positions in batches when confidence ≥ 70\n- Allow higher positions, but must strictly set stop-loss and explain risk-reward ratio\n\n") + case "conservative": + sb.WriteString("## Mode: Conservative\n- Only open positions when multiple signals resonate\n- Prioritize cash preservation, must pause for multiple periods after consecutive losses\n\n") + case "scalping": + sb.WriteString("## Mode: Scalping\n- Focus on short-term momentum, smaller profit targets but require quick action\n- If price doesn't move as expected within two bars, immediately reduce position or stop-loss\n\n") + } + + // 3. Hard constraints (risk control) + btcEthPosValueRatio := riskControl.BTCETHMaxPositionValueRatio + if btcEthPosValueRatio <= 0 { + btcEthPosValueRatio = 5.0 + } + altcoinPosValueRatio := riskControl.AltcoinMaxPositionValueRatio + if altcoinPosValueRatio <= 0 { + altcoinPosValueRatio = 1.0 + } + + sb.WriteString("# Hard Constraints (Risk Control)\n\n") + sb.WriteString("## CODE ENFORCED (Backend validation, cannot be bypassed):\n") + sb.WriteString(fmt.Sprintf("- Max Positions: %d coins simultaneously\n", riskControl.MaxPositions)) + sb.WriteString(fmt.Sprintf("- Position Value Limit (Altcoins): max %.0f USDT (= equity %.0f × %.1fx)\n", + accountEquity*altcoinPosValueRatio, accountEquity, altcoinPosValueRatio)) + sb.WriteString(fmt.Sprintf("- Position Value Limit (BTC/ETH): max %.0f USDT (= equity %.0f × %.1fx)\n", + accountEquity*btcEthPosValueRatio, accountEquity, btcEthPosValueRatio)) + sb.WriteString(fmt.Sprintf("- Max Margin Usage: ≤%.0f%%\n", riskControl.MaxMarginUsage*100)) + sb.WriteString(fmt.Sprintf("- Min Position Size: ≥%.0f USDT\n\n", riskControl.MinPositionSize)) + + sb.WriteString("## AI GUIDED (Recommended, you should follow):\n") + sb.WriteString(fmt.Sprintf("- Trading Leverage: Altcoins max %dx | BTC/ETH max %dx\n", + riskControl.AltcoinMaxLeverage, riskControl.BTCETHMaxLeverage)) + sb.WriteString(fmt.Sprintf("- Risk-Reward Ratio: ≥1:%.1f (take_profit / stop_loss)\n", riskControl.MinRiskRewardRatio)) + sb.WriteString(fmt.Sprintf("- Min Confidence: ≥%d to open position\n\n", riskControl.MinConfidence)) + + // Position sizing guidance + sb.WriteString("## Position Sizing Guidance\n") + sb.WriteString("Calculate `position_size_usd` based on your confidence and the Position Value Limits above:\n") + sb.WriteString("- High confidence (≥85): Use 80-100%% of max position value limit\n") + sb.WriteString("- Medium confidence (70-84): Use 50-80%% of max position value limit\n") + sb.WriteString("- Low confidence (60-69): Use 30-50%% of max position value limit\n") + sb.WriteString(fmt.Sprintf("- Example: With equity %.0f and BTC/ETH ratio %.1fx, max is %.0f USDT\n", + accountEquity, btcEthPosValueRatio, accountEquity*btcEthPosValueRatio)) + sb.WriteString("- **DO NOT** just use available_balance as position_size_usd. Use the Position Value Limits!\n\n") + + // 4. Trading frequency (editable) + if promptSections.TradingFrequency != "" { + sb.WriteString(promptSections.TradingFrequency) + sb.WriteString("\n\n") + } else { + sb.WriteString("# ⏱️ Trading Frequency Awareness\n\n") + sb.WriteString("- Excellent traders: 2-4 trades/day ≈ 0.1-0.2 trades/hour\n") + sb.WriteString("- >2 trades/hour = Overtrading\n") + sb.WriteString("- Single position hold time ≥ 30-60 minutes\n") + sb.WriteString("If you find yourself trading every period → standards too low; if closing positions < 30 minutes → too impatient.\n\n") + } + + // 5. Entry standards (editable) + if promptSections.EntryStandards != "" { + sb.WriteString(promptSections.EntryStandards) + sb.WriteString("\n\nYou have the following indicator data:\n") + e.writeAvailableIndicators(&sb) + sb.WriteString(fmt.Sprintf("\n**Confidence ≥ %d** required to open positions.\n\n", riskControl.MinConfidence)) + } else { + sb.WriteString("# 🎯 Entry Standards (Strict)\n\n") + sb.WriteString("Only open positions when multiple signals resonate. You have:\n") + e.writeAvailableIndicators(&sb) + sb.WriteString(fmt.Sprintf("\nFeel free to use any effective analysis method, but **confidence ≥ %d** required to open positions; avoid low-quality behaviors such as single indicators, contradictory signals, sideways consolidation, reopening immediately after closing, etc.\n\n", riskControl.MinConfidence)) + } + + // 6. Decision process (editable) + if promptSections.DecisionProcess != "" { + sb.WriteString(promptSections.DecisionProcess) + sb.WriteString("\n\n") + } else { + sb.WriteString("# 📋 Decision Process\n\n") + sb.WriteString("1. Check positions → Should we take profit/stop-loss\n") + sb.WriteString("2. Scan candidate coins + multi-timeframe → Are there strong signals\n") + sb.WriteString("3. Write chain of thought first, then output structured JSON\n\n") + } + + // 7. Output format + sb.WriteString("# Output Format (Strictly Follow)\n\n") + sb.WriteString("**Must use XML tags and to separate chain of thought and decision JSON, avoiding parsing errors**\n\n") + sb.WriteString("## Format Requirements\n\n") + sb.WriteString("\n") + sb.WriteString("Your chain of thought analysis...\n") + sb.WriteString("- Briefly analyze your thinking process \n") + sb.WriteString("\n\n") + sb.WriteString("\n") + sb.WriteString("Step 2: JSON decision array\n\n") + sb.WriteString("```json\n[\n") + // Use the actual configured position value ratio for BTC/ETH in the example + examplePositionSize := accountEquity * btcEthPosValueRatio + sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300},\n", + riskControl.BTCETHMaxLeverage, examplePositionSize)) + sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\"}\n") + sb.WriteString("]\n```\n") + sb.WriteString("\n\n") + sb.WriteString("## Field Description\n\n") + sb.WriteString("- `action`: open_long | open_short | close_long | close_short | hold | wait\n") + sb.WriteString(fmt.Sprintf("- `confidence`: 0-100 (opening recommended ≥ %d)\n", riskControl.MinConfidence)) + sb.WriteString("- Required when opening: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n") + sb.WriteString("- **IMPORTANT**: All numeric values must be calculated numbers, NOT formulas/expressions (e.g., use `27.76` not `3000 * 0.01`)\n\n") + + // 8. Custom Prompt + if e.config.CustomPrompt != "" { + sb.WriteString("# 📌 Personalized Trading Strategy\n\n") + sb.WriteString(e.config.CustomPrompt) + sb.WriteString("\n\n") + sb.WriteString("Note: The above personalized strategy is a supplement to the basic rules and cannot violate the basic risk control principles.\n") + } + + return sb.String() +} + +func (e *StrategyEngine) writeAvailableIndicators(sb *strings.Builder) { + indicators := e.config.Indicators + kline := indicators.Klines + + sb.WriteString(fmt.Sprintf("- %s price series", kline.PrimaryTimeframe)) + if kline.EnableMultiTimeframe { + sb.WriteString(fmt.Sprintf(" + %s K-line series\n", kline.LongerTimeframe)) + } else { + sb.WriteString("\n") + } + + if indicators.EnableEMA { + sb.WriteString("- EMA indicators") + if len(indicators.EMAPeriods) > 0 { + sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.EMAPeriods)) + } + sb.WriteString("\n") + } + + if indicators.EnableMACD { + sb.WriteString("- MACD indicators\n") + } + + if indicators.EnableRSI { + sb.WriteString("- RSI indicators") + if len(indicators.RSIPeriods) > 0 { + sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.RSIPeriods)) + } + sb.WriteString("\n") + } + + if indicators.EnableATR { + sb.WriteString("- ATR indicators") + if len(indicators.ATRPeriods) > 0 { + sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.ATRPeriods)) + } + sb.WriteString("\n") + } + + if indicators.EnableBOLL { + sb.WriteString("- Bollinger Bands (BOLL) - Upper/Middle/Lower bands") + if len(indicators.BOLLPeriods) > 0 { + sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.BOLLPeriods)) + } + sb.WriteString("\n") + } + + if indicators.EnableVolume { + sb.WriteString("- Volume data\n") + } + + if indicators.EnableOI { + sb.WriteString("- Open Interest (OI) data\n") + } + + if indicators.EnableFundingRate { + sb.WriteString("- Funding rate\n") + } + + if len(e.config.CoinSource.StaticCoins) > 0 || e.config.CoinSource.UseAI500 || e.config.CoinSource.UseOITop { + sb.WriteString("- AI500 / OI_Top filter tags (if available)\n") + } + + if indicators.EnableQuantData { + sb.WriteString("- Quantitative data (institutional/retail fund flow, position changes, multi-period price changes)\n") + } +} + +// ============================================================================ +// Prompt Building - User Prompt +// ============================================================================ + +// BuildUserPrompt builds User Prompt based on strategy configuration +func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string { + var sb strings.Builder + + // System status + sb.WriteString(fmt.Sprintf("Time: %s | Period: #%d | Runtime: %d minutes\n\n", + ctx.CurrentTime, ctx.CallCount, ctx.RuntimeMinutes)) + + // BTC market + if btcData, hasBTC := ctx.MarketDataMap["BTCUSDT"]; hasBTC { + sb.WriteString(fmt.Sprintf("BTC: %.2f (1h: %+.2f%%, 4h: %+.2f%%) | MACD: %.4f | RSI: %.2f\n\n", + btcData.CurrentPrice, btcData.PriceChange1h, btcData.PriceChange4h, + btcData.CurrentMACD, btcData.CurrentRSI7)) + } + + // Account information + sb.WriteString(fmt.Sprintf("Account: Equity %.2f | Balance %.2f (%.1f%%) | PnL %+.2f%% | Margin %.1f%% | Positions %d\n\n", + ctx.Account.TotalEquity, + ctx.Account.AvailableBalance, + (ctx.Account.AvailableBalance/ctx.Account.TotalEquity)*100, + ctx.Account.TotalPnLPct, + ctx.Account.MarginUsedPct, + ctx.Account.PositionCount)) + + // Recently completed orders (placed before positions to ensure visibility) + if len(ctx.RecentOrders) > 0 { + sb.WriteString("## Recent Completed Trades\n") + for i, order := range ctx.RecentOrders { + resultStr := "Profit" + if order.RealizedPnL < 0 { + resultStr = "Loss" + } + sb.WriteString(fmt.Sprintf("%d. %s %s | Entry %.4f Exit %.4f | %s: %+.2f USDT (%+.2f%%) | %s→%s (%s)\n", + i+1, order.Symbol, order.Side, + order.EntryPrice, order.ExitPrice, + resultStr, order.RealizedPnL, order.PnLPct, + order.EntryTime, order.ExitTime, order.HoldDuration)) + } + sb.WriteString("\n") + } + + // Historical trading statistics (helps AI understand past performance) + if ctx.TradingStats != nil && ctx.TradingStats.TotalTrades > 0 { + // Get language from strategy config + lang := e.GetLanguage() + + // Win/Loss ratio + var winLossRatio float64 + if ctx.TradingStats.AvgLoss > 0 { + winLossRatio = ctx.TradingStats.AvgWin / ctx.TradingStats.AvgLoss + } + + if lang == LangChinese { + sb.WriteString("## 历史交易统计\n") + sb.WriteString(fmt.Sprintf("总交易: %d 笔 | 盈利因子: %.2f | 夏普比率: %.2f | 盈亏比: %.2f\n", + ctx.TradingStats.TotalTrades, + ctx.TradingStats.ProfitFactor, + ctx.TradingStats.SharpeRatio, + winLossRatio)) + sb.WriteString(fmt.Sprintf("总盈亏: %+.2f USDT | 平均盈利: +%.2f | 平均亏损: -%.2f | 最大回撤: %.1f%%\n", + ctx.TradingStats.TotalPnL, + ctx.TradingStats.AvgWin, + ctx.TradingStats.AvgLoss, + ctx.TradingStats.MaxDrawdownPct)) + + // Performance hints based on profit factor, sharpe, and drawdown + if ctx.TradingStats.ProfitFactor >= 1.5 && ctx.TradingStats.SharpeRatio >= 1 { + sb.WriteString("表现: 良好 - 保持当前策略\n") + } else if ctx.TradingStats.ProfitFactor < 1 { + sb.WriteString("表现: 需改进 - 提高盈亏比,优化止盈止损\n") + } else if ctx.TradingStats.MaxDrawdownPct > 30 { + sb.WriteString("表现: 风险偏高 - 减少仓位,控制回撤\n") + } else { + sb.WriteString("表现: 正常 - 有优化空间\n") + } + } else { + sb.WriteString("## Historical Trading Statistics\n") + sb.WriteString(fmt.Sprintf("Total Trades: %d | Profit Factor: %.2f | Sharpe: %.2f | Win/Loss Ratio: %.2f\n", + ctx.TradingStats.TotalTrades, + ctx.TradingStats.ProfitFactor, + ctx.TradingStats.SharpeRatio, + winLossRatio)) + sb.WriteString(fmt.Sprintf("Total PnL: %+.2f USDT | Avg Win: +%.2f | Avg Loss: -%.2f | Max Drawdown: %.1f%%\n", + ctx.TradingStats.TotalPnL, + ctx.TradingStats.AvgWin, + ctx.TradingStats.AvgLoss, + ctx.TradingStats.MaxDrawdownPct)) + + // Performance hints based on profit factor, sharpe, and drawdown + if ctx.TradingStats.ProfitFactor >= 1.5 && ctx.TradingStats.SharpeRatio >= 1 { + sb.WriteString("Performance: GOOD - maintain current strategy\n") + } else if ctx.TradingStats.ProfitFactor < 1 { + sb.WriteString("Performance: NEEDS IMPROVEMENT - improve win/loss ratio, optimize TP/SL\n") + } else if ctx.TradingStats.MaxDrawdownPct > 30 { + sb.WriteString("Performance: HIGH RISK - reduce position size, control drawdown\n") + } else { + sb.WriteString("Performance: NORMAL - room for optimization\n") + } + } + sb.WriteString("\n") + } + + // Position information + if len(ctx.Positions) > 0 { + sb.WriteString("## Current Positions\n") + for i, pos := range ctx.Positions { + sb.WriteString(e.formatPositionInfo(i+1, pos, ctx)) + } + } else { + sb.WriteString("Current Positions: None\n\n") + } + + // Candidate coins (exclude coins already in positions to avoid duplicate data) + positionSymbols := make(map[string]bool) + for _, pos := range ctx.Positions { + // Normalize symbol to handle both "ETH" and "ETHUSDT" formats + normalizedSymbol := market.Normalize(pos.Symbol) + positionSymbols[normalizedSymbol] = true + } + + sb.WriteString(fmt.Sprintf("## Candidate Coins (%d coins)\n\n", len(ctx.MarketDataMap))) + displayedCount := 0 + for _, coin := range ctx.CandidateCoins { + // Skip if this coin is already a position (data already shown in positions section) + normalizedCoinSymbol := market.Normalize(coin.Symbol) + if positionSymbols[normalizedCoinSymbol] { + continue + } + + marketData, hasData := ctx.MarketDataMap[coin.Symbol] + if !hasData { + continue + } + displayedCount++ + + sourceTags := e.formatCoinSourceTag(coin.Sources) + sb.WriteString(fmt.Sprintf("### %d. %s%s\n\n", displayedCount, coin.Symbol, sourceTags)) + sb.WriteString(e.formatMarketData(marketData)) + + if ctx.QuantDataMap != nil { + if quantData, hasQuant := ctx.QuantDataMap[coin.Symbol]; hasQuant { + sb.WriteString(e.formatQuantData(quantData)) + } + } + sb.WriteString("\n") + } + sb.WriteString("\n") + + // Get language for market data formatting + nofxosLang := nofxos.LangEnglish + if e.GetLanguage() == LangChinese { + nofxosLang = nofxos.LangChinese + } + + // OI Ranking data (market-wide open interest changes) + if ctx.OIRankingData != nil { + sb.WriteString(nofxos.FormatOIRankingForAI(ctx.OIRankingData, nofxosLang)) + } + + // NetFlow Ranking data (market-wide fund flow) + if ctx.NetFlowRankingData != nil { + sb.WriteString(nofxos.FormatNetFlowRankingForAI(ctx.NetFlowRankingData, nofxosLang)) + } + + // Price Ranking data (market-wide gainers/losers) + if ctx.PriceRankingData != nil { + sb.WriteString(nofxos.FormatPriceRankingForAI(ctx.PriceRankingData, nofxosLang)) + } + + sb.WriteString("---\n\n") + sb.WriteString("Now please analyze and output your decision (Chain of Thought + JSON)\n") + + return sb.String() +} + +func (e *StrategyEngine) formatPositionInfo(index int, pos PositionInfo, ctx *Context) string { + var sb strings.Builder + + holdingDuration := "" + if pos.UpdateTime > 0 { + durationMs := time.Now().UnixMilli() - pos.UpdateTime + durationMin := durationMs / (1000 * 60) + if durationMin < 60 { + holdingDuration = fmt.Sprintf(" | Holding Duration %d min", durationMin) + } else { + durationHour := durationMin / 60 + durationMinRemainder := durationMin % 60 + holdingDuration = fmt.Sprintf(" | Holding Duration %dh %dm", durationHour, durationMinRemainder) + } + } + + positionValue := pos.Quantity * pos.MarkPrice + if positionValue < 0 { + positionValue = -positionValue + } + + sb.WriteString(fmt.Sprintf("%d. %s %s | Entry %.4f Current %.4f | Qty %.4f | Position Value %.2f USDT | PnL%+.2f%% | PnL Amount%+.2f USDT | Peak PnL%.2f%% | Leverage %dx | Margin %.0f | Liq Price %.4f%s\n\n", + index, pos.Symbol, strings.ToUpper(pos.Side), + pos.EntryPrice, pos.MarkPrice, pos.Quantity, positionValue, pos.UnrealizedPnLPct, pos.UnrealizedPnL, pos.PeakPnLPct, + pos.Leverage, pos.MarginUsed, pos.LiquidationPrice, holdingDuration)) + + if marketData, ok := ctx.MarketDataMap[pos.Symbol]; ok { + sb.WriteString(e.formatMarketData(marketData)) + + if ctx.QuantDataMap != nil { + if quantData, hasQuant := ctx.QuantDataMap[pos.Symbol]; hasQuant { + sb.WriteString(e.formatQuantData(quantData)) + } + } + sb.WriteString("\n") + } + + return sb.String() +} + +func (e *StrategyEngine) formatCoinSourceTag(sources []string) string { + if len(sources) > 1 { + // 多信号源组合 + hasAI500 := false + hasOITop := false + hasOILow := false + hasHyperAll := false + hasHyperMain := false + for _, s := range sources { + switch s { + case "ai500": + hasAI500 = true + case "oi_top": + hasOITop = true + case "oi_low": + hasOILow = true + case "hyper_all": + hasHyperAll = true + case "hyper_main": + hasHyperMain = true + } + } + if hasAI500 && hasOITop { + return " (AI500+OI_Top dual signal)" + } + if hasAI500 && hasOILow { + return " (AI500+OI_Low dual signal)" + } + if hasOITop && hasOILow { + return " (OI_Top+OI_Low)" + } + if hasHyperMain && hasAI500 { + return " (HyperMain+AI500)" + } + if hasHyperAll || hasHyperMain { + return " (Hyperliquid)" + } + return " (Multiple sources)" + } else if len(sources) == 1 { + switch sources[0] { + case "ai500": + return " (AI500)" + case "oi_top": + return " (OI_Top 持仓增加)" + case "oi_low": + return " (OI_Low 持仓减少)" + case "static": + return " (Manual selection)" + case "hyper_all": + return " (Hyperliquid All)" + case "hyper_main": + return " (Hyperliquid Top20)" + } + } + return "" +} + +// ============================================================================ +// Market Data Formatting +// ============================================================================ + +func (e *StrategyEngine) formatMarketData(data *market.Data) string { + var sb strings.Builder + indicators := e.config.Indicators + + // 明确标注币种 + sb.WriteString(fmt.Sprintf("=== %s Market Data ===\n\n", data.Symbol)) + sb.WriteString(fmt.Sprintf("current_price = %.4f", data.CurrentPrice)) + + if indicators.EnableEMA { + sb.WriteString(fmt.Sprintf(", current_ema20 = %.3f", data.CurrentEMA20)) + } + + if indicators.EnableMACD { + sb.WriteString(fmt.Sprintf(", current_macd = %.3f", data.CurrentMACD)) + } + + if indicators.EnableRSI { + sb.WriteString(fmt.Sprintf(", current_rsi7 = %.3f", data.CurrentRSI7)) + } + + sb.WriteString("\n\n") + + if indicators.EnableOI || indicators.EnableFundingRate { + sb.WriteString(fmt.Sprintf("Additional data for %s:\n\n", data.Symbol)) + + if indicators.EnableOI && data.OpenInterest != nil { + sb.WriteString(fmt.Sprintf("Open Interest: Latest: %.2f Average: %.2f\n\n", + data.OpenInterest.Latest, data.OpenInterest.Average)) + } + + if indicators.EnableFundingRate { + sb.WriteString(fmt.Sprintf("Funding Rate: %.2e\n\n", data.FundingRate)) + } + } + + if len(data.TimeframeData) > 0 { + timeframeOrder := []string{"1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"} + for _, tf := range timeframeOrder { + if tfData, ok := data.TimeframeData[tf]; ok { + sb.WriteString(fmt.Sprintf("=== %s Timeframe (oldest → latest) ===\n\n", strings.ToUpper(tf))) + e.formatTimeframeSeriesData(&sb, tfData, indicators) + } + } + } else { + // Compatible with old data format + if data.IntradaySeries != nil { + klineConfig := indicators.Klines + sb.WriteString(fmt.Sprintf("Intraday series (%s intervals, oldest → latest):\n\n", klineConfig.PrimaryTimeframe)) + + if len(data.IntradaySeries.MidPrices) > 0 { + sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.IntradaySeries.MidPrices))) + } + + if indicators.EnableEMA && len(data.IntradaySeries.EMA20Values) > 0 { + sb.WriteString(fmt.Sprintf("EMA indicators (20-period): %s\n\n", formatFloatSlice(data.IntradaySeries.EMA20Values))) + } + + if indicators.EnableMACD && len(data.IntradaySeries.MACDValues) > 0 { + sb.WriteString(fmt.Sprintf("MACD indicators: %s\n\n", formatFloatSlice(data.IntradaySeries.MACDValues))) + } + + if indicators.EnableRSI { + if len(data.IntradaySeries.RSI7Values) > 0 { + sb.WriteString(fmt.Sprintf("RSI indicators (7-Period): %s\n\n", formatFloatSlice(data.IntradaySeries.RSI7Values))) + } + if len(data.IntradaySeries.RSI14Values) > 0 { + sb.WriteString(fmt.Sprintf("RSI indicators (14-Period): %s\n\n", formatFloatSlice(data.IntradaySeries.RSI14Values))) + } + } + + if indicators.EnableVolume && len(data.IntradaySeries.Volume) > 0 { + sb.WriteString(fmt.Sprintf("Volume: %s\n\n", formatFloatSlice(data.IntradaySeries.Volume))) + } + + if indicators.EnableATR { + sb.WriteString(fmt.Sprintf("3m ATR (14-period): %.3f\n\n", data.IntradaySeries.ATR14)) + } + } + + if data.LongerTermContext != nil && indicators.Klines.EnableMultiTimeframe { + sb.WriteString(fmt.Sprintf("Longer-term context (%s timeframe):\n\n", indicators.Klines.LongerTimeframe)) + + if indicators.EnableEMA { + sb.WriteString(fmt.Sprintf("20-Period EMA: %.3f vs. 50-Period EMA: %.3f\n\n", + data.LongerTermContext.EMA20, data.LongerTermContext.EMA50)) + } + + if indicators.EnableATR { + sb.WriteString(fmt.Sprintf("3-Period ATR: %.3f vs. 14-Period ATR: %.3f\n\n", + data.LongerTermContext.ATR3, data.LongerTermContext.ATR14)) + } + + if indicators.EnableVolume { + sb.WriteString(fmt.Sprintf("Current Volume: %.3f vs. Average Volume: %.3f\n\n", + data.LongerTermContext.CurrentVolume, data.LongerTermContext.AverageVolume)) + } + + if indicators.EnableMACD && len(data.LongerTermContext.MACDValues) > 0 { + sb.WriteString(fmt.Sprintf("MACD indicators: %s\n\n", formatFloatSlice(data.LongerTermContext.MACDValues))) + } + + if indicators.EnableRSI && len(data.LongerTermContext.RSI14Values) > 0 { + sb.WriteString(fmt.Sprintf("RSI indicators (14-Period): %s\n\n", formatFloatSlice(data.LongerTermContext.RSI14Values))) + } + } + } + + return sb.String() +} + +func (e *StrategyEngine) formatTimeframeSeriesData(sb *strings.Builder, data *market.TimeframeSeriesData, indicators store.IndicatorConfig) { + if len(data.Klines) > 0 { + sb.WriteString("Time(UTC) Open High Low Close Volume\n") + for i, k := range data.Klines { + t := time.Unix(k.Time/1000, 0).UTC() + timeStr := t.Format("01-02 15:04") + marker := "" + if i == len(data.Klines)-1 { + marker = " <- current" + } + sb.WriteString(fmt.Sprintf("%-14s %-9.4f %-9.4f %-9.4f %-9.4f %-12.2f%s\n", + timeStr, k.Open, k.High, k.Low, k.Close, k.Volume, marker)) + } + sb.WriteString("\n") + } else if len(data.MidPrices) > 0 { + sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.MidPrices))) + if indicators.EnableVolume && len(data.Volume) > 0 { + sb.WriteString(fmt.Sprintf("Volume: %s\n\n", formatFloatSlice(data.Volume))) + } + } + + if indicators.EnableEMA { + if len(data.EMA20Values) > 0 { + sb.WriteString(fmt.Sprintf("EMA20: %s\n", formatFloatSlice(data.EMA20Values))) + } + if len(data.EMA50Values) > 0 { + sb.WriteString(fmt.Sprintf("EMA50: %s\n", formatFloatSlice(data.EMA50Values))) + } + } + + if indicators.EnableMACD && len(data.MACDValues) > 0 { + sb.WriteString(fmt.Sprintf("MACD: %s\n", formatFloatSlice(data.MACDValues))) + } + + if indicators.EnableRSI { + if len(data.RSI7Values) > 0 { + sb.WriteString(fmt.Sprintf("RSI7: %s\n", formatFloatSlice(data.RSI7Values))) + } + if len(data.RSI14Values) > 0 { + sb.WriteString(fmt.Sprintf("RSI14: %s\n", formatFloatSlice(data.RSI14Values))) + } + } + + if indicators.EnableATR && data.ATR14 > 0 { + sb.WriteString(fmt.Sprintf("ATR14: %.4f\n", data.ATR14)) + } + + if indicators.EnableBOLL && len(data.BOLLUpper) > 0 { + sb.WriteString(fmt.Sprintf("BOLL Upper: %s\n", formatFloatSlice(data.BOLLUpper))) + sb.WriteString(fmt.Sprintf("BOLL Middle: %s\n", formatFloatSlice(data.BOLLMiddle))) + sb.WriteString(fmt.Sprintf("BOLL Lower: %s\n", formatFloatSlice(data.BOLLLower))) + } + + sb.WriteString("\n") +} + +func (e *StrategyEngine) formatQuantData(data *QuantData) string { + if data == nil { + return "" + } + + indicators := e.config.Indicators + if !indicators.EnableQuantOI && !indicators.EnableQuantNetflow { + return "" + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("📊 %s Quantitative Data:\n", data.Symbol)) + + if len(data.PriceChange) > 0 { + sb.WriteString("Price Change: ") + timeframes := []string{"5m", "15m", "1h", "4h", "12h", "24h"} + parts := []string{} + for _, tf := range timeframes { + if v, ok := data.PriceChange[tf]; ok { + parts = append(parts, fmt.Sprintf("%s: %+.4f%%", tf, v*100)) + } + } + sb.WriteString(strings.Join(parts, " | ")) + sb.WriteString("\n") + } + + if indicators.EnableQuantNetflow && data.Netflow != nil { + sb.WriteString("Fund Flow (Netflow):\n") + timeframes := []string{"5m", "15m", "1h", "4h", "12h", "24h"} + + if data.Netflow.Institution != nil { + if data.Netflow.Institution.Future != nil && len(data.Netflow.Institution.Future) > 0 { + sb.WriteString(" Institutional Futures:\n") + for _, tf := range timeframes { + if v, ok := data.Netflow.Institution.Future[tf]; ok { + sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) + } + } + } + if data.Netflow.Institution.Spot != nil && len(data.Netflow.Institution.Spot) > 0 { + sb.WriteString(" Institutional Spot:\n") + for _, tf := range timeframes { + if v, ok := data.Netflow.Institution.Spot[tf]; ok { + sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) + } + } + } + } + + if data.Netflow.Personal != nil { + if data.Netflow.Personal.Future != nil && len(data.Netflow.Personal.Future) > 0 { + sb.WriteString(" Retail Futures:\n") + for _, tf := range timeframes { + if v, ok := data.Netflow.Personal.Future[tf]; ok { + sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) + } + } + } + if data.Netflow.Personal.Spot != nil && len(data.Netflow.Personal.Spot) > 0 { + sb.WriteString(" Retail Spot:\n") + for _, tf := range timeframes { + if v, ok := data.Netflow.Personal.Spot[tf]; ok { + sb.WriteString(fmt.Sprintf(" %s: %s\n", tf, formatFlowValue(v))) + } + } + } + } + } + + if indicators.EnableQuantOI && len(data.OI) > 0 { + for exchange, oiData := range data.OI { + if len(oiData.Delta) > 0 { + sb.WriteString(fmt.Sprintf("Open Interest (%s):\n", exchange)) + for _, tf := range []string{"5m", "15m", "1h", "4h", "12h", "24h"} { + if d, ok := oiData.Delta[tf]; ok { + sb.WriteString(fmt.Sprintf(" %s: %+.4f%% (%s)\n", tf, d.OIDeltaPercent, formatFlowValue(d.OIDeltaValue))) + } + } + } + } + } + + return sb.String() +} + +func formatFlowValue(v float64) string { + sign := "" + if v >= 0 { + sign = "+" + } + absV := v + if absV < 0 { + absV = -absV + } + if absV >= 1e9 { + return fmt.Sprintf("%s%.2fB", sign, v/1e9) + } else if absV >= 1e6 { + return fmt.Sprintf("%s%.2fM", sign, v/1e6) + } else if absV >= 1e3 { + return fmt.Sprintf("%s%.2fK", sign, v/1e3) + } + return fmt.Sprintf("%s%.2f", sign, v) +} + +func formatFloatSlice(values []float64) string { + strValues := make([]string, len(values)) + for i, v := range values { + strValues[i] = fmt.Sprintf("%.4f", v) + } + return "[" + strings.Join(strValues, ", ") + "]" +} diff --git a/kernel/schema_test.go b/kernel/schema_test.go deleted file mode 100644 index e0ee0a1e..00000000 --- a/kernel/schema_test.go +++ /dev/null @@ -1,278 +0,0 @@ -package kernel - -import ( - "strings" - "testing" -) - -// TestDataDictionary 测试数据字典定义 -func TestDataDictionary(t *testing.T) { - // 测试账户指标字典 - t.Run("AccountMetrics", func(t *testing.T) { - equity := DataDictionary["AccountMetrics"]["Equity"] - - if equity.NameZH != "总权益" { - t.Errorf("Expected NameZH='总权益', got '%s'", equity.NameZH) - } - - if equity.NameEN != "Total Equity" { - t.Errorf("Expected NameEN='Total Equity', got '%s'", equity.NameEN) - } - - if equity.Unit != "USDT" { - t.Errorf("Expected Unit='USDT', got '%s'", equity.Unit) - } - - if equity.GetName(LangChinese) != "总权益" { - t.Errorf("GetName(Chinese) failed") - } - - if equity.GetName(LangEnglish) != "Total Equity" { - t.Errorf("GetName(English) failed") - } - }) - - // 测试持仓指标字典 - t.Run("PositionMetrics", func(t *testing.T) { - peakPnL := DataDictionary["PositionMetrics"]["PeakPnL%"] - - if peakPnL.NameZH == "" { - t.Error("PeakPnL% NameZH is empty") - } - - if peakPnL.NameEN == "" { - t.Error("PeakPnL% NameEN is empty") - } - - if !strings.Contains(peakPnL.DescZH, "峰值") { - t.Error("PeakPnL% DescZH should contain '峰值'") - } - }) -} - -// TestTradingRules 测试交易规则定义 -func TestTradingRules(t *testing.T) { - t.Run("RiskManagement", func(t *testing.T) { - maxMargin := TradingRules.RiskManagement["MaxMarginUsage"] - - if maxMargin.Value != 0.30 { - t.Errorf("Expected MaxMarginUsage=0.30, got %v", maxMargin.Value) - } - - if maxMargin.GetDesc(LangChinese) == "" { - t.Error("MaxMarginUsage DescZH is empty") - } - - if maxMargin.GetDesc(LangEnglish) == "" { - t.Error("MaxMarginUsage DescEN is empty") - } - - if !strings.Contains(maxMargin.DescZH, "30%") { - t.Error("MaxMarginUsage DescZH should mention 30%") - } - }) - - t.Run("ExitSignals", func(t *testing.T) { - trailing := TradingRules.ExitSignals["TrailingStop"] - - if trailing.Value != 0.30 { - t.Errorf("Expected TrailingStop=0.30, got %v", trailing.Value) - } - - if !strings.Contains(trailing.ReasonZH, "止盈") { - t.Error("TrailingStop ReasonZH should mention '止盈'") - } - - if !strings.Contains(trailing.ReasonEN, "profit") { - t.Error("TrailingStop ReasonEN should mention 'profit'") - } - }) -} - -// TestOIInterpretation 测试OI解读 -func TestOIInterpretation(t *testing.T) { - t.Run("OI_Up_Price_Up", func(t *testing.T) { - if OIInterpretation.OIUp_PriceUp.ZH == "" { - t.Error("OI Up + Price Up ZH is empty") - } - - if OIInterpretation.OIUp_PriceUp.EN == "" { - t.Error("OI Up + Price Up EN is empty") - } - - if !strings.Contains(OIInterpretation.OIUp_PriceUp.ZH, "多头") { - t.Error("OI Up + Price Up should indicate bullish trend") - } - }) -} - -// TestCommonMistakes 测试常见错误定义 -func TestCommonMistakes(t *testing.T) { - if len(CommonMistakes) == 0 { - t.Error("CommonMistakes should not be empty") - } - - for i, mistake := range CommonMistakes { - if mistake.ErrorZH == "" { - t.Errorf("Mistake #%d ErrorZH is empty", i+1) - } - - if mistake.ErrorEN == "" { - t.Errorf("Mistake #%d ErrorEN is empty", i+1) - } - - if mistake.CorrectZH == "" { - t.Errorf("Mistake #%d CorrectZH is empty", i+1) - } - - if mistake.CorrectEN == "" { - t.Errorf("Mistake #%d CorrectEN is empty", i+1) - } - } -} - -// TestGetSchemaPrompt 测试Schema提示词生成 -func TestGetSchemaPrompt(t *testing.T) { - t.Run("Chinese", func(t *testing.T) { - prompt := GetSchemaPrompt(LangChinese) - - if prompt == "" { - t.Fatal("Chinese schema prompt is empty") - } - - // 验证包含关键内容 - mustContain := []string{ - "数据字典", - "账户指标", - "交易指标", - "持仓指标", - "市场数据", - "持仓量(OI)变化解读", - } - - for _, keyword := range mustContain { - if !strings.Contains(prompt, keyword) { - t.Errorf("Chinese prompt should contain '%s'", keyword) - } - } - }) - - t.Run("English", func(t *testing.T) { - prompt := GetSchemaPrompt(LangEnglish) - - if prompt == "" { - t.Fatal("English schema prompt is empty") - } - - // 验证包含关键内容 - mustContain := []string{ - "Data Dictionary", - "Account Metrics", - "Trade Metrics", - "Position Metrics", - "Market Data", - "Open Interest", - } - - for _, keyword := range mustContain { - if !strings.Contains(prompt, keyword) { - t.Errorf("English prompt should contain '%s'", keyword) - } - } - }) - - t.Run("Consistency", func(t *testing.T) { - promptZH := GetSchemaPrompt(LangChinese) - promptEN := GetSchemaPrompt(LangEnglish) - - // 两个版本都应该包含相同数量的字段定义 - // 虽然内容不同,但结构应该相似 - - zhLines := strings.Split(promptZH, "\n") - enLines := strings.Split(promptEN, "\n") - - // 行数应该大致相当(允许10%的差异) - ratio := float64(len(zhLines)) / float64(len(enLines)) - if ratio < 0.9 || ratio > 1.1 { - t.Logf("Warning: Line count difference is significant (ZH: %d, EN: %d)", - len(zhLines), len(enLines)) - } - }) -} - -// BenchmarkGetSchemaPrompt 性能测试 -func BenchmarkGetSchemaPrompt(b *testing.B) { - b.Run("Chinese", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = GetSchemaPrompt(LangChinese) - } - }) - - b.Run("English", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = GetSchemaPrompt(LangEnglish) - } - }) -} - -// TestFieldDefinitionMethods 测试字段定义方法 -func TestFieldDefinitionMethods(t *testing.T) { - field := BilingualFieldDef{ - NameZH: "测试字段", - NameEN: "Test Field", - Unit: "USDT", - FormulaZH: "中文公式", - FormulaEN: "English formula", - DescZH: "中文描述", - DescEN: "English description", - } - - // 测试GetName - if field.GetName(LangChinese) != "测试字段" { - t.Error("GetName(Chinese) failed") - } - if field.GetName(LangEnglish) != "Test Field" { - t.Error("GetName(English) failed") - } - - // 测试GetFormula - if field.GetFormula(LangChinese) != "中文公式" { - t.Error("GetFormula(Chinese) failed") - } - if field.GetFormula(LangEnglish) != "English formula" { - t.Error("GetFormula(English) failed") - } - - // 测试GetDesc - if field.GetDesc(LangChinese) != "中文描述" { - t.Error("GetDesc(Chinese) failed") - } - if field.GetDesc(LangEnglish) != "English description" { - t.Error("GetDesc(English) failed") - } -} - -// TestRuleDefinitionMethods 测试规则定义方法 -func TestRuleDefinitionMethods(t *testing.T) { - rule := BilingualRuleDef{ - Value: 0.30, - DescZH: "中文描述", - DescEN: "English description", - ReasonZH: "中文原因", - ReasonEN: "English reason", - } - - if rule.GetDesc(LangChinese) != "中文描述" { - t.Error("GetDesc(Chinese) failed") - } - if rule.GetDesc(LangEnglish) != "English description" { - t.Error("GetDesc(English) failed") - } - - if rule.GetReason(LangChinese) != "中文原因" { - t.Error("GetReason(Chinese) failed") - } - if rule.GetReason(LangEnglish) != "English reason" { - t.Error("GetReason(English) failed") - } -} diff --git a/main.go b/main.go index 06f76a17..4dcf8bb6 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,7 @@ import ( "nofx/backtest" "nofx/config" "nofx/crypto" - "nofx/experience" + "nofx/telemetry" "nofx/logger" "nofx/manager" "nofx/mcp" @@ -194,5 +194,5 @@ func initInstallationID(st *store.Store) { } // Set installation ID in experience module - experience.SetInstallationID(installationID) + telemetry.SetInstallationID(installationID) } diff --git a/manager/trader_manager_test.go b/manager/trader_manager_test.go deleted file mode 100644 index 6d3f0bc5..00000000 --- a/manager/trader_manager_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package manager - -import ( - "testing" -) - -// TestRemoveTrader tests removing trader from memory -func TestRemoveTrader(t *testing.T) { - tm := NewTraderManager() - - // Create a mock trader and add it to map - traderID := "test-trader-123" - tm.traders[traderID] = nil // Use nil as placeholder, only need to verify deletion logic in test - - // Verify trader exists - if _, exists := tm.traders[traderID]; !exists { - t.Fatal("trader should exist in map") - } - - // Call RemoveTrader - tm.RemoveTrader(traderID) - - // Verify trader has been removed - if _, exists := tm.traders[traderID]; exists { - t.Error("trader should be removed from map") - } -} - -// TestRemoveTrader_NonExistent tests that removing non-existent trader doesn't error -func TestRemoveTrader_NonExistent(t *testing.T) { - tm := NewTraderManager() - - // Trying to remove non-existent trader should not panic - defer func() { - if r := recover(); r != nil { - t.Errorf("removing non-existent trader should not panic: %v", r) - } - }() - - tm.RemoveTrader("non-existent-trader") -} - -// TestRemoveTrader_Concurrent tests concurrent removal of trader safety -func TestRemoveTrader_Concurrent(t *testing.T) { - tm := NewTraderManager() - traderID := "test-trader-concurrent" - - // Add trader - tm.traders[traderID] = nil - - // Concurrently call RemoveTrader - done := make(chan bool, 10) - for i := 0; i < 10; i++ { - go func() { - tm.RemoveTrader(traderID) - done <- true - }() - } - - // Wait for all goroutines to complete - for i := 0; i < 10; i++ { - <-done - } - - // Verify trader has been removed - if _, exists := tm.traders[traderID]; exists { - t.Error("trader should be removed from map") - } -} - -// TestGetTrader_AfterRemove tests that getting trader after removal returns error -func TestGetTrader_AfterRemove(t *testing.T) { - tm := NewTraderManager() - traderID := "test-trader-get" - - // Add trader - tm.traders[traderID] = nil - - // Remove trader - tm.RemoveTrader(traderID) - - // Try to get removed trader - _, err := tm.GetTrader(traderID) - if err == nil { - t.Error("getting removed trader should return error") - } -} diff --git a/market/data.go b/market/data.go index 1860ff54..99dbfc9c 100644 --- a/market/data.go +++ b/market/data.go @@ -1,15 +1,11 @@ package market import ( - "context" "encoding/json" "fmt" "io" "math" "nofx/logger" - "nofx/provider/coinank/coinank_api" - "nofx/provider/coinank/coinank_enum" - "nofx/provider/hyperliquid" "strconv" "strings" "sync" @@ -28,143 +24,6 @@ var ( frCacheTTL = 1 * time.Hour ) -// Note: Kline data now uses free/open API (coinank_api.Kline) which doesn't require authentication - -// getKlinesFromCoinAnk fetches kline data from CoinAnk API (replacement for WSMonitorCli) -func getKlinesFromCoinAnk(symbol, interval, exchange string, limit int) ([]Kline, error) { - // Map interval string to coinank enum - var coinankInterval coinank_enum.Interval - switch interval { - case "1m": - coinankInterval = coinank_enum.Minute1 - case "3m": - coinankInterval = coinank_enum.Minute3 - case "5m": - coinankInterval = coinank_enum.Minute5 - case "15m": - coinankInterval = coinank_enum.Minute15 - case "30m": - coinankInterval = coinank_enum.Minute30 - case "1h": - coinankInterval = coinank_enum.Hour1 - case "2h": - coinankInterval = coinank_enum.Hour2 - case "4h": - coinankInterval = coinank_enum.Hour4 - case "6h": - coinankInterval = coinank_enum.Hour6 - case "8h": - coinankInterval = coinank_enum.Hour8 - case "12h": - coinankInterval = coinank_enum.Hour12 - case "1d": - coinankInterval = coinank_enum.Day1 - case "3d": - coinankInterval = coinank_enum.Day3 - case "1w": - coinankInterval = coinank_enum.Week1 - default: - return nil, fmt.Errorf("unsupported interval: %s", interval) - } - - // Map exchange string to coinank enum - var coinankExchange coinank_enum.Exchange - switch strings.ToLower(exchange) { - case "binance": - coinankExchange = coinank_enum.Binance - case "bybit": - coinankExchange = coinank_enum.Bybit - case "okx": - coinankExchange = coinank_enum.Okex - case "bitget": - coinankExchange = coinank_enum.Bitget - case "gate": - coinankExchange = coinank_enum.Gate - case "hyperliquid": - coinankExchange = coinank_enum.Hyperliquid - case "aster": - coinankExchange = coinank_enum.Aster - default: - // Default to Binance for unknown exchanges - coinankExchange = coinank_enum.Binance - } - - // Call CoinAnk free/open API (no authentication required) - ctx := context.Background() - ts := time.Now().UnixMilli() - // Use "To" side to search backward from current time (get historical klines) - coinankKlines, err := coinank_api.Kline(ctx, symbol, coinankExchange, ts, coinank_enum.To, limit, coinankInterval) - if err != nil { - // If exchange-specific data fails, fallback to Binance - if coinankExchange != coinank_enum.Binance { - logger.Warnf("⚠️ CoinAnk %s data failed, falling back to Binance: %v", exchange, err) - coinankKlines, err = coinank_api.Kline(ctx, symbol, coinank_enum.Binance, ts, coinank_enum.To, limit, coinankInterval) - if err != nil { - return nil, fmt.Errorf("CoinAnk API error (fallback): %w", err) - } - } else { - return nil, fmt.Errorf("CoinAnk API error: %w", err) - } - } - - // Convert coinank kline format to market.Kline format - klines := make([]Kline, len(coinankKlines)) - for i, ck := range coinankKlines { - klines[i] = Kline{ - OpenTime: ck.StartTime, - Open: ck.Open, - High: ck.High, - Low: ck.Low, - Close: ck.Close, - Volume: ck.Volume, - CloseTime: ck.EndTime, - } - } - - return klines, nil -} - -// getKlinesFromHyperliquid fetches kline data from Hyperliquid API for xyz dex assets -func getKlinesFromHyperliquid(symbol, interval string, limit int) ([]Kline, error) { - // Remove xyz: prefix if present for the API call - baseCoin := strings.TrimPrefix(symbol, "xyz:") - - // Map interval to Hyperliquid format - hlInterval := hyperliquid.MapTimeframe(interval) - - // Create Hyperliquid client - client := hyperliquid.NewClient() - - // Fetch candles - ctx := context.Background() - candles, err := client.GetCandles(ctx, baseCoin, hlInterval, limit) - if err != nil { - return nil, fmt.Errorf("Hyperliquid API error: %w", err) - } - - // Convert to market.Kline format - klines := make([]Kline, len(candles)) - for i, c := range candles { - open, _ := strconv.ParseFloat(c.Open, 64) - high, _ := strconv.ParseFloat(c.High, 64) - low, _ := strconv.ParseFloat(c.Low, 64) - closePrice, _ := strconv.ParseFloat(c.Close, 64) - volume, _ := strconv.ParseFloat(c.Volume, 64) - - klines[i] = Kline{ - OpenTime: c.OpenTime, - Open: open, - High: high, - Low: low, - Close: closePrice, - Volume: volume, - CloseTime: c.CloseTime, - } - } - - return klines, nil -} - // Get retrieves market data for the specified token (uses Binance data by default) func Get(symbol string) (*Data, error) { return GetWithExchange(symbol, "binance") @@ -396,398 +255,6 @@ func GetWithTimeframes(symbol string, timeframes []string, primaryTimeframe stri }, nil } -// calculateTimeframeSeries calculates series data for a single timeframe -func calculateTimeframeSeries(klines []Kline, timeframe string, count int) *TimeframeSeriesData { - if count <= 0 { - count = 10 // default - } - - data := &TimeframeSeriesData{ - Timeframe: timeframe, - Klines: make([]KlineBar, 0, count), - MidPrices: make([]float64, 0, count), - EMA20Values: make([]float64, 0, count), - EMA50Values: make([]float64, 0, count), - MACDValues: make([]float64, 0, count), - RSI7Values: make([]float64, 0, count), - RSI14Values: make([]float64, 0, count), - Volume: make([]float64, 0, count), - BOLLUpper: make([]float64, 0, count), - BOLLMiddle: make([]float64, 0, count), - BOLLLower: make([]float64, 0, count), - } - - // Get latest N data points based on count from config - start := len(klines) - count - if start < 0 { - start = 0 - } - - for i := start; i < len(klines); i++ { - // Store full OHLCV kline data - data.Klines = append(data.Klines, KlineBar{ - Time: klines[i].OpenTime, - Open: klines[i].Open, - High: klines[i].High, - Low: klines[i].Low, - Close: klines[i].Close, - Volume: klines[i].Volume, - }) - - // Keep MidPrices and Volume for backward compatibility - data.MidPrices = append(data.MidPrices, klines[i].Close) - data.Volume = append(data.Volume, klines[i].Volume) - - // Calculate EMA20 for each point - if i >= 19 { - ema20 := calculateEMA(klines[:i+1], 20) - data.EMA20Values = append(data.EMA20Values, ema20) - } - - // Calculate EMA50 for each point - if i >= 49 { - ema50 := calculateEMA(klines[:i+1], 50) - data.EMA50Values = append(data.EMA50Values, ema50) - } - - // Calculate MACD for each point - if i >= 25 { - macd := calculateMACD(klines[:i+1]) - data.MACDValues = append(data.MACDValues, macd) - } - - // Calculate RSI for each point - if i >= 7 { - rsi7 := calculateRSI(klines[:i+1], 7) - data.RSI7Values = append(data.RSI7Values, rsi7) - } - if i >= 14 { - rsi14 := calculateRSI(klines[:i+1], 14) - data.RSI14Values = append(data.RSI14Values, rsi14) - } - - // Calculate Bollinger Bands (period 20, std dev multiplier 2) - if i >= 19 { - upper, middle, lower := calculateBOLL(klines[:i+1], 20, 2.0) - data.BOLLUpper = append(data.BOLLUpper, upper) - data.BOLLMiddle = append(data.BOLLMiddle, middle) - data.BOLLLower = append(data.BOLLLower, lower) - } - } - - // Calculate ATR14 - data.ATR14 = calculateATR(klines, 14) - - return data -} - -// calculatePriceChangeByBars calculates how many K-lines to look back for price change based on timeframe -func calculatePriceChangeByBars(klines []Kline, timeframe string, targetMinutes int) float64 { - if len(klines) < 2 { - return 0 - } - - // Parse timeframe to minutes - tfMinutes := parseTimeframeToMinutes(timeframe) - if tfMinutes <= 0 { - return 0 - } - - // Calculate how many K-lines to look back - barsBack := targetMinutes / tfMinutes - if barsBack < 1 { - barsBack = 1 - } - - currentPrice := klines[len(klines)-1].Close - idx := len(klines) - 1 - barsBack - if idx < 0 { - idx = 0 - } - - oldPrice := klines[idx].Close - if oldPrice > 0 { - return ((currentPrice - oldPrice) / oldPrice) * 100 - } - return 0 -} - -// parseTimeframeToMinutes parses timeframe string to minutes -func parseTimeframeToMinutes(tf string) int { - switch tf { - case "1m": - return 1 - case "3m": - return 3 - case "5m": - return 5 - case "15m": - return 15 - case "30m": - return 30 - case "1h": - return 60 - case "2h": - return 120 - case "4h": - return 240 - case "6h": - return 360 - case "8h": - return 480 - case "12h": - return 720 - case "1d": - return 1440 - case "3d": - return 4320 - case "1w": - return 10080 - default: - return 0 - } -} - -// calculateEMA calculates EMA -func calculateEMA(klines []Kline, period int) float64 { - if len(klines) < period { - return 0 - } - - // Calculate SMA as initial EMA - sum := 0.0 - for i := 0; i < period; i++ { - sum += klines[i].Close - } - ema := sum / float64(period) - - // Calculate EMA - multiplier := 2.0 / float64(period+1) - for i := period; i < len(klines); i++ { - ema = (klines[i].Close-ema)*multiplier + ema - } - - return ema -} - -// calculateMACD calculates MACD -func calculateMACD(klines []Kline) float64 { - if len(klines) < 26 { - return 0 - } - - // Calculate 12-period and 26-period EMA - ema12 := calculateEMA(klines, 12) - ema26 := calculateEMA(klines, 26) - - // MACD = EMA12 - EMA26 - return ema12 - ema26 -} - -// calculateRSI calculates RSI -func calculateRSI(klines []Kline, period int) float64 { - if len(klines) <= period { - return 0 - } - - gains := 0.0 - losses := 0.0 - - // Calculate initial average gain/loss - for i := 1; i <= period; i++ { - change := klines[i].Close - klines[i-1].Close - if change > 0 { - gains += change - } else { - losses += -change - } - } - - avgGain := gains / float64(period) - avgLoss := losses / float64(period) - - // Use Wilder smoothing method to calculate subsequent RSI - for i := period + 1; i < len(klines); i++ { - change := klines[i].Close - klines[i-1].Close - if change > 0 { - avgGain = (avgGain*float64(period-1) + change) / float64(period) - avgLoss = (avgLoss * float64(period-1)) / float64(period) - } else { - avgGain = (avgGain * float64(period-1)) / float64(period) - avgLoss = (avgLoss*float64(period-1) + (-change)) / float64(period) - } - } - - if avgLoss == 0 { - return 100 - } - - rs := avgGain / avgLoss - rsi := 100 - (100 / (1 + rs)) - - return rsi -} - -// calculateATR calculates ATR -func calculateATR(klines []Kline, period int) float64 { - if len(klines) <= period { - return 0 - } - - trs := make([]float64, len(klines)) - for i := 1; i < len(klines); i++ { - high := klines[i].High - low := klines[i].Low - prevClose := klines[i-1].Close - - tr1 := high - low - tr2 := math.Abs(high - prevClose) - tr3 := math.Abs(low - prevClose) - - trs[i] = math.Max(tr1, math.Max(tr2, tr3)) - } - - // Calculate initial ATR - sum := 0.0 - for i := 1; i <= period; i++ { - sum += trs[i] - } - atr := sum / float64(period) - - // Wilder smoothing - for i := period + 1; i < len(klines); i++ { - atr = (atr*float64(period-1) + trs[i]) / float64(period) - } - - return atr -} - -// calculateBOLL calculates Bollinger Bands (upper, middle, lower) -// period: typically 20, multiplier: typically 2 -func calculateBOLL(klines []Kline, period int, multiplier float64) (upper, middle, lower float64) { - if len(klines) < period { - return 0, 0, 0 - } - - // Calculate SMA (middle band) - sum := 0.0 - for i := len(klines) - period; i < len(klines); i++ { - sum += klines[i].Close - } - sma := sum / float64(period) - - // Calculate standard deviation - variance := 0.0 - for i := len(klines) - period; i < len(klines); i++ { - diff := klines[i].Close - sma - variance += diff * diff - } - stdDev := math.Sqrt(variance / float64(period)) - - // Calculate bands - middle = sma - upper = sma + multiplier*stdDev - lower = sma - multiplier*stdDev - - return upper, middle, lower -} - -// calculateIntradaySeries calculates intraday series data -func calculateIntradaySeries(klines []Kline) *IntradayData { - data := &IntradayData{ - MidPrices: make([]float64, 0, 10), - EMA20Values: make([]float64, 0, 10), - MACDValues: make([]float64, 0, 10), - RSI7Values: make([]float64, 0, 10), - RSI14Values: make([]float64, 0, 10), - Volume: make([]float64, 0, 10), - } - - // Get latest 10 data points - start := len(klines) - 10 - if start < 0 { - start = 0 - } - - for i := start; i < len(klines); i++ { - data.MidPrices = append(data.MidPrices, klines[i].Close) - data.Volume = append(data.Volume, klines[i].Volume) - - // Calculate EMA20 for each point - if i >= 19 { - ema20 := calculateEMA(klines[:i+1], 20) - data.EMA20Values = append(data.EMA20Values, ema20) - } - - // Calculate MACD for each point - if i >= 25 { - macd := calculateMACD(klines[:i+1]) - data.MACDValues = append(data.MACDValues, macd) - } - - // Calculate RSI for each point - if i >= 7 { - rsi7 := calculateRSI(klines[:i+1], 7) - data.RSI7Values = append(data.RSI7Values, rsi7) - } - if i >= 14 { - rsi14 := calculateRSI(klines[:i+1], 14) - data.RSI14Values = append(data.RSI14Values, rsi14) - } - } - - // Calculate 3m ATR14 - data.ATR14 = calculateATR(klines, 14) - - return data -} - -// calculateLongerTermData calculates longer-term data -func calculateLongerTermData(klines []Kline) *LongerTermData { - data := &LongerTermData{ - MACDValues: make([]float64, 0, 10), - RSI14Values: make([]float64, 0, 10), - } - - // Calculate EMA - data.EMA20 = calculateEMA(klines, 20) - data.EMA50 = calculateEMA(klines, 50) - - // Calculate ATR - data.ATR3 = calculateATR(klines, 3) - data.ATR14 = calculateATR(klines, 14) - - // Calculate volume - if len(klines) > 0 { - data.CurrentVolume = klines[len(klines)-1].Volume - // Calculate average volume - sum := 0.0 - for _, k := range klines { - sum += k.Volume - } - data.AverageVolume = sum / float64(len(klines)) - } - - // Calculate MACD and RSI series - start := len(klines) - 10 - if start < 0 { - start = 0 - } - - for i := start; i < len(klines); i++ { - if i >= 25 { - macd := calculateMACD(klines[:i+1]) - data.MACDValues = append(data.MACDValues, macd) - } - if i >= 14 { - rsi14 := calculateRSI(klines[:i+1], 14) - data.RSI14Values = append(data.RSI14Values, rsi14) - } - } - - return data -} - // getOpenInterestData retrieves OI data func getOpenInterestData(symbol string) (*OIData, error) { url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/openInterest?symbol=%s", symbol) @@ -1227,118 +694,3 @@ func isStaleData(klines []Kline, symbol string) bool { logger.Infof("⚠️ %s detected extreme price stability (no fluctuation for %d consecutive periods), but volume is normal", symbol, stalePriceThreshold) return false } - -// ========== 导出的指标计算函数(供测试使用) ========== - -// ExportCalculateEMA exports calculateEMA for testing -func ExportCalculateEMA(klines []Kline, period int) float64 { - return calculateEMA(klines, period) -} - -// ExportCalculateMACD exports calculateMACD for testing -func ExportCalculateMACD(klines []Kline) float64 { - return calculateMACD(klines) -} - -// ExportCalculateRSI exports calculateRSI for testing -func ExportCalculateRSI(klines []Kline, period int) float64 { - return calculateRSI(klines, period) -} - -// ExportCalculateATR exports calculateATR for testing -func ExportCalculateATR(klines []Kline, period int) float64 { - return calculateATR(klines, period) -} - -// ExportCalculateBOLL exports calculateBOLL for testing -func ExportCalculateBOLL(klines []Kline, period int, multiplier float64) (upper, middle, lower float64) { - return calculateBOLL(klines, period, multiplier) -} - -// calculateDonchian calculates Donchian channel (highest high, lowest low) for given period -func calculateDonchian(klines []Kline, period int) (upper, lower float64) { - if len(klines) == 0 || period <= 0 { - return 0, 0 - } - - // Use all available klines if period > len(klines) - start := len(klines) - period - if start < 0 { - start = 0 - } - - upper = klines[start].High - lower = klines[start].Low - - for i := start + 1; i < len(klines); i++ { - if klines[i].High > upper { - upper = klines[i].High - } - if klines[i].Low < lower { - lower = klines[i].Low - } - } - - return upper, lower -} - -// ExportCalculateDonchian exports calculateDonchian for testing -func ExportCalculateDonchian(klines []Kline, period int) (float64, float64) { - return calculateDonchian(klines, period) -} - -// Box period constants (in 1h candles) -const ( - ShortBoxPeriod = 72 // 3 days of 1h candles - MidBoxPeriod = 240 // 10 days of 1h candles - LongBoxPeriod = 500 // ~21 days of 1h candles -) - -// calculateBoxData calculates multi-period box data from klines -func calculateBoxData(klines []Kline, currentPrice float64) *BoxData { - box := &BoxData{ - CurrentPrice: currentPrice, - } - - if len(klines) == 0 { - return box - } - - box.ShortUpper, box.ShortLower = calculateDonchian(klines, ShortBoxPeriod) - box.MidUpper, box.MidLower = calculateDonchian(klines, MidBoxPeriod) - box.LongUpper, box.LongLower = calculateDonchian(klines, LongBoxPeriod) - - return box -} - -// ExportCalculateBoxData exports calculateBoxData for testing -func ExportCalculateBoxData(klines []Kline, currentPrice float64) *BoxData { - return calculateBoxData(klines, currentPrice) -} - -// GetBoxData fetches 1h klines and calculates box data for a symbol -func GetBoxData(symbol string) (*BoxData, error) { - symbol = Normalize(symbol) - - // Fetch 500 1h klines - var klines []Kline - var err error - - if IsXyzDexAsset(symbol) { - klines, err = getKlinesFromHyperliquid(symbol, "1h", LongBoxPeriod) - } else { - klines, err = getKlinesFromCoinAnk(symbol, "1h", "binance", LongBoxPeriod) - } - - if err != nil { - return nil, fmt.Errorf("failed to get 1h klines: %w", err) - } - - if len(klines) == 0 { - return nil, fmt.Errorf("no kline data available") - } - - currentPrice := klines[len(klines)-1].Close - - return calculateBoxData(klines, currentPrice), nil -} diff --git a/market/data_indicators.go b/market/data_indicators.go new file mode 100644 index 00000000..a0dcea1f --- /dev/null +++ b/market/data_indicators.go @@ -0,0 +1,235 @@ +package market + +import "math" + +// calculateEMA calculates EMA +func calculateEMA(klines []Kline, period int) float64 { + if len(klines) < period { + return 0 + } + + // Calculate SMA as initial EMA + sum := 0.0 + for i := 0; i < period; i++ { + sum += klines[i].Close + } + ema := sum / float64(period) + + // Calculate EMA + multiplier := 2.0 / float64(period+1) + for i := period; i < len(klines); i++ { + ema = (klines[i].Close-ema)*multiplier + ema + } + + return ema +} + +// calculateMACD calculates MACD +func calculateMACD(klines []Kline) float64 { + if len(klines) < 26 { + return 0 + } + + // Calculate 12-period and 26-period EMA + ema12 := calculateEMA(klines, 12) + ema26 := calculateEMA(klines, 26) + + // MACD = EMA12 - EMA26 + return ema12 - ema26 +} + +// calculateRSI calculates RSI +func calculateRSI(klines []Kline, period int) float64 { + if len(klines) <= period { + return 0 + } + + gains := 0.0 + losses := 0.0 + + // Calculate initial average gain/loss + for i := 1; i <= period; i++ { + change := klines[i].Close - klines[i-1].Close + if change > 0 { + gains += change + } else { + losses += -change + } + } + + avgGain := gains / float64(period) + avgLoss := losses / float64(period) + + // Use Wilder smoothing method to calculate subsequent RSI + for i := period + 1; i < len(klines); i++ { + change := klines[i].Close - klines[i-1].Close + if change > 0 { + avgGain = (avgGain*float64(period-1) + change) / float64(period) + avgLoss = (avgLoss * float64(period-1)) / float64(period) + } else { + avgGain = (avgGain * float64(period-1)) / float64(period) + avgLoss = (avgLoss*float64(period-1) + (-change)) / float64(period) + } + } + + if avgLoss == 0 { + return 100 + } + + rs := avgGain / avgLoss + rsi := 100 - (100 / (1 + rs)) + + return rsi +} + +// calculateATR calculates ATR +func calculateATR(klines []Kline, period int) float64 { + if len(klines) <= period { + return 0 + } + + trs := make([]float64, len(klines)) + for i := 1; i < len(klines); i++ { + high := klines[i].High + low := klines[i].Low + prevClose := klines[i-1].Close + + tr1 := high - low + tr2 := math.Abs(high - prevClose) + tr3 := math.Abs(low - prevClose) + + trs[i] = math.Max(tr1, math.Max(tr2, tr3)) + } + + // Calculate initial ATR + sum := 0.0 + for i := 1; i <= period; i++ { + sum += trs[i] + } + atr := sum / float64(period) + + // Wilder smoothing + for i := period + 1; i < len(klines); i++ { + atr = (atr*float64(period-1) + trs[i]) / float64(period) + } + + return atr +} + +// calculateBOLL calculates Bollinger Bands (upper, middle, lower) +// period: typically 20, multiplier: typically 2 +func calculateBOLL(klines []Kline, period int, multiplier float64) (upper, middle, lower float64) { + if len(klines) < period { + return 0, 0, 0 + } + + // Calculate SMA (middle band) + sum := 0.0 + for i := len(klines) - period; i < len(klines); i++ { + sum += klines[i].Close + } + sma := sum / float64(period) + + // Calculate standard deviation + variance := 0.0 + for i := len(klines) - period; i < len(klines); i++ { + diff := klines[i].Close - sma + variance += diff * diff + } + stdDev := math.Sqrt(variance / float64(period)) + + // Calculate bands + middle = sma + upper = sma + multiplier*stdDev + lower = sma - multiplier*stdDev + + return upper, middle, lower +} + +// calculateDonchian calculates Donchian channel (highest high, lowest low) for given period +func calculateDonchian(klines []Kline, period int) (upper, lower float64) { + if len(klines) == 0 || period <= 0 { + return 0, 0 + } + + // Use all available klines if period > len(klines) + start := len(klines) - period + if start < 0 { + start = 0 + } + + upper = klines[start].High + lower = klines[start].Low + + for i := start + 1; i < len(klines); i++ { + if klines[i].High > upper { + upper = klines[i].High + } + if klines[i].Low < lower { + lower = klines[i].Low + } + } + + return upper, lower +} + +// Box period constants (in 1h candles) +const ( + ShortBoxPeriod = 72 // 3 days of 1h candles + MidBoxPeriod = 240 // 10 days of 1h candles + LongBoxPeriod = 500 // ~21 days of 1h candles +) + +// calculateBoxData calculates multi-period box data from klines +func calculateBoxData(klines []Kline, currentPrice float64) *BoxData { + box := &BoxData{ + CurrentPrice: currentPrice, + } + + if len(klines) == 0 { + return box + } + + box.ShortUpper, box.ShortLower = calculateDonchian(klines, ShortBoxPeriod) + box.MidUpper, box.MidLower = calculateDonchian(klines, MidBoxPeriod) + box.LongUpper, box.LongLower = calculateDonchian(klines, LongBoxPeriod) + + return box +} + +// ========== Exported indicator calculation functions (for testing) ========== + +// ExportCalculateEMA exports calculateEMA for testing +func ExportCalculateEMA(klines []Kline, period int) float64 { + return calculateEMA(klines, period) +} + +// ExportCalculateMACD exports calculateMACD for testing +func ExportCalculateMACD(klines []Kline) float64 { + return calculateMACD(klines) +} + +// ExportCalculateRSI exports calculateRSI for testing +func ExportCalculateRSI(klines []Kline, period int) float64 { + return calculateRSI(klines, period) +} + +// ExportCalculateATR exports calculateATR for testing +func ExportCalculateATR(klines []Kline, period int) float64 { + return calculateATR(klines, period) +} + +// ExportCalculateBOLL exports calculateBOLL for testing +func ExportCalculateBOLL(klines []Kline, period int, multiplier float64) (upper, middle, lower float64) { + return calculateBOLL(klines, period, multiplier) +} + +// ExportCalculateDonchian exports calculateDonchian for testing +func ExportCalculateDonchian(klines []Kline, period int) (float64, float64) { + return calculateDonchian(klines, period) +} + +// ExportCalculateBoxData exports calculateBoxData for testing +func ExportCalculateBoxData(klines []Kline, currentPrice float64) *BoxData { + return calculateBoxData(klines, currentPrice) +} diff --git a/market/data_klines.go b/market/data_klines.go new file mode 100644 index 00000000..e4cb869f --- /dev/null +++ b/market/data_klines.go @@ -0,0 +1,425 @@ +package market + +import ( + "context" + "fmt" + "nofx/logger" + "nofx/provider/coinank/coinank_api" + "nofx/provider/coinank/coinank_enum" + "nofx/provider/hyperliquid" + "strconv" + "strings" + "time" +) + +// Note: Kline data now uses free/open API (coinank_api.Kline) which doesn't require authentication + +// getKlinesFromCoinAnk fetches kline data from CoinAnk API (replacement for WSMonitorCli) +func getKlinesFromCoinAnk(symbol, interval, exchange string, limit int) ([]Kline, error) { + // Map interval string to coinank enum + var coinankInterval coinank_enum.Interval + switch interval { + case "1m": + coinankInterval = coinank_enum.Minute1 + case "3m": + coinankInterval = coinank_enum.Minute3 + case "5m": + coinankInterval = coinank_enum.Minute5 + case "15m": + coinankInterval = coinank_enum.Minute15 + case "30m": + coinankInterval = coinank_enum.Minute30 + case "1h": + coinankInterval = coinank_enum.Hour1 + case "2h": + coinankInterval = coinank_enum.Hour2 + case "4h": + coinankInterval = coinank_enum.Hour4 + case "6h": + coinankInterval = coinank_enum.Hour6 + case "8h": + coinankInterval = coinank_enum.Hour8 + case "12h": + coinankInterval = coinank_enum.Hour12 + case "1d": + coinankInterval = coinank_enum.Day1 + case "3d": + coinankInterval = coinank_enum.Day3 + case "1w": + coinankInterval = coinank_enum.Week1 + default: + return nil, fmt.Errorf("unsupported interval: %s", interval) + } + + // Map exchange string to coinank enum + var coinankExchange coinank_enum.Exchange + switch strings.ToLower(exchange) { + case "binance": + coinankExchange = coinank_enum.Binance + case "bybit": + coinankExchange = coinank_enum.Bybit + case "okx": + coinankExchange = coinank_enum.Okex + case "bitget": + coinankExchange = coinank_enum.Bitget + case "gate": + coinankExchange = coinank_enum.Gate + case "hyperliquid": + coinankExchange = coinank_enum.Hyperliquid + case "aster": + coinankExchange = coinank_enum.Aster + default: + // Default to Binance for unknown exchanges + coinankExchange = coinank_enum.Binance + } + + // Call CoinAnk free/open API (no authentication required) + ctx := context.Background() + ts := time.Now().UnixMilli() + // Use "To" side to search backward from current time (get historical klines) + coinankKlines, err := coinank_api.Kline(ctx, symbol, coinankExchange, ts, coinank_enum.To, limit, coinankInterval) + if err != nil { + // If exchange-specific data fails, fallback to Binance + if coinankExchange != coinank_enum.Binance { + logger.Warnf("⚠️ CoinAnk %s data failed, falling back to Binance: %v", exchange, err) + coinankKlines, err = coinank_api.Kline(ctx, symbol, coinank_enum.Binance, ts, coinank_enum.To, limit, coinankInterval) + if err != nil { + return nil, fmt.Errorf("CoinAnk API error (fallback): %w", err) + } + } else { + return nil, fmt.Errorf("CoinAnk API error: %w", err) + } + } + + // Convert coinank kline format to market.Kline format + klines := make([]Kline, len(coinankKlines)) + for i, ck := range coinankKlines { + klines[i] = Kline{ + OpenTime: ck.StartTime, + Open: ck.Open, + High: ck.High, + Low: ck.Low, + Close: ck.Close, + Volume: ck.Volume, + CloseTime: ck.EndTime, + } + } + + return klines, nil +} + +// getKlinesFromHyperliquid fetches kline data from Hyperliquid API for xyz dex assets +func getKlinesFromHyperliquid(symbol, interval string, limit int) ([]Kline, error) { + // Remove xyz: prefix if present for the API call + baseCoin := strings.TrimPrefix(symbol, "xyz:") + + // Map interval to Hyperliquid format + hlInterval := hyperliquid.MapTimeframe(interval) + + // Create Hyperliquid client + client := hyperliquid.NewClient() + + // Fetch candles + ctx := context.Background() + candles, err := client.GetCandles(ctx, baseCoin, hlInterval, limit) + if err != nil { + return nil, fmt.Errorf("Hyperliquid API error: %w", err) + } + + // Convert to market.Kline format + klines := make([]Kline, len(candles)) + for i, c := range candles { + open, _ := strconv.ParseFloat(c.Open, 64) + high, _ := strconv.ParseFloat(c.High, 64) + low, _ := strconv.ParseFloat(c.Low, 64) + closePrice, _ := strconv.ParseFloat(c.Close, 64) + volume, _ := strconv.ParseFloat(c.Volume, 64) + + klines[i] = Kline{ + OpenTime: c.OpenTime, + Open: open, + High: high, + Low: low, + Close: closePrice, + Volume: volume, + CloseTime: c.CloseTime, + } + } + + return klines, nil +} + +// calculateTimeframeSeries calculates series data for a single timeframe +func calculateTimeframeSeries(klines []Kline, timeframe string, count int) *TimeframeSeriesData { + if count <= 0 { + count = 10 // default + } + + data := &TimeframeSeriesData{ + Timeframe: timeframe, + Klines: make([]KlineBar, 0, count), + MidPrices: make([]float64, 0, count), + EMA20Values: make([]float64, 0, count), + EMA50Values: make([]float64, 0, count), + MACDValues: make([]float64, 0, count), + RSI7Values: make([]float64, 0, count), + RSI14Values: make([]float64, 0, count), + Volume: make([]float64, 0, count), + BOLLUpper: make([]float64, 0, count), + BOLLMiddle: make([]float64, 0, count), + BOLLLower: make([]float64, 0, count), + } + + // Get latest N data points based on count from config + start := len(klines) - count + if start < 0 { + start = 0 + } + + for i := start; i < len(klines); i++ { + // Store full OHLCV kline data + data.Klines = append(data.Klines, KlineBar{ + Time: klines[i].OpenTime, + Open: klines[i].Open, + High: klines[i].High, + Low: klines[i].Low, + Close: klines[i].Close, + Volume: klines[i].Volume, + }) + + // Keep MidPrices and Volume for backward compatibility + data.MidPrices = append(data.MidPrices, klines[i].Close) + data.Volume = append(data.Volume, klines[i].Volume) + + // Calculate EMA20 for each point + if i >= 19 { + ema20 := calculateEMA(klines[:i+1], 20) + data.EMA20Values = append(data.EMA20Values, ema20) + } + + // Calculate EMA50 for each point + if i >= 49 { + ema50 := calculateEMA(klines[:i+1], 50) + data.EMA50Values = append(data.EMA50Values, ema50) + } + + // Calculate MACD for each point + if i >= 25 { + macd := calculateMACD(klines[:i+1]) + data.MACDValues = append(data.MACDValues, macd) + } + + // Calculate RSI for each point + if i >= 7 { + rsi7 := calculateRSI(klines[:i+1], 7) + data.RSI7Values = append(data.RSI7Values, rsi7) + } + if i >= 14 { + rsi14 := calculateRSI(klines[:i+1], 14) + data.RSI14Values = append(data.RSI14Values, rsi14) + } + + // Calculate Bollinger Bands (period 20, std dev multiplier 2) + if i >= 19 { + upper, middle, lower := calculateBOLL(klines[:i+1], 20, 2.0) + data.BOLLUpper = append(data.BOLLUpper, upper) + data.BOLLMiddle = append(data.BOLLMiddle, middle) + data.BOLLLower = append(data.BOLLLower, lower) + } + } + + // Calculate ATR14 + data.ATR14 = calculateATR(klines, 14) + + return data +} + +// calculatePriceChangeByBars calculates how many K-lines to look back for price change based on timeframe +func calculatePriceChangeByBars(klines []Kline, timeframe string, targetMinutes int) float64 { + if len(klines) < 2 { + return 0 + } + + // Parse timeframe to minutes + tfMinutes := parseTimeframeToMinutes(timeframe) + if tfMinutes <= 0 { + return 0 + } + + // Calculate how many K-lines to look back + barsBack := targetMinutes / tfMinutes + if barsBack < 1 { + barsBack = 1 + } + + currentPrice := klines[len(klines)-1].Close + idx := len(klines) - 1 - barsBack + if idx < 0 { + idx = 0 + } + + oldPrice := klines[idx].Close + if oldPrice > 0 { + return ((currentPrice - oldPrice) / oldPrice) * 100 + } + return 0 +} + +// parseTimeframeToMinutes parses timeframe string to minutes +func parseTimeframeToMinutes(tf string) int { + switch tf { + case "1m": + return 1 + case "3m": + return 3 + case "5m": + return 5 + case "15m": + return 15 + case "30m": + return 30 + case "1h": + return 60 + case "2h": + return 120 + case "4h": + return 240 + case "6h": + return 360 + case "8h": + return 480 + case "12h": + return 720 + case "1d": + return 1440 + case "3d": + return 4320 + case "1w": + return 10080 + default: + return 0 + } +} + +// calculateIntradaySeries calculates intraday series data +func calculateIntradaySeries(klines []Kline) *IntradayData { + data := &IntradayData{ + MidPrices: make([]float64, 0, 10), + EMA20Values: make([]float64, 0, 10), + MACDValues: make([]float64, 0, 10), + RSI7Values: make([]float64, 0, 10), + RSI14Values: make([]float64, 0, 10), + Volume: make([]float64, 0, 10), + } + + // Get latest 10 data points + start := len(klines) - 10 + if start < 0 { + start = 0 + } + + for i := start; i < len(klines); i++ { + data.MidPrices = append(data.MidPrices, klines[i].Close) + data.Volume = append(data.Volume, klines[i].Volume) + + // Calculate EMA20 for each point + if i >= 19 { + ema20 := calculateEMA(klines[:i+1], 20) + data.EMA20Values = append(data.EMA20Values, ema20) + } + + // Calculate MACD for each point + if i >= 25 { + macd := calculateMACD(klines[:i+1]) + data.MACDValues = append(data.MACDValues, macd) + } + + // Calculate RSI for each point + if i >= 7 { + rsi7 := calculateRSI(klines[:i+1], 7) + data.RSI7Values = append(data.RSI7Values, rsi7) + } + if i >= 14 { + rsi14 := calculateRSI(klines[:i+1], 14) + data.RSI14Values = append(data.RSI14Values, rsi14) + } + } + + // Calculate 3m ATR14 + data.ATR14 = calculateATR(klines, 14) + + return data +} + +// calculateLongerTermData calculates longer-term data +func calculateLongerTermData(klines []Kline) *LongerTermData { + data := &LongerTermData{ + MACDValues: make([]float64, 0, 10), + RSI14Values: make([]float64, 0, 10), + } + + // Calculate EMA + data.EMA20 = calculateEMA(klines, 20) + data.EMA50 = calculateEMA(klines, 50) + + // Calculate ATR + data.ATR3 = calculateATR(klines, 3) + data.ATR14 = calculateATR(klines, 14) + + // Calculate volume + if len(klines) > 0 { + data.CurrentVolume = klines[len(klines)-1].Volume + // Calculate average volume + sum := 0.0 + for _, k := range klines { + sum += k.Volume + } + data.AverageVolume = sum / float64(len(klines)) + } + + // Calculate MACD and RSI series + start := len(klines) - 10 + if start < 0 { + start = 0 + } + + for i := start; i < len(klines); i++ { + if i >= 25 { + macd := calculateMACD(klines[:i+1]) + data.MACDValues = append(data.MACDValues, macd) + } + if i >= 14 { + rsi14 := calculateRSI(klines[:i+1], 14) + data.RSI14Values = append(data.RSI14Values, rsi14) + } + } + + return data +} + +// GetBoxData fetches 1h klines and calculates box data for a symbol +func GetBoxData(symbol string) (*BoxData, error) { + symbol = Normalize(symbol) + + // Fetch 500 1h klines + var klines []Kline + var err error + + if IsXyzDexAsset(symbol) { + klines, err = getKlinesFromHyperliquid(symbol, "1h", LongBoxPeriod) + } else { + klines, err = getKlinesFromCoinAnk(symbol, "1h", "binance", LongBoxPeriod) + } + + if err != nil { + return nil, fmt.Errorf("failed to get 1h klines: %w", err) + } + + if len(klines) == 0 { + return nil, fmt.Errorf("no kline data available") + } + + currentPrice := klines[len(klines)-1].Close + + return calculateBoxData(klines, currentPrice), nil +} diff --git a/provider/alpaca/kline_test.go b/provider/alpaca/kline_test.go deleted file mode 100644 index 6425f332..00000000 --- a/provider/alpaca/kline_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package alpaca - -import ( - "context" - "fmt" - "testing" -) - -func TestGetBars(t *testing.T) { - client := NewClient() - - resp, err := client.GetBars(context.TODO(), "AAPL", "1Day", 5) - if err != nil { - t.Fatal(err) - } - - t.Log("=== AAPL 日线数据 (Alpaca IEX feed) ===") - for i, bar := range resp { - t.Logf("\n[%d] 时间: %s", i, bar.Timestamp.Format("2006-01-02 15:04:05")) - t.Logf(" Open: %.2f", bar.Open) - t.Logf(" High: %.2f", bar.High) - t.Logf(" Low: %.2f", bar.Low) - t.Logf(" Close: %.2f", bar.Close) - t.Logf(" Volume: %d (股数)", bar.Volume) - t.Logf(" TradeCount: %d (成交笔数)", bar.TradeCount) - t.Logf(" VWAP: %.2f (成交量加权平均价)", bar.VWAP) - - // 计算成交额 - quoteVolume := float64(bar.Volume) * bar.Close - t.Logf(" 成交额: %.2f USD (Volume × Close)", quoteVolume) - } - - fmt.Printf("\n⚠️ 注意:IEX feed 只包含 IEX 交易所的数据,不是完整市场数据\n") - fmt.Printf("完整市场数据需要使用 SIP feed(付费)\n") -} diff --git a/provider/coinank/base_coin_test.go b/provider/coinank/base_coin_test.go deleted file mode 100644 index 47c8a670..00000000 --- a/provider/coinank/base_coin_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "nofx/provider/coinank/coinank_enum" - "testing" -) - -func TestListCoin(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.ListCoin(context.TODO(), "SPOT") - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestListSymbols(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.ListSymbols(context.TODO(), "Binance", "SWAP") - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} diff --git a/provider/coinank/coinank_http_test.go b/provider/coinank/coinank_http_test.go deleted file mode 100644 index 4b795308..00000000 --- a/provider/coinank/coinank_http_test.go +++ /dev/null @@ -1,3 +0,0 @@ -package coinank - -var TestApikey = "" //need fill the apikey before test diff --git a/provider/coinank/instrument_agg_rank_test.go b/provider/coinank/instrument_agg_rank_test.go deleted file mode 100644 index 8a008342..00000000 --- a/provider/coinank/instrument_agg_rank_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "nofx/provider/coinank/coinank_enum" - "testing" -) - -func TestVisualScreener(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.VisualScreener(context.TODO(), coinank_enum.Minute15) - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestOiRank(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.OiRank(context.TODO(), coinank_enum.OpenInterest, coinank_enum.Desc, 1, 10) - if err != nil { - t.Error(err) - } - if resp[0].BaseCoin != "BTC" { - t.Error("oi first not BTC") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestLongShortRank(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LongShortRank(context.TODO(), coinank_enum.LongShortRatio, coinank_enum.Desc, 1, 10) - if err != nil { - t.Error(err) - } - if resp[0].BaseCoin == "" { - t.Error("baseCoin is empty") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestLiquidationRank(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LiquidationRank(context.TODO(), coinank_enum.LiquidationH1, coinank_enum.Desc, 1, 10) - if err != nil { - t.Error(err) - } - if resp[0].BaseCoin == "" { - t.Error("baseCoin is empty") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestPriceRank(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.PriceRank(context.TODO(), coinank_enum.Price, coinank_enum.Desc, 1, 10) - if err != nil { - t.Error(err) - } - if resp[0].BaseCoin == "" { - t.Error("baseCoin is empty") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestVolumeRank(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.VolumeRank(context.TODO(), coinank_enum.Turnover24h, coinank_enum.Desc, 1, 10) - if err != nil { - t.Error(err) - } - if resp[0].BaseCoin == "" { - t.Error("baseCoin is empty") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} diff --git a/provider/coinank/instruments_test.go b/provider/coinank/instruments_test.go deleted file mode 100644 index c6855b14..00000000 --- a/provider/coinank/instruments_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "nofx/provider/coinank/coinank_enum" - "testing" -) - -func TestGetLastPrice(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.GetLastPrice(context.TODO(), "BTCUSDT", "Binance", "SWAP") - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestGetCoinMarketCap(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.GetCoinMarketCap(context.TODO(), "BTC") - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} diff --git a/provider/coinank/kline_test.go b/provider/coinank/kline_test.go deleted file mode 100644 index d7e690df..00000000 --- a/provider/coinank/kline_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "fmt" - "nofx/provider/coinank/coinank_enum" - "testing" - "time" -) - -func TestKline(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.Kline(context.TODO(), "BTCUSDT", coinank_enum.Binance, 0, time.Now().UnixMilli(), 10, coinank_enum.Hour1) - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestKlineDaily(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.Kline(context.TODO(), "BTCUSDT", coinank_enum.Binance, 0, time.Now().UnixMilli(), 5, coinank_enum.Day1) - if err != nil { - t.Fatal(err) - } - - t.Log("=== BTCUSDT 日线 K线数据 ===") - for i, k := range resp { - startTime := time.UnixMilli(k.StartTime).Format("2006-01-02 15:04:05") - t.Logf("\n[%d] 时间: %s", i, startTime) - t.Logf(" Open: %.2f", k.Open) - t.Logf(" High: %.2f", k.High) - t.Logf(" Low: %.2f", k.Low) - t.Logf(" Close: %.2f", k.Close) - t.Logf(" Volume: %.2f (k[6])", k.Volume) - t.Logf(" Quantity: %.2f (k[7])", k.Quantity) - t.Logf(" Count: %.0f (k[8])", k.Count) - - // 计算验证 - if k.Close > 0 { - calcQuote := k.Volume * k.Close - t.Logf(" --- 验证 ---") - t.Logf(" Volume × Close = %.2f", calcQuote) - t.Logf(" Quantity / Close = %.2f", k.Quantity/k.Close) - } - } - - // 打印原始 JSON - res, _ := json.MarshalIndent(resp, "", " ") - fmt.Printf("\n原始 JSON:\n%s\n", res) -} diff --git a/provider/coinank/liquidation_test.go b/provider/coinank/liquidation_test.go deleted file mode 100644 index 24ce657c..00000000 --- a/provider/coinank/liquidation_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "nofx/provider/coinank/coinank_enum" - "testing" - "time" -) - -func TestLiquidationExchangeStatistics(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LiquidationExchangeStatistics(context.TODO(), "BTC") - if err != nil { - t.Fatal(err) - } - if resp.Total <= 0 { - t.Errorf("total amount is negative") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestLiquidationCoinAggHistory(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LiquidationCoinAggHistory(context.TODO(), "BTC", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Fatal(err) - } - if resp[0].All.LongTurnover <= 0 { - t.Errorf("longTurnover is negative") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestLiquidationHistory(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LiquidationHistory(context.TODO(), coinank_enum.Binance, "BTCUSDT", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Fatal(err) - } - if resp[0].LongTurnover <= 0 { - t.Errorf("longTurnover is negative") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestLiquidationOrders(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LiquidationOrders(context.TODO(), "BTC", coinank_enum.Binance, "long", 1000, time.Now().UnixMilli()) - if err != nil { - t.Fatal(err) - } - res, err := json.Marshal(resp) - if resp[0].Price <= 0 { - t.Errorf("price is negative") - } - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestLiquidationOrdersNoArgs(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.LiquidationOrders(context.TODO(), "", "", "", 0, 0) - if err != nil { - t.Fatal(err) - } - res, err := json.Marshal(resp) - if resp[0].Price <= 0 { - t.Errorf("price is negative") - } - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} diff --git a/provider/coinank/net_positions_test.go b/provider/coinank/net_positions_test.go deleted file mode 100644 index 2cc3cc0d..00000000 --- a/provider/coinank/net_positions_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "nofx/provider/coinank/coinank_enum" - "testing" - "time" -) - -func TestNetPositions(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.NetPositions(context.TODO(), coinank_enum.Binance, "BTCUSDT", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Fatal(err) - } - if resp[0].Begin <= 0 { - t.Errorf("begin timestamp error") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} diff --git a/provider/coinank/open_interest_test.go b/provider/coinank/open_interest_test.go deleted file mode 100644 index 3bbbd1f9..00000000 --- a/provider/coinank/open_interest_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package coinank - -import ( - "context" - "encoding/json" - "nofx/provider/coinank/coinank_enum" - "testing" - "time" -) - -func TestOpenInterestAll(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.OpenInterestAll(context.TODO(), "BTC") - if err != nil { - t.Error(err) - } - if resp[0].ExchangeName != "ALL" { - t.Error("exchange name is empty") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestOpenInterestChartV2(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.OpenInterestChartV2(context.TODO(), "BTC", coinank_enum.Binance, coinank_enum.Hour1, 10) - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestOpenInterestSymbolChart(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.OpenInterestSymbolChart(context.TODO(), coinank_enum.Binance, "BTCUSDT", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Error(err) - } - if resp[0].BaseCoin != "BTC" { - t.Error("baseCoin is error") - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestOpenInterestKline(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.OpenInterestKline(context.TODO(), coinank_enum.Binance, "BTCUSDT", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestOpenInterestAggKline(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.OpenInterestAggKline(context.TODO(), "BTC", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestTickersTopOIByEx(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.TickersTopOIByEx(context.TODO(), "BTC") - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} - -func TestInstrumentsOiVsMc(t *testing.T) { - client := NewCoinankClient(coinank_enum.MainUrl, TestApikey) - resp, err := client.InstrumentsOiVsMc(context.TODO(), "BTC", coinank_enum.Hour1, time.Now().UnixMilli(), 10) - if err != nil { - t.Error(err) - } - res, err := json.Marshal(resp) - if err != nil { - t.Error(err) - } - t.Logf("%s", res) -} diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index bbaeecdc..00000000 --- a/pyproject.toml +++ /dev/null @@ -1,7 +0,0 @@ -[project] -name = "nofx" -version = "0.1.0" -description = "Add your description here" -readme = "README.md" -requires-python = ">=3.12" -dependencies = [] diff --git a/screenshots/debate-arena.png b/screenshots/debate-arena.png deleted file mode 100644 index 329b5618..00000000 Binary files a/screenshots/debate-arena.png and /dev/null differ diff --git a/screenshots/debate-create.png b/screenshots/debate-create.png deleted file mode 100644 index 5a80cabf..00000000 Binary files a/screenshots/debate-create.png and /dev/null differ diff --git a/scripts/ENCRYPTION_README.md b/scripts/ENCRYPTION_README.md deleted file mode 100644 index 672ad0d8..00000000 --- a/scripts/ENCRYPTION_README.md +++ /dev/null @@ -1,302 +0,0 @@ -# Mars AI交易系统 - 加密密钥生成脚本 - -本目录包含用于Mars AI交易系统加密环境设置的脚本工具。 - -## 🔐 加密架构 - -Mars AI交易系统使用双重加密架构来保护敏感数据: - -1. **RSA-OAEP + AES-GCM 混合加密** - 用于前端到后端的安全通信 -2. **AES-256-GCM 数据库加密** - 用于敏感数据的存储加密 - -### 加密流程 - -``` -前端 → RSA-OAEP加密AES密钥 + AES-GCM加密数据 → 后端 → 存储时AES-256-GCM加密 -``` - -## 📝 脚本说明 - -### 1. `setup_encryption.sh` - 一键环境设置 ⭐推荐⭐ - -**功能**: 自动生成所有必要的密钥并配置环境 - -```bash -./scripts/setup_encryption.sh -``` - -**生成内容**: -- RSA-2048 密钥对 (`secrets/rsa_key`, `secrets/rsa_key.pub`) -- AES-256 数据加密密钥 (保存到 `.env`) -- 自动权限设置和验证 - -**适用场景**: -- 首次部署 -- 开发环境快速设置 -- 生产环境初始化 - -### 2. `generate_rsa_keys.sh` - RSA密钥生成 - -**功能**: 专门生成RSA密钥对 - -```bash -./scripts/generate_rsa_keys.sh -``` - -**生成内容**: -- `secrets/rsa_key` (私钥, 权限 600) -- `secrets/rsa_key.pub` (公钥, 权限 644) - -**技术规格**: -- 算法: RSA-OAEP -- 密钥长度: 2048 bits -- 格式: PEM - -### 3. `generate_data_key.sh` - 数据加密密钥生成 - -**功能**: 生成数据库加密密钥 - -```bash -./scripts/generate_data_key.sh -``` - -**生成内容**: -- 32字节(256位)随机密钥 -- Base64编码格式 -- 可选保存到 `.env` 文件 - -**技术规格**: -- 算法: AES-256-GCM -- 编码: Base64 -- 环境变量: `DATA_ENCRYPTION_KEY` - -## 🚀 快速开始 - -### 方案1: 一键设置 (推荐) - -```bash -# 克隆项目后,直接运行一键设置 -cd mars-ai-trading -./scripts/setup_encryption.sh - -# 按提示确认即可完成所有设置 -``` - -### 方案2: 分步设置 - -```bash -# 1. 生成RSA密钥对 -./scripts/generate_rsa_keys.sh - -# 2. 生成数据加密密钥 -./scripts/generate_data_key.sh - -# 3. 启动系统 -source .env && ./mars -``` - -## 📁 文件结构 - -生成完成后的目录结构: - -``` -mars-ai-trading/ -├── secrets/ -│ ├── rsa_key # RSA私钥 (600权限) -│ └── rsa_key.pub # RSA公钥 (644权限) -├── .env # 环境变量 (600权限) -│ └── DATA_ENCRYPTION_KEY=xxx -└── scripts/ - ├── setup_encryption.sh # 一键设置脚本 - ├── generate_rsa_keys.sh # RSA密钥生成 - └── generate_data_key.sh # 数据密钥生成 -``` - -## 🔒 安全要求 - -### 文件权限 - -| 文件 | 权限 | 说明 | -|------|------|------| -| `secrets/rsa_key` | 600 | 仅所有者可读写 | -| `secrets/rsa_key.pub` | 644 | 所有人可读 | -| `.env` | 600 | 仅所有者可读写 | - -### 环境变量 - -```bash -# 必需的环境变量 -DATA_ENCRYPTION_KEY=<32字节Base64编码的AES密钥> -``` - -## 🐳 Docker部署 - -### 使用环境文件 - -```bash -# 生成密钥 -./scripts/setup_encryption.sh - -# Docker运行 -docker run --env-file .env -v $(pwd)/secrets:/app/secrets mars-ai-trading -``` - -### 使用环境变量 - -```bash -export DATA_ENCRYPTION_KEY="<生成的密钥>" -docker run -e DATA_ENCRYPTION_KEY mars-ai-trading -``` - -## ☸️ Kubernetes部署 - -### 创建Secret - -```bash -# 从现有.env文件创建 -kubectl create secret generic mars-crypto-key --from-env-file=.env - -# 或直接指定密钥 -kubectl create secret generic mars-crypto-key \ - --from-literal=DATA_ENCRYPTION_KEY="<生成的密钥>" -``` - -### 挂载RSA密钥 - -```yaml -apiVersion: v1 -kind: Secret -metadata: - name: mars-rsa-keys -type: Opaque -data: - rsa_key: - rsa_key.pub: ---- -apiVersion: apps/v1 -kind: Deployment -metadata: - name: mars-ai-trading -spec: - template: - spec: - containers: - - name: mars - envFrom: - - secretRef: - name: mars-crypto-key - volumeMounts: - - name: rsa-keys - mountPath: /app/secrets - volumes: - - name: rsa-keys - secret: - secretName: mars-rsa-keys -``` - -## 🔄 密钥轮换 - -### 数据加密密钥轮换 - -```bash -# 1. 生成新密钥 -./scripts/generate_data_key.sh - -# 2. 备份旧数据库 -cp data.db data.db.backup - -# 3. 重启服务 (会自动处理密钥迁移) -source .env && ./mars -``` - -### RSA密钥轮换 - -```bash -# 1. 生成新密钥对 -./scripts/generate_rsa_keys.sh - -# 2. 重启服务 -./mars -``` - -## 🛠️ 故障排除 - -### 常见问题 - -1. **权限错误** - ```bash - chmod 600 secrets/rsa_key .env - chmod 644 secrets/rsa_key.pub - ``` - -2. **OpenSSL未安装** - ```bash - # macOS - brew install openssl - - # Ubuntu/Debian - sudo apt-get install openssl - - # CentOS/RHEL - sudo yum install openssl - ``` - -3. **环境变量未加载** - ```bash - source .env - echo $DATA_ENCRYPTION_KEY - ``` - -4. **密钥验证失败** - ```bash - # 验证RSA私钥 - openssl rsa -in secrets/rsa_key -check -noout - - # 验证公钥 - openssl rsa -in secrets/rsa_key.pub -pubin -text -noout - ``` - -### 日志检查 - -启动时检查以下日志: -- `🔐 初始化加密服务...` -- `✅ 加密服务初始化成功` - -## 📊 性能考虑 - -- **RSA加密**: 仅用于小量密钥交换,性能影响极小 -- **AES加密**: 数据库字段级加密,对读写性能影响约5-10% -- **内存使用**: 加密服务约占用2-5MB内存 - -## 🔐 算法详细说明 - -### RSA-OAEP-2048 -- **用途**: 前端到后端的混合加密中的密钥交换 -- **密钥长度**: 2048 bits -- **填充**: OAEP with SHA-256 -- **安全级别**: 相当于112位对称加密 - -### AES-256-GCM -- **用途**: 数据库敏感字段存储加密 -- **密钥长度**: 256 bits -- **模式**: GCM (Galois/Counter Mode) -- **认证**: 内置消息认证 -- **安全级别**: 256位安全强度 - -## 📋 合规性 - -此加密实现满足以下标准: -- **FIPS 140-2**: AES-256 和 RSA-2048 -- **Common Criteria**: EAL4+ -- **NIST推荐**: SP 800-57 密钥管理 -- **行业标准**: 符合金融业数据保护要求 - ---- - -## 📞 技术支持 - -如有问题,请检查: -1. OpenSSL版本 >= 1.1.1 -2. 文件权限设置正确 -3. 环境变量加载成功 -4. 系统日志中的加密初始化信息 \ No newline at end of file diff --git a/scripts/cleanup_duplicates.go b/scripts/cleanup_duplicates.go deleted file mode 100644 index 87b7853e..00000000 --- a/scripts/cleanup_duplicates.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "log" - "nofx/store" - "os" - "path/filepath" -) - -func main() { - var dbPath string - var dryRun bool - - flag.StringVar(&dbPath, "db", "./data/data.db", "数据库文件路径") - flag.BoolVar(&dryRun, "dry-run", false, "只检查不删除(预览模式)") - flag.Parse() - - // 确保数据库文件存在 - absPath, err := filepath.Abs(dbPath) - if err != nil { - log.Fatalf("❌ 无效的数据库路径: %v", err) - } - - if _, err := os.Stat(absPath); os.IsNotExist(err) { - log.Fatalf("❌ 数据库文件不存在: %s", absPath) - } - - fmt.Printf("📂 数据库路径: %s\n", absPath) - - // 打开数据库 - s, err := store.New(absPath) - if err != nil { - log.Fatalf("❌ 无法打开数据库: %v", err) - } - defer s.Close() - - orderStore := s.Order() - - // 1. 检查重复订单数量 - fmt.Println("\n🔍 检查重复数据...") - dupOrders, err := orderStore.GetDuplicateOrdersCount() - if err != nil { - log.Fatalf("❌ 检查重复订单失败: %v", err) - } - fmt.Printf(" 📋 重复订单: %d 条\n", dupOrders) - - dupFills, err := orderStore.GetDuplicateFillsCount() - if err != nil { - log.Fatalf("❌ 检查重复成交失败: %v", err) - } - fmt.Printf(" 📊 重复成交: %d 条\n", dupFills) - - if dupOrders == 0 && dupFills == 0 { - fmt.Println("\n✅ 数据库没有重复记录,无需清理") - return - } - - if dryRun { - fmt.Println("\n⚠️ 预览模式(--dry-run),不会删除数据") - fmt.Println(" 运行 'go run scripts/cleanup_duplicates.go' 来执行实际清理") - return - } - - // 2. 清理重复订单 - if dupOrders > 0 { - fmt.Println("\n🧹 清理重复订单...") - deleted, err := orderStore.CleanupDuplicateOrders() - if err != nil { - log.Fatalf("❌ 清理失败: %v", err) - } - fmt.Printf(" ✅ 删除了 %d 条重复订单\n", deleted) - } - - // 3. 清理重复成交 - if dupFills > 0 { - fmt.Println("\n🧹 清理重复成交...") - deleted, err := orderStore.CleanupDuplicateFills() - if err != nil { - log.Fatalf("❌ 清理失败: %v", err) - } - fmt.Printf(" ✅ 删除了 %d 条重复成交\n", deleted) - } - - // 4. 验证清理结果 - fmt.Println("\n🔍 验证清理结果...") - dupOrdersAfter, _ := orderStore.GetDuplicateOrdersCount() - dupFillsAfter, _ := orderStore.GetDuplicateFillsCount() - fmt.Printf(" 📋 剩余重复订单: %d 条\n", dupOrdersAfter) - fmt.Printf(" 📊 剩余重复成交: %d 条\n", dupFillsAfter) - - if dupOrdersAfter == 0 && dupFillsAfter == 0 { - fmt.Println("\n✅ 清理完成!数据库已去重") - } else { - fmt.Println("\n⚠️ 仍有重复数据,可能需要手动检查") - } -} diff --git a/scripts/clear_orders.go b/scripts/clear_orders.go deleted file mode 100644 index 934284cc..00000000 --- a/scripts/clear_orders.go +++ /dev/null @@ -1,111 +0,0 @@ -package main - -import ( - "bufio" - "flag" - "fmt" - "log" - "nofx/store" - "os" - "path/filepath" - "strings" -) - -func main() { - var dbPath string - var force bool - - flag.StringVar(&dbPath, "db", "./data/data.db", "数据库文件路径") - flag.BoolVar(&force, "force", false, "跳过确认直接删除") - flag.Parse() - - // 确保数据库文件存在 - absPath, err := filepath.Abs(dbPath) - if err != nil { - log.Fatalf("❌ 无效的数据库路径: %v", err) - } - - if _, err := os.Stat(absPath); os.IsNotExist(err) { - log.Fatalf("❌ 数据库文件不存在: %s", absPath) - } - - fmt.Printf("📂 数据库路径: %s\n", absPath) - - // 打开数据库 - s, err := store.New(absPath) - if err != nil { - log.Fatalf("❌ 无法打开数据库: %v", err) - } - defer s.Close() - - db := s.DB() - - // 统计当前数据 - var orderCount, fillCount int - db.QueryRow(`SELECT COUNT(*) FROM trader_orders`).Scan(&orderCount) - db.QueryRow(`SELECT COUNT(*) FROM trader_fills`).Scan(&fillCount) - - fmt.Printf("\n📊 当前数据统计:\n") - fmt.Printf(" trader_orders: %d 条记录\n", orderCount) - fmt.Printf(" trader_fills: %d 条记录\n", fillCount) - - if orderCount == 0 && fillCount == 0 { - fmt.Println("\n✅ 表已经是空的,无需清空") - return - } - - // 确认删除 - if !force { - fmt.Println("\n⚠️ 警告: 此操作将删除所有订单和成交记录,无法恢复!") - fmt.Print("\n确认删除?请输入 'yes' 继续: ") - - reader := bufio.NewReader(os.Stdin) - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(input) - - if input != "yes" { - fmt.Println("\n❌ 操作已取消") - return - } - } - - fmt.Println("\n🗑️ 开始清空表...") - - // 清空 trader_fills 表(先删除,因为有外键约束) - result, err := db.Exec(`DELETE FROM trader_fills`) - if err != nil { - log.Fatalf("❌ 清空 trader_fills 失败: %v", err) - } - fillsDeleted, _ := result.RowsAffected() - fmt.Printf(" ✅ 删除了 %d 条成交记录\n", fillsDeleted) - - // 清空 trader_orders 表 - result, err = db.Exec(`DELETE FROM trader_orders`) - if err != nil { - log.Fatalf("❌ 清空 trader_orders 失败: %v", err) - } - ordersDeleted, _ := result.RowsAffected() - fmt.Printf(" ✅ 删除了 %d 条订单记录\n", ordersDeleted) - - // 重置自增ID(可选,让ID从1重新开始) - _, err = db.Exec(`DELETE FROM sqlite_sequence WHERE name IN ('trader_orders', 'trader_fills')`) - if err == nil { - fmt.Println(" ✅ 重置了自增ID计数器") - } - - // 验证清空结果 - db.QueryRow(`SELECT COUNT(*) FROM trader_orders`).Scan(&orderCount) - db.QueryRow(`SELECT COUNT(*) FROM trader_fills`).Scan(&fillCount) - - fmt.Printf("\n🔍 验证结果:\n") - fmt.Printf(" trader_orders: %d 条记录\n", orderCount) - fmt.Printf(" trader_fills: %d 条记录\n", fillCount) - - if orderCount == 0 && fillCount == 0 { - fmt.Println("\n✅ 表已成功清空!") - fmt.Println("\n💡 现在可以重新运行 trader 进行测试") - fmt.Println(" 新的订单将从 ID=1 开始记录") - } else { - fmt.Println("\n⚠️ 清空未完成,请检查数据库") - } -} diff --git a/scripts/diagnose_orders.go b/scripts/diagnose_orders.go deleted file mode 100644 index 0a1b2bed..00000000 --- a/scripts/diagnose_orders.go +++ /dev/null @@ -1,189 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "log" - "nofx/store" - "os" - "path/filepath" - "time" -) - -func main() { - var dbPath string - var traderID string - - flag.StringVar(&dbPath, "db", "./data/data.db", "数据库文件路径") - flag.StringVar(&traderID, "trader", "", "Trader ID(可选)") - flag.Parse() - - // 确保数据库文件存在 - absPath, err := filepath.Abs(dbPath) - if err != nil { - log.Fatalf("❌ 无效的数据库路径: %v", err) - } - - if _, err := os.Stat(absPath); os.IsNotExist(err) { - log.Fatalf("❌ 数据库文件不存在: %s", absPath) - } - - fmt.Printf("📂 数据库路径: %s\n", absPath) - - // 打开数据库 - s, err := store.New(absPath) - if err != nil { - log.Fatalf("❌ 无法打开数据库: %v", err) - } - defer s.Close() - - orderStore := s.Order() - - // 如果指定了 traderID,获取该 trader 的订单 - if traderID == "" { - fmt.Println("\n⚠️ 未指定 trader_id,使用: --trader ") - fmt.Println(" 获取所有 trader 的统计信息...\n") - } - - // 获取订单列表 - orders, err := orderStore.GetTraderOrders(traderID, 100) - if err != nil { - log.Fatalf("❌ 获取订单失败: %v", err) - } - - fmt.Printf("\n📋 找到 %d 条订单记录\n\n", len(orders)) - - if len(orders) == 0 { - fmt.Println("⚠️ 没有订单数据!可能的原因:") - fmt.Println(" 1. Trader 还没有执行过交易") - fmt.Println(" 2. CreateOrder 插入失败(重复键冲突)") - fmt.Println(" 3. 指定的 trader_id 不存在") - return - } - - // 统计数据 - var ( - totalOrders = len(orders) - filledOrders = 0 - withFilledAt = 0 - withAvgFillPrice = 0 - withOrderAction = 0 - missingFilledAt = 0 - missingAvgPrice = 0 - missingOrderAction = 0 - ) - - fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - fmt.Printf("%-15s %-10s %-10s %-15s %-10s %-15s\n", "订单ID", "状态", "动作", "平均成交价", "成交时间", "问题") - fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - - for _, order := range orders { - issues := []string{} - - if order.Status == "FILLED" { - filledOrders++ - - // 检查 filled_at - if order.FilledAt > 0 { - withFilledAt++ - } else { - missingFilledAt++ - issues = append(issues, "❌ 缺少成交时间") - } - - // 检查 avg_fill_price - if order.AvgFillPrice > 0 { - withAvgFillPrice++ - } else { - missingAvgPrice++ - issues = append(issues, "❌ 成交价为0") - } - } - - // 检查 order_action - if order.OrderAction != "" { - withOrderAction++ - } else { - missingOrderAction++ - issues = append(issues, "⚠️ 缺少订单动作") - } - - issueStr := "✅ 正常" - if len(issues) > 0 { - issueStr = "" - for i, issue := range issues { - if i > 0 { - issueStr += ", " - } - issueStr += issue - } - } - - filledAtStr := "N/A" - if order.FilledAt > 0 { - filledAtStr = time.UnixMilli(order.FilledAt).Format("01-02 15:04") - } - - fmt.Printf("%-15s %-10s %-10s %-15.2f %-10s %s\n", - order.ExchangeOrderID[:min(15, len(order.ExchangeOrderID))], - order.Status, - order.OrderAction, - order.AvgFillPrice, - filledAtStr, - issueStr, - ) - } - - fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - - // 统计摘要 - fmt.Printf("\n📊 统计摘要:\n") - fmt.Printf(" 总订单数: %d\n", totalOrders) - fmt.Printf(" 已成交订单: %d\n", filledOrders) - fmt.Printf(" 有成交时间: %d / %d (%.1f%%)\n", withFilledAt, filledOrders, float64(withFilledAt)/float64(max(filledOrders, 1))*100) - fmt.Printf(" 有成交价格: %d / %d (%.1f%%)\n", withAvgFillPrice, filledOrders, float64(withAvgFillPrice)/float64(max(filledOrders, 1))*100) - fmt.Printf(" 有订单动作: %d / %d (%.1f%%)\n", withOrderAction, totalOrders, float64(withOrderAction)/float64(max(totalOrders, 1))*100) - - fmt.Printf("\n⚠️ 问题订单:\n") - if missingFilledAt > 0 { - fmt.Printf(" ❌ %d 条订单缺少成交时间 (filled_at)\n", missingFilledAt) - } - if missingAvgPrice > 0 { - fmt.Printf(" ❌ %d 条订单成交价为 0 (avg_fill_price)\n", missingAvgPrice) - } - if missingOrderAction > 0 { - fmt.Printf(" ⚠️ %d 条订单缺少订单动作 (order_action)\n", missingOrderAction) - } - - if missingFilledAt > 0 || missingAvgPrice > 0 { - fmt.Println("\n💡 这些订单无法在图表上显示,因为:") - fmt.Println(" - 缺少成交时间 → 前端无法定位到K线时间轴") - fmt.Println(" - 成交价为 0 → 前端会过滤掉 (line 164: if (!orderPrice || orderPrice === 0) return)") - fmt.Println("\n🔧 可能的原因:") - fmt.Println(" 1. UpdateOrderStatus 没有被正确调用") - fmt.Println(" 2. GetOrderStatus 返回的数据缺少 avgPrice 字段") - fmt.Println(" 3. Lighter 交易所的订单状态查询有问题") - } - - if missingFilledAt == 0 && missingAvgPrice == 0 && missingOrderAction == 0 { - fmt.Println("\n✅ 所有订单数据完整!") - fmt.Println(" 如果图表仍然没有显示 B/S 标记,检查:") - fmt.Println(" 1. 前端是否正确调用了 /api/orders API") - fmt.Println(" 2. 浏览器控制台是否有错误") - fmt.Println(" 3. 订单时间是否在图表的时间范围内") - } -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func max(a, b int) int { - if a > b { - return a - } - return b -} diff --git a/scripts/fix_order_data.go b/scripts/fix_order_data.go deleted file mode 100644 index eac8d4b6..00000000 --- a/scripts/fix_order_data.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "log" - "nofx/store" - "os" - "path/filepath" - "time" -) - -func main() { - var dbPath string - var dryRun bool - - flag.StringVar(&dbPath, "db", "./data/data.db", "数据库文件路径") - flag.BoolVar(&dryRun, "dry-run", false, "只检查不修复(预览模式)") - flag.Parse() - - // 确保数据库文件存在 - absPath, err := filepath.Abs(dbPath) - if err != nil { - log.Fatalf("❌ 无效的数据库路径: %v", err) - } - - if _, err := os.Stat(absPath); os.IsNotExist(err) { - log.Fatalf("❌ 数据库文件不存在: %s", absPath) - } - - fmt.Printf("📂 数据库路径: %s\n", absPath) - - // 打开数据库 - s, err := store.New(absPath) - if err != nil { - log.Fatalf("❌ 无法打开数据库: %v", err) - } - defer s.Close() - - db := s.DB() - - fmt.Println("\n🔍 检查需要修复的订单...") - - // 1. 修复缺少 filled_at 的 FILLED 订单(使用 updated_at 或 created_at) - var needFixFilledAt int - err = db.QueryRow(` - SELECT COUNT(*) - FROM trader_orders - WHERE status = 'FILLED' AND (filled_at IS NULL OR filled_at = '') - `).Scan(&needFixFilledAt) - if err != nil { - log.Fatalf("❌ 查询失败: %v", err) - } - - fmt.Printf(" 📋 缺少成交时间的订单: %d 条\n", needFixFilledAt) - - // 2. 修复 avg_fill_price = 0 的 FILLED 订单(使用 price 字段) - var needFixAvgPrice int - err = db.QueryRow(` - SELECT COUNT(*) - FROM trader_orders - WHERE status = 'FILLED' AND (avg_fill_price = 0 OR avg_fill_price IS NULL) AND price > 0 - `).Scan(&needFixAvgPrice) - if err != nil { - log.Fatalf("❌ 查询失败: %v", err) - } - - fmt.Printf(" 💰 成交价为0的订单: %d 条\n", needFixAvgPrice) - - if needFixFilledAt == 0 && needFixAvgPrice == 0 { - fmt.Println("\n✅ 没有需要修复的订单!") - return - } - - if dryRun { - fmt.Println("\n⚠️ 预览模式(--dry-run),不会修改数据") - fmt.Println(" 运行 'go run scripts/fix_order_data.go' 来执行实际修复") - return - } - - fmt.Println("\n🔧 开始修复...") - - // 修复缺少 filled_at 的订单 - if needFixFilledAt > 0 { - result, err := db.Exec(` - UPDATE trader_orders - SET filled_at = COALESCE(updated_at, created_at) - WHERE status = 'FILLED' AND (filled_at IS NULL OR filled_at = '') - `) - if err != nil { - log.Fatalf("❌ 修复成交时间失败: %v", err) - } - rows, _ := result.RowsAffected() - fmt.Printf(" ✅ 修复了 %d 条订单的成交时间\n", rows) - } - - // 修复 avg_fill_price = 0 的订单 - if needFixAvgPrice > 0 { - result, err := db.Exec(` - UPDATE trader_orders - SET avg_fill_price = price, - filled_quantity = quantity - WHERE status = 'FILLED' - AND (avg_fill_price = 0 OR avg_fill_price IS NULL) - AND price > 0 - `) - if err != nil { - log.Fatalf("❌ 修复成交价失败: %v", err) - } - rows, _ := result.RowsAffected() - fmt.Printf(" ✅ 修复了 %d 条订单的成交价\n", rows) - } - - // 验证修复结果 - fmt.Println("\n🔍 验证修复结果...") - time.Sleep(100 * time.Millisecond) - - var stillMissingFilledAt int - db.QueryRow(` - SELECT COUNT(*) - FROM trader_orders - WHERE status = 'FILLED' AND (filled_at IS NULL OR filled_at = '') - `).Scan(&stillMissingFilledAt) - - var stillMissingAvgPrice int - db.QueryRow(` - SELECT COUNT(*) - FROM trader_orders - WHERE status = 'FILLED' AND (avg_fill_price = 0 OR avg_fill_price IS NULL) - `).Scan(&stillMissingAvgPrice) - - fmt.Printf(" 📋 仍缺少成交时间: %d 条\n", stillMissingFilledAt) - fmt.Printf(" 💰 仍缺少成交价: %d 条\n", stillMissingAvgPrice) - - if stillMissingFilledAt == 0 && stillMissingAvgPrice == 0 { - fmt.Println("\n✅ 修复完成!所有订单数据已完整") - fmt.Println("\n💡 现在刷新图表页面,应该能看到 B/S 标记了") - } else { - fmt.Println("\n⚠️ 仍有部分订单无法修复,可能需要手动检查") - } -} diff --git a/scripts/migrate_encryption.go b/scripts/migrate_encryption.go deleted file mode 100644 index bfdb120e..00000000 --- a/scripts/migrate_encryption.go +++ /dev/null @@ -1,200 +0,0 @@ -package main - -import ( - "database/sql" - "fmt" - "log" - "os" - - "nofx/crypto" - - _ "modernc.org/sqlite" -) - -func main() { - log.Println("🔄 Starting database migration to encrypted format...") - - // 1. Check database file - dbPath := "data/data.db" - if len(os.Args) > 1 { - dbPath = os.Args[1] - } - - if _, err := os.Stat(dbPath); os.IsNotExist(err) { - log.Fatalf("❌ Database file does not exist: %s", dbPath) - } - - // 2. Backup database - backupPath := fmt.Sprintf("%s.pre_encryption_backup", dbPath) - log.Printf("📦 Backing up database to: %s", backupPath) - - input, err := os.ReadFile(dbPath) - if err != nil { - log.Fatalf("❌ Failed to read database: %v", err) - } - - if err := os.WriteFile(backupPath, input, 0600); err != nil { - log.Fatalf("❌ Backup failed: %v", err) - } - - // 3. Open database - db, err := sql.Open("sqlite", dbPath) - if err != nil { - log.Fatalf("❌ Failed to open database: %v", err) - } - defer db.Close() - - // 4. Initialize CryptoService (load key from environment variables) - cs, err := crypto.NewCryptoService() - if err != nil { - log.Fatalf("❌ Failed to initialize encryption service: %v", err) - } - - // 5. Migrate exchange configurations - if err := migrateExchanges(db, cs); err != nil { - log.Fatalf("❌ Failed to migrate exchange configurations: %v", err) - } - - // 6. Migrate AI model configurations - if err := migrateAIModels(db, cs); err != nil { - log.Fatalf("❌ Failed to migrate AI model configurations: %v", err) - } - - log.Println("✅ Data migration completed!") - log.Printf("📝 Original data backed up at: %s", backupPath) - log.Println("⚠️ Please verify system functionality before manually deleting backup file") -} - -// migrateExchanges migrates exchange configurations -func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error { - log.Println("🔄 Migrating exchange configurations...") - - // Query all unencrypted records (encrypted data starts with ENC:v1:) - rows, err := db.Query(` - SELECT user_id, id, api_key, secret_key, - COALESCE(hyperliquid_private_key, ''), - COALESCE(aster_private_key, '') - FROM exchanges - WHERE (api_key != '' AND api_key NOT LIKE 'ENC:v1:%') - OR (secret_key != '' AND secret_key NOT LIKE 'ENC:v1:%') - `) - if err != nil { - return err - } - defer rows.Close() - - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - count := 0 - for rows.Next() { - var userID, exchangeID, apiKey, secretKey, hlPrivateKey, asterPrivateKey string - if err := rows.Scan(&userID, &exchangeID, &apiKey, &secretKey, &hlPrivateKey, &asterPrivateKey); err != nil { - return err - } - - // Encrypt each field - encAPIKey, err := cs.EncryptForStorage(apiKey) - if err != nil { - return fmt.Errorf("failed to encrypt API Key: %w", err) - } - - encSecretKey, err := cs.EncryptForStorage(secretKey) - if err != nil { - return fmt.Errorf("failed to encrypt Secret Key: %w", err) - } - - encHLPrivateKey := "" - if hlPrivateKey != "" { - encHLPrivateKey, err = cs.EncryptForStorage(hlPrivateKey) - if err != nil { - return fmt.Errorf("failed to encrypt Hyperliquid Private Key: %w", err) - } - } - - encAsterPrivateKey := "" - if asterPrivateKey != "" { - encAsterPrivateKey, err = cs.EncryptForStorage(asterPrivateKey) - if err != nil { - return fmt.Errorf("failed to encrypt Aster Private Key: %w", err) - } - } - - // Update database - _, err = tx.Exec(` - UPDATE exchanges - SET api_key = ?, secret_key = ?, - hyperliquid_private_key = ?, aster_private_key = ? - WHERE user_id = ? AND id = ? - `, encAPIKey, encSecretKey, encHLPrivateKey, encAsterPrivateKey, userID, exchangeID) - - if err != nil { - return fmt.Errorf("failed to update database: %w", err) - } - - log.Printf(" ✓ Encrypted: [%s] %s", userID, exchangeID) - count++ - } - - if err := tx.Commit(); err != nil { - return err - } - - log.Printf("✅ Migrated %d exchange configurations", count) - return nil -} - -// migrateAIModels migrates AI model configurations -func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error { - log.Println("🔄 Migrating AI model configurations...") - - rows, err := db.Query(` - SELECT user_id, id, api_key - FROM ai_models - WHERE api_key != '' AND api_key NOT LIKE 'ENC:v1:%' - `) - if err != nil { - return err - } - defer rows.Close() - - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - count := 0 - for rows.Next() { - var userID, modelID, apiKey string - if err := rows.Scan(&userID, &modelID, &apiKey); err != nil { - return err - } - - encAPIKey, err := cs.EncryptForStorage(apiKey) - if err != nil { - return fmt.Errorf("failed to encrypt API Key: %w", err) - } - - _, err = tx.Exec(` - UPDATE ai_models SET api_key = ? WHERE user_id = ? AND id = ? - `, encAPIKey, userID, modelID) - - if err != nil { - return fmt.Errorf("failed to update database: %w", err) - } - - log.Printf(" ✓ Encrypted: [%s] %s", userID, modelID) - count++ - } - - if err := tx.Commit(); err != nil { - return err - } - - log.Printf("✅ Migrated %d AI model configurations", count) - return nil -} diff --git a/scripts/pr-check.sh b/scripts/pr-check.sh deleted file mode 100755 index 277254d7..00000000 --- a/scripts/pr-check.sh +++ /dev/null @@ -1,413 +0,0 @@ -#!/bin/bash - -# 🔍 PR Health Check Script -# Analyzes your PR and gives suggestions on how to meet the new standards -# This script only analyzes and suggests - it won't modify your code - -set -e - -# Colors -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -CYAN='\033[0;36m' -NC='\033[0m' # No Color - -# Counters -ISSUES_FOUND=0 -WARNINGS_FOUND=0 -PASSED_CHECKS=0 - -# Helper functions -log_section() { - echo "" - echo -e "${CYAN}═══════════════════════════════════════════${NC}" - echo -e "${CYAN} $1${NC}" - echo -e "${CYAN}═══════════════════════════════════════════${NC}" -} - -log_check() { - echo -e "${BLUE}🔍 Checking: $1${NC}" -} - -log_pass() { - echo -e "${GREEN}✅ PASS: $1${NC}" - ((PASSED_CHECKS++)) -} - -log_warning() { - echo -e "${YELLOW}⚠️ WARNING: $1${NC}" - ((WARNINGS_FOUND++)) -} - -log_error() { - echo -e "${RED}❌ ISSUE: $1${NC}" - ((ISSUES_FOUND++)) -} - -log_suggestion() { - echo -e "${CYAN}💡 Suggestion: $1${NC}" -} - -log_command() { - echo -e "${GREEN} Run: ${NC}$1" -} - -# Welcome -echo "" -echo "╔═══════════════════════════════════════════╗" -echo "║ NOFX PR Health Check ║" -echo "║ Analyze your PR and get suggestions ║" -echo "╚═══════════════════════════════════════════╝" -echo "" - -# Check if we're in a git repo -if ! git rev-parse --is-inside-work-tree > /dev/null 2>&1; then - log_error "Not a git repository" - exit 1 -fi - -# Get current branch -CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) -echo -e "${BLUE}Current branch: ${GREEN}$CURRENT_BRANCH${NC}" - -if [ "$CURRENT_BRANCH" = "main" ] || [ "$CURRENT_BRANCH" = "dev" ]; then - log_error "You're on the $CURRENT_BRANCH branch. Please switch to your PR branch." - exit 1 -fi - -# Check if upstream exists -if ! git remote | grep -q "^upstream$"; then - log_warning "Upstream remote not found" - log_suggestion "Add upstream remote:" - log_command "git remote add upstream https://github.com/NoFxAiOS/nofx.git" - echo "" -fi - -# ═══════════════════════════════════════════ -# 1. GIT BRANCH CHECKS -# ═══════════════════════════════════════════ -log_section "1. Git Branch Status" - -# Check if branch is up to date with upstream -log_check "Is branch based on latest upstream/dev?" -if git remote | grep -q "^upstream$"; then - git fetch upstream -q 2>/dev/null || true - - if git merge-base --is-ancestor upstream/dev HEAD 2>/dev/null; then - log_pass "Branch is up to date with upstream/dev" - else - log_error "Branch is not based on latest upstream/dev" - log_suggestion "Rebase your branch:" - log_command "git fetch upstream && git rebase upstream/dev" - echo "" - fi -else - log_warning "Cannot check - upstream remote not configured" -fi - -# Check for merge conflicts -log_check "Any merge conflicts?" -if git diff --check > /dev/null 2>&1; then - log_pass "No merge conflicts detected" -else - log_error "Merge conflicts detected" - log_suggestion "Resolve conflicts and commit" -fi - -# ═══════════════════════════════════════════ -# 2. COMMIT MESSAGE CHECKS -# ═══════════════════════════════════════════ -log_section "2. Commit Messages" - -# Get commits in this branch (not in upstream/dev) -if git remote | grep -q "^upstream$"; then - COMMITS=$(git log upstream/dev..HEAD --oneline 2>/dev/null || git log --oneline -10) -else - COMMITS=$(git log --oneline -10) -fi - -COMMIT_COUNT=$(echo "$COMMITS" | wc -l | tr -d ' ') -echo -e "${BLUE}Found $COMMIT_COUNT commit(s) in your branch${NC}" -echo "" - -# Check each commit message -echo "$COMMITS" | while read -r line; do - COMMIT_MSG=$(echo "$line" | cut -d' ' -f2-) - - # Check if follows conventional commits - if echo "$COMMIT_MSG" | grep -qE "^(feat|fix|docs|style|refactor|perf|test|chore|ci|security)(\(.+\))?: .+"; then - log_pass "\"$COMMIT_MSG\"" - else - log_warning "\"$COMMIT_MSG\"" - log_suggestion "Should follow format: type(scope): description" - echo " Examples:" - echo " - feat(exchange): add OKX integration" - echo " - fix(trader): resolve position bug" - echo "" - fi -done - -# Suggest PR title based on commits -echo "" -log_check "Suggested PR title:" -SUGGESTED_TITLE=$(git log --pretty=%s upstream/dev..HEAD 2>/dev/null | head -1 || git log --pretty=%s -1) -echo -e "${GREEN} \"$SUGGESTED_TITLE\"${NC}" -echo "" - -# ═══════════════════════════════════════════ -# 3. CODE QUALITY - BACKEND (Go) -# ═══════════════════════════════════════════ -if find . -name "*.go" -not -path "./vendor/*" -not -path "./.git/*" | grep -q .; then - log_section "3. Backend Code Quality (Go)" - - # Check if Go is installed - if ! command -v go &> /dev/null; then - log_warning "Go not installed - skipping backend checks" - log_suggestion "Install Go: https://go.dev/doc/install" - else - # Check go fmt - log_check "Go code formatting (go fmt)" - UNFORMATTED=$(gofmt -l . 2>/dev/null | grep -v vendor || true) - if [ -z "$UNFORMATTED" ]; then - log_pass "All Go files are formatted" - else - log_error "Some files need formatting:" - echo "$UNFORMATTED" | head -5 | while read -r file; do - echo " - $file" - done - log_suggestion "Format your code:" - log_command "go fmt ./..." - echo "" - fi - - # Check go vet - log_check "Go static analysis (go vet)" - if go vet ./... > /tmp/vet-output.txt 2>&1; then - log_pass "No issues found by go vet" - else - log_error "Go vet found issues:" - head -10 /tmp/vet-output.txt | sed 's/^/ /' - log_suggestion "Fix the issues above" - echo "" - fi - - # Check tests exist - log_check "Do tests exist?" - TEST_FILES=$(find . -name "*_test.go" -not -path "./vendor/*" | wc -l) - if [ "$TEST_FILES" -gt 0 ]; then - log_pass "Found $TEST_FILES test file(s)" - else - log_warning "No test files found" - log_suggestion "Add tests for your changes" - echo "" - fi - - # Run tests - log_check "Running Go tests..." - if go test ./... -v > /tmp/test-output.txt 2>&1; then - log_pass "All tests passed" - else - log_error "Some tests failed:" - grep -E "FAIL|ERROR" /tmp/test-output.txt | head -10 | sed 's/^/ /' || true - log_suggestion "Fix failing tests:" - log_command "go test ./... -v" - echo "" - fi - fi -fi - -# ═══════════════════════════════════════════ -# 4. CODE QUALITY - FRONTEND -# ═══════════════════════════════════════════ -if [ -d "web" ]; then - log_section "4. Frontend Code Quality" - - # Check if npm is installed - if ! command -v npm &> /dev/null; then - log_warning "npm not installed - skipping frontend checks" - log_suggestion "Install Node.js: https://nodejs.org/" - else - cd web - - # Check if node_modules exists - if [ ! -d "node_modules" ]; then - log_warning "Dependencies not installed" - log_suggestion "Install dependencies:" - log_command "cd web && npm install" - cd .. - else - # Check linting - log_check "Frontend linting" - if npm run lint > /tmp/lint-output.txt 2>&1; then - log_pass "No linting issues" - else - log_error "Linting issues found:" - tail -20 /tmp/lint-output.txt | sed 's/^/ /' || true - log_suggestion "Fix linting issues:" - log_command "cd web && npm run lint -- --fix" - echo "" - fi - - # Check type errors - log_check "TypeScript type checking" - if npm run type-check > /tmp/typecheck-output.txt 2>&1; then - log_pass "No type errors" - else - log_error "Type errors found:" - tail -20 /tmp/typecheck-output.txt | sed 's/^/ /' || true - log_suggestion "Fix type errors in your code" - echo "" - fi - - # Check build - log_check "Frontend build" - if npm run build > /tmp/build-output.txt 2>&1; then - log_pass "Build successful" - else - log_error "Build failed:" - tail -20 /tmp/build-output.txt | sed 's/^/ /' || true - log_suggestion "Fix build errors" - echo "" - fi - fi - - cd .. - fi -fi - -# ═══════════════════════════════════════════ -# 5. PR SIZE CHECK -# ═══════════════════════════════════════════ -log_section "5. PR Size" - -if git remote | grep -q "^upstream$"; then - ADDED=$(git diff --numstat upstream/dev...HEAD | awk '{sum+=$1} END {print sum+0}') - DELETED=$(git diff --numstat upstream/dev...HEAD | awk '{sum+=$2} END {print sum+0}') - TOTAL=$((ADDED + DELETED)) - FILES_CHANGED=$(git diff --name-only upstream/dev...HEAD | wc -l) - - echo -e "${BLUE}Lines changed: ${GREEN}+$ADDED ${RED}-$DELETED ${NC}(total: $TOTAL)" - echo -e "${BLUE}Files changed: ${GREEN}$FILES_CHANGED${NC}" - echo "" - - if [ "$TOTAL" -lt 100 ]; then - log_pass "Small PR (<100 lines) - ideal for quick review" - elif [ "$TOTAL" -lt 500 ]; then - log_pass "Medium PR (100-500 lines) - reasonable size" - elif [ "$TOTAL" -lt 1000 ]; then - log_warning "Large PR (500-1000 lines) - consider splitting" - log_suggestion "Breaking into smaller PRs makes review faster" - else - log_error "Very large PR (>1000 lines) - strongly consider splitting" - log_suggestion "Split into multiple smaller PRs, each with a focused change" - echo "" - fi -fi - -# ═══════════════════════════════════════════ -# 6. DOCUMENTATION CHECK -# ═══════════════════════════════════════════ -log_section "6. Documentation" - -# Check if README or docs were updated -log_check "Documentation updates" -if git remote | grep -q "^upstream$"; then - DOC_CHANGES=$(git diff --name-only upstream/dev...HEAD | grep -E "\.(md|txt)$" || true) - - if [ -n "$DOC_CHANGES" ]; then - log_pass "Documentation files updated" - echo "$DOC_CHANGES" | sed 's/^/ - /' - else - # Check if this is a feature/fix that might need docs - COMMIT_TYPES=$(git log --pretty=%s upstream/dev..HEAD | grep -oE "^(feat|fix)" || true) - if [ -n "$COMMIT_TYPES" ]; then - log_warning "No documentation updates found" - log_suggestion "Consider updating docs if your changes affect usage" - echo "" - else - log_pass "No documentation update needed" - fi - fi -fi - -# ═══════════════════════════════════════════ -# 7. ROADMAP ALIGNMENT -# ═══════════════════════════════════════════ -log_section "7. Roadmap Alignment" - -log_check "Does your PR align with the roadmap?" -echo "" -echo "Current priorities (Phase 1):" -echo " ✅ Security enhancements" -echo " ✅ AI model integrations" -echo " ✅ Exchange integrations (OKX, Bybit, Lighter, EdgeX)" -echo " ✅ UI/UX improvements" -echo " ✅ Performance optimizations" -echo " ✅ Bug fixes" -echo "" -log_suggestion "Check roadmap: https://github.com/NoFxAiOS/nofx/blob/dev/docs/roadmap/README.md" -echo "" - -# ═══════════════════════════════════════════ -# FINAL REPORT -# ═══════════════════════════════════════════ -log_section "Summary Report" - -echo "" -echo -e "${GREEN}✅ Passed checks: $PASSED_CHECKS${NC}" -echo -e "${YELLOW}⚠️ Warnings: $WARNINGS_FOUND${NC}" -echo -e "${RED}❌ Issues found: $ISSUES_FOUND${NC}" -echo "" - -# Overall assessment -if [ "$ISSUES_FOUND" -eq 0 ] && [ "$WARNINGS_FOUND" -eq 0 ]; then - echo "╔═══════════════════════════════════════════╗" - echo "║ 🎉 Excellent! Your PR looks great! ║" - echo "║ Ready to submit or update your PR ║" - echo "╚═══════════════════════════════════════════╝" -elif [ "$ISSUES_FOUND" -eq 0 ]; then - echo "╔═══════════════════════════════════════════╗" - echo "║ 👍 Good! Minor warnings found ║" - echo "║ Consider addressing warnings ║" - echo "╚═══════════════════════════════════════════╝" -elif [ "$ISSUES_FOUND" -le 3 ]; then - echo "╔═══════════════════════════════════════════╗" - echo "║ ⚠️ Issues found - Please fix ║" - echo "║ See suggestions above ║" - echo "╚═══════════════════════════════════════════╝" -else - echo "╔═══════════════════════════════════════════╗" - echo "║ ❌ Multiple issues found ║" - echo "║ Please address issues before submitting ║" - echo "╚═══════════════════════════════════════════╝" -fi - -echo "" -echo "📖 Next steps:" -echo "" - -if [ "$ISSUES_FOUND" -gt 0 ] || [ "$WARNINGS_FOUND" -gt 0 ]; then - echo "1. Fix the issues and warnings listed above" - echo "2. Run this script again to verify: ./scripts/pr-check.sh" - echo "3. Commit your fixes" - echo "4. Push to your PR: git push origin $CURRENT_BRANCH" -else - echo "1. Push your changes: git push origin $CURRENT_BRANCH" - echo "2. Create or update your PR on GitHub" - echo "3. Wait for automated CI checks" - echo "4. Address reviewer feedback" -fi - -echo "" -echo "📚 Resources:" -echo " - Contributing Guide: https://github.com/NoFxAiOS/nofx/blob/dev/CONTRIBUTING.md" -echo " - Migration Guide: https://github.com/NoFxAiOS/nofx/blob/dev/docs/community/MIGRATION_ANNOUNCEMENT.md" -echo "" - -# Cleanup temp files -rm -f /tmp/vet-output.txt /tmp/test-output.txt /tmp/lint-output.txt /tmp/typecheck-output.txt /tmp/build-output.txt - -echo "✨ Analysis complete! Good luck with your PR! 🚀" -echo "" diff --git a/scripts/pr-fix.sh b/scripts/pr-fix.sh deleted file mode 100755 index f643cd56..00000000 --- a/scripts/pr-fix.sh +++ /dev/null @@ -1,335 +0,0 @@ -#!/bin/bash - -# 🔄 PR Migration Script for Contributors -# This script helps you migrate your PR to the new format -# Run this in your local fork to update your PR automatically - -set -e - -# Colors -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# Helper functions -log_info() { - echo -e "${BLUE}ℹ️ $1${NC}" -} - -log_success() { - echo -e "${GREEN}✅ $1${NC}" -} - -log_warning() { - echo -e "${YELLOW}⚠️ $1${NC}" -} - -log_error() { - echo -e "${RED}❌ $1${NC}" -} - -confirm() { - read -p "$(echo -e ${YELLOW}"$1 (y/N): "${NC})" -n 1 -r - echo - [[ $REPLY =~ ^[Yy]$ ]] -} - -# Welcome message -echo "" -echo "╔═══════════════════════════════════════════╗" -echo "║ NOFX PR Migration Tool ║" -echo "║ Migrate your PR to the new format ║" -echo "╚═══════════════════════════════════════════╝" -echo "" - -# Check if we're in a git repo -if ! git rev-parse --is-inside-work-tree > /dev/null 2>&1; then - log_error "Not a git repository. Please run this from your NOFX fork." - exit 1 -fi - -# Check current branch -CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) -log_info "Current branch: $CURRENT_BRANCH" - -if [ "$CURRENT_BRANCH" = "main" ] || [ "$CURRENT_BRANCH" = "dev" ]; then - log_warning "You're on the $CURRENT_BRANCH branch." - log_info "This script should be run on your PR branch." - - # List branches - log_info "Your branches:" - git branch - - echo "" - read -p "Enter your PR branch name: " PR_BRANCH - - if [ -z "$PR_BRANCH" ]; then - log_error "No branch specified. Exiting." - exit 1 - fi - - git checkout "$PR_BRANCH" || { - log_error "Failed to checkout branch $PR_BRANCH" - exit 1 - } - - CURRENT_BRANCH="$PR_BRANCH" -fi - -log_success "Working on branch: $CURRENT_BRANCH" - -echo "" -log_info "What this script will do:" -echo " 1. ✅ Verify you're rebased on latest upstream/dev" -echo " 2. ✅ Check and format Go code (go fmt)" -echo " 3. ✅ Run Go linting (go vet)" -echo " 4. ✅ Run Go tests" -echo " 5. ✅ Check frontend code (if modified)" -echo " 6. ✅ Give you feedback and suggestions" -echo "" -log_warning "Make sure you've already run: git fetch upstream && git rebase upstream/dev" -echo "" - -if ! confirm "Continue with migration?"; then - log_info "Migration cancelled" - exit 0 -fi - -# Step 1: Verify upstream sync -echo "" -log_info "Step 1: Verifying upstream sync..." - -# Check if upstream remote exists -if ! git remote | grep -q "^upstream$"; then - log_warning "Upstream remote not found. Adding it..." - git remote add upstream https://github.com/NoFxAiOS/nofx.git - git fetch upstream - log_success "Added upstream remote" -fi - -# Check if we're up to date with upstream/dev -if git merge-base --is-ancestor upstream/dev HEAD; then - log_success "Your branch is up to date with upstream/dev" -else - log_warning "Your branch is not based on latest upstream/dev" - log_info "Please run first: git fetch upstream && git rebase upstream/dev" - - if confirm "Try to rebase now?"; then - git fetch upstream - if git rebase upstream/dev; then - log_success "Successfully rebased on upstream/dev" - else - log_error "Rebase failed. Please resolve conflicts manually." - exit 1 - fi - else - log_warning "Skipping rebase. Results may not be accurate." - fi -fi - -# Step 2: Backend checks (if Go files exist) -if find . -name "*.go" -not -path "./vendor/*" | grep -q .; then - echo "" - log_info "Step 2: Running backend checks..." - - # Check if Go is installed - if ! command -v go &> /dev/null; then - log_warning "Go not found. Skipping backend checks." - log_info "Install Go: https://go.dev/doc/install" - else - # Format Go code - log_info "Formatting Go code..." - if go fmt ./...; then - log_success "Go code formatted" - - # Check if there are changes - if ! git diff --quiet; then - log_info "Formatting created changes. Committing..." - git add . - git commit -m "chore: format Go code with go fmt" || true - fi - else - log_warning "Go formatting had issues (non-critical)" - fi - - # Run go vet - log_info "Running go vet..." - if go vet ./...; then - log_success "Go vet passed" - else - log_warning "Go vet found issues. Please review them." - if confirm "Continue anyway?"; then - log_info "Continuing..." - else - exit 1 - fi - fi - - # Run tests - log_info "Running Go tests..." - if go test ./...; then - log_success "All Go tests passed" - else - log_warning "Some tests failed. Please fix them before pushing." - if confirm "Continue anyway?"; then - log_info "Continuing..." - else - exit 1 - fi - fi - fi -else - log_info "Step 2: No Go files found, skipping backend checks" -fi - -# Step 3: Frontend checks (if web directory exists) -if [ -d "web" ]; then - echo "" - log_info "Step 3: Running frontend checks..." - - # Check if npm is installed - if ! command -v npm &> /dev/null; then - log_warning "npm not found. Skipping frontend checks." - log_info "Install Node.js: https://nodejs.org/" - else - cd web - - # Install dependencies if needed - if [ ! -d "node_modules" ]; then - log_info "Installing dependencies..." - npm install - fi - - # Run linter - log_info "Running linter..." - if npm run lint; then - log_success "Linting passed" - else - log_warning "Linting found issues" - log_info "Attempting to auto-fix..." - npm run lint -- --fix || true - - # Commit fixes if any - if ! git diff --quiet; then - git add . - git commit -m "chore: fix linting issues" || true - fi - fi - - # Type check - log_info "Running type check..." - if npm run type-check; then - log_success "Type checking passed" - else - log_warning "Type checking found issues. Please fix them." - fi - - # Build - log_info "Testing build..." - if npm run build; then - log_success "Build successful" - else - log_error "Build failed. Please fix build errors." - cd .. - exit 1 - fi - - cd .. - fi -else - log_info "Step 3: No frontend changes, skipping frontend checks" -fi - -# Step 4: Check PR title format -echo "" -log_info "Step 4: Checking PR title format..." - -# Get the commit messages to suggest a title -COMMITS=$(git log upstream/dev..HEAD --oneline) -COMMIT_COUNT=$(echo "$COMMITS" | wc -l | tr -d ' ') - -log_info "Found $COMMIT_COUNT commit(s) in your PR" - -if [ "$COMMIT_COUNT" -eq 1 ]; then - SUGGESTED_TITLE=$(git log -1 --pretty=%s) -else - SUGGESTED_TITLE=$(git log --pretty=%s upstream/dev..HEAD | head -1) -fi - -log_info "Current/suggested title: $SUGGESTED_TITLE" - -# Check if it follows conventional commits -if echo "$SUGGESTED_TITLE" | grep -qE "^(feat|fix|docs|style|refactor|perf|test|chore|ci|security)(\(.+\))?: .+"; then - log_success "Title follows Conventional Commits format" -else - log_warning "Title doesn't follow Conventional Commits format" - echo "" - echo "Conventional Commits format:" - echo " (): " - echo "" - echo "Types: feat, fix, docs, style, refactor, perf, test, chore, ci, security" - echo "" - echo "Examples:" - echo " feat(exchange): add OKX integration" - echo " fix(trader): resolve position tracking bug" - echo " docs(readme): update installation guide" - echo "" - - read -p "Enter new title (or press Enter to keep current): " NEW_TITLE - - if [ -n "$NEW_TITLE" ]; then - log_info "You can update the PR title on GitHub after pushing" - log_info "Suggested title: $NEW_TITLE" - fi -fi - -# Step 5: Push changes -echo "" -log_info "Step 5: Ready to push changes" - -# Check if there are changes to push -if git diff upstream/dev..HEAD --quiet; then - log_info "No changes to push" -else - log_info "Changes ready to push to origin/$CURRENT_BRANCH" - - if confirm "Push changes now?"; then - log_info "Pushing to origin/$CURRENT_BRANCH..." - if git push -f origin "$CURRENT_BRANCH"; then - log_success "Successfully pushed changes!" - else - log_error "Failed to push. You may need to push manually:" - echo " git push -f origin $CURRENT_BRANCH" - exit 1 - fi - else - log_info "Skipped push. You can push manually later:" - echo " git push -f origin $CURRENT_BRANCH" - fi -fi - -# Summary -echo "" -echo "╔═══════════════════════════════════════════╗" -echo "║ ✅ Migration Complete! ║" -echo "╚═══════════════════════════════════════════╝" -echo "" - -log_success "Your PR has been migrated!" - -echo "" -log_info "Next steps:" -echo " 1. Check your PR on GitHub" -echo " 2. Update PR title if needed (Conventional Commits format)" -echo " 3. Wait for CI checks to run" -echo " 4. Address any reviewer feedback" -echo "" - -log_info "Need help? Ask in the PR comments or Telegram!" -log_info "Telegram: https://t.me/nofx_dev_community" - -echo "" -log_success "Thank you for contributing to NOFX! 🚀" -echo "" diff --git a/scripts/restart_and_test.sh b/scripts/restart_and_test.sh deleted file mode 100644 index 740c7542..00000000 --- a/scripts/restart_and_test.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash - -echo "==================================" -echo "NOFX 后端重启和测试脚本" -echo "==================================" - -# 1. 停止旧进程 -echo "" -echo "1️⃣ 停止旧进程..." -pkill -f "bin/nofx" || echo " 没有运行中的进程" -sleep 2 - -# 2. 清理旧数据 -echo "" -echo "2️⃣ 清理测试数据..." -sqlite3 data/data.db "DELETE FROM trader_fills; DELETE FROM trader_orders;" -echo " ✅ trader_orders 和 trader_fills 表已清空" - -# 3. 验证数据库已清空 -ORDERS_COUNT=$(sqlite3 data/data.db "SELECT COUNT(*) FROM trader_orders") -FILLS_COUNT=$(sqlite3 data/data.db "SELECT COUNT(*) FROM trader_fills") -echo " 验证: trader_orders=$ORDERS_COUNT, trader_fills=$FILLS_COUNT" - -# 4. 启动新进程 -echo "" -echo "3️⃣ 启动新编译的后端服务..." -if [ ! -f "bin/nofx" ]; then - echo " ❌ bin/nofx 不存在,请先运行 go build -o bin/nofx ." - exit 1 -fi - -nohup ./bin/nofx > data/nofx_$(date +%Y-%m-%d).log 2>&1 & -NOFX_PID=$! -echo " ✅ 后端已启动 (PID: $NOFX_PID)" - -# 5. 等待服务启动 -echo "" -echo "4️⃣ 等待服务启动..." -sleep 3 - -# 6. 验证进程运行 -if ps -p $NOFX_PID > /dev/null; then - echo " ✅ 后端进程运行正常 (PID: $NOFX_PID)" -else - echo " ❌ 后端进程启动失败,请检查日志" - tail -20 data/nofx_$(date +%Y-%m-%d).log - exit 1 -fi - -echo "" -echo "==================================" -echo "✅ 重启完成!" -echo "==================================" -echo "" -echo "📝 下一步操作:" -echo " 1. 访问前端页面" -echo " 2. 执行一次平仓操作(手动或AI)" -echo " 3. 等待 10 秒(让 pollLighterTradeHistory 完成)" -echo " 4. 检查数据库:" -echo " sqlite3 data/data.db \"SELECT id, status, avg_fill_price, filled_quantity FROM trader_orders\"" -echo " 5. 刷新图表页面,应该能看到 B/S 标记" -echo "" -echo "📊 实时日志查看:" -echo " tail -f data/nofx_$(date +%Y-%m-%d).log | grep -E 'Order recorded|Found matching trade|Fill recorded'" -echo "" diff --git a/scripts/test_lighter_orders.go b/scripts/test_lighter_orders.go deleted file mode 100644 index e064cac2..00000000 --- a/scripts/test_lighter_orders.go +++ /dev/null @@ -1,168 +0,0 @@ -//go:build ignore - -// Test script to verify Lighter API authentication -// Run: go run scripts/test_lighter_orders.go -package main - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "time" - - lighterClient "github.com/elliottech/lighter-go/client" - lighterHTTP "github.com/elliottech/lighter-go/client/http" -) - -func main() { - // Configuration - update these values - walletAddr := os.Getenv("LIGHTER_WALLET") - apiKeyPrivateKey := os.Getenv("LIGHTER_API_KEY") - - if walletAddr == "" || apiKeyPrivateKey == "" { - fmt.Println("Usage: LIGHTER_WALLET=0x... LIGHTER_API_KEY=... go run scripts/test_lighter_orders.go") - fmt.Println("Environment variables required:") - fmt.Println(" LIGHTER_WALLET - Ethereum wallet address") - fmt.Println(" LIGHTER_API_KEY - API key private key (40 bytes hex)") - os.Exit(1) - } - - fmt.Println("=== Lighter API Test ===") - fmt.Printf("Wallet: %s\n\n", walletAddr) - - baseURL := "https://mainnet.zklighter.elliot.ai" - chainID := uint32(304) - client := &http.Client{Timeout: 30 * time.Second} - - // Step 1: Get account info (no auth required) - fmt.Println("1. Getting account info...") - accountIndex, err := getAccountIndex(client, baseURL, walletAddr) - if err != nil { - fmt.Printf(" FAILED: %v\n", err) - os.Exit(1) - } - fmt.Printf(" OK: account_index = %d\n\n", accountIndex) - - // Step 2: Create TxClient and generate auth token - fmt.Println("2. Creating TxClient and generating auth token...") - httpClient := lighterHTTP.NewClient(baseURL) - txClient, err := lighterClient.NewTxClient(httpClient, apiKeyPrivateKey, accountIndex, 0, chainID) - if err != nil { - fmt.Printf(" FAILED: %v\n", err) - os.Exit(1) - } - - authToken, err := txClient.GetAuthToken(time.Now().Add(1 * time.Hour)) - if err != nil { - fmt.Printf(" FAILED: %v\n", err) - os.Exit(1) - } - fmt.Printf(" OK: auth token generated\n\n") - - // Step 3: Test GetActiveOrders with auth query parameter (NEW method) - fmt.Println("3. Testing GetActiveOrders with auth query parameter (FIXED)...") - encodedAuth := url.QueryEscape(authToken) - endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0&auth=%s", - baseURL, accountIndex, encodedAuth) - - resp, err := client.Get(endpoint) - if err != nil { - fmt.Printf(" FAILED: %v\n", err) - os.Exit(1) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var result map[string]interface{} - json.Unmarshal(body, &result) - - if code, ok := result["code"].(float64); ok && code == 200 { - orders := result["orders"].([]interface{}) - fmt.Printf(" OK: Retrieved %d orders\n", len(orders)) - if len(orders) > 0 { - fmt.Println(" Sample orders:") - for i, o := range orders { - if i >= 3 { - fmt.Printf(" ... and %d more\n", len(orders)-3) - break - } - order := o.(map[string]interface{}) - fmt.Printf(" - ID: %v, Price: %v, Side: %v\n", - order["order_id"], order["price"], order["is_ask"]) - } - } - } else { - fmt.Printf(" FAILED: %s\n", string(body)) - fmt.Println("\n Possible causes:") - fmt.Println(" - API key not registered on-chain") - fmt.Println(" - API key private key incorrect") - fmt.Println(" - Account index mismatch") - os.Exit(1) - } - - // Step 4: Test GetActiveOrders with Authorization header (OLD method - for comparison) - fmt.Println("\n4. Testing GetActiveOrders with Authorization header (OLD method)...") - endpoint2 := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0", - baseURL, accountIndex) - - req, _ := http.NewRequest("GET", endpoint2, nil) - req.Header.Set("Authorization", authToken) - req.Header.Set("Content-Type", "application/json") - - resp2, err := client.Do(req) - if err != nil { - fmt.Printf(" FAILED: %v\n", err) - } else { - defer resp2.Body.Close() - body2, _ := io.ReadAll(resp2.Body) - var result2 map[string]interface{} - json.Unmarshal(body2, &result2) - - if code, ok := result2["code"].(float64); ok && code == 200 { - orders := result2["orders"].([]interface{}) - fmt.Printf(" OK: Retrieved %d orders (both methods work!)\n", len(orders)) - } else { - fmt.Printf(" FAILED: %s\n", string(body2)) - fmt.Println(" ^ This is expected - Authorization header doesn't work consistently") - } - } - - fmt.Println("\n=== TEST COMPLETE ===") - fmt.Println("If test 3 passed, the fix is working correctly.") -} - -func getAccountIndex(client *http.Client, baseURL, walletAddr string) (int64, error) { - endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", baseURL, walletAddr) - resp, err := client.Get(endpoint) - if err != nil { - return 0, err - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var result struct { - Code int `json:"code"` - Accounts []struct { - AccountIndex int64 `json:"account_index"` - } `json:"accounts"` - SubAccounts []struct { - AccountIndex int64 `json:"account_index"` - } `json:"sub_accounts"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return 0, fmt.Errorf("failed to parse: %w", err) - } - - if len(result.Accounts) > 0 { - return result.Accounts[0].AccountIndex, nil - } - if len(result.SubAccounts) > 0 { - return result.SubAccounts[0].AccountIndex, nil - } - - return 0, fmt.Errorf("no account found") -} diff --git a/store/position.go b/store/position.go index b62a260f..92463db9 100644 --- a/store/position.go +++ b/store/position.go @@ -60,19 +60,36 @@ func getPriceDecimalPlaces(price float64) int { return len(s) - idx - 1 } -// TraderStats trading statistics metrics -type TraderStats struct { - TotalTrades int `json:"total_trades"` - WinTrades int `json:"win_trades"` - LossTrades int `json:"loss_trades"` - WinRate float64 `json:"win_rate"` - ProfitFactor float64 `json:"profit_factor"` - SharpeRatio float64 `json:"sharpe_ratio"` - TotalPnL float64 `json:"total_pnl"` - TotalFee float64 `json:"total_fee"` - AvgWin float64 `json:"avg_win"` - AvgLoss float64 `json:"avg_loss"` - MaxDrawdownPct float64 `json:"max_drawdown_pct"` +// formatDuration formats a duration +func formatDuration(d time.Duration) string { + return formatDurationMs(d.Milliseconds()) +} + +// formatDurationMs formats a duration in milliseconds +func formatDurationMs(ms int64) string { + seconds := ms / 1000 + minutes := seconds / 60 + hours := minutes / 60 + days := hours / 24 + + if seconds < 60 { + return fmt.Sprintf("%ds", seconds) + } + if minutes < 60 { + return fmt.Sprintf("%dm", minutes) + } + if hours < 24 { + remainingMins := minutes % 60 + if remainingMins == 0 { + return fmt.Sprintf("%dh", hours) + } + return fmt.Sprintf("%dh%dm", hours, remainingMins) + } + remainingHours := hours % 24 + if remainingHours == 0 { + return fmt.Sprintf("%dd", days) + } + return fmt.Sprintf("%dd%dh", days, remainingHours) } // TraderPosition position record @@ -400,585 +417,6 @@ func (s *PositionStore) GetAllOpenPositions() ([]*TraderPosition, error) { return positions, nil } -// GetPositionStats gets position statistics -func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{}, error) { - stats := make(map[string]interface{}) - - type result struct { - Total int - Wins int - TotalPnL float64 - TotalFee float64 - } - var r result - - err := s.db.Model(&TraderPosition{}). - Select("COUNT(*) as total, SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as wins, COALESCE(SUM(realized_pnl), 0) as total_pnl, COALESCE(SUM(fee), 0) as total_fee"). - Where("trader_id = ? AND status = ?", traderID, "CLOSED"). - Scan(&r).Error - if err != nil { - return nil, err - } - - stats["total_trades"] = r.Total - stats["win_trades"] = r.Wins - stats["total_pnl"] = r.TotalPnL - stats["total_fee"] = r.TotalFee - if r.Total > 0 { - stats["win_rate"] = float64(r.Wins) / float64(r.Total) * 100 - } else { - stats["win_rate"] = 0.0 - } - - return stats, nil -} - -// GetFullStats gets complete trading statistics -func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) { - stats := &TraderStats{} - - var count int64 - if err := s.db.Model(&TraderPosition{}).Where("trader_id = ? AND status = ?", traderID, "CLOSED").Count(&count).Error; err != nil { - return nil, err - } - if count == 0 { - return stats, nil - } - - var positions []TraderPosition - err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). - Order("exit_time ASC"). - Find(&positions).Error - if err != nil { - return nil, fmt.Errorf("failed to query position statistics: %w", err) - } - - var pnls []float64 - var totalWin, totalLoss float64 - - for _, pos := range positions { - stats.TotalTrades++ - stats.TotalPnL += pos.RealizedPnL - stats.TotalFee += pos.Fee - pnls = append(pnls, pos.RealizedPnL) - - if pos.RealizedPnL > 0 { - stats.WinTrades++ - totalWin += pos.RealizedPnL - } else if pos.RealizedPnL < 0 { - stats.LossTrades++ - totalLoss += -pos.RealizedPnL - } - } - - if stats.TotalTrades > 0 { - stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100 - } - if totalLoss > 0 { - stats.ProfitFactor = totalWin / totalLoss - } - if stats.WinTrades > 0 { - stats.AvgWin = totalWin / float64(stats.WinTrades) - } - if stats.LossTrades > 0 { - stats.AvgLoss = totalLoss / float64(stats.LossTrades) - } - if len(pnls) > 1 { - stats.SharpeRatio = calculateSharpeRatioFromPnls(pnls) - } - if len(pnls) > 0 { - stats.MaxDrawdownPct = calculateMaxDrawdownFromPnls(pnls) - } - - return stats, nil -} - -// RecentTrade recent trade record -type RecentTrade struct { - Symbol string `json:"symbol"` - Side string `json:"side"` - EntryPrice float64 `json:"entry_price"` - ExitPrice float64 `json:"exit_price"` - RealizedPnL float64 `json:"realized_pnl"` - PnLPct float64 `json:"pnl_pct"` - EntryTime int64 `json:"entry_time"` - ExitTime int64 `json:"exit_time"` - HoldDuration string `json:"hold_duration"` -} - -// GetRecentTrades gets recent closed trades -func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTrade, error) { - var positions []TraderPosition - err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). - Order("exit_time DESC"). - Limit(limit). - Find(&positions).Error - if err != nil { - return nil, fmt.Errorf("failed to query recent trades: %w", err) - } - - var trades []RecentTrade - for _, pos := range positions { - t := RecentTrade{ - Symbol: pos.Symbol, - Side: strings.ToLower(pos.Side), - EntryPrice: pos.EntryPrice, - ExitPrice: pos.ExitPrice, - RealizedPnL: pos.RealizedPnL, - EntryTime: pos.EntryTime / 1000, // Convert ms to seconds for API compatibility - } - - if pos.ExitTime > 0 { - t.ExitTime = pos.ExitTime / 1000 // Convert ms to seconds - durationMs := pos.ExitTime - pos.EntryTime - t.HoldDuration = formatDurationMs(durationMs) - } - - if pos.EntryPrice > 0 { - if t.Side == "long" { - t.PnLPct = (pos.ExitPrice - pos.EntryPrice) / pos.EntryPrice * 100 * float64(pos.Leverage) - } else { - t.PnLPct = (pos.EntryPrice - pos.ExitPrice) / pos.EntryPrice * 100 * float64(pos.Leverage) - } - } - - trades = append(trades, t) - } - - return trades, nil -} - -// formatDuration formats a duration -func formatDuration(d time.Duration) string { - return formatDurationMs(d.Milliseconds()) -} - -// formatDurationMs formats a duration in milliseconds -func formatDurationMs(ms int64) string { - seconds := ms / 1000 - minutes := seconds / 60 - hours := minutes / 60 - days := hours / 24 - - if seconds < 60 { - return fmt.Sprintf("%ds", seconds) - } - if minutes < 60 { - return fmt.Sprintf("%dm", minutes) - } - if hours < 24 { - remainingMins := minutes % 60 - if remainingMins == 0 { - return fmt.Sprintf("%dh", hours) - } - return fmt.Sprintf("%dh%dm", hours, remainingMins) - } - remainingHours := hours % 24 - if remainingHours == 0 { - return fmt.Sprintf("%dd", days) - } - return fmt.Sprintf("%dd%dh", days, remainingHours) -} - -// calculateSharpeRatioFromPnls calculates Sharpe ratio -func calculateSharpeRatioFromPnls(pnls []float64) float64 { - if len(pnls) < 2 { - return 0 - } - - var sum float64 - for _, pnl := range pnls { - sum += pnl - } - mean := sum / float64(len(pnls)) - - var variance float64 - for _, pnl := range pnls { - variance += (pnl - mean) * (pnl - mean) - } - stdDev := math.Sqrt(variance / float64(len(pnls)-1)) - - if stdDev == 0 { - return 0 - } - - return mean / stdDev -} - -// calculateMaxDrawdownFromPnls calculates maximum drawdown -func calculateMaxDrawdownFromPnls(pnls []float64) float64 { - if len(pnls) == 0 { - return 0 - } - - const startingEquity = 10000.0 - equity := startingEquity - peak := startingEquity - var maxDD float64 - - for _, pnl := range pnls { - equity += pnl - if equity > peak { - peak = equity - } - if peak > 0 { - dd := (peak - equity) / peak * 100 - if dd > maxDD { - maxDD = dd - } - } - } - - return maxDD -} - -// SymbolStats per-symbol trading statistics -type SymbolStats struct { - Symbol string `json:"symbol"` - TotalTrades int `json:"total_trades"` - WinTrades int `json:"win_trades"` - WinRate float64 `json:"win_rate"` - TotalPnL float64 `json:"total_pnl"` - AvgPnL float64 `json:"avg_pnl"` - AvgHoldMins float64 `json:"avg_hold_mins"` -} - -// GetSymbolStats gets per-symbol trading statistics -func (s *PositionStore) GetSymbolStats(traderID string, limit int) ([]SymbolStats, error) { - var positions []TraderPosition - err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED").Find(&positions).Error - if err != nil { - return nil, fmt.Errorf("failed to query symbol stats: %w", err) - } - - // Group by symbol - symbolMap := make(map[string]*SymbolStats) - symbolHoldMins := make(map[string][]float64) - - for _, pos := range positions { - if _, ok := symbolMap[pos.Symbol]; !ok { - symbolMap[pos.Symbol] = &SymbolStats{Symbol: pos.Symbol} - symbolHoldMins[pos.Symbol] = []float64{} - } - s := symbolMap[pos.Symbol] - s.TotalTrades++ - s.TotalPnL += pos.RealizedPnL - if pos.RealizedPnL > 0 { - s.WinTrades++ - } - - if pos.ExitTime > 0 { - holdMins := float64(pos.ExitTime-pos.EntryTime) / 60000.0 // ms to minutes - symbolHoldMins[pos.Symbol] = append(symbolHoldMins[pos.Symbol], holdMins) - } - } - - var stats []SymbolStats - for symbol, s := range symbolMap { - if s.TotalTrades > 0 { - s.WinRate = float64(s.WinTrades) / float64(s.TotalTrades) * 100 - s.AvgPnL = s.TotalPnL / float64(s.TotalTrades) - } - if len(symbolHoldMins[symbol]) > 0 { - var totalMins float64 - for _, m := range symbolHoldMins[symbol] { - totalMins += m - } - s.AvgHoldMins = totalMins / float64(len(symbolHoldMins[symbol])) - } - stats = append(stats, *s) - } - - // Sort by TotalPnL descending and limit - for i := 0; i < len(stats)-1; i++ { - for j := i + 1; j < len(stats); j++ { - if stats[j].TotalPnL > stats[i].TotalPnL { - stats[i], stats[j] = stats[j], stats[i] - } - } - } - - if limit > 0 && len(stats) > limit { - stats = stats[:limit] - } - - return stats, nil -} - -// HoldingTimeStats holding duration analysis -type HoldingTimeStats struct { - Range string `json:"range"` - TradeCount int `json:"trade_count"` - WinRate float64 `json:"win_rate"` - AvgPnL float64 `json:"avg_pnl"` -} - -// GetHoldingTimeStats analyzes performance by holding duration -func (s *PositionStore) GetHoldingTimeStats(traderID string) ([]HoldingTimeStats, error) { - var positions []TraderPosition - err := s.db.Where("trader_id = ? AND status = ? AND exit_time > 0", traderID, "CLOSED").Find(&positions).Error - if err != nil { - return nil, fmt.Errorf("failed to query holding time stats: %w", err) - } - - rangeStats := map[string]*struct { - count int - wins int - totalPnL float64 - }{ - "<1h": {}, - "1-4h": {}, - "4-24h": {}, - ">24h": {}, - } - - for _, pos := range positions { - if pos.ExitTime == 0 { - continue - } - holdHours := float64(pos.ExitTime-pos.EntryTime) / 3600000.0 // ms to hours - - var rangeKey string - switch { - case holdHours < 1: - rangeKey = "<1h" - case holdHours < 4: - rangeKey = "1-4h" - case holdHours < 24: - rangeKey = "4-24h" - default: - rangeKey = ">24h" - } - - r := rangeStats[rangeKey] - r.count++ - r.totalPnL += pos.RealizedPnL - if pos.RealizedPnL > 0 { - r.wins++ - } - } - - var stats []HoldingTimeStats - for _, rangeKey := range []string{"<1h", "1-4h", "4-24h", ">24h"} { - r := rangeStats[rangeKey] - if r.count > 0 { - stats = append(stats, HoldingTimeStats{ - Range: rangeKey, - TradeCount: r.count, - WinRate: float64(r.wins) / float64(r.count) * 100, - AvgPnL: r.totalPnL / float64(r.count), - }) - } - } - - return stats, nil -} - -// DirectionStats long/short performance comparison -type DirectionStats struct { - Side string `json:"side"` - TradeCount int `json:"trade_count"` - WinRate float64 `json:"win_rate"` - TotalPnL float64 `json:"total_pnl"` - AvgPnL float64 `json:"avg_pnl"` -} - -// GetDirectionStats analyzes long vs short performance -func (s *PositionStore) GetDirectionStats(traderID string) ([]DirectionStats, error) { - var positions []TraderPosition - err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED").Find(&positions).Error - if err != nil { - return nil, fmt.Errorf("failed to query direction stats: %w", err) - } - - sideStats := make(map[string]*DirectionStats) - for _, pos := range positions { - if _, ok := sideStats[pos.Side]; !ok { - sideStats[pos.Side] = &DirectionStats{Side: pos.Side} - } - s := sideStats[pos.Side] - s.TradeCount++ - s.TotalPnL += pos.RealizedPnL - if pos.RealizedPnL > 0 { - s.WinRate++ - } - } - - var stats []DirectionStats - for _, s := range sideStats { - if s.TradeCount > 0 { - s.AvgPnL = s.TotalPnL / float64(s.TradeCount) - s.WinRate = s.WinRate / float64(s.TradeCount) * 100 - } - stats = append(stats, *s) - } - - return stats, nil -} - -// HistorySummary comprehensive trading history for AI context -type HistorySummary struct { - TotalTrades int `json:"total_trades"` - WinRate float64 `json:"win_rate"` - TotalPnL float64 `json:"total_pnl"` - AvgTradeReturn float64 `json:"avg_trade_return"` - - BestSymbols []SymbolStats `json:"best_symbols"` - WorstSymbols []SymbolStats `json:"worst_symbols"` - - LongWinRate float64 `json:"long_win_rate"` - ShortWinRate float64 `json:"short_win_rate"` - LongPnL float64 `json:"long_pnl"` - ShortPnL float64 `json:"short_pnl"` - - AvgHoldingMins float64 `json:"avg_holding_mins"` - BestHoldRange string `json:"best_hold_range"` - - RecentWinRate float64 `json:"recent_win_rate"` - RecentPnL float64 `json:"recent_pnl"` - - CurrentStreak int `json:"current_streak"` - MaxWinStreak int `json:"max_win_streak"` - MaxLoseStreak int `json:"max_lose_streak"` -} - -// GetHistorySummary generates comprehensive AI context summary -func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, error) { - summary := &HistorySummary{} - - fullStats, err := s.GetFullStats(traderID) - if err != nil { - return nil, err - } - summary.TotalTrades = fullStats.TotalTrades - summary.WinRate = fullStats.WinRate - summary.TotalPnL = fullStats.TotalPnL - if fullStats.TotalTrades > 0 { - summary.AvgTradeReturn = fullStats.TotalPnL / float64(fullStats.TotalTrades) - } - - symbolStats, _ := s.GetSymbolStats(traderID, 20) - if len(symbolStats) > 0 { - for i := 0; i < len(symbolStats) && i < 3; i++ { - if symbolStats[i].TotalPnL > 0 { - summary.BestSymbols = append(summary.BestSymbols, symbolStats[i]) - } - } - for i := len(symbolStats) - 1; i >= 0 && len(summary.WorstSymbols) < 3; i-- { - if symbolStats[i].TotalPnL < 0 { - summary.WorstSymbols = append(summary.WorstSymbols, symbolStats[i]) - } - } - } - - dirStats, _ := s.GetDirectionStats(traderID) - for _, d := range dirStats { - if d.Side == "LONG" { - summary.LongWinRate = d.WinRate - summary.LongPnL = d.TotalPnL - } else if d.Side == "SHORT" { - summary.ShortWinRate = d.WinRate - summary.ShortPnL = d.TotalPnL - } - } - - holdStats, _ := s.GetHoldingTimeStats(traderID) - var bestHoldWinRate float64 - for _, h := range holdStats { - if h.WinRate > bestHoldWinRate && h.TradeCount >= 3 { - bestHoldWinRate = h.WinRate - summary.BestHoldRange = h.Range - } - } - - // Calculate average holding time - var positions []TraderPosition - s.db.Where("trader_id = ? AND status = ? AND exit_time > 0", traderID, "CLOSED").Find(&positions) - if len(positions) > 0 { - var totalMins float64 - for _, pos := range positions { - if pos.ExitTime > 0 { - totalMins += float64(pos.ExitTime-pos.EntryTime) / 60000.0 // ms to minutes - } - } - summary.AvgHoldingMins = totalMins / float64(len(positions)) - } - - // Recent 20 trades - var recent []TraderPosition - s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). - Order("exit_time DESC").Limit(20).Find(&recent) - for _, pos := range recent { - summary.RecentPnL += pos.RealizedPnL - if pos.RealizedPnL > 0 { - summary.RecentWinRate++ - } - } - if len(recent) > 0 { - summary.RecentWinRate = summary.RecentWinRate / float64(len(recent)) * 100 - } - - // Calculate streaks - s.calculateStreaks(traderID, summary) - - return summary, nil -} - -// calculateStreaks calculates win/loss streaks -func (s *PositionStore) calculateStreaks(traderID string, summary *HistorySummary) { - var positions []TraderPosition - err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). - Order("exit_time DESC"). - Find(&positions).Error - if err != nil || len(positions) == 0 { - return - } - - var currentStreak, maxWin, maxLose int - var prevWin *bool - isFirst := true - - for _, pos := range positions { - isWin := pos.RealizedPnL > 0 - - if isFirst { - if isWin { - currentStreak = 1 - } else { - currentStreak = -1 - } - isFirst = false - } - - if prevWin == nil { - prevWin = &isWin - } else if *prevWin == isWin { - if isWin { - currentStreak++ - if currentStreak > maxWin { - maxWin = currentStreak - } - } else { - currentStreak-- - if -currentStreak > maxLose { - maxLose = -currentStreak - } - } - } else { - if isWin { - currentStreak = 1 - } else { - currentStreak = -1 - } - *prevWin = isWin - } - } - - summary.CurrentStreak = currentStreak - summary.MaxWinStreak = maxWin - summary.MaxLoseStreak = maxLose -} - // ExistsWithExchangePositionID checks if a position exists func (s *PositionStore) ExistsWithExchangePositionID(exchangeID, exchangePositionID string) (bool, error) { if exchangePositionID == "" { @@ -1017,124 +455,6 @@ func (s *PositionStore) GetOpenPositionByExchangePositionID(exchangeID, exchange return &pos, nil } -// ClosedPnLRecord represents a closed position record from exchange -// All time fields use int64 millisecond timestamps (UTC) -type ClosedPnLRecord struct { - Symbol string - Side string - EntryPrice float64 - ExitPrice float64 - Quantity float64 - RealizedPnL float64 - Fee float64 - Leverage int - EntryTime int64 // Unix milliseconds UTC - ExitTime int64 // Unix milliseconds UTC - OrderID string - CloseType string - ExchangeID string -} - -// CreateFromClosedPnL creates a closed position record from exchange data -func (s *PositionStore) CreateFromClosedPnL(traderID, exchangeID, exchangeType string, record *ClosedPnLRecord) (bool, error) { - if record.Symbol == "" { - return false, nil - } - - side := strings.ToUpper(record.Side) - if side == "LONG" || side == "BUY" { - side = "LONG" - } else if side == "SHORT" || side == "SELL" { - side = "SHORT" - } else { - return false, nil - } - - if record.Quantity <= 0 || record.ExitPrice <= 0 || record.EntryPrice <= 0 { - return false, nil - } - - exchangePositionID := record.ExchangeID - if exchangePositionID == "" { - exchangePositionID = fmt.Sprintf("%s_%s_%d_%.8f", record.Symbol, side, record.ExitTime, record.RealizedPnL) - } - - exists, err := s.ExistsWithExchangePositionID(exchangeID, exchangePositionID) - if err != nil { - return false, err - } - if exists { - return false, nil - } - - exitTimeMs := record.ExitTime - entryTimeMs := record.EntryTime - - // Validate timestamps (must be after year 2000 = ~946684800000 ms) - minValidTime := int64(946684800000) // 2000-01-01 UTC in milliseconds - if exitTimeMs < minValidTime { - return false, nil - } - if entryTimeMs < minValidTime { - entryTimeMs = exitTimeMs - } - if entryTimeMs > exitTimeMs { - entryTimeMs = exitTimeMs - } - - nowMs := time.Now().UTC().UnixMilli() - pos := &TraderPosition{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - ExchangePositionID: exchangePositionID, - Symbol: record.Symbol, - Side: side, - Quantity: record.Quantity, - EntryQuantity: record.Quantity, - EntryPrice: record.EntryPrice, - EntryTime: entryTimeMs, - ExitPrice: record.ExitPrice, - ExitOrderID: record.OrderID, - ExitTime: exitTimeMs, - RealizedPnL: record.RealizedPnL, - Fee: record.Fee, - Leverage: record.Leverage, - Status: "CLOSED", - CloseReason: record.CloseType, - Source: "sync", - CreatedAt: nowMs, - UpdatedAt: nowMs, - } - - err = s.db.Create(pos).Error - if err != nil { - if strings.Contains(err.Error(), "UNIQUE constraint failed") { - return false, nil - } - return false, fmt.Errorf("failed to create position from closed PnL: %w", err) - } - - return true, nil -} - -// GetLastClosedPositionTime gets the most recent exit time (Unix ms) -func (s *PositionStore) GetLastClosedPositionTime(traderID string) (int64, error) { - var pos TraderPosition - err := s.db.Where("trader_id = ? AND status = ? AND exit_time > 0", traderID, "CLOSED"). - Order("exit_time DESC"). - First(&pos).Error - - if err == gorm.ErrRecordNotFound || pos.ExitTime == 0 { - return time.Now().UTC().Add(-30 * 24 * time.Hour).UnixMilli(), nil - } - if err != nil { - return 0, fmt.Errorf("failed to get last closed position time: %w", err) - } - - return pos.ExitTime, nil -} - // CreateOpenPosition creates an open position func (s *PositionStore) CreateOpenPosition(pos *TraderPosition) error { if pos.ExchangePositionID != "" && pos.ExchangeID != "" { @@ -1196,21 +516,3 @@ func (s *PositionStore) ClosePositionWithAccurateData(id int64, exitPrice float6 "updated_at": time.Now().UTC().UnixMilli(), }).Error } - -// SyncClosedPositions syncs closed positions from exchange -func (s *PositionStore) SyncClosedPositions(traderID, exchangeID, exchangeType string, records []ClosedPnLRecord) (int, int, error) { - created, skipped := 0, 0 - for _, record := range records { - rec := record - wasCreated, err := s.CreateFromClosedPnL(traderID, exchangeID, exchangeType, &rec) - if err != nil { - return created, skipped, fmt.Errorf("failed to sync position: %w", err) - } - if wasCreated { - created++ - } else { - skipped++ - } - } - return created, skipped, nil -} diff --git a/store/position_history.go b/store/position_history.go new file mode 100644 index 00000000..d217839f --- /dev/null +++ b/store/position_history.go @@ -0,0 +1,308 @@ +package store + +import ( + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +// HistorySummary comprehensive trading history for AI context +type HistorySummary struct { + TotalTrades int `json:"total_trades"` + WinRate float64 `json:"win_rate"` + TotalPnL float64 `json:"total_pnl"` + AvgTradeReturn float64 `json:"avg_trade_return"` + + BestSymbols []SymbolStats `json:"best_symbols"` + WorstSymbols []SymbolStats `json:"worst_symbols"` + + LongWinRate float64 `json:"long_win_rate"` + ShortWinRate float64 `json:"short_win_rate"` + LongPnL float64 `json:"long_pnl"` + ShortPnL float64 `json:"short_pnl"` + + AvgHoldingMins float64 `json:"avg_holding_mins"` + BestHoldRange string `json:"best_hold_range"` + + RecentWinRate float64 `json:"recent_win_rate"` + RecentPnL float64 `json:"recent_pnl"` + + CurrentStreak int `json:"current_streak"` + MaxWinStreak int `json:"max_win_streak"` + MaxLoseStreak int `json:"max_lose_streak"` +} + +// GetHistorySummary generates comprehensive AI context summary +func (s *PositionStore) GetHistorySummary(traderID string) (*HistorySummary, error) { + summary := &HistorySummary{} + + fullStats, err := s.GetFullStats(traderID) + if err != nil { + return nil, err + } + summary.TotalTrades = fullStats.TotalTrades + summary.WinRate = fullStats.WinRate + summary.TotalPnL = fullStats.TotalPnL + if fullStats.TotalTrades > 0 { + summary.AvgTradeReturn = fullStats.TotalPnL / float64(fullStats.TotalTrades) + } + + symbolStats, _ := s.GetSymbolStats(traderID, 20) + if len(symbolStats) > 0 { + for i := 0; i < len(symbolStats) && i < 3; i++ { + if symbolStats[i].TotalPnL > 0 { + summary.BestSymbols = append(summary.BestSymbols, symbolStats[i]) + } + } + for i := len(symbolStats) - 1; i >= 0 && len(summary.WorstSymbols) < 3; i-- { + if symbolStats[i].TotalPnL < 0 { + summary.WorstSymbols = append(summary.WorstSymbols, symbolStats[i]) + } + } + } + + dirStats, _ := s.GetDirectionStats(traderID) + for _, d := range dirStats { + if d.Side == "LONG" { + summary.LongWinRate = d.WinRate + summary.LongPnL = d.TotalPnL + } else if d.Side == "SHORT" { + summary.ShortWinRate = d.WinRate + summary.ShortPnL = d.TotalPnL + } + } + + holdStats, _ := s.GetHoldingTimeStats(traderID) + var bestHoldWinRate float64 + for _, h := range holdStats { + if h.WinRate > bestHoldWinRate && h.TradeCount >= 3 { + bestHoldWinRate = h.WinRate + summary.BestHoldRange = h.Range + } + } + + // Calculate average holding time + var positions []TraderPosition + s.db.Where("trader_id = ? AND status = ? AND exit_time > 0", traderID, "CLOSED").Find(&positions) + if len(positions) > 0 { + var totalMins float64 + for _, pos := range positions { + if pos.ExitTime > 0 { + totalMins += float64(pos.ExitTime-pos.EntryTime) / 60000.0 // ms to minutes + } + } + summary.AvgHoldingMins = totalMins / float64(len(positions)) + } + + // Recent 20 trades + var recent []TraderPosition + s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC").Limit(20).Find(&recent) + for _, pos := range recent { + summary.RecentPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + summary.RecentWinRate++ + } + } + if len(recent) > 0 { + summary.RecentWinRate = summary.RecentWinRate / float64(len(recent)) * 100 + } + + // Calculate streaks + s.calculateStreaks(traderID, summary) + + return summary, nil +} + +// calculateStreaks calculates win/loss streaks +func (s *PositionStore) calculateStreaks(traderID string, summary *HistorySummary) { + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC"). + Find(&positions).Error + if err != nil || len(positions) == 0 { + return + } + + var currentStreak, maxWin, maxLose int + var prevWin *bool + isFirst := true + + for _, pos := range positions { + isWin := pos.RealizedPnL > 0 + + if isFirst { + if isWin { + currentStreak = 1 + } else { + currentStreak = -1 + } + isFirst = false + } + + if prevWin == nil { + prevWin = &isWin + } else if *prevWin == isWin { + if isWin { + currentStreak++ + if currentStreak > maxWin { + maxWin = currentStreak + } + } else { + currentStreak-- + if -currentStreak > maxLose { + maxLose = -currentStreak + } + } + } else { + if isWin { + currentStreak = 1 + } else { + currentStreak = -1 + } + *prevWin = isWin + } + } + + summary.CurrentStreak = currentStreak + summary.MaxWinStreak = maxWin + summary.MaxLoseStreak = maxLose +} + +// ClosedPnLRecord represents a closed position record from exchange +// All time fields use int64 millisecond timestamps (UTC) +type ClosedPnLRecord struct { + Symbol string + Side string + EntryPrice float64 + ExitPrice float64 + Quantity float64 + RealizedPnL float64 + Fee float64 + Leverage int + EntryTime int64 // Unix milliseconds UTC + ExitTime int64 // Unix milliseconds UTC + OrderID string + CloseType string + ExchangeID string +} + +// CreateFromClosedPnL creates a closed position record from exchange data +func (s *PositionStore) CreateFromClosedPnL(traderID, exchangeID, exchangeType string, record *ClosedPnLRecord) (bool, error) { + if record.Symbol == "" { + return false, nil + } + + side := strings.ToUpper(record.Side) + if side == "LONG" || side == "BUY" { + side = "LONG" + } else if side == "SHORT" || side == "SELL" { + side = "SHORT" + } else { + return false, nil + } + + if record.Quantity <= 0 || record.ExitPrice <= 0 || record.EntryPrice <= 0 { + return false, nil + } + + exchangePositionID := record.ExchangeID + if exchangePositionID == "" { + exchangePositionID = fmt.Sprintf("%s_%s_%d_%.8f", record.Symbol, side, record.ExitTime, record.RealizedPnL) + } + + exists, err := s.ExistsWithExchangePositionID(exchangeID, exchangePositionID) + if err != nil { + return false, err + } + if exists { + return false, nil + } + + exitTimeMs := record.ExitTime + entryTimeMs := record.EntryTime + + // Validate timestamps (must be after year 2000 = ~946684800000 ms) + minValidTime := int64(946684800000) // 2000-01-01 UTC in milliseconds + if exitTimeMs < minValidTime { + return false, nil + } + if entryTimeMs < minValidTime { + entryTimeMs = exitTimeMs + } + if entryTimeMs > exitTimeMs { + entryTimeMs = exitTimeMs + } + + nowMs := time.Now().UTC().UnixMilli() + pos := &TraderPosition{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + ExchangePositionID: exchangePositionID, + Symbol: record.Symbol, + Side: side, + Quantity: record.Quantity, + EntryQuantity: record.Quantity, + EntryPrice: record.EntryPrice, + EntryTime: entryTimeMs, + ExitPrice: record.ExitPrice, + ExitOrderID: record.OrderID, + ExitTime: exitTimeMs, + RealizedPnL: record.RealizedPnL, + Fee: record.Fee, + Leverage: record.Leverage, + Status: "CLOSED", + CloseReason: record.CloseType, + Source: "sync", + CreatedAt: nowMs, + UpdatedAt: nowMs, + } + + err = s.db.Create(pos).Error + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") { + return false, nil + } + return false, fmt.Errorf("failed to create position from closed PnL: %w", err) + } + + return true, nil +} + +// GetLastClosedPositionTime gets the most recent exit time (Unix ms) +func (s *PositionStore) GetLastClosedPositionTime(traderID string) (int64, error) { + var pos TraderPosition + err := s.db.Where("trader_id = ? AND status = ? AND exit_time > 0", traderID, "CLOSED"). + Order("exit_time DESC"). + First(&pos).Error + + if err == gorm.ErrRecordNotFound || pos.ExitTime == 0 { + return time.Now().UTC().Add(-30 * 24 * time.Hour).UnixMilli(), nil + } + if err != nil { + return 0, fmt.Errorf("failed to get last closed position time: %w", err) + } + + return pos.ExitTime, nil +} + +// SyncClosedPositions syncs closed positions from exchange +func (s *PositionStore) SyncClosedPositions(traderID, exchangeID, exchangeType string, records []ClosedPnLRecord) (int, int, error) { + created, skipped := 0, 0 + for _, record := range records { + rec := record + wasCreated, err := s.CreateFromClosedPnL(traderID, exchangeID, exchangeType, &rec) + if err != nil { + return created, skipped, fmt.Errorf("failed to sync position: %w", err) + } + if wasCreated { + created++ + } else { + skipped++ + } + } + return created, skipped, nil +} diff --git a/store/position_query.go b/store/position_query.go new file mode 100644 index 00000000..b50a5034 --- /dev/null +++ b/store/position_query.go @@ -0,0 +1,406 @@ +package store + +import ( + "fmt" + "math" + "strings" +) + +// TraderStats trading statistics metrics +type TraderStats struct { + TotalTrades int `json:"total_trades"` + WinTrades int `json:"win_trades"` + LossTrades int `json:"loss_trades"` + WinRate float64 `json:"win_rate"` + ProfitFactor float64 `json:"profit_factor"` + SharpeRatio float64 `json:"sharpe_ratio"` + TotalPnL float64 `json:"total_pnl"` + TotalFee float64 `json:"total_fee"` + AvgWin float64 `json:"avg_win"` + AvgLoss float64 `json:"avg_loss"` + MaxDrawdownPct float64 `json:"max_drawdown_pct"` +} + +// GetPositionStats gets position statistics +func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{}, error) { + stats := make(map[string]interface{}) + + type result struct { + Total int + Wins int + TotalPnL float64 + TotalFee float64 + } + var r result + + err := s.db.Model(&TraderPosition{}). + Select("COUNT(*) as total, SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as wins, COALESCE(SUM(realized_pnl), 0) as total_pnl, COALESCE(SUM(fee), 0) as total_fee"). + Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Scan(&r).Error + if err != nil { + return nil, err + } + + stats["total_trades"] = r.Total + stats["win_trades"] = r.Wins + stats["total_pnl"] = r.TotalPnL + stats["total_fee"] = r.TotalFee + if r.Total > 0 { + stats["win_rate"] = float64(r.Wins) / float64(r.Total) * 100 + } else { + stats["win_rate"] = 0.0 + } + + return stats, nil +} + +// GetFullStats gets complete trading statistics +func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) { + stats := &TraderStats{} + + var count int64 + if err := s.db.Model(&TraderPosition{}).Where("trader_id = ? AND status = ?", traderID, "CLOSED").Count(&count).Error; err != nil { + return nil, err + } + if count == 0 { + return stats, nil + } + + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time ASC"). + Find(&positions).Error + if err != nil { + return nil, fmt.Errorf("failed to query position statistics: %w", err) + } + + var pnls []float64 + var totalWin, totalLoss float64 + + for _, pos := range positions { + stats.TotalTrades++ + stats.TotalPnL += pos.RealizedPnL + stats.TotalFee += pos.Fee + pnls = append(pnls, pos.RealizedPnL) + + if pos.RealizedPnL > 0 { + stats.WinTrades++ + totalWin += pos.RealizedPnL + } else if pos.RealizedPnL < 0 { + stats.LossTrades++ + totalLoss += -pos.RealizedPnL + } + } + + if stats.TotalTrades > 0 { + stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100 + } + if totalLoss > 0 { + stats.ProfitFactor = totalWin / totalLoss + } + if stats.WinTrades > 0 { + stats.AvgWin = totalWin / float64(stats.WinTrades) + } + if stats.LossTrades > 0 { + stats.AvgLoss = totalLoss / float64(stats.LossTrades) + } + if len(pnls) > 1 { + stats.SharpeRatio = calculateSharpeRatioFromPnls(pnls) + } + if len(pnls) > 0 { + stats.MaxDrawdownPct = calculateMaxDrawdownFromPnls(pnls) + } + + return stats, nil +} + +// RecentTrade recent trade record +type RecentTrade struct { + Symbol string `json:"symbol"` + Side string `json:"side"` + EntryPrice float64 `json:"entry_price"` + ExitPrice float64 `json:"exit_price"` + RealizedPnL float64 `json:"realized_pnl"` + PnLPct float64 `json:"pnl_pct"` + EntryTime int64 `json:"entry_time"` + ExitTime int64 `json:"exit_time"` + HoldDuration string `json:"hold_duration"` +} + +// GetRecentTrades gets recent closed trades +func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTrade, error) { + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED"). + Order("exit_time DESC"). + Limit(limit). + Find(&positions).Error + if err != nil { + return nil, fmt.Errorf("failed to query recent trades: %w", err) + } + + var trades []RecentTrade + for _, pos := range positions { + t := RecentTrade{ + Symbol: pos.Symbol, + Side: strings.ToLower(pos.Side), + EntryPrice: pos.EntryPrice, + ExitPrice: pos.ExitPrice, + RealizedPnL: pos.RealizedPnL, + EntryTime: pos.EntryTime / 1000, // Convert ms to seconds for API compatibility + } + + if pos.ExitTime > 0 { + t.ExitTime = pos.ExitTime / 1000 // Convert ms to seconds + durationMs := pos.ExitTime - pos.EntryTime + t.HoldDuration = formatDurationMs(durationMs) + } + + if pos.EntryPrice > 0 { + if t.Side == "long" { + t.PnLPct = (pos.ExitPrice - pos.EntryPrice) / pos.EntryPrice * 100 * float64(pos.Leverage) + } else { + t.PnLPct = (pos.EntryPrice - pos.ExitPrice) / pos.EntryPrice * 100 * float64(pos.Leverage) + } + } + + trades = append(trades, t) + } + + return trades, nil +} + +// calculateSharpeRatioFromPnls calculates Sharpe ratio +func calculateSharpeRatioFromPnls(pnls []float64) float64 { + if len(pnls) < 2 { + return 0 + } + + var sum float64 + for _, pnl := range pnls { + sum += pnl + } + mean := sum / float64(len(pnls)) + + var variance float64 + for _, pnl := range pnls { + variance += (pnl - mean) * (pnl - mean) + } + stdDev := math.Sqrt(variance / float64(len(pnls)-1)) + + if stdDev == 0 { + return 0 + } + + return mean / stdDev +} + +// calculateMaxDrawdownFromPnls calculates maximum drawdown +func calculateMaxDrawdownFromPnls(pnls []float64) float64 { + if len(pnls) == 0 { + return 0 + } + + const startingEquity = 10000.0 + equity := startingEquity + peak := startingEquity + var maxDD float64 + + for _, pnl := range pnls { + equity += pnl + if equity > peak { + peak = equity + } + if peak > 0 { + dd := (peak - equity) / peak * 100 + if dd > maxDD { + maxDD = dd + } + } + } + + return maxDD +} + +// SymbolStats per-symbol trading statistics +type SymbolStats struct { + Symbol string `json:"symbol"` + TotalTrades int `json:"total_trades"` + WinTrades int `json:"win_trades"` + WinRate float64 `json:"win_rate"` + TotalPnL float64 `json:"total_pnl"` + AvgPnL float64 `json:"avg_pnl"` + AvgHoldMins float64 `json:"avg_hold_mins"` +} + +// GetSymbolStats gets per-symbol trading statistics +func (s *PositionStore) GetSymbolStats(traderID string, limit int) ([]SymbolStats, error) { + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED").Find(&positions).Error + if err != nil { + return nil, fmt.Errorf("failed to query symbol stats: %w", err) + } + + // Group by symbol + symbolMap := make(map[string]*SymbolStats) + symbolHoldMins := make(map[string][]float64) + + for _, pos := range positions { + if _, ok := symbolMap[pos.Symbol]; !ok { + symbolMap[pos.Symbol] = &SymbolStats{Symbol: pos.Symbol} + symbolHoldMins[pos.Symbol] = []float64{} + } + s := symbolMap[pos.Symbol] + s.TotalTrades++ + s.TotalPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + s.WinTrades++ + } + + if pos.ExitTime > 0 { + holdMins := float64(pos.ExitTime-pos.EntryTime) / 60000.0 // ms to minutes + symbolHoldMins[pos.Symbol] = append(symbolHoldMins[pos.Symbol], holdMins) + } + } + + var stats []SymbolStats + for symbol, s := range symbolMap { + if s.TotalTrades > 0 { + s.WinRate = float64(s.WinTrades) / float64(s.TotalTrades) * 100 + s.AvgPnL = s.TotalPnL / float64(s.TotalTrades) + } + if len(symbolHoldMins[symbol]) > 0 { + var totalMins float64 + for _, m := range symbolHoldMins[symbol] { + totalMins += m + } + s.AvgHoldMins = totalMins / float64(len(symbolHoldMins[symbol])) + } + stats = append(stats, *s) + } + + // Sort by TotalPnL descending and limit + for i := 0; i < len(stats)-1; i++ { + for j := i + 1; j < len(stats); j++ { + if stats[j].TotalPnL > stats[i].TotalPnL { + stats[i], stats[j] = stats[j], stats[i] + } + } + } + + if limit > 0 && len(stats) > limit { + stats = stats[:limit] + } + + return stats, nil +} + +// HoldingTimeStats holding duration analysis +type HoldingTimeStats struct { + Range string `json:"range"` + TradeCount int `json:"trade_count"` + WinRate float64 `json:"win_rate"` + AvgPnL float64 `json:"avg_pnl"` +} + +// GetHoldingTimeStats analyzes performance by holding duration +func (s *PositionStore) GetHoldingTimeStats(traderID string) ([]HoldingTimeStats, error) { + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ? AND exit_time > 0", traderID, "CLOSED").Find(&positions).Error + if err != nil { + return nil, fmt.Errorf("failed to query holding time stats: %w", err) + } + + rangeStats := map[string]*struct { + count int + wins int + totalPnL float64 + }{ + "<1h": {}, + "1-4h": {}, + "4-24h": {}, + ">24h": {}, + } + + for _, pos := range positions { + if pos.ExitTime == 0 { + continue + } + holdHours := float64(pos.ExitTime-pos.EntryTime) / 3600000.0 // ms to hours + + var rangeKey string + switch { + case holdHours < 1: + rangeKey = "<1h" + case holdHours < 4: + rangeKey = "1-4h" + case holdHours < 24: + rangeKey = "4-24h" + default: + rangeKey = ">24h" + } + + r := rangeStats[rangeKey] + r.count++ + r.totalPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + r.wins++ + } + } + + var stats []HoldingTimeStats + for _, rangeKey := range []string{"<1h", "1-4h", "4-24h", ">24h"} { + r := rangeStats[rangeKey] + if r.count > 0 { + stats = append(stats, HoldingTimeStats{ + Range: rangeKey, + TradeCount: r.count, + WinRate: float64(r.wins) / float64(r.count) * 100, + AvgPnL: r.totalPnL / float64(r.count), + }) + } + } + + return stats, nil +} + +// DirectionStats long/short performance comparison +type DirectionStats struct { + Side string `json:"side"` + TradeCount int `json:"trade_count"` + WinRate float64 `json:"win_rate"` + TotalPnL float64 `json:"total_pnl"` + AvgPnL float64 `json:"avg_pnl"` +} + +// GetDirectionStats analyzes long vs short performance +func (s *PositionStore) GetDirectionStats(traderID string) ([]DirectionStats, error) { + var positions []TraderPosition + err := s.db.Where("trader_id = ? AND status = ?", traderID, "CLOSED").Find(&positions).Error + if err != nil { + return nil, fmt.Errorf("failed to query direction stats: %w", err) + } + + sideStats := make(map[string]*DirectionStats) + for _, pos := range positions { + if _, ok := sideStats[pos.Side]; !ok { + sideStats[pos.Side] = &DirectionStats{Side: pos.Side} + } + s := sideStats[pos.Side] + s.TradeCount++ + s.TotalPnL += pos.RealizedPnL + if pos.RealizedPnL > 0 { + s.WinRate++ + } + } + + var stats []DirectionStats + for _, s := range sideStats { + if s.TradeCount > 0 { + s.AvgPnL = s.TotalPnL / float64(s.TradeCount) + s.WinRate = s.WinRate / float64(s.TradeCount) * 100 + } + stats = append(stats, *s) + } + + return stats, nil +} diff --git a/experience/experience.go b/telemetry/experience.go similarity index 90% rename from experience/experience.go rename to telemetry/experience.go index b6a24430..69f57c37 100644 --- a/experience/experience.go +++ b/telemetry/experience.go @@ -1,5 +1,5 @@ -// Package experience handles product telemetry -package experience +// Package telemetry handles product telemetry +package telemetry import ( "bytes" @@ -28,13 +28,13 @@ type Client struct { } type TradeEvent struct { - Exchange string - TradeType string - Symbol string - AmountUSD float64 - Leverage int - UserID string - TraderID string + Exchange string + TradeType string + Symbol string + AmountUSD float64 + Leverage int + UserID string + TraderID string } type AIUsageEvent struct { @@ -129,10 +129,10 @@ func sendTradeEvent(event TradeEvent) error { "symbol": event.Symbol, "amount_usd": event.AmountUSD, "leverage": event.Leverage, - "installation_id": installationID, // For counting active installations - "user_id": event.UserID, // For counting active users - "trader_id": event.TraderID, // For counting active traders - "engagement_time_msec": 1, // Required by GA4 + "installation_id": installationID, // For counting active installations + "user_id": event.UserID, // For counting active users + "trader_id": event.TraderID, // For counting active traders + "engagement_time_msec": 1, // Required by GA4 }, }, }, diff --git a/trader/aster/trader.go b/trader/aster/trader.go index 117d73be..d7272f21 100644 --- a/trader/aster/trader.go +++ b/trader/aster/trader.go @@ -5,10 +5,8 @@ import ( "crypto/ecdsa" "encoding/hex" "encoding/json" - "errors" "fmt" "io" - "nofx/logger" "math" "math/big" "net/http" @@ -23,7 +21,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" - "nofx/trader/types" ) // AsterTrader Aster trading platform implementation @@ -431,1178 +428,3 @@ func (t *AsterTrader) doRequest(method, endpoint string, params map[string]inter return nil, fmt.Errorf("unsupported HTTP method: %s", method) } } - -// GetBalance Get account balance -func (t *AsterTrader) GetBalance() (map[string]interface{}, error) { - params := make(map[string]interface{}) - body, err := t.request("GET", "/fapi/v3/balance", params) - if err != nil { - return nil, err - } - - var balances []map[string]interface{} - if err := json.Unmarshal(body, &balances); err != nil { - return nil, err - } - - // Find USDT balance - availableBalance := 0.0 - crossUnPnl := 0.0 - crossWalletBalance := 0.0 - foundUSDT := false - - for _, bal := range balances { - if asset, ok := bal["asset"].(string); ok && asset == "USDT" { - foundUSDT = true - - // Parse Aster fields (reference: https://github.com/asterdex/api-docs) - if avail, ok := bal["availableBalance"].(string); ok { - availableBalance, _ = strconv.ParseFloat(avail, 64) - } - if unpnl, ok := bal["crossUnPnl"].(string); ok { - crossUnPnl, _ = strconv.ParseFloat(unpnl, 64) - } - if cwb, ok := bal["crossWalletBalance"].(string); ok { - crossWalletBalance, _ = strconv.ParseFloat(cwb, 64) - } - break - } - } - - if !foundUSDT { - logger.Infof("⚠️ USDT asset record not found!") - } - - // Get positions to calculate margin used and real unrealized PnL - positions, err := t.GetPositions() - if err != nil { - logger.Infof("⚠️ Failed to get position information: %v", err) - // fallback: use simple calculation when unable to get positions - return map[string]interface{}{ - "totalWalletBalance": crossWalletBalance, - "availableBalance": availableBalance, - "totalUnrealizedProfit": crossUnPnl, - }, nil - } - - // ⚠️ Critical fix: accumulate real unrealized PnL from positions - // Aster's crossUnPnl field is inaccurate, need to recalculate from position data - totalMarginUsed := 0.0 - realUnrealizedPnl := 0.0 - for _, pos := range positions { - markPrice := pos["markPrice"].(float64) - quantity := pos["positionAmt"].(float64) - if quantity < 0 { - quantity = -quantity - } - unrealizedPnl := pos["unRealizedProfit"].(float64) - realUnrealizedPnl += unrealizedPnl - - leverage := 10 - if lev, ok := pos["leverage"].(float64); ok { - leverage = int(lev) - } - marginUsed := (quantity * markPrice) / float64(leverage) - totalMarginUsed += marginUsed - } - - // ✅ Aster correct calculation method: - // Total equity = available balance + margin used - // Wallet balance = total equity - unrealized PnL - // Unrealized PnL = calculated from accumulated positions (don't use API's crossUnPnl) - totalEquity := availableBalance + totalMarginUsed - totalWalletBalance := totalEquity - realUnrealizedPnl - - return map[string]interface{}{ - "totalWalletBalance": totalWalletBalance, // Wallet balance (excluding unrealized PnL) - "availableBalance": availableBalance, // Available balance - "totalUnrealizedProfit": realUnrealizedPnl, // Unrealized PnL (accumulated from positions) - }, nil -} - -// GetPositions Get position information -func (t *AsterTrader) GetPositions() ([]map[string]interface{}, error) { - params := make(map[string]interface{}) - body, err := t.request("GET", "/fapi/v3/positionRisk", params) - if err != nil { - return nil, err - } - - var positions []map[string]interface{} - if err := json.Unmarshal(body, &positions); err != nil { - return nil, err - } - - result := []map[string]interface{}{} - for _, pos := range positions { - posAmtStr, ok := pos["positionAmt"].(string) - if !ok { - continue - } - - posAmt, _ := strconv.ParseFloat(posAmtStr, 64) - if posAmt == 0 { - continue // Skip empty positions - } - - entryPrice, _ := strconv.ParseFloat(pos["entryPrice"].(string), 64) - markPrice, _ := strconv.ParseFloat(pos["markPrice"].(string), 64) - unRealizedProfit, _ := strconv.ParseFloat(pos["unRealizedProfit"].(string), 64) - leverageVal, _ := strconv.ParseFloat(pos["leverage"].(string), 64) - liquidationPrice, _ := strconv.ParseFloat(pos["liquidationPrice"].(string), 64) - - // Determine direction (consistent with Binance) - side := "long" - if posAmt < 0 { - side = "short" - posAmt = -posAmt - } - - // Return same field names as Binance - result = append(result, map[string]interface{}{ - "symbol": pos["symbol"], - "side": side, - "positionAmt": posAmt, - "entryPrice": entryPrice, - "markPrice": markPrice, - "unRealizedProfit": unRealizedProfit, - "leverage": leverageVal, - "liquidationPrice": liquidationPrice, - }) - } - - return result, nil -} - -// OpenLong Open long position -func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // Cancel all pending orders before opening position to prevent position stacking from residual orders - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders (continuing to open position): %v", err) - } - - // Set leverage first (non-fatal if position already exists) - if err := t.SetLeverage(symbol, leverage); err != nil { - // Error -2030: Cannot adjust leverage when position exists - // This is expected when adding to an existing position, continue with current leverage - if strings.Contains(err.Error(), "-2030") { - logger.Infof(" ⚠ Cannot change leverage (position exists), using current leverage: %v", err) - } else { - return nil, fmt.Errorf("failed to set leverage: %w", err) - } - } - - // Get current price - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - // Use limit order to simulate market order (price set slightly higher to ensure execution) - limitPrice := price * 1.01 - - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(symbol, limitPrice) - if err != nil { - return nil, err - } - formattedQty, err := t.formatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Get precision information - prec, err := t.getPrecision(symbol) - if err != nil { - return nil, err - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", - limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) - - params := map[string]interface{}{ - "symbol": symbol, - "positionSide": "BOTH", - "type": "LIMIT", - "side": "BUY", - "timeInForce": "GTC", - "quantity": qtyStr, - "price": priceStr, - } - - body, err := t.request("POST", "/fapi/v3/order", params) - if err != nil { - return nil, err - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, err - } - - return result, nil -} - -// OpenShort Open short position -func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // Cancel all pending orders before opening position to prevent position stacking from residual orders - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders (continuing to open position): %v", err) - } - - // Set leverage first (non-fatal if position already exists) - if err := t.SetLeverage(symbol, leverage); err != nil { - // Error -2030: Cannot adjust leverage when position exists - // This is expected when adding to an existing position, continue with current leverage - if strings.Contains(err.Error(), "-2030") { - logger.Infof(" ⚠ Cannot change leverage (position exists), using current leverage: %v", err) - } else { - return nil, fmt.Errorf("failed to set leverage: %w", err) - } - } - - // Get current price - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - // Use limit order to simulate market order (price set slightly lower to ensure execution) - limitPrice := price * 0.99 - - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(symbol, limitPrice) - if err != nil { - return nil, err - } - formattedQty, err := t.formatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Get precision information - prec, err := t.getPrecision(symbol) - if err != nil { - return nil, err - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", - limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) - - params := map[string]interface{}{ - "symbol": symbol, - "positionSide": "BOTH", - "type": "LIMIT", - "side": "SELL", - "timeInForce": "GTC", - "quantity": qtyStr, - "price": priceStr, - } - - body, err := t.request("POST", "/fapi/v3/order", params) - if err != nil { - return nil, err - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, err - } - - return result, nil -} - -// CloseLong Close long position -func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - // If quantity is 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "long" { - quantity = pos["positionAmt"].(float64) - break - } - } - - if quantity == 0 { - return nil, fmt.Errorf("no long position found for %s", symbol) - } - logger.Infof(" 📊 Retrieved long position quantity: %.8f", quantity) - } - - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - limitPrice := price * 0.99 - - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(symbol, limitPrice) - if err != nil { - return nil, err - } - formattedQty, err := t.formatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Get precision information - prec, err := t.getPrecision(symbol) - if err != nil { - return nil, err - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", - limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) - - params := map[string]interface{}{ - "symbol": symbol, - "positionSide": "BOTH", - "type": "LIMIT", - "side": "SELL", - "timeInForce": "GTC", - "quantity": qtyStr, - "price": priceStr, - } - - body, err := t.request("POST", "/fapi/v3/order", params) - if err != nil { - return nil, err - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, err - } - - logger.Infof("✓ Successfully closed long position: %s quantity: %s", symbol, qtyStr) - - // Cancel all pending orders for this symbol after closing position (stop-loss/take-profit orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) - } - - return result, nil -} - -// CloseShort Close short position -func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - // If quantity is 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "short" { - // Aster's GetPositions has already converted short position quantity to positive, use directly - quantity = pos["positionAmt"].(float64) - break - } - } - - if quantity == 0 { - return nil, fmt.Errorf("no short position found for %s", symbol) - } - logger.Infof(" 📊 Retrieved short position quantity: %.8f", quantity) - } - - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - limitPrice := price * 1.01 - - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(symbol, limitPrice) - if err != nil { - return nil, err - } - formattedQty, err := t.formatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Get precision information - prec, err := t.getPrecision(symbol) - if err != nil { - return nil, err - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", - limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) - - params := map[string]interface{}{ - "symbol": symbol, - "positionSide": "BOTH", - "type": "LIMIT", - "side": "BUY", - "timeInForce": "GTC", - "quantity": qtyStr, - "price": priceStr, - } - - body, err := t.request("POST", "/fapi/v3/order", params) - if err != nil { - return nil, err - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, err - } - - logger.Infof("✓ Successfully closed short position: %s quantity: %s", symbol, qtyStr) - - // Cancel all pending orders for this symbol after closing position (stop-loss/take-profit orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) - } - - return result, nil -} - -// SetMarginMode Set margin mode -func (t *AsterTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - // Aster supports margin mode settings - // API format similar to Binance: CROSSED (cross margin) / ISOLATED (isolated margin) - marginType := "CROSSED" - if !isCrossMargin { - marginType = "ISOLATED" - } - - params := map[string]interface{}{ - "symbol": symbol, - "marginType": marginType, - } - - // Use request method to call API - _, err := t.request("POST", "/fapi/v3/marginType", params) - if err != nil { - // Ignore error if it indicates no need to change - if strings.Contains(err.Error(), "No need to change") || - strings.Contains(err.Error(), "Margin type cannot be changed") { - logger.Infof(" ✓ %s margin mode is already %s or cannot be changed due to existing positions", symbol, marginType) - return nil - } - // Detect multi-assets mode (error code -4168) - if strings.Contains(err.Error(), "Multi-Assets mode") || - strings.Contains(err.Error(), "-4168") || - strings.Contains(err.Error(), "4168") { - logger.Infof(" ⚠️ %s detected multi-assets mode, forcing cross margin mode", symbol) - logger.Infof(" 💡 Tip: To use isolated margin mode, please disable multi-assets mode on the exchange") - return nil - } - // Detect unified account API - if strings.Contains(err.Error(), "unified") || - strings.Contains(err.Error(), "portfolio") || - strings.Contains(err.Error(), "Portfolio") { - logger.Infof(" ❌ %s detected unified account API, cannot perform futures trading", symbol) - return fmt.Errorf("please use 'Spot & Futures Trading' API permission, not 'Unified Account API'") - } - logger.Infof(" ⚠️ Failed to set margin mode: %v", err) - // Don't return error, let trading continue - return nil - } - - logger.Infof(" ✓ %s margin mode has been set to %s", symbol, marginType) - return nil -} - -// SetLeverage Set leverage multiplier -func (t *AsterTrader) SetLeverage(symbol string, leverage int) error { - params := map[string]interface{}{ - "symbol": symbol, - "leverage": leverage, - } - - _, err := t.request("POST", "/fapi/v3/leverage", params) - return err -} - -// GetMarketPrice Get market price -func (t *AsterTrader) GetMarketPrice(symbol string) (float64, error) { - // Use ticker interface to get current price - resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/ticker/price?symbol=%s", t.baseURL, symbol)) - if err != nil { - return 0, err - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return 0, err - } - - priceStr, ok := result["price"].(string) - if !ok { - return 0, errors.New("unable to get price") - } - - return strconv.ParseFloat(priceStr, 64) -} - -// SetStopLoss Set stop loss -func (t *AsterTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - side := "SELL" - if positionSide == "SHORT" { - side = "BUY" - } - - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(symbol, stopPrice) - if err != nil { - return err - } - formattedQty, err := t.formatQuantity(symbol, quantity) - if err != nil { - return err - } - - // Get precision information - prec, err := t.getPrecision(symbol) - if err != nil { - return err - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - params := map[string]interface{}{ - "symbol": symbol, - "positionSide": "BOTH", - "type": "STOP_MARKET", - "side": side, - "stopPrice": priceStr, - "quantity": qtyStr, - "timeInForce": "GTC", - } - - _, err = t.request("POST", "/fapi/v3/order", params) - return err -} - -// SetTakeProfit Set take profit -func (t *AsterTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - side := "SELL" - if positionSide == "SHORT" { - side = "BUY" - } - - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(symbol, takeProfitPrice) - if err != nil { - return err - } - formattedQty, err := t.formatQuantity(symbol, quantity) - if err != nil { - return err - } - - // Get precision information - prec, err := t.getPrecision(symbol) - if err != nil { - return err - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - params := map[string]interface{}{ - "symbol": symbol, - "positionSide": "BOTH", - "type": "TAKE_PROFIT_MARKET", - "side": side, - "stopPrice": priceStr, - "quantity": qtyStr, - "timeInForce": "GTC", - } - - _, err = t.request("POST", "/fapi/v3/order", params) - return err -} - -// CancelStopLossOrders Cancel stop-loss orders only (does not affect take-profit orders) -func (t *AsterTrader) CancelStopLossOrders(symbol string) error { - // Get all open orders for this symbol - params := map[string]interface{}{ - "symbol": symbol, - } - - body, err := t.request("GET", "/fapi/v3/openOrders", params) - if err != nil { - return fmt.Errorf("failed to get open orders: %w", err) - } - - var orders []map[string]interface{} - if err := json.Unmarshal(body, &orders); err != nil { - return fmt.Errorf("failed to parse order data: %w", err) - } - - // Filter and cancel stop-loss orders (cancel all directions including LONG and SHORT) - canceledCount := 0 - var cancelErrors []error - for _, order := range orders { - orderType, _ := order["type"].(string) - - // Only cancel stop-loss orders (don't cancel take-profit orders) - if orderType == "STOP_MARKET" || orderType == "STOP" { - orderID, _ := order["orderId"].(float64) - positionSide, _ := order["positionSide"].(string) - cancelParams := map[string]interface{}{ - "symbol": symbol, - "orderId": int64(orderID), - } - - _, err := t.request("DELETE", "/fapi/v1/order", cancelParams) - if err != nil { - errMsg := fmt.Sprintf("order ID %d: %v", int64(orderID), err) - cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - logger.Infof(" ⚠ Failed to cancel stop-loss order: %s", errMsg) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled stop-loss order (order ID: %d, type: %s, direction: %s)", int64(orderID), orderType, positionSide) - } - } - - if canceledCount == 0 && len(cancelErrors) == 0 { - logger.Infof(" ℹ %s no stop-loss orders to cancel", symbol) - } else if canceledCount > 0 { - logger.Infof(" ✓ Canceled %d stop-loss order(s) for %s", canceledCount, symbol) - } - - // Return error if all cancellations failed - if len(cancelErrors) > 0 && canceledCount == 0 { - return fmt.Errorf("failed to cancel stop-loss orders: %v", cancelErrors) - } - - return nil -} - -// CancelTakeProfitOrders Cancel take-profit orders only (does not affect stop-loss orders) -func (t *AsterTrader) CancelTakeProfitOrders(symbol string) error { - // Get all open orders for this symbol - params := map[string]interface{}{ - "symbol": symbol, - } - - body, err := t.request("GET", "/fapi/v3/openOrders", params) - if err != nil { - return fmt.Errorf("failed to get open orders: %w", err) - } - - var orders []map[string]interface{} - if err := json.Unmarshal(body, &orders); err != nil { - return fmt.Errorf("failed to parse order data: %w", err) - } - - // Filter and cancel take-profit orders (cancel all directions including LONG and SHORT) - canceledCount := 0 - var cancelErrors []error - for _, order := range orders { - orderType, _ := order["type"].(string) - - // Only cancel take-profit orders (don't cancel stop-loss orders) - if orderType == "TAKE_PROFIT_MARKET" || orderType == "TAKE_PROFIT" { - orderID, _ := order["orderId"].(float64) - positionSide, _ := order["positionSide"].(string) - cancelParams := map[string]interface{}{ - "symbol": symbol, - "orderId": int64(orderID), - } - - _, err := t.request("DELETE", "/fapi/v1/order", cancelParams) - if err != nil { - errMsg := fmt.Sprintf("order ID %d: %v", int64(orderID), err) - cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - logger.Infof(" ⚠ Failed to cancel take-profit order: %s", errMsg) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled take-profit order (order ID: %d, type: %s, direction: %s)", int64(orderID), orderType, positionSide) - } - } - - if canceledCount == 0 && len(cancelErrors) == 0 { - logger.Infof(" ℹ %s no take-profit orders to cancel", symbol) - } else if canceledCount > 0 { - logger.Infof(" ✓ Canceled %d take-profit order(s) for %s", canceledCount, symbol) - } - - // Return error if all cancellations failed - if len(cancelErrors) > 0 && canceledCount == 0 { - return fmt.Errorf("failed to cancel take-profit orders: %v", cancelErrors) - } - - return nil -} - -// CancelAllOrders Cancel all orders -func (t *AsterTrader) CancelAllOrders(symbol string) error { - params := map[string]interface{}{ - "symbol": symbol, - } - - _, err := t.request("DELETE", "/fapi/v3/allOpenOrders", params) - return err -} - -// CancelStopOrders Cancel take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions) -func (t *AsterTrader) CancelStopOrders(symbol string) error { - // Get all open orders for this symbol - params := map[string]interface{}{ - "symbol": symbol, - } - - body, err := t.request("GET", "/fapi/v3/openOrders", params) - if err != nil { - return fmt.Errorf("failed to get open orders: %w", err) - } - - var orders []map[string]interface{} - if err := json.Unmarshal(body, &orders); err != nil { - return fmt.Errorf("failed to parse order data: %w", err) - } - - // Filter and cancel take-profit/stop-loss orders - canceledCount := 0 - for _, order := range orders { - orderType, _ := order["type"].(string) - - // Only cancel stop-loss and take-profit orders - if orderType == "STOP_MARKET" || - orderType == "TAKE_PROFIT_MARKET" || - orderType == "STOP" || - orderType == "TAKE_PROFIT" { - - orderID, _ := order["orderId"].(float64) - cancelParams := map[string]interface{}{ - "symbol": symbol, - "orderId": int64(orderID), - } - - _, err := t.request("DELETE", "/fapi/v3/order", cancelParams) - if err != nil { - logger.Infof(" ⚠ Failed to cancel order %d: %v", int64(orderID), err) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled take-profit/stop-loss order for %s (order ID: %d, type: %s)", - symbol, int64(orderID), orderType) - } - } - - if canceledCount == 0 { - logger.Infof(" ℹ %s no take-profit/stop-loss orders to cancel", symbol) - } else { - logger.Infof(" ✓ Canceled %d take-profit/stop-loss order(s) for %s", canceledCount, symbol) - } - - return nil -} - -// FormatQuantity Format quantity (implements Trader interface) -func (t *AsterTrader) FormatQuantity(symbol string, quantity float64) (string, error) { - formatted, err := t.formatQuantity(symbol, quantity) - if err != nil { - return "", err - } - return fmt.Sprintf("%v", formatted), nil -} - -// GetOrderStatus Get order status -func (t *AsterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - params := map[string]interface{}{ - "symbol": symbol, - "orderId": orderID, - } - - body, err := t.request("GET", "/fapi/v3/order", params) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - // Standardize return fields - response := map[string]interface{}{ - "orderId": result["orderId"], - "symbol": result["symbol"], - "status": result["status"], - "side": result["side"], - "type": result["type"], - "time": result["time"], - "updateTime": result["updateTime"], - "commission": 0.0, // Aster may require separate query - } - - // Parse numeric fields - if avgPrice, ok := result["avgPrice"].(string); ok { - if v, err := strconv.ParseFloat(avgPrice, 64); err == nil { - response["avgPrice"] = v - } - } else if avgPrice, ok := result["avgPrice"].(float64); ok { - response["avgPrice"] = avgPrice - } - - if executedQty, ok := result["executedQty"].(string); ok { - if v, err := strconv.ParseFloat(executedQty, 64); err == nil { - response["executedQty"] = v - } - } else if executedQty, ok := result["executedQty"].(float64); ok { - response["executedQty"] = executedQty - } - - return response, nil -} - -// GetClosedPnL gets recent closing trades from Aster -// Note: Aster does NOT have a position history API, only trade history. -// This returns individual closing trades for real-time position closure detection. -func (t *AsterTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - trades, err := t.GetTrades(startTime, limit) - if err != nil { - return nil, err - } - - // Filter only closing trades (realizedPnl != 0) - var records []types.ClosedPnLRecord - for _, trade := range trades { - if trade.RealizedPnL == 0 { - continue - } - - // Determine side from PositionSide or trade direction - side := "long" - if trade.PositionSide == "SHORT" || trade.PositionSide == "short" { - side = "short" - } else if trade.PositionSide == "BOTH" || trade.PositionSide == "" { - if trade.Side == "SELL" || trade.Side == "Sell" { - side = "long" - } else { - side = "short" - } - } - - // Calculate entry price from PnL - var entryPrice float64 - if trade.Quantity > 0 { - if side == "long" { - entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity - } else { - entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity - } - } - - records = append(records, types.ClosedPnLRecord{ - Symbol: trade.Symbol, - Side: side, - EntryPrice: entryPrice, - ExitPrice: trade.Price, - Quantity: trade.Quantity, - RealizedPnL: trade.RealizedPnL, - Fee: trade.Fee, - ExitTime: trade.Time, - EntryTime: trade.Time, - OrderID: trade.TradeID, - ExchangeID: trade.TradeID, - CloseType: "unknown", - }) - } - - return records, nil -} - -// AsterTradeRecord represents a trade from Aster API -type AsterTradeRecord struct { - ID int64 `json:"id"` - Symbol string `json:"symbol"` - OrderID int64 `json:"orderId"` - Side string `json:"side"` // BUY or SELL - PositionSide string `json:"positionSide"` // LONG or SHORT - Price string `json:"price"` - Qty string `json:"qty"` - RealizedPnl string `json:"realizedPnl"` - Commission string `json:"commission"` - Time int64 `json:"time"` - Buyer bool `json:"buyer"` - Maker bool `json:"maker"` -} - -// GetTrades retrieves trade history from Aster -func (t *AsterTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) { - if limit <= 0 { - limit = 500 - } - - // Build request params - params := map[string]interface{}{ - "startTime": startTime.UnixMilli(), - "limit": limit, - } - - // Use existing request method with signing - body, err := t.request("GET", "/fapi/v3/userTrades", params) - if err != nil { - logger.Infof("⚠️ Aster userTrades API error: %v", err) - return []types.TradeRecord{}, nil - } - - var asterTrades []AsterTradeRecord - if err := json.Unmarshal(body, &asterTrades); err != nil { - logger.Infof("⚠️ Failed to parse Aster trades response: %v", err) - return []types.TradeRecord{}, nil - } - - // Convert to unified TradeRecord format - var result []types.TradeRecord - for _, at := range asterTrades { - price, _ := strconv.ParseFloat(at.Price, 64) - qty, _ := strconv.ParseFloat(at.Qty, 64) - fee, _ := strconv.ParseFloat(at.Commission, 64) - pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64) - - trade := types.TradeRecord{ - TradeID: strconv.FormatInt(at.ID, 10), - Symbol: at.Symbol, - Side: at.Side, - PositionSide: at.PositionSide, - Price: price, - Quantity: qty, - RealizedPnL: pnl, - Fee: fee, - Time: time.UnixMilli(at.Time).UTC(), - } - result = append(result, trade) - } - - return result, nil -} - -// GetOpenOrders gets all open/pending orders for a symbol -func (t *AsterTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - params := map[string]interface{}{ - "symbol": symbol, - } - - body, err := t.request("GET", "/fapi/v3/openOrders", params) - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - var orders []struct { - OrderID int64 `json:"orderId"` - Symbol string `json:"symbol"` - Side string `json:"side"` - PositionSide string `json:"positionSide"` - Type string `json:"type"` - Price string `json:"price"` - StopPrice string `json:"stopPrice"` - OrigQty string `json:"origQty"` - Status string `json:"status"` - } - - if err := json.Unmarshal(body, &orders); err != nil { - return nil, fmt.Errorf("failed to parse open orders: %w", err) - } - - var result []types.OpenOrder - for _, order := range orders { - price, _ := strconv.ParseFloat(order.Price, 64) - stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64) - quantity, _ := strconv.ParseFloat(order.OrigQty, 64) - - result = append(result, types.OpenOrder{ - OrderID: fmt.Sprintf("%d", order.OrderID), - Symbol: order.Symbol, - Side: order.Side, - PositionSide: order.PositionSide, - Type: order.Type, - Price: price, - StopPrice: stopPrice, - Quantity: quantity, - Status: order.Status, - }) - } - - logger.Infof("✓ ASTER GetOpenOrders: found %d open orders for %s", len(result), symbol) - return result, nil -} - -// PlaceLimitOrder places a limit order for grid trading -func (t *AsterTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { - // Format price and quantity to correct precision - formattedPrice, err := t.formatPrice(req.Symbol, req.Price) - if err != nil { - return nil, fmt.Errorf("failed to format price: %w", err) - } - formattedQty, err := t.formatQuantity(req.Symbol, req.Quantity) - if err != nil { - return nil, fmt.Errorf("failed to format quantity: %w", err) - } - - // Get precision information - prec, err := t.getPrecision(req.Symbol) - if err != nil { - return nil, fmt.Errorf("failed to get precision: %w", err) - } - - // Convert to string with correct precision format - priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) - qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - - // Determine side - side := "BUY" - if req.Side == "SELL" || req.Side == "Sell" || req.Side == "sell" { - side = "SELL" - } - - params := map[string]interface{}{ - "symbol": req.Symbol, - "positionSide": "BOTH", - "type": "LIMIT", - "side": side, - "timeInForce": "GTC", - "quantity": qtyStr, - "price": priceStr, - } - - // Add reduceOnly if specified - if req.ReduceOnly { - params["reduceOnly"] = "true" - } - - body, err := t.request("POST", "/fapi/v3/order", params) - if err != nil { - return nil, fmt.Errorf("failed to place limit order: %w", err) - } - - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - // Extract order ID - orderID := "" - if id, ok := result["orderId"].(float64); ok { - orderID = fmt.Sprintf("%.0f", id) - } else if id, ok := result["orderId"].(string); ok { - orderID = id - } - - // Extract client order ID - clientOrderID := "" - if cid, ok := result["clientOrderId"].(string); ok { - clientOrderID = cid - } - - return &types.LimitOrderResult{ - OrderID: orderID, - ClientID: clientOrderID, - Symbol: req.Symbol, - Side: side, - Price: formattedPrice, - Quantity: formattedQty, - Status: "NEW", - }, nil -} - -// CancelOrder cancels a specific order by order ID -func (t *AsterTrader) CancelOrder(symbol, orderID string) error { - params := map[string]interface{}{ - "symbol": symbol, - "orderId": orderID, - } - - _, err := t.request("DELETE", "/fapi/v3/order", params) - if err != nil { - return fmt.Errorf("failed to cancel order %s: %w", orderID, err) - } - - return nil -} - -// GetOrderBook gets the order book for a symbol -func (t *AsterTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { - if depth <= 0 { - depth = 20 - } - - // Aster uses public endpoint (no signature required) - resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/depth?symbol=%s&limit=%d", t.baseURL, symbol, depth)) - if err != nil { - return nil, nil, fmt.Errorf("failed to fetch order book: %w", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) - } - - var result struct { - Bids [][]string `json:"bids"` // [[price, qty], ...] - Asks [][]string `json:"asks"` // [[price, qty], ...] - } - if err := json.Unmarshal(body, &result); err != nil { - return nil, nil, fmt.Errorf("failed to parse order book: %w", err) - } - - // Convert string arrays to float64 arrays - bids = make([][]float64, len(result.Bids)) - for i, bid := range result.Bids { - if len(bid) >= 2 { - price, _ := strconv.ParseFloat(bid[0], 64) - qty, _ := strconv.ParseFloat(bid[1], 64) - bids[i] = []float64{price, qty} - } - } - - asks = make([][]float64, len(result.Asks)) - for i, ask := range result.Asks { - if len(ask) >= 2 { - price, _ := strconv.ParseFloat(ask[0], 64) - qty, _ := strconv.ParseFloat(ask[1], 64) - asks[i] = []float64{price, qty} - } - } - - return bids, asks, nil -} diff --git a/trader/aster/trader_account.go b/trader/aster/trader_account.go new file mode 100644 index 00000000..bfc41304 --- /dev/null +++ b/trader/aster/trader_account.go @@ -0,0 +1,299 @@ +package aster + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "nofx/logger" + "nofx/trader/types" + "strconv" + "time" +) + +// GetBalance Get account balance +func (t *AsterTrader) GetBalance() (map[string]interface{}, error) { + params := make(map[string]interface{}) + body, err := t.request("GET", "/fapi/v3/balance", params) + if err != nil { + return nil, err + } + + var balances []map[string]interface{} + if err := json.Unmarshal(body, &balances); err != nil { + return nil, err + } + + // Find USDT balance + availableBalance := 0.0 + crossUnPnl := 0.0 + crossWalletBalance := 0.0 + foundUSDT := false + + for _, bal := range balances { + if asset, ok := bal["asset"].(string); ok && asset == "USDT" { + foundUSDT = true + + // Parse Aster fields (reference: https://github.com/asterdex/api-docs) + if avail, ok := bal["availableBalance"].(string); ok { + availableBalance, _ = strconv.ParseFloat(avail, 64) + } + if unpnl, ok := bal["crossUnPnl"].(string); ok { + crossUnPnl, _ = strconv.ParseFloat(unpnl, 64) + } + if cwb, ok := bal["crossWalletBalance"].(string); ok { + crossWalletBalance, _ = strconv.ParseFloat(cwb, 64) + } + break + } + } + + if !foundUSDT { + logger.Infof("⚠️ USDT asset record not found!") + } + + // Get positions to calculate margin used and real unrealized PnL + positions, err := t.GetPositions() + if err != nil { + logger.Infof("⚠️ Failed to get position information: %v", err) + // fallback: use simple calculation when unable to get positions + return map[string]interface{}{ + "totalWalletBalance": crossWalletBalance, + "availableBalance": availableBalance, + "totalUnrealizedProfit": crossUnPnl, + }, nil + } + + // Critical fix: accumulate real unrealized PnL from positions + // Aster's crossUnPnl field is inaccurate, need to recalculate from position data + totalMarginUsed := 0.0 + realUnrealizedPnl := 0.0 + for _, pos := range positions { + markPrice := pos["markPrice"].(float64) + quantity := pos["positionAmt"].(float64) + if quantity < 0 { + quantity = -quantity + } + unrealizedPnl := pos["unRealizedProfit"].(float64) + realUnrealizedPnl += unrealizedPnl + + leverage := 10 + if lev, ok := pos["leverage"].(float64); ok { + leverage = int(lev) + } + marginUsed := (quantity * markPrice) / float64(leverage) + totalMarginUsed += marginUsed + } + + // Aster correct calculation method: + // Total equity = available balance + margin used + // Wallet balance = total equity - unrealized PnL + // Unrealized PnL = calculated from accumulated positions (don't use API's crossUnPnl) + totalEquity := availableBalance + totalMarginUsed + totalWalletBalance := totalEquity - realUnrealizedPnl + + return map[string]interface{}{ + "totalWalletBalance": totalWalletBalance, // Wallet balance (excluding unrealized PnL) + "availableBalance": availableBalance, // Available balance + "totalUnrealizedProfit": realUnrealizedPnl, // Unrealized PnL (accumulated from positions) + }, nil +} + +// GetMarketPrice Get market price +func (t *AsterTrader) GetMarketPrice(symbol string) (float64, error) { + // Use ticker interface to get current price + resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/ticker/price?symbol=%s", t.baseURL, symbol)) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return 0, err + } + + priceStr, ok := result["price"].(string) + if !ok { + return 0, errors.New("unable to get price") + } + + return strconv.ParseFloat(priceStr, 64) +} + +// GetClosedPnL gets recent closing trades from Aster +// Note: Aster does NOT have a position history API, only trade history. +// This returns individual closing trades for real-time position closure detection. +func (t *AsterTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + trades, err := t.GetTrades(startTime, limit) + if err != nil { + return nil, err + } + + // Filter only closing trades (realizedPnl != 0) + var records []types.ClosedPnLRecord + for _, trade := range trades { + if trade.RealizedPnL == 0 { + continue + } + + // Determine side from PositionSide or trade direction + side := "long" + if trade.PositionSide == "SHORT" || trade.PositionSide == "short" { + side = "short" + } else if trade.PositionSide == "BOTH" || trade.PositionSide == "" { + if trade.Side == "SELL" || trade.Side == "Sell" { + side = "long" + } else { + side = "short" + } + } + + // Calculate entry price from PnL + var entryPrice float64 + if trade.Quantity > 0 { + if side == "long" { + entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity + } else { + entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity + } + } + + records = append(records, types.ClosedPnLRecord{ + Symbol: trade.Symbol, + Side: side, + EntryPrice: entryPrice, + ExitPrice: trade.Price, + Quantity: trade.Quantity, + RealizedPnL: trade.RealizedPnL, + Fee: trade.Fee, + ExitTime: trade.Time, + EntryTime: trade.Time, + OrderID: trade.TradeID, + ExchangeID: trade.TradeID, + CloseType: "unknown", + }) + } + + return records, nil +} + +// AsterTradeRecord represents a trade from Aster API +type AsterTradeRecord struct { + ID int64 `json:"id"` + Symbol string `json:"symbol"` + OrderID int64 `json:"orderId"` + Side string `json:"side"` // BUY or SELL + PositionSide string `json:"positionSide"` // LONG or SHORT + Price string `json:"price"` + Qty string `json:"qty"` + RealizedPnl string `json:"realizedPnl"` + Commission string `json:"commission"` + Time int64 `json:"time"` + Buyer bool `json:"buyer"` + Maker bool `json:"maker"` +} + +// GetTrades retrieves trade history from Aster +func (t *AsterTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) { + if limit <= 0 { + limit = 500 + } + + // Build request params + params := map[string]interface{}{ + "startTime": startTime.UnixMilli(), + "limit": limit, + } + + // Use existing request method with signing + body, err := t.request("GET", "/fapi/v3/userTrades", params) + if err != nil { + logger.Infof("⚠️ Aster userTrades API error: %v", err) + return []types.TradeRecord{}, nil + } + + var asterTrades []AsterTradeRecord + if err := json.Unmarshal(body, &asterTrades); err != nil { + logger.Infof("⚠️ Failed to parse Aster trades response: %v", err) + return []types.TradeRecord{}, nil + } + + // Convert to unified TradeRecord format + var result []types.TradeRecord + for _, at := range asterTrades { + price, _ := strconv.ParseFloat(at.Price, 64) + qty, _ := strconv.ParseFloat(at.Qty, 64) + fee, _ := strconv.ParseFloat(at.Commission, 64) + pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64) + + trade := types.TradeRecord{ + TradeID: strconv.FormatInt(at.ID, 10), + Symbol: at.Symbol, + Side: at.Side, + PositionSide: at.PositionSide, + Price: price, + Quantity: qty, + RealizedPnL: pnl, + Fee: fee, + Time: time.UnixMilli(at.Time).UTC(), + } + result = append(result, trade) + } + + return result, nil +} + +// GetOrderBook gets the order book for a symbol +func (t *AsterTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + if depth <= 0 { + depth = 20 + } + + // Aster uses public endpoint (no signature required) + resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/depth?symbol=%s&limit=%d", t.baseURL, symbol, depth)) + if err != nil { + return nil, nil, fmt.Errorf("failed to fetch order book: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Bids [][]string `json:"bids"` // [[price, qty], ...] + Asks [][]string `json:"asks"` // [[price, qty], ...] + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + // Convert string arrays to float64 arrays + bids = make([][]float64, len(result.Bids)) + for i, bid := range result.Bids { + if len(bid) >= 2 { + price, _ := strconv.ParseFloat(bid[0], 64) + qty, _ := strconv.ParseFloat(bid[1], 64) + bids[i] = []float64{price, qty} + } + } + + asks = make([][]float64, len(result.Asks)) + for i, ask := range result.Asks { + if len(ask) >= 2 { + price, _ := strconv.ParseFloat(ask[0], 64) + qty, _ := strconv.ParseFloat(ask[1], 64) + asks[i] = []float64{price, qty} + } + } + + return bids, asks, nil +} diff --git a/trader/aster/trader_orders.go b/trader/aster/trader_orders.go new file mode 100644 index 00000000..12ce9502 --- /dev/null +++ b/trader/aster/trader_orders.go @@ -0,0 +1,787 @@ +package aster + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" +) + +// OpenLong Open long position +func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // Cancel all pending orders before opening position to prevent position stacking from residual orders + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders (continuing to open position): %v", err) + } + + // Set leverage first (non-fatal if position already exists) + if err := t.SetLeverage(symbol, leverage); err != nil { + // Error -2030: Cannot adjust leverage when position exists + // This is expected when adding to an existing position, continue with current leverage + if strings.Contains(err.Error(), "-2030") { + logger.Infof(" ⚠ Cannot change leverage (position exists), using current leverage: %v", err) + } else { + return nil, fmt.Errorf("failed to set leverage: %w", err) + } + } + + // Get current price + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + // Use limit order to simulate market order (price set slightly higher to ensure execution) + limitPrice := price * 1.01 + + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(symbol, limitPrice) + if err != nil { + return nil, err + } + formattedQty, err := t.formatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Get precision information + prec, err := t.getPrecision(symbol) + if err != nil { + return nil, err + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", + limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) + + params := map[string]interface{}{ + "symbol": symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": "BUY", + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, err + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, err + } + + return result, nil +} + +// OpenShort Open short position +func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // Cancel all pending orders before opening position to prevent position stacking from residual orders + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders (continuing to open position): %v", err) + } + + // Set leverage first (non-fatal if position already exists) + if err := t.SetLeverage(symbol, leverage); err != nil { + // Error -2030: Cannot adjust leverage when position exists + // This is expected when adding to an existing position, continue with current leverage + if strings.Contains(err.Error(), "-2030") { + logger.Infof(" ⚠ Cannot change leverage (position exists), using current leverage: %v", err) + } else { + return nil, fmt.Errorf("failed to set leverage: %w", err) + } + } + + // Get current price + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + // Use limit order to simulate market order (price set slightly lower to ensure execution) + limitPrice := price * 0.99 + + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(symbol, limitPrice) + if err != nil { + return nil, err + } + formattedQty, err := t.formatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Get precision information + prec, err := t.getPrecision(symbol) + if err != nil { + return nil, err + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", + limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) + + params := map[string]interface{}{ + "symbol": symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": "SELL", + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, err + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, err + } + + return result, nil +} + +// CloseLong Close long position +func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + // If quantity is 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "long" { + quantity = pos["positionAmt"].(float64) + break + } + } + + if quantity == 0 { + return nil, fmt.Errorf("no long position found for %s", symbol) + } + logger.Infof(" 📊 Retrieved long position quantity: %.8f", quantity) + } + + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + limitPrice := price * 0.99 + + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(symbol, limitPrice) + if err != nil { + return nil, err + } + formattedQty, err := t.formatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Get precision information + prec, err := t.getPrecision(symbol) + if err != nil { + return nil, err + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", + limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) + + params := map[string]interface{}{ + "symbol": symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": "SELL", + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, err + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, err + } + + logger.Infof("✓ Successfully closed long position: %s quantity: %s", symbol, qtyStr) + + // Cancel all pending orders for this symbol after closing position (stop-loss/take-profit orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) + } + + return result, nil +} + +// CloseShort Close short position +func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + // If quantity is 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "short" { + // Aster's GetPositions has already converted short position quantity to positive, use directly + quantity = pos["positionAmt"].(float64) + break + } + } + + if quantity == 0 { + return nil, fmt.Errorf("no short position found for %s", symbol) + } + logger.Infof(" 📊 Retrieved short position quantity: %.8f", quantity) + } + + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + limitPrice := price * 1.01 + + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(symbol, limitPrice) + if err != nil { + return nil, err + } + formattedQty, err := t.formatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Get precision information + prec, err := t.getPrecision(symbol) + if err != nil { + return nil, err + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)", + limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) + + params := map[string]interface{}{ + "symbol": symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": "BUY", + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, err + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, err + } + + logger.Infof("✓ Successfully closed short position: %s quantity: %s", symbol, qtyStr) + + // Cancel all pending orders for this symbol after closing position (stop-loss/take-profit orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) + } + + return result, nil +} + +// SetStopLoss Set stop loss +func (t *AsterTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + side := "SELL" + if positionSide == "SHORT" { + side = "BUY" + } + + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(symbol, stopPrice) + if err != nil { + return err + } + formattedQty, err := t.formatQuantity(symbol, quantity) + if err != nil { + return err + } + + // Get precision information + prec, err := t.getPrecision(symbol) + if err != nil { + return err + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + params := map[string]interface{}{ + "symbol": symbol, + "positionSide": "BOTH", + "type": "STOP_MARKET", + "side": side, + "stopPrice": priceStr, + "quantity": qtyStr, + "timeInForce": "GTC", + } + + _, err = t.request("POST", "/fapi/v3/order", params) + return err +} + +// SetTakeProfit Set take profit +func (t *AsterTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + side := "SELL" + if positionSide == "SHORT" { + side = "BUY" + } + + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(symbol, takeProfitPrice) + if err != nil { + return err + } + formattedQty, err := t.formatQuantity(symbol, quantity) + if err != nil { + return err + } + + // Get precision information + prec, err := t.getPrecision(symbol) + if err != nil { + return err + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + params := map[string]interface{}{ + "symbol": symbol, + "positionSide": "BOTH", + "type": "TAKE_PROFIT_MARKET", + "side": side, + "stopPrice": priceStr, + "quantity": qtyStr, + "timeInForce": "GTC", + } + + _, err = t.request("POST", "/fapi/v3/order", params) + return err +} + +// CancelStopLossOrders Cancel stop-loss orders only (does not affect take-profit orders) +func (t *AsterTrader) CancelStopLossOrders(symbol string) error { + // Get all open orders for this symbol + params := map[string]interface{}{ + "symbol": symbol, + } + + body, err := t.request("GET", "/fapi/v3/openOrders", params) + if err != nil { + return fmt.Errorf("failed to get open orders: %w", err) + } + + var orders []map[string]interface{} + if err := json.Unmarshal(body, &orders); err != nil { + return fmt.Errorf("failed to parse order data: %w", err) + } + + // Filter and cancel stop-loss orders (cancel all directions including LONG and SHORT) + canceledCount := 0 + var cancelErrors []error + for _, order := range orders { + orderType, _ := order["type"].(string) + + // Only cancel stop-loss orders (don't cancel take-profit orders) + if orderType == "STOP_MARKET" || orderType == "STOP" { + orderID, _ := order["orderId"].(float64) + positionSide, _ := order["positionSide"].(string) + cancelParams := map[string]interface{}{ + "symbol": symbol, + "orderId": int64(orderID), + } + + _, err := t.request("DELETE", "/fapi/v1/order", cancelParams) + if err != nil { + errMsg := fmt.Sprintf("order ID %d: %v", int64(orderID), err) + cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) + logger.Infof(" ⚠ Failed to cancel stop-loss order: %s", errMsg) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled stop-loss order (order ID: %d, type: %s, direction: %s)", int64(orderID), orderType, positionSide) + } + } + + if canceledCount == 0 && len(cancelErrors) == 0 { + logger.Infof(" ℹ %s no stop-loss orders to cancel", symbol) + } else if canceledCount > 0 { + logger.Infof(" ✓ Canceled %d stop-loss order(s) for %s", canceledCount, symbol) + } + + // Return error if all cancellations failed + if len(cancelErrors) > 0 && canceledCount == 0 { + return fmt.Errorf("failed to cancel stop-loss orders: %v", cancelErrors) + } + + return nil +} + +// CancelTakeProfitOrders Cancel take-profit orders only (does not affect stop-loss orders) +func (t *AsterTrader) CancelTakeProfitOrders(symbol string) error { + // Get all open orders for this symbol + params := map[string]interface{}{ + "symbol": symbol, + } + + body, err := t.request("GET", "/fapi/v3/openOrders", params) + if err != nil { + return fmt.Errorf("failed to get open orders: %w", err) + } + + var orders []map[string]interface{} + if err := json.Unmarshal(body, &orders); err != nil { + return fmt.Errorf("failed to parse order data: %w", err) + } + + // Filter and cancel take-profit orders (cancel all directions including LONG and SHORT) + canceledCount := 0 + var cancelErrors []error + for _, order := range orders { + orderType, _ := order["type"].(string) + + // Only cancel take-profit orders (don't cancel stop-loss orders) + if orderType == "TAKE_PROFIT_MARKET" || orderType == "TAKE_PROFIT" { + orderID, _ := order["orderId"].(float64) + positionSide, _ := order["positionSide"].(string) + cancelParams := map[string]interface{}{ + "symbol": symbol, + "orderId": int64(orderID), + } + + _, err := t.request("DELETE", "/fapi/v1/order", cancelParams) + if err != nil { + errMsg := fmt.Sprintf("order ID %d: %v", int64(orderID), err) + cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) + logger.Infof(" ⚠ Failed to cancel take-profit order: %s", errMsg) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled take-profit order (order ID: %d, type: %s, direction: %s)", int64(orderID), orderType, positionSide) + } + } + + if canceledCount == 0 && len(cancelErrors) == 0 { + logger.Infof(" ℹ %s no take-profit orders to cancel", symbol) + } else if canceledCount > 0 { + logger.Infof(" ✓ Canceled %d take-profit order(s) for %s", canceledCount, symbol) + } + + // Return error if all cancellations failed + if len(cancelErrors) > 0 && canceledCount == 0 { + return fmt.Errorf("failed to cancel take-profit orders: %v", cancelErrors) + } + + return nil +} + +// CancelAllOrders Cancel all orders +func (t *AsterTrader) CancelAllOrders(symbol string) error { + params := map[string]interface{}{ + "symbol": symbol, + } + + _, err := t.request("DELETE", "/fapi/v3/allOpenOrders", params) + return err +} + +// CancelStopOrders Cancel take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions) +func (t *AsterTrader) CancelStopOrders(symbol string) error { + // Get all open orders for this symbol + params := map[string]interface{}{ + "symbol": symbol, + } + + body, err := t.request("GET", "/fapi/v3/openOrders", params) + if err != nil { + return fmt.Errorf("failed to get open orders: %w", err) + } + + var orders []map[string]interface{} + if err := json.Unmarshal(body, &orders); err != nil { + return fmt.Errorf("failed to parse order data: %w", err) + } + + // Filter and cancel take-profit/stop-loss orders + canceledCount := 0 + for _, order := range orders { + orderType, _ := order["type"].(string) + + // Only cancel stop-loss and take-profit orders + if orderType == "STOP_MARKET" || + orderType == "TAKE_PROFIT_MARKET" || + orderType == "STOP" || + orderType == "TAKE_PROFIT" { + + orderID, _ := order["orderId"].(float64) + cancelParams := map[string]interface{}{ + "symbol": symbol, + "orderId": int64(orderID), + } + + _, err := t.request("DELETE", "/fapi/v3/order", cancelParams) + if err != nil { + logger.Infof(" ⚠ Failed to cancel order %d: %v", int64(orderID), err) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled take-profit/stop-loss order for %s (order ID: %d, type: %s)", + symbol, int64(orderID), orderType) + } + } + + if canceledCount == 0 { + logger.Infof(" ℹ %s no take-profit/stop-loss orders to cancel", symbol) + } else { + logger.Infof(" ✓ Canceled %d take-profit/stop-loss order(s) for %s", canceledCount, symbol) + } + + return nil +} + +// FormatQuantity Format quantity (implements Trader interface) +func (t *AsterTrader) FormatQuantity(symbol string, quantity float64) (string, error) { + formatted, err := t.formatQuantity(symbol, quantity) + if err != nil { + return "", err + } + return fmt.Sprintf("%v", formatted), nil +} + +// GetOrderStatus Get order status +func (t *AsterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + params := map[string]interface{}{ + "symbol": symbol, + "orderId": orderID, + } + + body, err := t.request("GET", "/fapi/v3/order", params) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + // Standardize return fields + response := map[string]interface{}{ + "orderId": result["orderId"], + "symbol": result["symbol"], + "status": result["status"], + "side": result["side"], + "type": result["type"], + "time": result["time"], + "updateTime": result["updateTime"], + "commission": 0.0, // Aster may require separate query + } + + // Parse numeric fields + if avgPrice, ok := result["avgPrice"].(string); ok { + if v, err := strconv.ParseFloat(avgPrice, 64); err == nil { + response["avgPrice"] = v + } + } else if avgPrice, ok := result["avgPrice"].(float64); ok { + response["avgPrice"] = avgPrice + } + + if executedQty, ok := result["executedQty"].(string); ok { + if v, err := strconv.ParseFloat(executedQty, 64); err == nil { + response["executedQty"] = v + } + } else if executedQty, ok := result["executedQty"].(float64); ok { + response["executedQty"] = executedQty + } + + return response, nil +} + +// GetOpenOrders gets all open/pending orders for a symbol +func (t *AsterTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + params := map[string]interface{}{ + "symbol": symbol, + } + + body, err := t.request("GET", "/fapi/v3/openOrders", params) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + var orders []struct { + OrderID int64 `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PositionSide string `json:"positionSide"` + Type string `json:"type"` + Price string `json:"price"` + StopPrice string `json:"stopPrice"` + OrigQty string `json:"origQty"` + Status string `json:"status"` + } + + if err := json.Unmarshal(body, &orders); err != nil { + return nil, fmt.Errorf("failed to parse open orders: %w", err) + } + + var result []types.OpenOrder + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Price, 64) + stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64) + quantity, _ := strconv.ParseFloat(order.OrigQty, 64) + + result = append(result, types.OpenOrder{ + OrderID: fmt.Sprintf("%d", order.OrderID), + Symbol: order.Symbol, + Side: order.Side, + PositionSide: order.PositionSide, + Type: order.Type, + Price: price, + StopPrice: stopPrice, + Quantity: quantity, + Status: order.Status, + }) + } + + logger.Infof("✓ ASTER GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +func (t *AsterTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(req.Symbol, req.Price) + if err != nil { + return nil, fmt.Errorf("failed to format price: %w", err) + } + formattedQty, err := t.formatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Get precision information + prec, err := t.getPrecision(req.Symbol) + if err != nil { + return nil, fmt.Errorf("failed to get precision: %w", err) + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + // Determine side + side := "BUY" + if req.Side == "SELL" || req.Side == "Sell" || req.Side == "sell" { + side = "SELL" + } + + params := map[string]interface{}{ + "symbol": req.Symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": side, + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + // Add reduceOnly if specified + if req.ReduceOnly { + params["reduceOnly"] = "true" + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + // Extract order ID + orderID := "" + if id, ok := result["orderId"].(float64); ok { + orderID = fmt.Sprintf("%.0f", id) + } else if id, ok := result["orderId"].(string); ok { + orderID = id + } + + // Extract client order ID + clientOrderID := "" + if cid, ok := result["clientOrderId"].(string); ok { + clientOrderID = cid + } + + return &types.LimitOrderResult{ + OrderID: orderID, + ClientID: clientOrderID, + Symbol: req.Symbol, + Side: side, + Price: formattedPrice, + Quantity: formattedQty, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by order ID +func (t *AsterTrader) CancelOrder(symbol, orderID string) error { + params := map[string]interface{}{ + "symbol": symbol, + "orderId": orderID, + } + + _, err := t.request("DELETE", "/fapi/v3/order", params) + if err != nil { + return fmt.Errorf("failed to cancel order %s: %w", orderID, err) + } + + return nil +} diff --git a/trader/aster/trader_positions.go b/trader/aster/trader_positions.go new file mode 100644 index 00000000..746fd8dd --- /dev/null +++ b/trader/aster/trader_positions.go @@ -0,0 +1,121 @@ +package aster + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "strconv" + "strings" +) + +// GetPositions Get position information +func (t *AsterTrader) GetPositions() ([]map[string]interface{}, error) { + params := make(map[string]interface{}) + body, err := t.request("GET", "/fapi/v3/positionRisk", params) + if err != nil { + return nil, err + } + + var positions []map[string]interface{} + if err := json.Unmarshal(body, &positions); err != nil { + return nil, err + } + + result := []map[string]interface{}{} + for _, pos := range positions { + posAmtStr, ok := pos["positionAmt"].(string) + if !ok { + continue + } + + posAmt, _ := strconv.ParseFloat(posAmtStr, 64) + if posAmt == 0 { + continue // Skip empty positions + } + + entryPrice, _ := strconv.ParseFloat(pos["entryPrice"].(string), 64) + markPrice, _ := strconv.ParseFloat(pos["markPrice"].(string), 64) + unRealizedProfit, _ := strconv.ParseFloat(pos["unRealizedProfit"].(string), 64) + leverageVal, _ := strconv.ParseFloat(pos["leverage"].(string), 64) + liquidationPrice, _ := strconv.ParseFloat(pos["liquidationPrice"].(string), 64) + + // Determine direction (consistent with Binance) + side := "long" + if posAmt < 0 { + side = "short" + posAmt = -posAmt + } + + // Return same field names as Binance + result = append(result, map[string]interface{}{ + "symbol": pos["symbol"], + "side": side, + "positionAmt": posAmt, + "entryPrice": entryPrice, + "markPrice": markPrice, + "unRealizedProfit": unRealizedProfit, + "leverage": leverageVal, + "liquidationPrice": liquidationPrice, + }) + } + + return result, nil +} + +// SetMarginMode Set margin mode +func (t *AsterTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + // Aster supports margin mode settings + // API format similar to Binance: CROSSED (cross margin) / ISOLATED (isolated margin) + marginType := "CROSSED" + if !isCrossMargin { + marginType = "ISOLATED" + } + + params := map[string]interface{}{ + "symbol": symbol, + "marginType": marginType, + } + + // Use request method to call API + _, err := t.request("POST", "/fapi/v3/marginType", params) + if err != nil { + // Ignore error if it indicates no need to change + if strings.Contains(err.Error(), "No need to change") || + strings.Contains(err.Error(), "Margin type cannot be changed") { + logger.Infof(" ✓ %s margin mode is already %s or cannot be changed due to existing positions", symbol, marginType) + return nil + } + // Detect multi-assets mode (error code -4168) + if strings.Contains(err.Error(), "Multi-Assets mode") || + strings.Contains(err.Error(), "-4168") || + strings.Contains(err.Error(), "4168") { + logger.Infof(" ⚠️ %s detected multi-assets mode, forcing cross margin mode", symbol) + logger.Infof(" 💡 Tip: To use isolated margin mode, please disable multi-assets mode on the exchange") + return nil + } + // Detect unified account API + if strings.Contains(err.Error(), "unified") || + strings.Contains(err.Error(), "portfolio") || + strings.Contains(err.Error(), "Portfolio") { + logger.Infof(" ❌ %s detected unified account API, cannot perform futures trading", symbol) + return fmt.Errorf("please use 'Spot & Futures Trading' API permission, not 'Unified Account API'") + } + logger.Infof(" ⚠️ Failed to set margin mode: %v", err) + // Don't return error, let trading continue + return nil + } + + logger.Infof(" ✓ %s margin mode has been set to %s", symbol, marginType) + return nil +} + +// SetLeverage Set leverage multiplier +func (t *AsterTrader) SetLeverage(symbol string, leverage int) error { + params := map[string]interface{}{ + "symbol": symbol, + "leverage": leverage, + } + + _, err := t.request("POST", "/fapi/v3/leverage", params) + return err +} diff --git a/trader/aster/order_sync.go b/trader/aster/trader_sync.go similarity index 100% rename from trader/aster/order_sync.go rename to trader/aster/trader_sync.go diff --git a/trader/auto_trader_decision.go b/trader/auto_trader_decision.go index 2cec395c..b2898174 100644 --- a/trader/auto_trader_decision.go +++ b/trader/auto_trader_decision.go @@ -3,7 +3,7 @@ package trader import ( "fmt" "math" - "nofx/experience" + "nofx/telemetry" "nofx/kernel" "nofx/logger" "nofx/market" @@ -345,7 +345,7 @@ func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{}, // Send anonymous trade statistics for experience improvement (async, non-blocking) // This helps us understand overall product usage across all deployments - experience.TrackTrade(experience.TradeEvent{ + telemetry.TrackTrade(telemetry.TradeEvent{ Exchange: at.exchange, TradeType: action, Symbol: symbol, diff --git a/trader/auto_trader_grid.go b/trader/auto_trader_grid.go index 37b62d2d..6151dbbb 100644 --- a/trader/auto_trader_grid.go +++ b/trader/auto_trader_grid.go @@ -3,7 +3,6 @@ package trader import ( "encoding/json" "fmt" - "math" "nofx/kernel" "nofx/logger" "nofx/market" @@ -32,16 +31,16 @@ type GridState struct { GridSpacing float64 // State flags - IsPaused bool + IsPaused bool IsInitialized bool // Performance tracking - TotalProfit float64 - TotalTrades int - WinningTrades int - MaxDrawdown float64 - PeakEquity float64 - DailyPnL float64 + TotalProfit float64 + TotalTrades int + WinningTrades int + MaxDrawdown float64 + PeakEquity float64 + DailyPnL float64 LastDailyReset time.Time // Order tracking @@ -67,9 +66,9 @@ type GridState struct { CurrentRegimeLevel string // Grid direction adjustment - CurrentDirection market.GridDirection - DirectionChangedAt time.Time - DirectionChangeCount int + CurrentDirection market.GridDirection + DirectionChangedAt time.Time + DirectionChangeCount int } // NewGridState creates a new grid state @@ -83,7 +82,7 @@ func NewGridState(config *store.GridStrategyConfig) *GridState { } // ============================================================================ -// Breakout Detection +// Breakout Detection (price vs grid boundary) // ============================================================================ // BreakoutType represents the type of price breakout @@ -282,226 +281,8 @@ func (at *AutoTrader) handleBreakout(breakoutType BreakoutType, breakoutPct floa return nil } -// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action -func (at *AutoTrader) checkBoxBreakout() error { - gridConfig := at.config.StrategyConfig.GridConfig - if gridConfig == nil { - return nil - } - - // Get box data - box, err := market.GetBoxData(gridConfig.Symbol) - if err != nil { - logger.Infof("Failed to get box data: %v", err) - return nil // Non-fatal, continue with other checks - } - - // Update grid state with box values - at.gridState.mu.Lock() - at.gridState.ShortBoxUpper = box.ShortUpper - at.gridState.ShortBoxLower = box.ShortLower - at.gridState.MidBoxUpper = box.MidUpper - at.gridState.MidBoxLower = box.MidLower - at.gridState.LongBoxUpper = box.LongUpper - at.gridState.LongBoxLower = box.LongLower - at.gridState.mu.Unlock() - - // Detect breakout - breakoutLevel, direction := detectBoxBreakout(box) - - // Get current breakout state - state := &BreakoutState{ - Level: market.BreakoutLevel(at.gridState.BreakoutLevel), - Direction: at.gridState.BreakoutDirection, - ConfirmCount: at.gridState.BreakoutConfirmCount, - } - - // Check if breakout is confirmed (3 candles) - confirmed := confirmBreakout(state, breakoutLevel, direction) - - // Update grid state - at.gridState.mu.Lock() - at.gridState.BreakoutLevel = string(state.Level) - at.gridState.BreakoutDirection = state.Direction - at.gridState.BreakoutConfirmCount = state.ConfirmCount - at.gridState.mu.Unlock() - - if !confirmed { - return nil - } - - // Take action based on breakout level - // Use direction-aware action if enabled - enableDirectionAdjust := gridConfig.EnableDirectionAdjust - action := getBreakoutActionWithDirection(breakoutLevel, enableDirectionAdjust) - - // If direction adjustment action, determine the new direction - if action == BreakoutActionAdjustDirection { - box, _ := market.GetBoxData(gridConfig.Symbol) - newDirection := determineGridDirection(box, at.gridState.CurrentDirection, breakoutLevel, direction) - return at.executeDirectionAdjustment(newDirection) - } - - return at.executeBreakoutAction(action) -} - -// executeBreakoutAction executes the appropriate action for a breakout -func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { - switch action { - case BreakoutActionReducePosition: - // Short box breakout: reduce position to 50% - logger.Infof("Short box breakout confirmed, reducing position to 50%%") - at.gridState.mu.Lock() - at.gridState.PositionReductionPct = 50 - at.gridState.mu.Unlock() - return nil - - case BreakoutActionPauseGrid: - // Mid box breakout: pause grid + cancel orders - logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders") - at.gridState.mu.Lock() - at.gridState.IsPaused = true - at.gridState.mu.Unlock() - return at.cancelAllGridOrders() - - case BreakoutActionCloseAll: - // Long box breakout: pause + cancel + close all - logger.Infof("Long box breakout confirmed, closing all positions") - at.gridState.mu.Lock() - at.gridState.IsPaused = true - at.gridState.mu.Unlock() - if err := at.cancelAllGridOrders(); err != nil { - logger.Infof("Failed to cancel orders: %v", err) - } - return at.closeAllPositions() - - case BreakoutActionAdjustDirection: - // Direction adjustment is handled separately via executeDirectionAdjustment - // This case should not be reached, but handle gracefully - logger.Infof("Direction adjustment action received via executeBreakoutAction") - return nil - } - - return nil -} - -// executeDirectionAdjustment handles grid direction changes based on box breakout -func (at *AutoTrader) executeDirectionAdjustment(newDirection market.GridDirection) error { - at.gridState.mu.RLock() - oldDirection := at.gridState.CurrentDirection - at.gridState.mu.RUnlock() - - if oldDirection == newDirection { - return nil // No change needed - } - - logger.Infof("[Grid] Direction adjustment: %s → %s", oldDirection, newDirection) - - // Cancel existing orders before adjusting - if err := at.cancelAllGridOrders(); err != nil { - logger.Warnf("[Grid] Failed to cancel orders during direction adjustment: %v", err) - } - - // Apply the new direction - return at.adjustGridDirection(newDirection) -} - -// closeAllPositions closes all open positions for the grid symbol -func (at *AutoTrader) closeAllPositions() error { - gridConfig := at.config.StrategyConfig.GridConfig - if gridConfig == nil { - return nil - } - - positions, err := at.trader.GetPositions() - if err != nil { - return fmt.Errorf("failed to get positions: %w", err) - } - - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - if symbol != gridConfig.Symbol { - continue - } - - size, _ := pos["positionAmt"].(float64) - if size == 0 { - continue - } - - if size > 0 { - _, err = at.trader.CloseLong(symbol, size) - } else { - _, err = at.trader.CloseShort(symbol, -size) - } - if err != nil { - logger.Infof("Failed to close position: %v", err) - } - } - - return nil -} - -// checkFalseBreakoutRecovery checks if price has returned to box after breakout -func (at *AutoTrader) checkFalseBreakoutRecovery() error { - gridConfig := at.config.StrategyConfig.GridConfig - if gridConfig == nil { - return nil - } - - at.gridState.mu.RLock() - breakoutLevel := at.gridState.BreakoutLevel - isPaused := at.gridState.IsPaused - positionReduction := at.gridState.PositionReductionPct - currentDirection := at.gridState.CurrentDirection - at.gridState.mu.RUnlock() - - // Only check if we had a breakout or non-neutral direction - needsRecoveryCheck := breakoutLevel != string(market.BreakoutNone) || - positionReduction != 0 || - isPaused || - (gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral) - - if !needsRecoveryCheck { - return nil - } - - // Get current box data - box, err := market.GetBoxData(gridConfig.Symbol) - if err != nil { - return nil - } - - // Check if price is back inside the long box - if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper { - logger.Infof("Price returned to box, recovering with 50%% position") - - at.gridState.mu.Lock() - at.gridState.BreakoutLevel = string(market.BreakoutNone) - at.gridState.BreakoutDirection = "" - at.gridState.BreakoutConfirmCount = 0 - at.gridState.PositionReductionPct = 50 // Recover at 50% - at.gridState.IsPaused = false - at.gridState.mu.Unlock() - } - - // Check for direction recovery toward neutral (if direction adjustment is enabled) - if gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral { - if shouldRecoverDirection(box, currentDirection) { - newDirection := determineRecoveryDirection(box.CurrentPrice, box, currentDirection) - if newDirection != currentDirection { - logger.Infof("[Grid] Direction recovery: %s → %s (price back in short box)", - currentDirection, newDirection) - at.adjustGridDirection(newDirection) - } - } - } - - return nil -} - // ============================================================================ -// AutoTrader Grid Methods +// AutoTrader Grid Lifecycle // ============================================================================ // InitializeGrid initializes the grid state and calculates levels @@ -551,210 +332,12 @@ func (at *AutoTrader) InitializeGrid() error { logger.Infof("[Grid] Leverage set to %dx for %s", gridConfig.Leverage, gridConfig.Symbol) } - logger.Infof("📊 [Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f", + logger.Infof("[Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f", gridConfig.GridCount, at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) return nil } -// calculateDefaultBounds calculates default bounds based on price -func (at *AutoTrader) calculateDefaultBounds(price float64, config *store.GridStrategyConfig) { - // Default: ±3% from current price - multiplier := 0.03 * float64(config.GridCount) / 10 - at.gridState.UpperPrice = price * (1 + multiplier) - at.gridState.LowerPrice = price * (1 - multiplier) -} - -// calculateATRBounds calculates bounds using ATR -func (at *AutoTrader) calculateATRBounds(price float64, mktData *market.Data, config *store.GridStrategyConfig) { - atr := 0.0 - if mktData.LongerTermContext != nil { - atr = mktData.LongerTermContext.ATR14 - } - - if atr <= 0 { - at.calculateDefaultBounds(price, config) - return - } - - multiplier := config.ATRMultiplier - if multiplier <= 0 { - multiplier = 2.0 - } - - halfRange := atr * multiplier - at.gridState.UpperPrice = price + halfRange - at.gridState.LowerPrice = price - halfRange -} - -// initializeGridLevels creates the grid level structure -func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.GridStrategyConfig) { - levels := make([]kernel.GridLevelInfo, config.GridCount) - totalWeight := 0.0 - weights := make([]float64, config.GridCount) - - // Calculate weights based on distribution - for i := 0; i < config.GridCount; i++ { - switch config.Distribution { - case "gaussian": - // Gaussian distribution - more weight in the middle - center := float64(config.GridCount-1) / 2 - sigma := float64(config.GridCount) / 4 - weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) - case "pyramid": - // Pyramid - more weight at bottom - weights[i] = float64(config.GridCount - i) - default: // uniform - weights[i] = 1.0 - } - totalWeight += weights[i] - } - - // Create levels - for i := 0; i < config.GridCount; i++ { - price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing - allocatedUSD := config.TotalInvestment * weights[i] / totalWeight - - // Determine initial side (below current price = buy, above = sell) - side := "buy" - if price > currentPrice { - side = "sell" - } - - levels[i] = kernel.GridLevelInfo{ - Index: i, - Price: price, - State: "empty", - Side: side, - AllocatedUSD: allocatedUSD, - } - } - - at.gridState.Levels = levels - - // Apply direction-based side assignment if enabled - if config.EnableDirectionAdjust { - at.applyGridDirection(currentPrice) - } -} - -// applyGridDirection adjusts grid level sides based on the current direction -// This redistributes buy/sell levels according to the direction bias ratio -func (at *AutoTrader) applyGridDirection(currentPrice float64) { - config := at.gridState.Config - direction := at.gridState.CurrentDirection - - // Get bias ratio from config, default to 0.7 (70%/30%) - biasRatio := config.DirectionBiasRatio - if biasRatio <= 0 || biasRatio > 1 { - biasRatio = 0.7 - } - - buyRatio, _ := direction.GetBuySellRatio(biasRatio) - - // Calculate how many levels should be buy vs sell based on direction - totalLevels := len(at.gridState.Levels) - targetBuyLevels := int(float64(totalLevels) * buyRatio) - - // For neutral: use price-based assignment (buy below, sell above) - if direction == market.GridDirectionNeutral { - for i := range at.gridState.Levels { - if at.gridState.Levels[i].Price <= currentPrice { - at.gridState.Levels[i].Side = "buy" - } else { - at.gridState.Levels[i].Side = "sell" - } - } - return - } - - // For long/long_bias: more buy levels - // For short/short_bias: more sell levels - switch direction { - case market.GridDirectionLong: - // 100% buy - all levels are buy - for i := range at.gridState.Levels { - at.gridState.Levels[i].Side = "buy" - } - - case market.GridDirectionShort: - // 100% sell - all levels are sell - for i := range at.gridState.Levels { - at.gridState.Levels[i].Side = "sell" - } - - case market.GridDirectionLongBias, market.GridDirectionShortBias: - // Assign sides based on position relative to current price - // For long_bias: keep all below as buy, convert some above to buy - // For short_bias: keep all above as sell, convert some below to sell - buyCount := 0 - sellCount := 0 - - for i := range at.gridState.Levels { - needMoreBuys := buyCount < targetBuyLevels - needMoreSells := sellCount < (totalLevels - targetBuyLevels) - - if at.gridState.Levels[i].Price <= currentPrice { - // Level below or at current price - if needMoreBuys { - at.gridState.Levels[i].Side = "buy" - buyCount++ - } else { - at.gridState.Levels[i].Side = "sell" - sellCount++ - } - } else { - // Level above current price - if needMoreSells && direction == market.GridDirectionShortBias { - at.gridState.Levels[i].Side = "sell" - sellCount++ - } else if needMoreBuys && direction == market.GridDirectionLongBias { - at.gridState.Levels[i].Side = "buy" - buyCount++ - } else if needMoreSells { - at.gridState.Levels[i].Side = "sell" - sellCount++ - } else { - at.gridState.Levels[i].Side = "buy" - buyCount++ - } - } - } - } - - logger.Infof("[Grid] Applied direction %s: buy_ratio=%.0f%%, levels reconfigured", - direction, buyRatio*100) -} - -// adjustGridDirection handles runtime direction adjustment when breakout is detected -func (at *AutoTrader) adjustGridDirection(newDirection market.GridDirection) error { - at.gridState.mu.Lock() - defer at.gridState.mu.Unlock() - - oldDirection := at.gridState.CurrentDirection - if oldDirection == newDirection { - return nil // No change needed - } - - at.gridState.CurrentDirection = newDirection - at.gridState.DirectionChangedAt = time.Now() - at.gridState.DirectionChangeCount++ - - logger.Infof("[Grid] Direction changed: %s → %s (change count: %d)", - oldDirection, newDirection, at.gridState.DirectionChangeCount) - - // Get current price for recalculation - currentPrice, err := at.trader.GetMarketPrice(at.gridState.Config.Symbol) - if err != nil { - return fmt.Errorf("failed to get market price: %w", err) - } - - // Reapply direction to grid levels - at.applyGridDirection(currentPrice) - - return nil -} - // RunGridCycle executes one grid trading cycle func (at *AutoTrader) RunGridCycle() error { // Check if trader is stopped (early exit to prevent trades after Stop() is called) @@ -965,312 +548,12 @@ func (at *AutoTrader) executeGridDecision(d *kernel.Decision) error { } } -// checkTotalPositionLimit checks if adding a new position would exceed total limits -// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64) -func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) { - gridConfig := at.config.StrategyConfig.GridConfig - - // Calculate max allowed total position value - // Total position should not exceed: TotalInvestment × Leverage - maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) - - // Get current position value from exchange - currentPositionValue := 0.0 - positions, err := at.trader.GetPositions() - if err == nil { - for _, pos := range positions { - if sym, ok := pos["symbol"].(string); ok && sym == symbol { - if size, ok := pos["positionAmt"].(float64); ok { - if price, ok := pos["markPrice"].(float64); ok { - currentPositionValue = math.Abs(size) * price - } else if entryPrice, ok := pos["entryPrice"].(float64); ok { - currentPositionValue = math.Abs(size) * entryPrice - } - } - } - } +// IsGridStrategy returns true if current strategy is grid trading +func (at *AutoTrader) IsGridStrategy() bool { + if at.config.StrategyConfig == nil { + return false } - - // Also count pending orders as potential position - at.gridState.mu.RLock() - pendingValue := 0.0 - for _, level := range at.gridState.Levels { - if level.State == "pending" { - pendingValue += level.OrderQuantity * level.Price - } - } - at.gridState.mu.RUnlock() - - totalAfterOrder := currentPositionValue + pendingValue + additionalValue - allowed := totalAfterOrder <= maxTotalPositionValue - - return allowed, currentPositionValue + pendingValue, maxTotalPositionValue -} - -// placeGridLimitOrder places a limit order for grid trading -func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error { - // Check if trader supports GridTrader interface - gridTrader, ok := at.trader.(GridTrader) - if !ok { - // Fallback to adapter - gridTrader = NewGridTraderAdapter(at.trader) - } - - gridConfig := at.config.StrategyConfig.GridConfig - - // CRITICAL: Validate and cap quantity to prevent excessive position sizes - // This protects against AI miscalculations or leverage misconfigurations - quantity := d.Quantity - if d.Price > 0 && gridConfig.TotalInvestment > 0 { - // Calculate max allowed position value per grid level - // Each level gets proportional share of total investment - maxMarginPerLevel := gridConfig.TotalInvestment / float64(gridConfig.GridCount) - maxPositionValuePerLevel := maxMarginPerLevel * float64(gridConfig.Leverage) - maxQuantityPerLevel := maxPositionValuePerLevel / d.Price - - // Also get the level's allocated USD for additional validation - at.gridState.mu.RLock() - var levelAllocatedUSD float64 - if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { - levelAllocatedUSD = at.gridState.Levels[d.LevelIndex].AllocatedUSD - } - at.gridState.mu.RUnlock() - - // Use level-specific allocation if available - if levelAllocatedUSD > 0 { - levelMaxPositionValue := levelAllocatedUSD * float64(gridConfig.Leverage) - levelMaxQuantity := levelMaxPositionValue / d.Price - if levelMaxQuantity < maxQuantityPerLevel { - maxQuantityPerLevel = levelMaxQuantity - } - } - - // Cap quantity if it exceeds the maximum allowed - if quantity > maxQuantityPerLevel { - logger.Warnf("[Grid] ⚠️ Quantity %.4f exceeds max allowed %.4f (position_value $%.2f > max $%.2f), capping", - quantity, maxQuantityPerLevel, quantity*d.Price, maxPositionValuePerLevel) - quantity = maxQuantityPerLevel - } - - // Safety check: ensure position value is reasonable (within 2x of intended max as absolute limit) - positionValue := quantity * d.Price - absoluteMaxValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) * 2 // 2x safety margin - if positionValue > absoluteMaxValue { - logger.Errorf("[Grid] CRITICAL: Position value $%.2f exceeds absolute max $%.2f! Rejecting order.", - positionValue, absoluteMaxValue) - return fmt.Errorf("position value $%.2f exceeds safety limit $%.2f", positionValue, absoluteMaxValue) - } - } - - // CRITICAL: Check total position limit before placing order - orderValue := quantity * d.Price - allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue) - if !allowed { - logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.", - currentValue, orderValue, maxValue) - return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue) - } - - req := &LimitOrderRequest{ - Symbol: d.Symbol, - Side: side, - Price: d.Price, - Quantity: quantity, // Use validated/capped quantity - Leverage: gridConfig.Leverage, - PostOnly: gridConfig.UseMakerOnly, - ReduceOnly: false, - ClientID: fmt.Sprintf("grid-%d-%d", d.LevelIndex, time.Now().UnixNano()%1000000), - } - - result, err := gridTrader.PlaceLimitOrder(req) - if err != nil { - return fmt.Errorf("failed to place limit order: %w", err) - } - - // Update grid level state - at.gridState.mu.Lock() - if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { - at.gridState.Levels[d.LevelIndex].State = "pending" - at.gridState.Levels[d.LevelIndex].OrderID = result.OrderID - at.gridState.Levels[d.LevelIndex].OrderQuantity = d.Quantity - at.gridState.OrderBook[result.OrderID] = d.LevelIndex - } - at.gridState.mu.Unlock() - - logger.Infof("[Grid] Placed %s limit order at $%.2f, qty=%.4f, level=%d, orderID=%s", - side, d.Price, d.Quantity, d.LevelIndex, result.OrderID) - - return nil -} - -// cancelGridOrder cancels a specific grid order -func (at *AutoTrader) cancelGridOrder(d *kernel.Decision) error { - gridTrader, ok := at.trader.(GridTrader) - if !ok { - gridTrader = NewGridTraderAdapter(at.trader) - } - - if err := gridTrader.CancelOrder(d.Symbol, d.OrderID); err != nil { - return fmt.Errorf("failed to cancel order: %w", err) - } - - // Update state - at.gridState.mu.Lock() - if levelIdx, ok := at.gridState.OrderBook[d.OrderID]; ok { - if levelIdx >= 0 && levelIdx < len(at.gridState.Levels) { - at.gridState.Levels[levelIdx].State = "empty" - at.gridState.Levels[levelIdx].OrderID = "" - at.gridState.Levels[levelIdx].OrderQuantity = 0 - } - delete(at.gridState.OrderBook, d.OrderID) - } - at.gridState.mu.Unlock() - - logger.Infof("[Grid] Cancelled order: %s", d.OrderID) - return nil -} - -// cancelAllGridOrders cancels all grid orders -func (at *AutoTrader) cancelAllGridOrders() error { - gridConfig := at.config.StrategyConfig.GridConfig - - if err := at.trader.CancelAllOrders(gridConfig.Symbol); err != nil { - return fmt.Errorf("failed to cancel all orders: %w", err) - } - - // Reset all pending levels - at.gridState.mu.Lock() - for i := range at.gridState.Levels { - if at.gridState.Levels[i].State == "pending" { - at.gridState.Levels[i].State = "empty" - at.gridState.Levels[i].OrderID = "" - at.gridState.Levels[i].OrderQuantity = 0 - } - } - at.gridState.OrderBook = make(map[string]int) - at.gridState.mu.Unlock() - - logger.Infof("[Grid] Cancelled all orders") - return nil -} - -// pauseGrid pauses grid trading -func (at *AutoTrader) pauseGrid(reason string) error { - at.cancelAllGridOrders() - - at.gridState.mu.Lock() - at.gridState.IsPaused = true - at.gridState.mu.Unlock() - - logger.Infof("[Grid] Paused: %s", reason) - return nil -} - -// resumeGrid resumes grid trading -func (at *AutoTrader) resumeGrid() error { - at.gridState.mu.Lock() - at.gridState.IsPaused = false - at.gridState.mu.Unlock() - - logger.Infof("[Grid] Resumed") - return nil -} - -// adjustGrid adjusts grid parameters -func (at *AutoTrader) adjustGrid(d *kernel.Decision) error { - // Cancel existing orders first - at.cancelAllGridOrders() - - gridConfig := at.config.StrategyConfig.GridConfig - - // Get current price - price, err := at.trader.GetMarketPrice(gridConfig.Symbol) - if err != nil { - return fmt.Errorf("failed to get market price: %w", err) - } - - // Reinitialize grid levels - at.initializeGridLevels(price, gridConfig) - - logger.Infof("[Grid] Adjusted grid bounds around price $%.2f", price) - return nil -} - -// syncGridState syncs grid state with exchange -func (at *AutoTrader) syncGridState() { - gridConfig := at.config.StrategyConfig.GridConfig - - // Get open orders from exchange - openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol) - if err != nil { - logger.Warnf("[Grid] Failed to get open orders: %v", err) - return - } - - // Build set of active order IDs - activeOrderIDs := make(map[string]bool) - for _, order := range openOrders { - activeOrderIDs[order.OrderID] = true - } - - // Get current positions to verify fills - positions, err := at.trader.GetPositions() - currentPositionSize := 0.0 - if err != nil { - logger.Warnf("[Grid] Failed to get positions for state sync: %v", err) - } else { - for _, pos := range positions { - if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { - if size, ok := pos["positionAmt"].(float64); ok { - currentPositionSize = size - } - } - } - } - - // Update levels based on order status - at.gridState.mu.Lock() - expectedPositionSize := 0.0 - for _, level := range at.gridState.Levels { - if level.State == "filled" { - expectedPositionSize += level.PositionSize - } - } - - for i := range at.gridState.Levels { - level := &at.gridState.Levels[i] - if level.State == "pending" && level.OrderID != "" { - if !activeOrderIDs[level.OrderID] { - // Order no longer exists - check if position changed to determine fill vs cancel - // This is a heuristic - ideally we'd query order history - // If current position is larger than expected filled positions, this order was likely filled - if math.Abs(currentPositionSize) > math.Abs(expectedPositionSize) { - // Position increased, likely filled - level.State = "filled" - level.PositionEntry = level.Price - level.PositionSize = level.OrderQuantity - at.gridState.TotalTrades++ - logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price) - } else { - // Position didn't increase as expected, likely cancelled - level.State = "empty" - level.OrderID = "" - level.OrderQuantity = 0 - logger.Infof("[Grid] Level %d order cancelled/expired", i) - } - delete(at.gridState.OrderBook, level.OrderID) - } - } - } - at.gridState.mu.Unlock() - - logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders)) - - // Check stop loss - at.checkAndExecuteStopLoss() - - // Check grid skew - at.autoAdjustGrid() + return at.config.StrategyConfig.StrategyType == "grid_trading" && at.config.StrategyConfig.GridConfig != nil } // saveGridDecisionRecord saves the grid decision to database @@ -1323,317 +606,6 @@ func (at *AutoTrader) saveGridDecisionRecord(decision *kernel.FullDecision) { } } -// IsGridStrategy returns true if current strategy is grid trading -func (at *AutoTrader) IsGridStrategy() bool { - if at.config.StrategyConfig == nil { - return false - } - return at.config.StrategyConfig.StrategyType == "grid_trading" && at.config.StrategyConfig.GridConfig != nil -} - -// checkGridSkew checks if grid is heavily skewed (too many fills on one side) -// Returns: (skewed bool, buyFilledCount int, sellFilledCount int) -func (at *AutoTrader) checkGridSkew() (bool, int, int) { - at.gridState.mu.RLock() - defer at.gridState.mu.RUnlock() - - buyFilled := 0 - sellFilled := 0 - buyEmpty := 0 - sellEmpty := 0 - - for _, level := range at.gridState.Levels { - if level.Side == "buy" { - if level.State == "filled" { - buyFilled++ - } else if level.State == "empty" { - buyEmpty++ - } - } else { - if level.State == "filled" { - sellFilled++ - } else if level.State == "empty" { - sellEmpty++ - } - } - } - - // Grid is skewed if one side has 3x more fills than the other - // or if one side is completely empty - skewed := false - if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 { - skewed = true // All buys filled, no sells - } else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 { - skewed = true // All sells filled, no buys - } else if buyFilled >= 3*sellFilled && buyFilled > 5 { - skewed = true - } else if sellFilled >= 3*buyFilled && sellFilled > 5 { - skewed = true - } - - return skewed, buyFilled, sellFilled -} - -// autoAdjustGrid automatically adjusts grid when heavily skewed -func (at *AutoTrader) autoAdjustGrid() { - skewed, buyFilled, sellFilled := at.checkGridSkew() - if !skewed { - return - } - - logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...", - buyFilled, sellFilled) - - gridConfig := at.config.StrategyConfig.GridConfig - - // Get current price - currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) - if err != nil { - logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err) - return - } - - // Check if price is near grid boundary - at.gridState.mu.RLock() - upper := at.gridState.UpperPrice - lower := at.gridState.LowerPrice - at.gridState.mu.RUnlock() - - // Only adjust if price has moved significantly (>30% of grid range) - gridRange := upper - lower - midPrice := (upper + lower) / 2 - priceDeviation := math.Abs(currentPrice - midPrice) - - if priceDeviation < gridRange*0.3 { - return // Price still near center, don't adjust - } - - logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice) - - // Cancel existing orders first (before taking the lock for state modification) - if err := at.cancelAllGridOrders(); err != nil { - logger.Errorf("[Grid] Failed to cancel orders during auto-adjust: %v", err) - // Continue with adjustment anyway - } - - // CRITICAL FIX: Hold lock for the entire adjustment operation to ensure atomicity - at.gridState.mu.Lock() - defer at.gridState.mu.Unlock() - - // Preserve filled positions before reinitializing - filledPositions := make(map[int]kernel.GridLevelInfo) - for i, level := range at.gridState.Levels { - if level.State == "filled" { - filledPositions[i] = level - } - } - - // CRITICAL FIX: Recalculate grid bounds centered on current price - // Use the same logic as InitializeGrid() - either ATR-based or default percentage - if gridConfig.UseATRBounds { - // Try to get ATR for bound calculation - mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20) - if err != nil { - logger.Warnf("[Grid] Failed to get market data for ATR during adjust: %v, using default bounds", err) - at.calculateDefaultBoundsLocked(currentPrice, gridConfig) - } else { - at.calculateATRBoundsLocked(currentPrice, mktData, gridConfig) - } - } else { - // Use default bounds calculation (scaled by grid count) - at.calculateDefaultBoundsLocked(currentPrice, gridConfig) - } - - // Recalculate grid spacing based on new bounds - at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1) - - logger.Infof("[Grid] New bounds: $%.2f - $%.2f, spacing: $%.2f", - at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) - - // Initialize new grid levels (without lock since we already hold it) - at.initializeGridLevelsLocked(currentPrice, gridConfig) - - // CRITICAL FIX: Restore filled positions - find closest new level for each filled position - for _, filledLevel := range filledPositions { - closestIdx := -1 - closestDist := math.MaxFloat64 - - for i, newLevel := range at.gridState.Levels { - dist := math.Abs(newLevel.Price - filledLevel.PositionEntry) - if dist < closestDist { - closestDist = dist - closestIdx = i - } - } - - if closestIdx >= 0 { - // Restore the filled state to the closest level - at.gridState.Levels[closestIdx].State = "filled" - at.gridState.Levels[closestIdx].PositionEntry = filledLevel.PositionEntry - at.gridState.Levels[closestIdx].PositionSize = filledLevel.PositionSize - at.gridState.Levels[closestIdx].UnrealizedPnL = filledLevel.UnrealizedPnL - at.gridState.Levels[closestIdx].OrderID = filledLevel.OrderID - at.gridState.Levels[closestIdx].OrderQuantity = filledLevel.OrderQuantity - logger.Infof("[Grid] Restored filled position at level %d (entry $%.2f)", closestIdx, filledLevel.PositionEntry) - } - } -} - -// calculateDefaultBoundsLocked calculates default bounds (caller must hold lock) -func (at *AutoTrader) calculateDefaultBoundsLocked(price float64, config *store.GridStrategyConfig) { - // Default: ±3% from current price, scaled by grid count - multiplier := 0.03 * float64(config.GridCount) / 10 - at.gridState.UpperPrice = price * (1 + multiplier) - at.gridState.LowerPrice = price * (1 - multiplier) -} - -// calculateATRBoundsLocked calculates bounds using ATR (caller must hold lock) -func (at *AutoTrader) calculateATRBoundsLocked(price float64, mktData *market.Data, config *store.GridStrategyConfig) { - atr := 0.0 - if mktData.LongerTermContext != nil { - atr = mktData.LongerTermContext.ATR14 - } - - if atr <= 0 { - at.calculateDefaultBoundsLocked(price, config) - return - } - - multiplier := config.ATRMultiplier - if multiplier <= 0 { - multiplier = 2.0 - } - - halfRange := atr * multiplier - at.gridState.UpperPrice = price + halfRange - at.gridState.LowerPrice = price - halfRange -} - -// initializeGridLevelsLocked creates the grid level structure (caller must hold lock) -func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *store.GridStrategyConfig) { - levels := make([]kernel.GridLevelInfo, config.GridCount) - totalWeight := 0.0 - weights := make([]float64, config.GridCount) - - // Calculate weights based on distribution - for i := 0; i < config.GridCount; i++ { - switch config.Distribution { - case "gaussian": - // Gaussian distribution - more weight in the middle - center := float64(config.GridCount-1) / 2 - sigma := float64(config.GridCount) / 4 - weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) - case "pyramid": - // Pyramid - more weight at bottom - weights[i] = float64(config.GridCount - i) - default: // uniform - weights[i] = 1.0 - } - totalWeight += weights[i] - } - - // Create levels - for i := 0; i < config.GridCount; i++ { - price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing - allocatedUSD := config.TotalInvestment * weights[i] / totalWeight - - // Determine initial side (below current price = buy, above = sell) - side := "buy" - if price > currentPrice { - side = "sell" - } - - levels[i] = kernel.GridLevelInfo{ - Index: i, - Price: price, - State: "empty", - Side: side, - AllocatedUSD: allocatedUSD, - } - } - - at.gridState.Levels = levels - - // Apply direction-based side assignment if enabled (note: caller holds lock) - if config.EnableDirectionAdjust { - at.applyGridDirectionLocked(currentPrice) - } -} - -// applyGridDirectionLocked adjusts grid level sides based on the current direction (caller must hold lock) -func (at *AutoTrader) applyGridDirectionLocked(currentPrice float64) { - config := at.gridState.Config - direction := at.gridState.CurrentDirection - - // Get bias ratio from config, default to 0.7 (70%/30%) - biasRatio := config.DirectionBiasRatio - if biasRatio <= 0 || biasRatio > 1 { - biasRatio = 0.7 - } - - buyRatio, _ := direction.GetBuySellRatio(biasRatio) - - // For neutral: use price-based assignment (buy below, sell above) - if direction == market.GridDirectionNeutral { - for i := range at.gridState.Levels { - if at.gridState.Levels[i].Price <= currentPrice { - at.gridState.Levels[i].Side = "buy" - } else { - at.gridState.Levels[i].Side = "sell" - } - } - return - } - - totalLevels := len(at.gridState.Levels) - targetBuyLevels := int(float64(totalLevels) * buyRatio) - - switch direction { - case market.GridDirectionLong: - for i := range at.gridState.Levels { - at.gridState.Levels[i].Side = "buy" - } - - case market.GridDirectionShort: - for i := range at.gridState.Levels { - at.gridState.Levels[i].Side = "sell" - } - - case market.GridDirectionLongBias, market.GridDirectionShortBias: - buyCount := 0 - sellCount := 0 - - for i := range at.gridState.Levels { - needMoreBuys := buyCount < targetBuyLevels - needMoreSells := sellCount < (totalLevels - targetBuyLevels) - - if at.gridState.Levels[i].Price <= currentPrice { - if needMoreBuys { - at.gridState.Levels[i].Side = "buy" - buyCount++ - } else { - at.gridState.Levels[i].Side = "sell" - sellCount++ - } - } else { - if needMoreSells && direction == market.GridDirectionShortBias { - at.gridState.Levels[i].Side = "sell" - sellCount++ - } else if needMoreBuys && direction == market.GridDirectionLongBias { - at.gridState.Levels[i].Side = "buy" - buyCount++ - } else if needMoreSells { - at.gridState.Levels[i].Side = "sell" - sellCount++ - } else { - at.gridState.Levels[i].Side = "buy" - buyCount++ - } - } - } - } -} - // GridRiskInfo contains risk information for frontend display type GridRiskInfo struct { CurrentLeverage int `json:"current_leverage"` @@ -1661,190 +633,7 @@ type GridRiskInfo struct { BreakoutDirection string `json:"breakout_direction"` // Grid direction - CurrentGridDirection string `json:"current_grid_direction"` - DirectionChangeCount int `json:"direction_change_count"` - EnableDirectionAdjust bool `json:"enable_direction_adjust"` -} - -// GetGridRiskInfo returns current risk information for frontend display -func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { - gridConfig := at.config.StrategyConfig.GridConfig - if gridConfig == nil { - return &GridRiskInfo{} - } - - at.gridState.mu.RLock() - defer at.gridState.mu.RUnlock() - - // Get current price - currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol) - - // Calculate effective leverage - totalInvestment := gridConfig.TotalInvestment - leverage := gridConfig.Leverage - - // Get current position value - positions, _ := at.trader.GetPositions() - var currentPositionValue float64 - var currentPositionSize float64 - for _, pos := range positions { - if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol { - size, _ := pos["positionAmt"].(float64) - entry, _ := pos["entryPrice"].(float64) - currentPositionValue = math.Abs(size * entry) - currentPositionSize = size - break - } - } - - effectiveLeverage := 0.0 - if totalInvestment > 0 { - effectiveLeverage = currentPositionValue / totalInvestment - } - - // Calculate max position based on regime - regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel) - if regimeLevel == "" { - regimeLevel = market.RegimeLevelStandard - } - - // Use default position limit since GridStrategyConfig doesn't have regime-specific limits - // Default is 70% for standard regime - maxPositionPct := 70.0 - switch regimeLevel { - case market.RegimeLevelNarrow: - maxPositionPct = 40.0 - case market.RegimeLevelStandard: - maxPositionPct = 70.0 - case market.RegimeLevelWide: - maxPositionPct = 60.0 - case market.RegimeLevelVolatile: - maxPositionPct = 40.0 - } - - maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage) - - // Use default leverage limits since GridStrategyConfig doesn't have regime-specific limits - recommendedLeverage := leverage - switch regimeLevel { - case market.RegimeLevelNarrow: - recommendedLeverage = min(leverage, 2) - case market.RegimeLevelStandard: - recommendedLeverage = min(leverage, 4) - case market.RegimeLevelWide: - recommendedLeverage = min(leverage, 3) - case market.RegimeLevelVolatile: - recommendedLeverage = min(leverage, 2) - } - - // Calculate liquidation distance and price only when there's a position - var liquidationDistance float64 - var liquidationPrice float64 - if currentPositionSize != 0 && currentPrice > 0 { - liquidationDistance = 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max - if currentPositionSize > 0 { - // Long position: liquidation below entry - liquidationPrice = currentPrice * (1 - liquidationDistance/100) - } else { - // Short position: liquidation above entry - liquidationPrice = currentPrice * (1 + liquidationDistance/100) - } - } - - positionPercent := 0.0 - if maxPosition > 0 { - positionPercent = currentPositionValue / maxPosition * 100 - } - - return &GridRiskInfo{ - CurrentLeverage: leverage, - EffectiveLeverage: effectiveLeverage, - RecommendedLeverage: recommendedLeverage, - - CurrentPosition: currentPositionValue, - MaxPosition: maxPosition, - PositionPercent: positionPercent, - - LiquidationPrice: liquidationPrice, - LiquidationDistance: liquidationDistance, - - RegimeLevel: string(regimeLevel), - - ShortBoxUpper: at.gridState.ShortBoxUpper, - ShortBoxLower: at.gridState.ShortBoxLower, - MidBoxUpper: at.gridState.MidBoxUpper, - MidBoxLower: at.gridState.MidBoxLower, - LongBoxUpper: at.gridState.LongBoxUpper, - LongBoxLower: at.gridState.LongBoxLower, - CurrentPrice: currentPrice, - - BreakoutLevel: at.gridState.BreakoutLevel, - BreakoutDirection: at.gridState.BreakoutDirection, - - CurrentGridDirection: string(at.gridState.CurrentDirection), - DirectionChangeCount: at.gridState.DirectionChangeCount, - EnableDirectionAdjust: gridConfig.EnableDirectionAdjust, - } -} - -// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it -func (at *AutoTrader) checkAndExecuteStopLoss() { - gridConfig := at.config.StrategyConfig.GridConfig - if gridConfig.StopLossPct <= 0 { - return // Stop loss not configured - } - - currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) - if err != nil { - logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err) - return - } - - at.gridState.mu.Lock() - defer at.gridState.mu.Unlock() - - for i := range at.gridState.Levels { - level := &at.gridState.Levels[i] - if level.State != "filled" || level.PositionEntry <= 0 { - continue - } - - // Calculate loss percentage - var lossPct float64 - if level.Side == "buy" { - // Long position: loss when price drops - lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100 - } else { - // Short position: loss when price rises - lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100 - } - - // Check if stop loss triggered - if lossPct >= gridConfig.StopLossPct { - logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%", - i, level.PositionEntry, currentPrice, lossPct) - - // Close the position - var closeErr error - if level.Side == "buy" { - _, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize) - } else { - _, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize) - } - - if closeErr != nil { - logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr) - } else { - level.State = "stopped" - realizedLoss := -lossPct * level.AllocatedUSD / 100 - level.UnrealizedPnL = realizedLoss - at.gridState.TotalTrades++ - // Update daily PnL tracking (lock already held, update directly) - at.gridState.DailyPnL += realizedLoss - at.gridState.TotalProfit += realizedLoss - logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)", - i, currentPrice, lossPct) - } - } - } + CurrentGridDirection string `json:"current_grid_direction"` + DirectionChangeCount int `json:"direction_change_count"` + EnableDirectionAdjust bool `json:"enable_direction_adjust"` } diff --git a/trader/auto_trader_grid_levels.go b/trader/auto_trader_grid_levels.go new file mode 100644 index 00000000..ddc2a408 --- /dev/null +++ b/trader/auto_trader_grid_levels.go @@ -0,0 +1,485 @@ +package trader + +import ( + "math" + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" +) + +// ============================================================================ +// Grid Level Calculation and Rebalancing +// ============================================================================ + +// calculateDefaultBounds calculates default bounds based on price +func (at *AutoTrader) calculateDefaultBounds(price float64, config *store.GridStrategyConfig) { + // Default: +/-3% from current price + multiplier := 0.03 * float64(config.GridCount) / 10 + at.gridState.UpperPrice = price * (1 + multiplier) + at.gridState.LowerPrice = price * (1 - multiplier) +} + +// calculateATRBounds calculates bounds using ATR +func (at *AutoTrader) calculateATRBounds(price float64, mktData *market.Data, config *store.GridStrategyConfig) { + atr := 0.0 + if mktData.LongerTermContext != nil { + atr = mktData.LongerTermContext.ATR14 + } + + if atr <= 0 { + at.calculateDefaultBounds(price, config) + return + } + + multiplier := config.ATRMultiplier + if multiplier <= 0 { + multiplier = 2.0 + } + + halfRange := atr * multiplier + at.gridState.UpperPrice = price + halfRange + at.gridState.LowerPrice = price - halfRange +} + +// initializeGridLevels creates the grid level structure +func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.GridStrategyConfig) { + levels := make([]kernel.GridLevelInfo, config.GridCount) + totalWeight := 0.0 + weights := make([]float64, config.GridCount) + + // Calculate weights based on distribution + for i := 0; i < config.GridCount; i++ { + switch config.Distribution { + case "gaussian": + // Gaussian distribution - more weight in the middle + center := float64(config.GridCount-1) / 2 + sigma := float64(config.GridCount) / 4 + weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) + case "pyramid": + // Pyramid - more weight at bottom + weights[i] = float64(config.GridCount - i) + default: // uniform + weights[i] = 1.0 + } + totalWeight += weights[i] + } + + // Create levels + for i := 0; i < config.GridCount; i++ { + price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing + allocatedUSD := config.TotalInvestment * weights[i] / totalWeight + + // Determine initial side (below current price = buy, above = sell) + side := "buy" + if price > currentPrice { + side = "sell" + } + + levels[i] = kernel.GridLevelInfo{ + Index: i, + Price: price, + State: "empty", + Side: side, + AllocatedUSD: allocatedUSD, + } + } + + at.gridState.Levels = levels + + // Apply direction-based side assignment if enabled + if config.EnableDirectionAdjust { + at.applyGridDirection(currentPrice) + } +} + +// applyGridDirection adjusts grid level sides based on the current direction +// This redistributes buy/sell levels according to the direction bias ratio +func (at *AutoTrader) applyGridDirection(currentPrice float64) { + config := at.gridState.Config + direction := at.gridState.CurrentDirection + + // Get bias ratio from config, default to 0.7 (70%/30%) + biasRatio := config.DirectionBiasRatio + if biasRatio <= 0 || biasRatio > 1 { + biasRatio = 0.7 + } + + buyRatio, _ := direction.GetBuySellRatio(biasRatio) + + // Calculate how many levels should be buy vs sell based on direction + totalLevels := len(at.gridState.Levels) + targetBuyLevels := int(float64(totalLevels) * buyRatio) + + // For neutral: use price-based assignment (buy below, sell above) + if direction == market.GridDirectionNeutral { + for i := range at.gridState.Levels { + if at.gridState.Levels[i].Price <= currentPrice { + at.gridState.Levels[i].Side = "buy" + } else { + at.gridState.Levels[i].Side = "sell" + } + } + return + } + + // For long/long_bias: more buy levels + // For short/short_bias: more sell levels + switch direction { + case market.GridDirectionLong: + // 100% buy - all levels are buy + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "buy" + } + + case market.GridDirectionShort: + // 100% sell - all levels are sell + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "sell" + } + + case market.GridDirectionLongBias, market.GridDirectionShortBias: + // Assign sides based on position relative to current price + // For long_bias: keep all below as buy, convert some above to buy + // For short_bias: keep all above as sell, convert some below to sell + buyCount := 0 + sellCount := 0 + + for i := range at.gridState.Levels { + needMoreBuys := buyCount < targetBuyLevels + needMoreSells := sellCount < (totalLevels - targetBuyLevels) + + if at.gridState.Levels[i].Price <= currentPrice { + // Level below or at current price + if needMoreBuys { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } + } else { + // Level above current price + if needMoreSells && direction == market.GridDirectionShortBias { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else if needMoreBuys && direction == market.GridDirectionLongBias { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else if needMoreSells { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } + } + } + } + + logger.Infof("[Grid] Applied direction %s: buy_ratio=%.0f%%, levels reconfigured", + direction, buyRatio*100) +} + +// checkGridSkew checks if grid is heavily skewed (too many fills on one side) +// Returns: (skewed bool, buyFilledCount int, sellFilledCount int) +func (at *AutoTrader) checkGridSkew() (bool, int, int) { + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + buyFilled := 0 + sellFilled := 0 + buyEmpty := 0 + sellEmpty := 0 + + for _, level := range at.gridState.Levels { + if level.Side == "buy" { + if level.State == "filled" { + buyFilled++ + } else if level.State == "empty" { + buyEmpty++ + } + } else { + if level.State == "filled" { + sellFilled++ + } else if level.State == "empty" { + sellEmpty++ + } + } + } + + // Grid is skewed if one side has 3x more fills than the other + // or if one side is completely empty + skewed := false + if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 { + skewed = true // All buys filled, no sells + } else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 { + skewed = true // All sells filled, no buys + } else if buyFilled >= 3*sellFilled && buyFilled > 5 { + skewed = true + } else if sellFilled >= 3*buyFilled && sellFilled > 5 { + skewed = true + } + + return skewed, buyFilled, sellFilled +} + +// autoAdjustGrid automatically adjusts grid when heavily skewed +func (at *AutoTrader) autoAdjustGrid() { + skewed, buyFilled, sellFilled := at.checkGridSkew() + if !skewed { + return + } + + logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...", + buyFilled, sellFilled) + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err) + return + } + + // Check if price is near grid boundary + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + // Only adjust if price has moved significantly (>30% of grid range) + gridRange := upper - lower + midPrice := (upper + lower) / 2 + priceDeviation := math.Abs(currentPrice - midPrice) + + if priceDeviation < gridRange*0.3 { + return // Price still near center, don't adjust + } + + logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice) + + // Cancel existing orders first (before taking the lock for state modification) + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders during auto-adjust: %v", err) + // Continue with adjustment anyway + } + + // CRITICAL FIX: Hold lock for the entire adjustment operation to ensure atomicity + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + // Preserve filled positions before reinitializing + filledPositions := make(map[int]kernel.GridLevelInfo) + for i, level := range at.gridState.Levels { + if level.State == "filled" { + filledPositions[i] = level + } + } + + // CRITICAL FIX: Recalculate grid bounds centered on current price + // Use the same logic as InitializeGrid() - either ATR-based or default percentage + if gridConfig.UseATRBounds { + // Try to get ATR for bound calculation + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20) + if err != nil { + logger.Warnf("[Grid] Failed to get market data for ATR during adjust: %v, using default bounds", err) + at.calculateDefaultBoundsLocked(currentPrice, gridConfig) + } else { + at.calculateATRBoundsLocked(currentPrice, mktData, gridConfig) + } + } else { + // Use default bounds calculation (scaled by grid count) + at.calculateDefaultBoundsLocked(currentPrice, gridConfig) + } + + // Recalculate grid spacing based on new bounds + at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1) + + logger.Infof("[Grid] New bounds: $%.2f - $%.2f, spacing: $%.2f", + at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + // Initialize new grid levels (without lock since we already hold it) + at.initializeGridLevelsLocked(currentPrice, gridConfig) + + // CRITICAL FIX: Restore filled positions - find closest new level for each filled position + for _, filledLevel := range filledPositions { + closestIdx := -1 + closestDist := math.MaxFloat64 + + for i, newLevel := range at.gridState.Levels { + dist := math.Abs(newLevel.Price - filledLevel.PositionEntry) + if dist < closestDist { + closestDist = dist + closestIdx = i + } + } + + if closestIdx >= 0 { + // Restore the filled state to the closest level + at.gridState.Levels[closestIdx].State = "filled" + at.gridState.Levels[closestIdx].PositionEntry = filledLevel.PositionEntry + at.gridState.Levels[closestIdx].PositionSize = filledLevel.PositionSize + at.gridState.Levels[closestIdx].UnrealizedPnL = filledLevel.UnrealizedPnL + at.gridState.Levels[closestIdx].OrderID = filledLevel.OrderID + at.gridState.Levels[closestIdx].OrderQuantity = filledLevel.OrderQuantity + logger.Infof("[Grid] Restored filled position at level %d (entry $%.2f)", closestIdx, filledLevel.PositionEntry) + } + } +} + +// calculateDefaultBoundsLocked calculates default bounds (caller must hold lock) +func (at *AutoTrader) calculateDefaultBoundsLocked(price float64, config *store.GridStrategyConfig) { + // Default: +/-3% from current price, scaled by grid count + multiplier := 0.03 * float64(config.GridCount) / 10 + at.gridState.UpperPrice = price * (1 + multiplier) + at.gridState.LowerPrice = price * (1 - multiplier) +} + +// calculateATRBoundsLocked calculates bounds using ATR (caller must hold lock) +func (at *AutoTrader) calculateATRBoundsLocked(price float64, mktData *market.Data, config *store.GridStrategyConfig) { + atr := 0.0 + if mktData.LongerTermContext != nil { + atr = mktData.LongerTermContext.ATR14 + } + + if atr <= 0 { + at.calculateDefaultBoundsLocked(price, config) + return + } + + multiplier := config.ATRMultiplier + if multiplier <= 0 { + multiplier = 2.0 + } + + halfRange := atr * multiplier + at.gridState.UpperPrice = price + halfRange + at.gridState.LowerPrice = price - halfRange +} + +// initializeGridLevelsLocked creates the grid level structure (caller must hold lock) +func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *store.GridStrategyConfig) { + levels := make([]kernel.GridLevelInfo, config.GridCount) + totalWeight := 0.0 + weights := make([]float64, config.GridCount) + + // Calculate weights based on distribution + for i := 0; i < config.GridCount; i++ { + switch config.Distribution { + case "gaussian": + // Gaussian distribution - more weight in the middle + center := float64(config.GridCount-1) / 2 + sigma := float64(config.GridCount) / 4 + weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) + case "pyramid": + // Pyramid - more weight at bottom + weights[i] = float64(config.GridCount - i) + default: // uniform + weights[i] = 1.0 + } + totalWeight += weights[i] + } + + // Create levels + for i := 0; i < config.GridCount; i++ { + price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing + allocatedUSD := config.TotalInvestment * weights[i] / totalWeight + + // Determine initial side (below current price = buy, above = sell) + side := "buy" + if price > currentPrice { + side = "sell" + } + + levels[i] = kernel.GridLevelInfo{ + Index: i, + Price: price, + State: "empty", + Side: side, + AllocatedUSD: allocatedUSD, + } + } + + at.gridState.Levels = levels + + // Apply direction-based side assignment if enabled (note: caller holds lock) + if config.EnableDirectionAdjust { + at.applyGridDirectionLocked(currentPrice) + } +} + +// applyGridDirectionLocked adjusts grid level sides based on the current direction (caller must hold lock) +func (at *AutoTrader) applyGridDirectionLocked(currentPrice float64) { + config := at.gridState.Config + direction := at.gridState.CurrentDirection + + // Get bias ratio from config, default to 0.7 (70%/30%) + biasRatio := config.DirectionBiasRatio + if biasRatio <= 0 || biasRatio > 1 { + biasRatio = 0.7 + } + + buyRatio, _ := direction.GetBuySellRatio(biasRatio) + + // For neutral: use price-based assignment (buy below, sell above) + if direction == market.GridDirectionNeutral { + for i := range at.gridState.Levels { + if at.gridState.Levels[i].Price <= currentPrice { + at.gridState.Levels[i].Side = "buy" + } else { + at.gridState.Levels[i].Side = "sell" + } + } + return + } + + totalLevels := len(at.gridState.Levels) + targetBuyLevels := int(float64(totalLevels) * buyRatio) + + switch direction { + case market.GridDirectionLong: + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "buy" + } + + case market.GridDirectionShort: + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "sell" + } + + case market.GridDirectionLongBias, market.GridDirectionShortBias: + buyCount := 0 + sellCount := 0 + + for i := range at.gridState.Levels { + needMoreBuys := buyCount < targetBuyLevels + needMoreSells := sellCount < (totalLevels - targetBuyLevels) + + if at.gridState.Levels[i].Price <= currentPrice { + if needMoreBuys { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } + } else { + if needMoreSells && direction == market.GridDirectionShortBias { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else if needMoreBuys && direction == market.GridDirectionLongBias { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else if needMoreSells { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } + } + } + } +} diff --git a/trader/auto_trader_grid_orders.go b/trader/auto_trader_grid_orders.go new file mode 100644 index 00000000..c19dc6d0 --- /dev/null +++ b/trader/auto_trader_grid_orders.go @@ -0,0 +1,419 @@ +package trader + +import ( + "fmt" + "math" + "nofx/kernel" + "nofx/logger" + "time" +) + +// ============================================================================ +// Grid Order Placement and Management +// ============================================================================ + +// checkTotalPositionLimit checks if adding a new position would exceed total limits +// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64) +func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Calculate max allowed total position value + // Total position should not exceed: TotalInvestment * Leverage + maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) + + // Get current position value from exchange + currentPositionValue := 0.0 + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == symbol { + if size, ok := pos["positionAmt"].(float64); ok { + if price, ok := pos["markPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * price + } else if entryPrice, ok := pos["entryPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * entryPrice + } + } + } + } + } + + // Also count pending orders as potential position + at.gridState.mu.RLock() + pendingValue := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "pending" { + pendingValue += level.OrderQuantity * level.Price + } + } + at.gridState.mu.RUnlock() + + totalAfterOrder := currentPositionValue + pendingValue + additionalValue + allowed := totalAfterOrder <= maxTotalPositionValue + + return allowed, currentPositionValue + pendingValue, maxTotalPositionValue +} + +// placeGridLimitOrder places a limit order for grid trading +func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error { + // Check if trader supports GridTrader interface + gridTrader, ok := at.trader.(GridTrader) + if !ok { + // Fallback to adapter + gridTrader = NewGridTraderAdapter(at.trader) + } + + gridConfig := at.config.StrategyConfig.GridConfig + + // CRITICAL: Validate and cap quantity to prevent excessive position sizes + // This protects against AI miscalculations or leverage misconfigurations + quantity := d.Quantity + if d.Price > 0 && gridConfig.TotalInvestment > 0 { + // Calculate max allowed position value per grid level + // Each level gets proportional share of total investment + maxMarginPerLevel := gridConfig.TotalInvestment / float64(gridConfig.GridCount) + maxPositionValuePerLevel := maxMarginPerLevel * float64(gridConfig.Leverage) + maxQuantityPerLevel := maxPositionValuePerLevel / d.Price + + // Also get the level's allocated USD for additional validation + at.gridState.mu.RLock() + var levelAllocatedUSD float64 + if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { + levelAllocatedUSD = at.gridState.Levels[d.LevelIndex].AllocatedUSD + } + at.gridState.mu.RUnlock() + + // Use level-specific allocation if available + if levelAllocatedUSD > 0 { + levelMaxPositionValue := levelAllocatedUSD * float64(gridConfig.Leverage) + levelMaxQuantity := levelMaxPositionValue / d.Price + if levelMaxQuantity < maxQuantityPerLevel { + maxQuantityPerLevel = levelMaxQuantity + } + } + + // Cap quantity if it exceeds the maximum allowed + if quantity > maxQuantityPerLevel { + logger.Warnf("[Grid] Quantity %.4f exceeds max allowed %.4f (position_value $%.2f > max $%.2f), capping", + quantity, maxQuantityPerLevel, quantity*d.Price, maxPositionValuePerLevel) + quantity = maxQuantityPerLevel + } + + // Safety check: ensure position value is reasonable (within 2x of intended max as absolute limit) + positionValue := quantity * d.Price + absoluteMaxValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) * 2 // 2x safety margin + if positionValue > absoluteMaxValue { + logger.Errorf("[Grid] CRITICAL: Position value $%.2f exceeds absolute max $%.2f! Rejecting order.", + positionValue, absoluteMaxValue) + return fmt.Errorf("position value $%.2f exceeds safety limit $%.2f", positionValue, absoluteMaxValue) + } + } + + // CRITICAL: Check total position limit before placing order + orderValue := quantity * d.Price + allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue) + if !allowed { + logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.", + currentValue, orderValue, maxValue) + return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue) + } + + req := &LimitOrderRequest{ + Symbol: d.Symbol, + Side: side, + Price: d.Price, + Quantity: quantity, // Use validated/capped quantity + Leverage: gridConfig.Leverage, + PostOnly: gridConfig.UseMakerOnly, + ReduceOnly: false, + ClientID: fmt.Sprintf("grid-%d-%d", d.LevelIndex, time.Now().UnixNano()%1000000), + } + + result, err := gridTrader.PlaceLimitOrder(req) + if err != nil { + return fmt.Errorf("failed to place limit order: %w", err) + } + + // Update grid level state + at.gridState.mu.Lock() + if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { + at.gridState.Levels[d.LevelIndex].State = "pending" + at.gridState.Levels[d.LevelIndex].OrderID = result.OrderID + at.gridState.Levels[d.LevelIndex].OrderQuantity = d.Quantity + at.gridState.OrderBook[result.OrderID] = d.LevelIndex + } + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Placed %s limit order at $%.2f, qty=%.4f, level=%d, orderID=%s", + side, d.Price, d.Quantity, d.LevelIndex, result.OrderID) + + return nil +} + +// cancelGridOrder cancels a specific grid order +func (at *AutoTrader) cancelGridOrder(d *kernel.Decision) error { + gridTrader, ok := at.trader.(GridTrader) + if !ok { + gridTrader = NewGridTraderAdapter(at.trader) + } + + if err := gridTrader.CancelOrder(d.Symbol, d.OrderID); err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + // Update state + at.gridState.mu.Lock() + if levelIdx, ok := at.gridState.OrderBook[d.OrderID]; ok { + if levelIdx >= 0 && levelIdx < len(at.gridState.Levels) { + at.gridState.Levels[levelIdx].State = "empty" + at.gridState.Levels[levelIdx].OrderID = "" + at.gridState.Levels[levelIdx].OrderQuantity = 0 + } + delete(at.gridState.OrderBook, d.OrderID) + } + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Cancelled order: %s", d.OrderID) + return nil +} + +// cancelAllGridOrders cancels all grid orders +func (at *AutoTrader) cancelAllGridOrders() error { + gridConfig := at.config.StrategyConfig.GridConfig + + if err := at.trader.CancelAllOrders(gridConfig.Symbol); err != nil { + return fmt.Errorf("failed to cancel all orders: %w", err) + } + + // Reset all pending levels + at.gridState.mu.Lock() + for i := range at.gridState.Levels { + if at.gridState.Levels[i].State == "pending" { + at.gridState.Levels[i].State = "empty" + at.gridState.Levels[i].OrderID = "" + at.gridState.Levels[i].OrderQuantity = 0 + } + } + at.gridState.OrderBook = make(map[string]int) + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Cancelled all orders") + return nil +} + +// pauseGrid pauses grid trading +func (at *AutoTrader) pauseGrid(reason string) error { + at.cancelAllGridOrders() + + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Paused: %s", reason) + return nil +} + +// resumeGrid resumes grid trading +func (at *AutoTrader) resumeGrid() error { + at.gridState.mu.Lock() + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Resumed") + return nil +} + +// adjustGrid adjusts grid parameters +func (at *AutoTrader) adjustGrid(d *kernel.Decision) error { + // Cancel existing orders first + at.cancelAllGridOrders() + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + price, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Reinitialize grid levels + at.initializeGridLevels(price, gridConfig) + + logger.Infof("[Grid] Adjusted grid bounds around price $%.2f", price) + return nil +} + +// syncGridState syncs grid state with exchange +func (at *AutoTrader) syncGridState() { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get open orders from exchange + openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get open orders: %v", err) + return + } + + // Build set of active order IDs + activeOrderIDs := make(map[string]bool) + for _, order := range openOrders { + activeOrderIDs[order.OrderID] = true + } + + // Get current positions to verify fills + positions, err := at.trader.GetPositions() + currentPositionSize := 0.0 + if err != nil { + logger.Warnf("[Grid] Failed to get positions for state sync: %v", err) + } else { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + currentPositionSize = size + } + } + } + } + + // Update levels based on order status + at.gridState.mu.Lock() + expectedPositionSize := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "filled" { + expectedPositionSize += level.PositionSize + } + } + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State == "pending" && level.OrderID != "" { + if !activeOrderIDs[level.OrderID] { + // Order no longer exists - check if position changed to determine fill vs cancel + // This is a heuristic - ideally we'd query order history + // If current position is larger than expected filled positions, this order was likely filled + if math.Abs(currentPositionSize) > math.Abs(expectedPositionSize) { + // Position increased, likely filled + level.State = "filled" + level.PositionEntry = level.Price + level.PositionSize = level.OrderQuantity + at.gridState.TotalTrades++ + logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price) + } else { + // Position didn't increase as expected, likely cancelled + level.State = "empty" + level.OrderID = "" + level.OrderQuantity = 0 + logger.Infof("[Grid] Level %d order cancelled/expired", i) + } + delete(at.gridState.OrderBook, level.OrderID) + } + } + } + at.gridState.mu.Unlock() + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders)) + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew + at.autoAdjustGrid() +} + +// closeAllPositions closes all open positions for the grid symbol +func (at *AutoTrader) closeAllPositions() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + for _, pos := range positions { + symbol, _ := pos["symbol"].(string) + if symbol != gridConfig.Symbol { + continue + } + + size, _ := pos["positionAmt"].(float64) + if size == 0 { + continue + } + + if size > 0 { + _, err = at.trader.CloseLong(symbol, size) + } else { + _, err = at.trader.CloseShort(symbol, -size) + } + if err != nil { + logger.Infof("Failed to close position: %v", err) + } + } + + return nil +} + +// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it +func (at *AutoTrader) checkAndExecuteStopLoss() { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.StopLossPct <= 0 { + return // Stop loss not configured + } + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err) + return + } + + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State != "filled" || level.PositionEntry <= 0 { + continue + } + + // Calculate loss percentage + var lossPct float64 + if level.Side == "buy" { + // Long position: loss when price drops + lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100 + } else { + // Short position: loss when price rises + lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100 + } + + // Check if stop loss triggered + if lossPct >= gridConfig.StopLossPct { + logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%", + i, level.PositionEntry, currentPrice, lossPct) + + // Close the position + var closeErr error + if level.Side == "buy" { + _, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize) + } else { + _, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize) + } + + if closeErr != nil { + logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr) + } else { + level.State = "stopped" + realizedLoss := -lossPct * level.AllocatedUSD / 100 + level.UnrealizedPnL = realizedLoss + at.gridState.TotalTrades++ + // Update daily PnL tracking (lock already held, update directly) + at.gridState.DailyPnL += realizedLoss + at.gridState.TotalProfit += realizedLoss + logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)", + i, currentPrice, lossPct) + } + } + } +} diff --git a/trader/auto_trader_grid_regime.go b/trader/auto_trader_grid_regime.go new file mode 100644 index 00000000..0999ff1f --- /dev/null +++ b/trader/auto_trader_grid_regime.go @@ -0,0 +1,345 @@ +package trader + +import ( + "fmt" + "math" + "nofx/logger" + "nofx/market" + "time" +) + +// ============================================================================ +// Regime Detection and Strategy Switching +// ============================================================================ + +// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action +func (at *AutoTrader) checkBoxBreakout() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + // Get box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + logger.Infof("Failed to get box data: %v", err) + return nil // Non-fatal, continue with other checks + } + + // Update grid state with box values + at.gridState.mu.Lock() + at.gridState.ShortBoxUpper = box.ShortUpper + at.gridState.ShortBoxLower = box.ShortLower + at.gridState.MidBoxUpper = box.MidUpper + at.gridState.MidBoxLower = box.MidLower + at.gridState.LongBoxUpper = box.LongUpper + at.gridState.LongBoxLower = box.LongLower + at.gridState.mu.Unlock() + + // Detect breakout + breakoutLevel, direction := detectBoxBreakout(box) + + // Get current breakout state + state := &BreakoutState{ + Level: market.BreakoutLevel(at.gridState.BreakoutLevel), + Direction: at.gridState.BreakoutDirection, + ConfirmCount: at.gridState.BreakoutConfirmCount, + } + + // Check if breakout is confirmed (3 candles) + confirmed := confirmBreakout(state, breakoutLevel, direction) + + // Update grid state + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(state.Level) + at.gridState.BreakoutDirection = state.Direction + at.gridState.BreakoutConfirmCount = state.ConfirmCount + at.gridState.mu.Unlock() + + if !confirmed { + return nil + } + + // Take action based on breakout level + // Use direction-aware action if enabled + enableDirectionAdjust := gridConfig.EnableDirectionAdjust + action := getBreakoutActionWithDirection(breakoutLevel, enableDirectionAdjust) + + // If direction adjustment action, determine the new direction + if action == BreakoutActionAdjustDirection { + box, _ := market.GetBoxData(gridConfig.Symbol) + newDirection := determineGridDirection(box, at.gridState.CurrentDirection, breakoutLevel, direction) + return at.executeDirectionAdjustment(newDirection) + } + + return at.executeBreakoutAction(action) +} + +// executeBreakoutAction executes the appropriate action for a breakout +func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { + switch action { + case BreakoutActionReducePosition: + // Short box breakout: reduce position to 50% + logger.Infof("Short box breakout confirmed, reducing position to 50%%") + at.gridState.mu.Lock() + at.gridState.PositionReductionPct = 50 + at.gridState.mu.Unlock() + return nil + + case BreakoutActionPauseGrid: + // Mid box breakout: pause grid + cancel orders + logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return at.cancelAllGridOrders() + + case BreakoutActionCloseAll: + // Long box breakout: pause + cancel + close all + logger.Infof("Long box breakout confirmed, closing all positions") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + if err := at.cancelAllGridOrders(); err != nil { + logger.Infof("Failed to cancel orders: %v", err) + } + return at.closeAllPositions() + + case BreakoutActionAdjustDirection: + // Direction adjustment is handled separately via executeDirectionAdjustment + // This case should not be reached, but handle gracefully + logger.Infof("Direction adjustment action received via executeBreakoutAction") + return nil + } + + return nil +} + +// executeDirectionAdjustment handles grid direction changes based on box breakout +func (at *AutoTrader) executeDirectionAdjustment(newDirection market.GridDirection) error { + at.gridState.mu.RLock() + oldDirection := at.gridState.CurrentDirection + at.gridState.mu.RUnlock() + + if oldDirection == newDirection { + return nil // No change needed + } + + logger.Infof("[Grid] Direction adjustment: %s -> %s", oldDirection, newDirection) + + // Cancel existing orders before adjusting + if err := at.cancelAllGridOrders(); err != nil { + logger.Warnf("[Grid] Failed to cancel orders during direction adjustment: %v", err) + } + + // Apply the new direction + return at.adjustGridDirection(newDirection) +} + +// adjustGridDirection handles runtime direction adjustment when breakout is detected +func (at *AutoTrader) adjustGridDirection(newDirection market.GridDirection) error { + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + oldDirection := at.gridState.CurrentDirection + if oldDirection == newDirection { + return nil // No change needed + } + + at.gridState.CurrentDirection = newDirection + at.gridState.DirectionChangedAt = time.Now() + at.gridState.DirectionChangeCount++ + + logger.Infof("[Grid] Direction changed: %s -> %s (change count: %d)", + oldDirection, newDirection, at.gridState.DirectionChangeCount) + + // Get current price for recalculation + currentPrice, err := at.trader.GetMarketPrice(at.gridState.Config.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Reapply direction to grid levels + at.applyGridDirection(currentPrice) + + return nil +} + +// checkFalseBreakoutRecovery checks if price has returned to box after breakout +func (at *AutoTrader) checkFalseBreakoutRecovery() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + at.gridState.mu.RLock() + breakoutLevel := at.gridState.BreakoutLevel + isPaused := at.gridState.IsPaused + positionReduction := at.gridState.PositionReductionPct + currentDirection := at.gridState.CurrentDirection + at.gridState.mu.RUnlock() + + // Only check if we had a breakout or non-neutral direction + needsRecoveryCheck := breakoutLevel != string(market.BreakoutNone) || + positionReduction != 0 || + isPaused || + (gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral) + + if !needsRecoveryCheck { + return nil + } + + // Get current box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + return nil + } + + // Check if price is back inside the long box + if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper { + logger.Infof("Price returned to box, recovering with 50%% position") + + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(market.BreakoutNone) + at.gridState.BreakoutDirection = "" + at.gridState.BreakoutConfirmCount = 0 + at.gridState.PositionReductionPct = 50 // Recover at 50% + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + } + + // Check for direction recovery toward neutral (if direction adjustment is enabled) + if gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral { + if shouldRecoverDirection(box, currentDirection) { + newDirection := determineRecoveryDirection(box.CurrentPrice, box, currentDirection) + if newDirection != currentDirection { + logger.Infof("[Grid] Direction recovery: %s -> %s (price back in short box)", + currentDirection, newDirection) + at.adjustGridDirection(newDirection) + } + } + } + + return nil +} + +// GetGridRiskInfo returns current risk information for frontend display +func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return &GridRiskInfo{} + } + + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + // Get current price + currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol) + + // Calculate effective leverage + totalInvestment := gridConfig.TotalInvestment + leverage := gridConfig.Leverage + + // Get current position value + positions, _ := at.trader.GetPositions() + var currentPositionValue float64 + var currentPositionSize float64 + for _, pos := range positions { + if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol { + size, _ := pos["positionAmt"].(float64) + entry, _ := pos["entryPrice"].(float64) + currentPositionValue = math.Abs(size * entry) + currentPositionSize = size + break + } + } + + effectiveLeverage := 0.0 + if totalInvestment > 0 { + effectiveLeverage = currentPositionValue / totalInvestment + } + + // Calculate max position based on regime + regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel) + if regimeLevel == "" { + regimeLevel = market.RegimeLevelStandard + } + + // Use default position limit since GridStrategyConfig doesn't have regime-specific limits + // Default is 70% for standard regime + maxPositionPct := 70.0 + switch regimeLevel { + case market.RegimeLevelNarrow: + maxPositionPct = 40.0 + case market.RegimeLevelStandard: + maxPositionPct = 70.0 + case market.RegimeLevelWide: + maxPositionPct = 60.0 + case market.RegimeLevelVolatile: + maxPositionPct = 40.0 + } + + maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage) + + // Use default leverage limits since GridStrategyConfig doesn't have regime-specific limits + recommendedLeverage := leverage + switch regimeLevel { + case market.RegimeLevelNarrow: + recommendedLeverage = min(leverage, 2) + case market.RegimeLevelStandard: + recommendedLeverage = min(leverage, 4) + case market.RegimeLevelWide: + recommendedLeverage = min(leverage, 3) + case market.RegimeLevelVolatile: + recommendedLeverage = min(leverage, 2) + } + + // Calculate liquidation distance and price only when there's a position + var liquidationDistance float64 + var liquidationPrice float64 + if currentPositionSize != 0 && currentPrice > 0 { + liquidationDistance = 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max + if currentPositionSize > 0 { + // Long position: liquidation below entry + liquidationPrice = currentPrice * (1 - liquidationDistance/100) + } else { + // Short position: liquidation above entry + liquidationPrice = currentPrice * (1 + liquidationDistance/100) + } + } + + positionPercent := 0.0 + if maxPosition > 0 { + positionPercent = currentPositionValue / maxPosition * 100 + } + + return &GridRiskInfo{ + CurrentLeverage: leverage, + EffectiveLeverage: effectiveLeverage, + RecommendedLeverage: recommendedLeverage, + + CurrentPosition: currentPositionValue, + MaxPosition: maxPosition, + PositionPercent: positionPercent, + + LiquidationPrice: liquidationPrice, + LiquidationDistance: liquidationDistance, + + RegimeLevel: string(regimeLevel), + + ShortBoxUpper: at.gridState.ShortBoxUpper, + ShortBoxLower: at.gridState.ShortBoxLower, + MidBoxUpper: at.gridState.MidBoxUpper, + MidBoxLower: at.gridState.MidBoxLower, + LongBoxUpper: at.gridState.LongBoxUpper, + LongBoxLower: at.gridState.LongBoxLower, + CurrentPrice: currentPrice, + + BreakoutLevel: at.gridState.BreakoutLevel, + BreakoutDirection: at.gridState.BreakoutDirection, + + CurrentGridDirection: string(at.gridState.CurrentDirection), + DirectionChangeCount: at.gridState.DirectionChangeCount, + EnableDirectionAdjust: gridConfig.EnableDirectionAdjust, + } +} diff --git a/trader/binance/futures.go b/trader/binance/futures.go index 723cb470..efc3bb32 100644 --- a/trader/binance/futures.go +++ b/trader/binance/futures.go @@ -7,8 +7,6 @@ import ( "fmt" "nofx/hook" "nofx/logger" - "nofx/trader/types" - "strconv" "strings" "sync" "time" @@ -123,981 +121,19 @@ func syncBinanceServerTime(client *futures.Client) { logger.Infof("⏱ Binance server time synced, offset %dms", offset) } -// GetBalance gets account balance (with cache) -func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) { - // First check if cache is valid - t.balanceCacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - cacheAge := time.Since(t.balanceCacheTime) - t.balanceCacheMutex.RUnlock() - logger.Infof("✓ Using cached account balance (cache age: %.1f seconds ago)", cacheAge.Seconds()) - return t.cachedBalance, nil - } - t.balanceCacheMutex.RUnlock() +// Helper functions - // Cache expired or doesn't exist, call API - logger.Infof("🔄 Cache expired, calling Binance API to get account balance...") - account, err := t.client.NewGetAccountService().Do(context.Background()) - if err != nil { - logger.Infof("❌ Binance API call failed: %v", err) - return nil, fmt.Errorf("failed to get account info: %w", err) - } - - result := make(map[string]interface{}) - result["totalWalletBalance"], _ = strconv.ParseFloat(account.TotalWalletBalance, 64) - result["availableBalance"], _ = strconv.ParseFloat(account.AvailableBalance, 64) - result["totalUnrealizedProfit"], _ = strconv.ParseFloat(account.TotalUnrealizedProfit, 64) - - logger.Infof("✓ Binance API returned: total balance=%s, available=%s, unrealized PnL=%s", - account.TotalWalletBalance, - account.AvailableBalance, - account.TotalUnrealizedProfit) - - // Update cache - t.balanceCacheMutex.Lock() - t.cachedBalance = result - t.balanceCacheTime = time.Now() - t.balanceCacheMutex.Unlock() - - return result, nil +func contains(s, substr string) bool { + return len(s) >= len(substr) && stringContains(s, substr) } -// GetPositions gets all positions (with cache) -func (t *FuturesTrader) GetPositions() ([]map[string]interface{}, error) { - // First check if cache is valid - t.positionsCacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { - cacheAge := time.Since(t.positionsCacheTime) - t.positionsCacheMutex.RUnlock() - logger.Infof("✓ Using cached position information (cache age: %.1f seconds ago)", cacheAge.Seconds()) - return t.cachedPositions, nil - } - t.positionsCacheMutex.RUnlock() - - // Cache expired or doesn't exist, call API - logger.Infof("🔄 Cache expired, calling Binance API to get position information...") - positions, err := t.client.NewGetPositionRiskService().Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var result []map[string]interface{} - for _, pos := range positions { - posAmt, _ := strconv.ParseFloat(pos.PositionAmt, 64) - if posAmt == 0 { - continue // Skip positions with zero amount - } - - posMap := make(map[string]interface{}) - posMap["symbol"] = pos.Symbol - posMap["positionAmt"], _ = strconv.ParseFloat(pos.PositionAmt, 64) - posMap["entryPrice"], _ = strconv.ParseFloat(pos.EntryPrice, 64) - posMap["markPrice"], _ = strconv.ParseFloat(pos.MarkPrice, 64) - posMap["unRealizedProfit"], _ = strconv.ParseFloat(pos.UnRealizedProfit, 64) - posMap["leverage"], _ = strconv.ParseFloat(pos.Leverage, 64) - posMap["liquidationPrice"], _ = strconv.ParseFloat(pos.LiquidationPrice, 64) - // Note: Binance SDK doesn't expose updateTime field, will fallback to local tracking - - // Determine direction - if posAmt > 0 { - posMap["side"] = "long" - } else { - posMap["side"] = "short" - } - - result = append(result, posMap) - } - - // Update cache - t.positionsCacheMutex.Lock() - t.cachedPositions = result - t.positionsCacheTime = time.Now() - t.positionsCacheMutex.Unlock() - - return result, nil -} - -// SetMarginMode sets margin mode -func (t *FuturesTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - var marginType futures.MarginType - if isCrossMargin { - marginType = futures.MarginTypeCrossed - } else { - marginType = futures.MarginTypeIsolated - } - - // Try to set margin mode - err := t.client.NewChangeMarginTypeService(). - Symbol(symbol). - MarginType(marginType). - Do(context.Background()) - - marginModeStr := "Cross Margin" - if !isCrossMargin { - marginModeStr = "Isolated Margin" - } - - if err != nil { - // If error message contains "No need to change", margin mode is already set to target value - if contains(err.Error(), "No need to change margin type") { - logger.Infof(" ✓ %s margin mode is already %s", symbol, marginModeStr) - return nil - } - // If there is an open position, margin mode cannot be changed, but this doesn't affect trading - if contains(err.Error(), "Margin type cannot be changed if there exists position") { - logger.Infof(" ⚠️ %s has open positions, cannot change margin mode, continuing with current mode", symbol) - return nil - } - // Detect Multi-Assets mode (error code -4168) - if contains(err.Error(), "Multi-Assets mode") || contains(err.Error(), "-4168") || contains(err.Error(), "4168") { - logger.Infof(" ⚠️ %s detected Multi-Assets mode, forcing Cross Margin mode", symbol) - logger.Infof(" 💡 Tip: To use Isolated Margin mode, please disable Multi-Assets mode in Binance") - return nil - } - // Detect Unified Account API (Portfolio Margin) - if contains(err.Error(), "unified") || contains(err.Error(), "portfolio") || contains(err.Error(), "Portfolio") { - logger.Infof(" ❌ %s detected Unified Account API, unable to trade futures", symbol) - return fmt.Errorf("please use 'Spot & Futures Trading' API permission, do not use 'Unified Account API'") - } - logger.Infof(" ⚠️ Failed to set margin mode: %v", err) - // Don't return error, let trading continue - return nil - } - - logger.Infof(" ✓ %s margin mode set to %s", symbol, marginModeStr) - return nil -} - -// SetLeverage sets leverage (with smart detection and cooldown period) -func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error { - // First try to get current leverage (from position information) - currentLeverage := 0 - positions, err := t.GetPositions() - if err == nil { - for _, pos := range positions { - if pos["symbol"] == symbol { - if lev, ok := pos["leverage"].(float64); ok { - currentLeverage = int(lev) - break - } - } +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true } } - - // If current leverage is already the target leverage, skip - if currentLeverage == leverage && currentLeverage > 0 { - logger.Infof(" ✓ %s leverage is already %dx, no need to change", symbol, leverage) - return nil - } - - // Change leverage - _, err = t.client.NewChangeLeverageService(). - Symbol(symbol). - Leverage(leverage). - Do(context.Background()) - - if err != nil { - // If error message contains "No need to change", leverage is already the target value - if contains(err.Error(), "No need to change") { - logger.Infof(" ✓ %s leverage is already %dx", symbol, leverage) - return nil - } - return fmt.Errorf("failed to set leverage: %w", err) - } - - logger.Infof(" ✓ %s leverage changed to %dx", symbol, leverage) - - // Wait 5 seconds after changing leverage (to avoid cooldown period errors) - logger.Infof(" ⏱ Waiting 5 seconds for cooldown period...") - time.Sleep(5 * time.Second) - - return nil -} - -// OpenLong opens a long position -func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // First cancel all pending orders for this symbol (clean up old stop-loss and take-profit orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel old pending orders (may not have any): %v", err) - } - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - return nil, err - } - - // Note: Margin mode should be set by the caller (AutoTrader) before opening position via SetMarginMode - - // Format quantity to correct precision - quantityStr, err := t.FormatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Check if formatted quantity is 0 (prevent rounding errors) - quantityFloat, parseErr := strconv.ParseFloat(quantityStr, 64) - if parseErr != nil || quantityFloat <= 0 { - return nil, fmt.Errorf("position size too small, rounded to 0 (original: %.8f → formatted: %s). Suggest increasing position amount or selecting a lower-priced coin", quantity, quantityStr) - } - - // Check minimum notional value (Binance requires at least 10 USDT) - if err := t.CheckMinNotional(symbol, quantityFloat); err != nil { - return nil, err - } - - // Create market buy order (using br ID) - order, err := t.client.NewCreateOrderService(). - Symbol(symbol). - Side(futures.SideTypeBuy). - PositionSide(futures.PositionSideTypeLong). - Type(futures.OrderTypeMarket). - Quantity(quantityStr). - NewClientOrderID(getBrOrderID()). - Do(context.Background()) - - if err != nil { - return nil, fmt.Errorf("failed to open long position: %w", err) - } - - logger.Infof("✓ Opened long position successfully: %s quantity: %s", symbol, quantityStr) - logger.Infof(" Order ID: %d", order.OrderID) - - result := make(map[string]interface{}) - result["orderId"] = order.OrderID - result["symbol"] = order.Symbol - result["status"] = order.Status - return result, nil -} - -// OpenShort opens a short position -func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // First cancel all pending orders for this symbol (clean up old stop-loss and take-profit orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel old pending orders (may not have any): %v", err) - } - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - return nil, err - } - - // Note: Margin mode should be set by the caller (AutoTrader) before opening position via SetMarginMode - - // Format quantity to correct precision - quantityStr, err := t.FormatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Check if formatted quantity is 0 (prevent rounding errors) - quantityFloat, parseErr := strconv.ParseFloat(quantityStr, 64) - if parseErr != nil || quantityFloat <= 0 { - return nil, fmt.Errorf("position size too small, rounded to 0 (original: %.8f → formatted: %s). Suggest increasing position amount or selecting a lower-priced coin", quantity, quantityStr) - } - - // Check minimum notional value (Binance requires at least 10 USDT) - if err := t.CheckMinNotional(symbol, quantityFloat); err != nil { - return nil, err - } - - // Create market sell order (using br ID) - order, err := t.client.NewCreateOrderService(). - Symbol(symbol). - Side(futures.SideTypeSell). - PositionSide(futures.PositionSideTypeShort). - Type(futures.OrderTypeMarket). - Quantity(quantityStr). - NewClientOrderID(getBrOrderID()). - Do(context.Background()) - - if err != nil { - return nil, fmt.Errorf("failed to open short position: %w", err) - } - - logger.Infof("✓ Opened short position successfully: %s quantity: %s", symbol, quantityStr) - logger.Infof(" Order ID: %d", order.OrderID) - - result := make(map[string]interface{}) - result["orderId"] = order.OrderID - result["symbol"] = order.Symbol - result["status"] = order.Status - return result, nil -} - -// CloseLong closes a long position -func (t *FuturesTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - // If quantity is 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "long" { - quantity = pos["positionAmt"].(float64) - break - } - } - - if quantity == 0 { - return nil, fmt.Errorf("no long position found for %s", symbol) - } - } - - // Format quantity - quantityStr, err := t.FormatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Create market sell order (close long, using br ID) - order, err := t.client.NewCreateOrderService(). - Symbol(symbol). - Side(futures.SideTypeSell). - PositionSide(futures.PositionSideTypeLong). - Type(futures.OrderTypeMarket). - Quantity(quantityStr). - NewClientOrderID(getBrOrderID()). - Do(context.Background()) - - if err != nil { - return nil, fmt.Errorf("failed to close long position: %w", err) - } - - logger.Infof("✓ Closed long position successfully: %s quantity: %s", symbol, quantityStr) - - // After closing position, cancel all pending orders for this symbol (stop-loss and take-profit orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) - } - - result := make(map[string]interface{}) - result["orderId"] = order.OrderID - result["symbol"] = order.Symbol - result["status"] = order.Status - return result, nil -} - -// CloseShort closes a short position -func (t *FuturesTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - // If quantity is 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "short" { - quantity = -pos["positionAmt"].(float64) // Short position quantity is negative, take absolute value - break - } - } - - if quantity == 0 { - return nil, fmt.Errorf("no short position found for %s", symbol) - } - } - - // Format quantity - quantityStr, err := t.FormatQuantity(symbol, quantity) - if err != nil { - return nil, err - } - - // Create market buy order (close short, using br ID) - order, err := t.client.NewCreateOrderService(). - Symbol(symbol). - Side(futures.SideTypeBuy). - PositionSide(futures.PositionSideTypeShort). - Type(futures.OrderTypeMarket). - Quantity(quantityStr). - NewClientOrderID(getBrOrderID()). - Do(context.Background()) - - if err != nil { - return nil, fmt.Errorf("failed to close short position: %w", err) - } - - logger.Infof("✓ Closed short position successfully: %s quantity: %s", symbol, quantityStr) - - // After closing position, cancel all pending orders for this symbol (stop-loss and take-profit orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) - } - - result := make(map[string]interface{}) - result["orderId"] = order.OrderID - result["symbol"] = order.Symbol - result["status"] = order.Status - return result, nil -} - -// CancelStopLossOrders cancels only stop-loss orders (doesn't affect take-profit orders) -// Now uses both legacy API and new Algo Order API -func (t *FuturesTrader) CancelStopLossOrders(symbol string) error { - canceledCount := 0 - var cancelErrors []error - - // 1. Cancel legacy stop-loss orders - orders, err := t.client.NewListOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err == nil { - for _, order := range orders { - orderType := string(order.Type) - - // Only cancel stop-loss orders (don't cancel take-profit orders) - // Use string comparison since OrderType constants were removed in v2.8.9 - if orderType == "STOP_MARKET" || orderType == "STOP" { - _, err := t.client.NewCancelOrderService(). - Symbol(symbol). - OrderID(order.OrderID). - Do(context.Background()) - - if err != nil { - errMsg := fmt.Sprintf("Order ID %d: %v", order.OrderID, err) - cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - logger.Infof(" ⚠ Failed to cancel legacy stop-loss order: %s", errMsg) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled legacy stop-loss order (Order ID: %d, Type: %s, Side: %s)", order.OrderID, orderType, order.PositionSide) - } - } - } - - // 2. Cancel Algo stop-loss orders - algoOrders, err := t.client.NewListOpenAlgoOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err == nil { - for _, algoOrder := range algoOrders { - // Only cancel stop-loss orders - if algoOrder.OrderType == futures.AlgoOrderTypeStopMarket || algoOrder.OrderType == futures.AlgoOrderTypeStop { - _, err := t.client.NewCancelAlgoOrderService(). - AlgoID(algoOrder.AlgoId). - Do(context.Background()) - - if err != nil { - errMsg := fmt.Sprintf("Algo ID %d: %v", algoOrder.AlgoId, err) - cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - logger.Infof(" ⚠ Failed to cancel Algo stop-loss order: %s", errMsg) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled Algo stop-loss order (Algo ID: %d, Type: %s)", algoOrder.AlgoId, algoOrder.OrderType) - } - } - } - - if canceledCount == 0 && len(cancelErrors) == 0 { - logger.Infof(" ℹ %s has no stop-loss orders to cancel", symbol) - } else if canceledCount > 0 { - logger.Infof(" ✓ Canceled %d stop-loss order(s) for %s", canceledCount, symbol) - } - - // If all cancellations failed, return error - if len(cancelErrors) > 0 && canceledCount == 0 { - return fmt.Errorf("failed to cancel stop-loss orders: %v", cancelErrors) - } - - return nil -} - -// CancelTakeProfitOrders cancels only take-profit orders (doesn't affect stop-loss orders) -// Now uses both legacy API and new Algo Order API -func (t *FuturesTrader) CancelTakeProfitOrders(symbol string) error { - canceledCount := 0 - var cancelErrors []error - - // 1. Cancel legacy take-profit orders - orders, err := t.client.NewListOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err == nil { - for _, order := range orders { - orderType := string(order.Type) - - // Only cancel take-profit orders (don't cancel stop-loss orders) - // Use string comparison since OrderType constants were removed in v2.8.9 - if orderType == "TAKE_PROFIT_MARKET" || orderType == "TAKE_PROFIT" { - _, err := t.client.NewCancelOrderService(). - Symbol(symbol). - OrderID(order.OrderID). - Do(context.Background()) - - if err != nil { - errMsg := fmt.Sprintf("Order ID %d: %v", order.OrderID, err) - cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - logger.Infof(" ⚠ Failed to cancel legacy take-profit order: %s", errMsg) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled legacy take-profit order (Order ID: %d, Type: %s, Side: %s)", order.OrderID, orderType, order.PositionSide) - } - } - } - - // 2. Cancel Algo take-profit orders - algoOrders, err := t.client.NewListOpenAlgoOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err == nil { - for _, algoOrder := range algoOrders { - // Only cancel take-profit orders - if algoOrder.OrderType == futures.AlgoOrderTypeTakeProfitMarket || algoOrder.OrderType == futures.AlgoOrderTypeTakeProfit { - _, err := t.client.NewCancelAlgoOrderService(). - AlgoID(algoOrder.AlgoId). - Do(context.Background()) - - if err != nil { - errMsg := fmt.Sprintf("Algo ID %d: %v", algoOrder.AlgoId, err) - cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - logger.Infof(" ⚠ Failed to cancel Algo take-profit order: %s", errMsg) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled Algo take-profit order (Algo ID: %d, Type: %s)", algoOrder.AlgoId, algoOrder.OrderType) - } - } - } - - if canceledCount == 0 && len(cancelErrors) == 0 { - logger.Infof(" ℹ %s has no take-profit orders to cancel", symbol) - } else if canceledCount > 0 { - logger.Infof(" ✓ Canceled %d take-profit order(s) for %s", canceledCount, symbol) - } - - // If all cancellations failed, return error - if len(cancelErrors) > 0 && canceledCount == 0 { - return fmt.Errorf("failed to cancel take-profit orders: %v", cancelErrors) - } - - return nil -} - -// CancelAllOrders cancels all pending orders for this symbol -// Now uses both legacy API and new Algo Order API -func (t *FuturesTrader) CancelAllOrders(symbol string) error { - // 1. Cancel all legacy orders - err := t.client.NewCancelAllOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err != nil { - logger.Infof(" ⚠ Failed to cancel legacy orders: %v", err) - } else { - logger.Infof(" ✓ Canceled all legacy pending orders for %s", symbol) - } - - // 2. Cancel all Algo orders - err = t.client.NewCancelAllAlgoOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err != nil { - // Ignore "no algo orders" error - if !contains(err.Error(), "no algo") && !contains(err.Error(), "No algo") { - logger.Infof(" ⚠ Failed to cancel Algo orders: %v", err) - } - } else { - logger.Infof(" ✓ Canceled all Algo orders for %s", symbol) - } - - return nil -} - -// PlaceLimitOrder places a limit order for grid trading -// This implements the GridTrader interface for FuturesTrader -func (t *FuturesTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { - // Format quantity to correct precision - quantityStr, err := t.FormatQuantity(req.Symbol, req.Quantity) - if err != nil { - return nil, fmt.Errorf("failed to format quantity: %w", err) - } - - // Format price to correct precision - priceStr, err := t.FormatPrice(req.Symbol, req.Price) - if err != nil { - return nil, fmt.Errorf("failed to format price: %w", err) - } - - // Set leverage if specified - if req.Leverage > 0 { - if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { - logger.Warnf("Failed to set leverage: %v", err) - } - } - - // Determine side and position side - var side futures.SideType - var positionSide futures.PositionSideType - - if req.Side == "BUY" { - side = futures.SideTypeBuy - positionSide = futures.PositionSideTypeLong - } else { - side = futures.SideTypeSell - positionSide = futures.PositionSideTypeShort - } - - // Build order service with broker ID - orderService := t.client.NewCreateOrderService(). - Symbol(req.Symbol). - Side(side). - PositionSide(positionSide). - Type(futures.OrderTypeLimit). - TimeInForce(futures.TimeInForceTypeGTC). - Quantity(quantityStr). - Price(priceStr). - NewClientOrderID(getBrOrderID()) - - // Execute order - order, err := orderService.Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to place limit order: %w", err) - } - - logger.Infof("✓ [Grid] Placed limit order: %s %s %s @ %s, qty=%s, orderID=%d", - req.Symbol, req.Side, positionSide, priceStr, quantityStr, order.OrderID) - - return &types.LimitOrderResult{ - OrderID: fmt.Sprintf("%d", order.OrderID), - ClientID: order.ClientOrderID, - Symbol: order.Symbol, - Side: string(order.Side), - PositionSide: string(order.PositionSide), - Price: req.Price, - Quantity: req.Quantity, - Status: string(order.Status), - }, nil -} - -// CancelOrder cancels a specific order by ID -// This implements the GridTrader interface for FuturesTrader -func (t *FuturesTrader) CancelOrder(symbol, orderID string) error { - // Parse order ID to int64 - orderIDInt, err := strconv.ParseInt(orderID, 10, 64) - if err != nil { - return fmt.Errorf("invalid order ID: %w", err) - } - - _, err = t.client.NewCancelOrderService(). - Symbol(symbol). - OrderID(orderIDInt). - Do(context.Background()) - - if err != nil { - return fmt.Errorf("failed to cancel order: %w", err) - } - - logger.Infof("✓ [Grid] Cancelled order: %s/%s", symbol, orderID) - return nil -} - -// GetOrderBook gets the order book for a symbol -// This implements the GridTrader interface for FuturesTrader -func (t *FuturesTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { - book, err := t.client.NewDepthService(). - Symbol(symbol). - Limit(depth). - Do(context.Background()) - - if err != nil { - return nil, nil, fmt.Errorf("failed to get order book: %w", err) - } - - // Convert bids - bids = make([][]float64, len(book.Bids)) - for i, bid := range book.Bids { - price, _ := strconv.ParseFloat(bid.Price, 64) - qty, _ := strconv.ParseFloat(bid.Quantity, 64) - bids[i] = []float64{price, qty} - } - - // Convert asks - asks = make([][]float64, len(book.Asks)) - for i, ask := range book.Asks { - price, _ := strconv.ParseFloat(ask.Price, 64) - qty, _ := strconv.ParseFloat(ask.Quantity, 64) - asks[i] = []float64{price, qty} - } - - return bids, asks, nil -} - -// CancelStopOrders cancels take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions) -// Now uses both legacy API and new Algo Order API (Binance migrated stop orders to Algo system) -func (t *FuturesTrader) CancelStopOrders(symbol string) error { - canceledCount := 0 - - // 1. Cancel legacy stop orders (for backward compatibility) - orders, err := t.client.NewListOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err == nil { - for _, order := range orders { - orderType := string(order.Type) - - // Only cancel stop-loss and take-profit orders - // Use string comparison since OrderType constants were removed in v2.8.9 - if orderType == "STOP_MARKET" || - orderType == "TAKE_PROFIT_MARKET" || - orderType == "STOP" || - orderType == "TAKE_PROFIT" { - - _, err := t.client.NewCancelOrderService(). - Symbol(symbol). - OrderID(order.OrderID). - Do(context.Background()) - - if err != nil { - logger.Infof(" ⚠ Failed to cancel legacy order %d: %v", order.OrderID, err) - continue - } - - canceledCount++ - logger.Infof(" ✓ Canceled legacy stop order for %s (Order ID: %d, Type: %s)", - symbol, order.OrderID, orderType) - } - } - } - - // 2. Cancel Algo orders (new API) - err = t.client.NewCancelAllAlgoOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err != nil { - // Ignore "no algo orders" error - if !contains(err.Error(), "no algo") && !contains(err.Error(), "No algo") { - logger.Infof(" ⚠ Failed to cancel Algo orders: %v", err) - } - } else { - logger.Infof(" ✓ Canceled all Algo orders for %s", symbol) - canceledCount++ - } - - if canceledCount == 0 { - logger.Infof(" ℹ %s has no take-profit/stop-loss orders to cancel", symbol) - } - - return nil -} - -// GetOpenOrders gets all open/pending orders for a symbol -func (t *FuturesTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - var result []types.OpenOrder - - // 1. Get legacy open orders - orders, err := t.client.NewListOpenOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - for _, order := range orders { - price, _ := strconv.ParseFloat(order.Price, 64) - stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64) - quantity, _ := strconv.ParseFloat(order.OrigQuantity, 64) - - result = append(result, types.OpenOrder{ - OrderID: fmt.Sprintf("%d", order.OrderID), - Symbol: order.Symbol, - Side: string(order.Side), - PositionSide: string(order.PositionSide), - Type: string(order.Type), - Price: price, - StopPrice: stopPrice, - Quantity: quantity, - Status: string(order.Status), - }) - } - - // 2. Get Algo orders (new API for stop-loss/take-profit) - algoOrders, err := t.client.NewListOpenAlgoOrdersService(). - Symbol(symbol). - Do(context.Background()) - - if err == nil { - for _, algoOrder := range algoOrders { - triggerPrice, _ := strconv.ParseFloat(algoOrder.TriggerPrice, 64) - quantity, _ := strconv.ParseFloat(algoOrder.Quantity, 64) - - result = append(result, types.OpenOrder{ - OrderID: fmt.Sprintf("%d", algoOrder.AlgoId), - Symbol: algoOrder.Symbol, - Side: string(algoOrder.Side), - PositionSide: string(algoOrder.PositionSide), - Type: string(algoOrder.OrderType), - Price: 0, // Algo orders use stop price - StopPrice: triggerPrice, - Quantity: quantity, - Status: "NEW", - }) - } - } - - return result, nil -} - -// GetMarketPrice gets market price -func (t *FuturesTrader) GetMarketPrice(symbol string) (float64, error) { - prices, err := t.client.NewListPricesService().Symbol(symbol).Do(context.Background()) - if err != nil { - return 0, fmt.Errorf("failed to get price: %w", err) - } - - if len(prices) == 0 { - return 0, fmt.Errorf("price not found") - } - - price, err := strconv.ParseFloat(prices[0].Price, 64) - if err != nil { - return 0, err - } - - return price, nil -} - -// CalculatePositionSize calculates position size -func (t *FuturesTrader) CalculatePositionSize(balance, riskPercent, price float64, leverage int) float64 { - riskAmount := balance * (riskPercent / 100.0) - positionValue := riskAmount * float64(leverage) - quantity := positionValue / price - return quantity -} - -// SetStopLoss sets stop-loss order using new Algo Order API -// Binance has migrated stop orders to Algo Order system (error -4120 STOP_ORDER_SWITCH_ALGO) -func (t *FuturesTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - var side futures.SideType - var posSide futures.PositionSideType - - if positionSide == "LONG" { - side = futures.SideTypeSell - posSide = futures.PositionSideTypeLong - } else { - side = futures.SideTypeBuy - posSide = futures.PositionSideTypeShort - } - - // Use new Algo Order API - _, err := t.client.NewCreateAlgoOrderService(). - Symbol(symbol). - Side(side). - PositionSide(posSide). - Type(futures.AlgoOrderTypeStopMarket). - TriggerPrice(fmt.Sprintf("%.8f", stopPrice)). - WorkingType(futures.WorkingTypeContractPrice). - ClosePosition(true). - ClientAlgoId(getBrOrderID()). - Do(context.Background()) - - if err != nil { - return fmt.Errorf("failed to set stop-loss: %w", err) - } - - logger.Infof(" Stop-loss price set (Algo Order): %.4f", stopPrice) - return nil -} - -// SetTakeProfit sets take-profit order using new Algo Order API -// Binance has migrated stop orders to Algo Order system (error -4120 STOP_ORDER_SWITCH_ALGO) -func (t *FuturesTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - var side futures.SideType - var posSide futures.PositionSideType - - if positionSide == "LONG" { - side = futures.SideTypeSell - posSide = futures.PositionSideTypeLong - } else { - side = futures.SideTypeBuy - posSide = futures.PositionSideTypeShort - } - - // Use new Algo Order API - _, err := t.client.NewCreateAlgoOrderService(). - Symbol(symbol). - Side(side). - PositionSide(posSide). - Type(futures.AlgoOrderTypeTakeProfitMarket). - TriggerPrice(fmt.Sprintf("%.8f", takeProfitPrice)). - WorkingType(futures.WorkingTypeContractPrice). - ClosePosition(true). - ClientAlgoId(getBrOrderID()). - Do(context.Background()) - - if err != nil { - return fmt.Errorf("failed to set take-profit: %w", err) - } - - logger.Infof(" Take-profit price set (Algo Order): %.4f", takeProfitPrice) - return nil -} - -// GetMinNotional gets minimum notional value (Binance requirement) -func (t *FuturesTrader) GetMinNotional(symbol string) float64 { - // Use conservative default value of 10 USDT to ensure order passes exchange validation - return 10.0 -} - -// CheckMinNotional checks if order meets minimum notional value requirement -func (t *FuturesTrader) CheckMinNotional(symbol string, quantity float64) error { - price, err := t.GetMarketPrice(symbol) - if err != nil { - return fmt.Errorf("failed to get market price: %w", err) - } - - notionalValue := quantity * price - minNotional := t.GetMinNotional(symbol) - - if notionalValue < minNotional { - return fmt.Errorf( - "order amount %.2f USDT is below minimum requirement %.2f USDT (quantity: %.4f, price: %.4f)", - notionalValue, minNotional, quantity, price, - ) - } - - return nil -} - -// GetSymbolPrecision gets the quantity precision for a trading pair -func (t *FuturesTrader) GetSymbolPrecision(symbol string) (int, error) { - exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background()) - if err != nil { - return 0, fmt.Errorf("failed to get trading rules: %w", err) - } - - for _, s := range exchangeInfo.Symbols { - if s.Symbol == symbol { - // Get precision from LOT_SIZE filter - for _, filter := range s.Filters { - if filter["filterType"] == "LOT_SIZE" { - stepSize := filter["stepSize"].(string) - precision := calculatePrecision(stepSize) - logger.Infof(" %s quantity precision: %d (stepSize: %s)", symbol, precision, stepSize) - return precision, nil - } - } - } - } - - logger.Infof(" ⚠ %s precision information not found, using default precision 3", symbol) - return 3, nil // Default precision is 3 + return false } // calculatePrecision calculates precision from stepSize @@ -1142,346 +178,3 @@ func trimTrailingZeros(s string) string { return s } - -// FormatQuantity formats quantity to correct precision -func (t *FuturesTrader) FormatQuantity(symbol string, quantity float64) (string, error) { - precision, err := t.GetSymbolPrecision(symbol) - if err != nil { - // If retrieval fails, use default format - return fmt.Sprintf("%.3f", quantity), nil - } - - format := fmt.Sprintf("%%.%df", precision) - return fmt.Sprintf(format, quantity), nil -} - -// GetSymbolPricePrecision gets the price precision for a trading pair -func (t *FuturesTrader) GetSymbolPricePrecision(symbol string) (int, error) { - exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background()) - if err != nil { - return 0, fmt.Errorf("failed to get trading rules: %w", err) - } - - for _, s := range exchangeInfo.Symbols { - if s.Symbol == symbol { - // Get precision from PRICE_FILTER filter - for _, filter := range s.Filters { - if filter["filterType"] == "PRICE_FILTER" { - tickSize := filter["tickSize"].(string) - precision := calculatePrecision(tickSize) - return precision, nil - } - } - } - } - - // Default to 2 decimal places for price - return 2, nil -} - -// FormatPrice formats price to correct precision -func (t *FuturesTrader) FormatPrice(symbol string, price float64) (string, error) { - precision, err := t.GetSymbolPricePrecision(symbol) - if err != nil { - // If retrieval fails, use default format - return fmt.Sprintf("%.2f", price), nil - } - - format := fmt.Sprintf("%%.%df", precision) - return fmt.Sprintf(format, price), nil -} - -// Helper functions -func contains(s, substr string) bool { - return len(s) >= len(substr) && stringContains(s, substr) -} - -func stringContains(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - -// GetOrderStatus gets order status -func (t *FuturesTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - // Convert orderID to int64 - orderIDInt, err := strconv.ParseInt(orderID, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid order ID: %s", orderID) - } - - order, err := t.client.NewGetOrderService(). - Symbol(symbol). - OrderID(orderIDInt). - Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - // Parse execution price - avgPrice, _ := strconv.ParseFloat(order.AvgPrice, 64) - executedQty, _ := strconv.ParseFloat(order.ExecutedQuantity, 64) - - result := map[string]interface{}{ - "orderId": order.OrderID, - "symbol": order.Symbol, - "status": string(order.Status), - "avgPrice": avgPrice, - "executedQty": executedQty, - "side": string(order.Side), - "type": string(order.Type), - "time": order.Time, - "updateTime": order.UpdateTime, - } - - // Binance futures commission fee needs to be obtained through GetUserTrades, not retrieved here for now - // Can be obtained later through WebSocket or separate query - result["commission"] = 0.0 - - return result, nil -} - -// GetClosedPnL retrieves recent closing trades from Binance Futures -// Note: Binance does NOT have a position history API, only trade history. -// This returns individual closing trades (realizedPnl != 0) for real-time position closure detection. -// NOT suitable for historical position reconstruction - use only for matching recent closures. -func (t *FuturesTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - trades, err := t.GetTrades(startTime, limit) - if err != nil { - return nil, err - } - - // Filter only closing trades (realizedPnl != 0) and convert to ClosedPnLRecord - var records []types.ClosedPnLRecord - for _, trade := range trades { - if trade.RealizedPnL == 0 { - continue // Skip opening trades - } - - // Determine side from trade - side := "long" - if trade.PositionSide == "SHORT" || trade.PositionSide == "short" { - side = "short" - } else if trade.PositionSide == "BOTH" || trade.PositionSide == "" { - // One-way mode: selling closes long, buying closes short - if trade.Side == "SELL" || trade.Side == "Sell" { - side = "long" - } else { - side = "short" - } - } - - // Calculate entry price from PnL (mathematically accurate for this trade) - var entryPrice float64 - if trade.Quantity > 0 { - if side == "long" { - entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity - } else { - entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity - } - } - - records = append(records, types.ClosedPnLRecord{ - Symbol: trade.Symbol, - Side: side, - EntryPrice: entryPrice, - ExitPrice: trade.Price, - Quantity: trade.Quantity, - RealizedPnL: trade.RealizedPnL, - Fee: trade.Fee, - ExitTime: trade.Time, - EntryTime: trade.Time, // Approximate - OrderID: trade.TradeID, - ExchangeID: trade.TradeID, - CloseType: "unknown", - }) - } - - return records, nil -} - -// GetTrades retrieves trade history from Binance Futures using Income API -// Note: Income API has delays (~minutes), for real-time use GetTradesForSymbol instead -func (t *FuturesTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 1000 { - limit = 1000 - } - - // Use Income API to get REALIZED_PNL records (all symbols) - incomes, err := t.client.NewGetIncomeHistoryService(). - IncomeType("REALIZED_PNL"). - StartTime(startTime.UnixMilli()). - Limit(int64(limit)). - Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get income history: %w", err) - } - - var trades []types.TradeRecord - for _, income := range incomes { - pnl, _ := strconv.ParseFloat(income.Income, 64) - if pnl == 0 { - continue // Skip zero PnL records - } - - // Income API doesn't provide full trade details, create a minimal record - // This is mainly used for detecting recent closures, not historical reconstruction - trade := types.TradeRecord{ - TradeID: strconv.FormatInt(income.TranID, 10), - Symbol: income.Symbol, - RealizedPnL: pnl, - Time: time.UnixMilli(income.Time).UTC(), - // Note: Income API doesn't provide price, quantity, side, fee - // For accurate data, use GetTradesForSymbol with specific symbol - } - trades = append(trades, trade) - } - - return trades, nil -} - -// GetTradesForSymbol retrieves trade history for a specific symbol -// This is more reliable than using Income API which may have delays -func (t *FuturesTrader) GetTradesForSymbol(symbol string, startTime time.Time, limit int) ([]types.TradeRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 1000 { - limit = 1000 - } - - accountTrades, err := t.client.NewListAccountTradeService(). - Symbol(symbol). - StartTime(startTime.UnixMilli()). - Limit(limit). - Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get trade history for %s: %w", symbol, err) - } - - var trades []types.TradeRecord - for _, at := range accountTrades { - price, _ := strconv.ParseFloat(at.Price, 64) - qty, _ := strconv.ParseFloat(at.Quantity, 64) - fee, _ := strconv.ParseFloat(at.Commission, 64) - pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64) - - trade := types.TradeRecord{ - TradeID: strconv.FormatInt(at.ID, 10), - Symbol: at.Symbol, - Side: string(at.Side), - PositionSide: string(at.PositionSide), - Price: price, - Quantity: qty, - RealizedPnL: pnl, - Fee: fee, - Time: time.UnixMilli(at.Time).UTC(), - } - trades = append(trades, trade) - } - - return trades, nil -} - -// GetTradesForSymbolFromID retrieves trade history for a specific symbol starting from a given trade ID -// This is used for incremental sync - only fetch new trades since last sync -func (t *FuturesTrader) GetTradesForSymbolFromID(symbol string, fromID int64, limit int) ([]types.TradeRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 1000 { - limit = 1000 - } - - accountTrades, err := t.client.NewListAccountTradeService(). - Symbol(symbol). - FromID(fromID). - Limit(limit). - Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get trade history for %s from ID %d: %w", symbol, fromID, err) - } - - var trades []types.TradeRecord - for _, at := range accountTrades { - price, _ := strconv.ParseFloat(at.Price, 64) - qty, _ := strconv.ParseFloat(at.Quantity, 64) - fee, _ := strconv.ParseFloat(at.Commission, 64) - pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64) - - trade := types.TradeRecord{ - TradeID: strconv.FormatInt(at.ID, 10), - Symbol: at.Symbol, - Side: string(at.Side), - PositionSide: string(at.PositionSide), - Price: price, - Quantity: qty, - RealizedPnL: pnl, - Fee: fee, - Time: time.UnixMilli(at.Time).UTC(), - } - trades = append(trades, trade) - } - - return trades, nil -} - -// GetCommissionSymbols returns symbols that have new commission records since lastSyncTime -// COMMISSION income is generated for every trade, so this is more reliable than REALIZED_PNL -func (t *FuturesTrader) GetCommissionSymbols(lastSyncTime time.Time) ([]string, error) { - incomes, err := t.client.NewGetIncomeHistoryService(). - IncomeType("COMMISSION"). - StartTime(lastSyncTime.UnixMilli()). - Limit(1000). - Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get commission history: %w", err) - } - - symbolMap := make(map[string]bool) - for _, income := range incomes { - if income.Symbol != "" { - symbolMap[income.Symbol] = true - } - } - - var symbols []string - for symbol := range symbolMap { - symbols = append(symbols, symbol) - } - - return symbols, nil -} - -// GetPnLSymbols returns symbols that have REALIZED_PNL records since lastSyncTime -// This is a fallback when COMMISSION detection fails (VIP users, BNB fee discount) -func (t *FuturesTrader) GetPnLSymbols(lastSyncTime time.Time) ([]string, error) { - incomes, err := t.client.NewGetIncomeHistoryService(). - IncomeType("REALIZED_PNL"). - StartTime(lastSyncTime.UnixMilli()). - Limit(1000). - Do(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get PnL history: %w", err) - } - - symbolMap := make(map[string]bool) - for _, income := range incomes { - if income.Symbol != "" { - symbolMap[income.Symbol] = true - } - } - - var symbols []string - for symbol := range symbolMap { - symbols = append(symbols, symbol) - } - - return symbols, nil -} diff --git a/trader/binance/futures_account.go b/trader/binance/futures_account.go new file mode 100644 index 00000000..549e3fd7 --- /dev/null +++ b/trader/binance/futures_account.go @@ -0,0 +1,291 @@ +package binance + +import ( + "context" + "fmt" + "nofx/logger" + "nofx/trader/types" + "strconv" + "time" +) + +// GetBalance gets account balance (with cache) +func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) { + // First check if cache is valid + t.balanceCacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + cacheAge := time.Since(t.balanceCacheTime) + t.balanceCacheMutex.RUnlock() + logger.Infof("✓ Using cached account balance (cache age: %.1f seconds ago)", cacheAge.Seconds()) + return t.cachedBalance, nil + } + t.balanceCacheMutex.RUnlock() + + // Cache expired or doesn't exist, call API + logger.Infof("🔄 Cache expired, calling Binance API to get account balance...") + account, err := t.client.NewGetAccountService().Do(context.Background()) + if err != nil { + logger.Infof("❌ Binance API call failed: %v", err) + return nil, fmt.Errorf("failed to get account info: %w", err) + } + + result := make(map[string]interface{}) + result["totalWalletBalance"], _ = strconv.ParseFloat(account.TotalWalletBalance, 64) + result["availableBalance"], _ = strconv.ParseFloat(account.AvailableBalance, 64) + result["totalUnrealizedProfit"], _ = strconv.ParseFloat(account.TotalUnrealizedProfit, 64) + + logger.Infof("✓ Binance API returned: total balance=%s, available=%s, unrealized PnL=%s", + account.TotalWalletBalance, + account.AvailableBalance, + account.TotalUnrealizedProfit) + + // Update cache + t.balanceCacheMutex.Lock() + t.cachedBalance = result + t.balanceCacheTime = time.Now() + t.balanceCacheMutex.Unlock() + + return result, nil +} + +// GetClosedPnL retrieves recent closing trades from Binance Futures +// Note: Binance does NOT have a position history API, only trade history. +// This returns individual closing trades (realizedPnl != 0) for real-time position closure detection. +// NOT suitable for historical position reconstruction - use only for matching recent closures. +func (t *FuturesTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + trades, err := t.GetTrades(startTime, limit) + if err != nil { + return nil, err + } + + // Filter only closing trades (realizedPnl != 0) and convert to ClosedPnLRecord + var records []types.ClosedPnLRecord + for _, trade := range trades { + if trade.RealizedPnL == 0 { + continue // Skip opening trades + } + + // Determine side from trade + side := "long" + if trade.PositionSide == "SHORT" || trade.PositionSide == "short" { + side = "short" + } else if trade.PositionSide == "BOTH" || trade.PositionSide == "" { + // One-way mode: selling closes long, buying closes short + if trade.Side == "SELL" || trade.Side == "Sell" { + side = "long" + } else { + side = "short" + } + } + + // Calculate entry price from PnL (mathematically accurate for this trade) + var entryPrice float64 + if trade.Quantity > 0 { + if side == "long" { + entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity + } else { + entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity + } + } + + records = append(records, types.ClosedPnLRecord{ + Symbol: trade.Symbol, + Side: side, + EntryPrice: entryPrice, + ExitPrice: trade.Price, + Quantity: trade.Quantity, + RealizedPnL: trade.RealizedPnL, + Fee: trade.Fee, + ExitTime: trade.Time, + EntryTime: trade.Time, // Approximate + OrderID: trade.TradeID, + ExchangeID: trade.TradeID, + CloseType: "unknown", + }) + } + + return records, nil +} + +// GetTrades retrieves trade history from Binance Futures using Income API +// Note: Income API has delays (~minutes), for real-time use GetTradesForSymbol instead +func (t *FuturesTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + // Use Income API to get REALIZED_PNL records (all symbols) + incomes, err := t.client.NewGetIncomeHistoryService(). + IncomeType("REALIZED_PNL"). + StartTime(startTime.UnixMilli()). + Limit(int64(limit)). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get income history: %w", err) + } + + var trades []types.TradeRecord + for _, income := range incomes { + pnl, _ := strconv.ParseFloat(income.Income, 64) + if pnl == 0 { + continue // Skip zero PnL records + } + + // Income API doesn't provide full trade details, create a minimal record + // This is mainly used for detecting recent closures, not historical reconstruction + trade := types.TradeRecord{ + TradeID: strconv.FormatInt(income.TranID, 10), + Symbol: income.Symbol, + RealizedPnL: pnl, + Time: time.UnixMilli(income.Time).UTC(), + // Note: Income API doesn't provide price, quantity, side, fee + // For accurate data, use GetTradesForSymbol with specific symbol + } + trades = append(trades, trade) + } + + return trades, nil +} + +// GetTradesForSymbol retrieves trade history for a specific symbol +// This is more reliable than using Income API which may have delays +func (t *FuturesTrader) GetTradesForSymbol(symbol string, startTime time.Time, limit int) ([]types.TradeRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + accountTrades, err := t.client.NewListAccountTradeService(). + Symbol(symbol). + StartTime(startTime.UnixMilli()). + Limit(limit). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get trade history for %s: %w", symbol, err) + } + + var trades []types.TradeRecord + for _, at := range accountTrades { + price, _ := strconv.ParseFloat(at.Price, 64) + qty, _ := strconv.ParseFloat(at.Quantity, 64) + fee, _ := strconv.ParseFloat(at.Commission, 64) + pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64) + + trade := types.TradeRecord{ + TradeID: strconv.FormatInt(at.ID, 10), + Symbol: at.Symbol, + Side: string(at.Side), + PositionSide: string(at.PositionSide), + Price: price, + Quantity: qty, + RealizedPnL: pnl, + Fee: fee, + Time: time.UnixMilli(at.Time).UTC(), + } + trades = append(trades, trade) + } + + return trades, nil +} + +// GetTradesForSymbolFromID retrieves trade history for a specific symbol starting from a given trade ID +// This is used for incremental sync - only fetch new trades since last sync +func (t *FuturesTrader) GetTradesForSymbolFromID(symbol string, fromID int64, limit int) ([]types.TradeRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + accountTrades, err := t.client.NewListAccountTradeService(). + Symbol(symbol). + FromID(fromID). + Limit(limit). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get trade history for %s from ID %d: %w", symbol, fromID, err) + } + + var trades []types.TradeRecord + for _, at := range accountTrades { + price, _ := strconv.ParseFloat(at.Price, 64) + qty, _ := strconv.ParseFloat(at.Quantity, 64) + fee, _ := strconv.ParseFloat(at.Commission, 64) + pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64) + + trade := types.TradeRecord{ + TradeID: strconv.FormatInt(at.ID, 10), + Symbol: at.Symbol, + Side: string(at.Side), + PositionSide: string(at.PositionSide), + Price: price, + Quantity: qty, + RealizedPnL: pnl, + Fee: fee, + Time: time.UnixMilli(at.Time).UTC(), + } + trades = append(trades, trade) + } + + return trades, nil +} + +// GetCommissionSymbols returns symbols that have new commission records since lastSyncTime +// COMMISSION income is generated for every trade, so this is more reliable than REALIZED_PNL +func (t *FuturesTrader) GetCommissionSymbols(lastSyncTime time.Time) ([]string, error) { + incomes, err := t.client.NewGetIncomeHistoryService(). + IncomeType("COMMISSION"). + StartTime(lastSyncTime.UnixMilli()). + Limit(1000). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get commission history: %w", err) + } + + symbolMap := make(map[string]bool) + for _, income := range incomes { + if income.Symbol != "" { + symbolMap[income.Symbol] = true + } + } + + var symbols []string + for symbol := range symbolMap { + symbols = append(symbols, symbol) + } + + return symbols, nil +} + +// GetPnLSymbols returns symbols that have REALIZED_PNL records since lastSyncTime +// This is a fallback when COMMISSION detection fails (VIP users, BNB fee discount) +func (t *FuturesTrader) GetPnLSymbols(lastSyncTime time.Time) ([]string, error) { + incomes, err := t.client.NewGetIncomeHistoryService(). + IncomeType("REALIZED_PNL"). + StartTime(lastSyncTime.UnixMilli()). + Limit(1000). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get PnL history: %w", err) + } + + symbolMap := make(map[string]bool) + for _, income := range incomes { + if income.Symbol != "" { + symbolMap[income.Symbol] = true + } + } + + var symbols []string + for symbol := range symbolMap { + symbols = append(symbols, symbol) + } + + return symbols, nil +} diff --git a/trader/binance/futures_orders.go b/trader/binance/futures_orders.go new file mode 100644 index 00000000..3e3f9644 --- /dev/null +++ b/trader/binance/futures_orders.go @@ -0,0 +1,758 @@ +package binance + +import ( + "context" + "fmt" + "nofx/logger" + "nofx/trader/types" + "strconv" + + "github.com/adshao/go-binance/v2/futures" +) + +// OpenLong opens a long position +func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // First cancel all pending orders for this symbol (clean up old stop-loss and take-profit orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel old pending orders (may not have any): %v", err) + } + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + return nil, err + } + + // Note: Margin mode should be set by the caller (AutoTrader) before opening position via SetMarginMode + + // Format quantity to correct precision + quantityStr, err := t.FormatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Check if formatted quantity is 0 (prevent rounding errors) + quantityFloat, parseErr := strconv.ParseFloat(quantityStr, 64) + if parseErr != nil || quantityFloat <= 0 { + return nil, fmt.Errorf("position size too small, rounded to 0 (original: %.8f → formatted: %s). Suggest increasing position amount or selecting a lower-priced coin", quantity, quantityStr) + } + + // Check minimum notional value (Binance requires at least 10 USDT) + if err := t.CheckMinNotional(symbol, quantityFloat); err != nil { + return nil, err + } + + // Create market buy order (using br ID) + order, err := t.client.NewCreateOrderService(). + Symbol(symbol). + Side(futures.SideTypeBuy). + PositionSide(futures.PositionSideTypeLong). + Type(futures.OrderTypeMarket). + Quantity(quantityStr). + NewClientOrderID(getBrOrderID()). + Do(context.Background()) + + if err != nil { + return nil, fmt.Errorf("failed to open long position: %w", err) + } + + logger.Infof("✓ Opened long position successfully: %s quantity: %s", symbol, quantityStr) + logger.Infof(" Order ID: %d", order.OrderID) + + result := make(map[string]interface{}) + result["orderId"] = order.OrderID + result["symbol"] = order.Symbol + result["status"] = order.Status + return result, nil +} + +// OpenShort opens a short position +func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // First cancel all pending orders for this symbol (clean up old stop-loss and take-profit orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel old pending orders (may not have any): %v", err) + } + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + return nil, err + } + + // Note: Margin mode should be set by the caller (AutoTrader) before opening position via SetMarginMode + + // Format quantity to correct precision + quantityStr, err := t.FormatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Check if formatted quantity is 0 (prevent rounding errors) + quantityFloat, parseErr := strconv.ParseFloat(quantityStr, 64) + if parseErr != nil || quantityFloat <= 0 { + return nil, fmt.Errorf("position size too small, rounded to 0 (original: %.8f → formatted: %s). Suggest increasing position amount or selecting a lower-priced coin", quantity, quantityStr) + } + + // Check minimum notional value (Binance requires at least 10 USDT) + if err := t.CheckMinNotional(symbol, quantityFloat); err != nil { + return nil, err + } + + // Create market sell order (using br ID) + order, err := t.client.NewCreateOrderService(). + Symbol(symbol). + Side(futures.SideTypeSell). + PositionSide(futures.PositionSideTypeShort). + Type(futures.OrderTypeMarket). + Quantity(quantityStr). + NewClientOrderID(getBrOrderID()). + Do(context.Background()) + + if err != nil { + return nil, fmt.Errorf("failed to open short position: %w", err) + } + + logger.Infof("✓ Opened short position successfully: %s quantity: %s", symbol, quantityStr) + logger.Infof(" Order ID: %d", order.OrderID) + + result := make(map[string]interface{}) + result["orderId"] = order.OrderID + result["symbol"] = order.Symbol + result["status"] = order.Status + return result, nil +} + +// CloseLong closes a long position +func (t *FuturesTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + // If quantity is 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "long" { + quantity = pos["positionAmt"].(float64) + break + } + } + + if quantity == 0 { + return nil, fmt.Errorf("no long position found for %s", symbol) + } + } + + // Format quantity + quantityStr, err := t.FormatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Create market sell order (close long, using br ID) + order, err := t.client.NewCreateOrderService(). + Symbol(symbol). + Side(futures.SideTypeSell). + PositionSide(futures.PositionSideTypeLong). + Type(futures.OrderTypeMarket). + Quantity(quantityStr). + NewClientOrderID(getBrOrderID()). + Do(context.Background()) + + if err != nil { + return nil, fmt.Errorf("failed to close long position: %w", err) + } + + logger.Infof("✓ Closed long position successfully: %s quantity: %s", symbol, quantityStr) + + // After closing position, cancel all pending orders for this symbol (stop-loss and take-profit orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) + } + + result := make(map[string]interface{}) + result["orderId"] = order.OrderID + result["symbol"] = order.Symbol + result["status"] = order.Status + return result, nil +} + +// CloseShort closes a short position +func (t *FuturesTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + // If quantity is 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "short" { + quantity = -pos["positionAmt"].(float64) // Short position quantity is negative, take absolute value + break + } + } + + if quantity == 0 { + return nil, fmt.Errorf("no short position found for %s", symbol) + } + } + + // Format quantity + quantityStr, err := t.FormatQuantity(symbol, quantity) + if err != nil { + return nil, err + } + + // Create market buy order (close short, using br ID) + order, err := t.client.NewCreateOrderService(). + Symbol(symbol). + Side(futures.SideTypeBuy). + PositionSide(futures.PositionSideTypeShort). + Type(futures.OrderTypeMarket). + Quantity(quantityStr). + NewClientOrderID(getBrOrderID()). + Do(context.Background()) + + if err != nil { + return nil, fmt.Errorf("failed to close short position: %w", err) + } + + logger.Infof("✓ Closed short position successfully: %s quantity: %s", symbol, quantityStr) + + // After closing position, cancel all pending orders for this symbol (stop-loss and take-profit orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) + } + + result := make(map[string]interface{}) + result["orderId"] = order.OrderID + result["symbol"] = order.Symbol + result["status"] = order.Status + return result, nil +} + +// CancelStopLossOrders cancels only stop-loss orders (doesn't affect take-profit orders) +// Now uses both legacy API and new Algo Order API +func (t *FuturesTrader) CancelStopLossOrders(symbol string) error { + canceledCount := 0 + var cancelErrors []error + + // 1. Cancel legacy stop-loss orders + orders, err := t.client.NewListOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err == nil { + for _, order := range orders { + orderType := string(order.Type) + + // Only cancel stop-loss orders (don't cancel take-profit orders) + // Use string comparison since OrderType constants were removed in v2.8.9 + if orderType == "STOP_MARKET" || orderType == "STOP" { + _, err := t.client.NewCancelOrderService(). + Symbol(symbol). + OrderID(order.OrderID). + Do(context.Background()) + + if err != nil { + errMsg := fmt.Sprintf("Order ID %d: %v", order.OrderID, err) + cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) + logger.Infof(" ⚠ Failed to cancel legacy stop-loss order: %s", errMsg) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled legacy stop-loss order (Order ID: %d, Type: %s, Side: %s)", order.OrderID, orderType, order.PositionSide) + } + } + } + + // 2. Cancel Algo stop-loss orders + algoOrders, err := t.client.NewListOpenAlgoOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err == nil { + for _, algoOrder := range algoOrders { + // Only cancel stop-loss orders + if algoOrder.OrderType == futures.AlgoOrderTypeStopMarket || algoOrder.OrderType == futures.AlgoOrderTypeStop { + _, err := t.client.NewCancelAlgoOrderService(). + AlgoID(algoOrder.AlgoId). + Do(context.Background()) + + if err != nil { + errMsg := fmt.Sprintf("Algo ID %d: %v", algoOrder.AlgoId, err) + cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) + logger.Infof(" ⚠ Failed to cancel Algo stop-loss order: %s", errMsg) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled Algo stop-loss order (Algo ID: %d, Type: %s)", algoOrder.AlgoId, algoOrder.OrderType) + } + } + } + + if canceledCount == 0 && len(cancelErrors) == 0 { + logger.Infof(" ℹ %s has no stop-loss orders to cancel", symbol) + } else if canceledCount > 0 { + logger.Infof(" ✓ Canceled %d stop-loss order(s) for %s", canceledCount, symbol) + } + + // If all cancellations failed, return error + if len(cancelErrors) > 0 && canceledCount == 0 { + return fmt.Errorf("failed to cancel stop-loss orders: %v", cancelErrors) + } + + return nil +} + +// CancelTakeProfitOrders cancels only take-profit orders (doesn't affect stop-loss orders) +// Now uses both legacy API and new Algo Order API +func (t *FuturesTrader) CancelTakeProfitOrders(symbol string) error { + canceledCount := 0 + var cancelErrors []error + + // 1. Cancel legacy take-profit orders + orders, err := t.client.NewListOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err == nil { + for _, order := range orders { + orderType := string(order.Type) + + // Only cancel take-profit orders (don't cancel stop-loss orders) + // Use string comparison since OrderType constants were removed in v2.8.9 + if orderType == "TAKE_PROFIT_MARKET" || orderType == "TAKE_PROFIT" { + _, err := t.client.NewCancelOrderService(). + Symbol(symbol). + OrderID(order.OrderID). + Do(context.Background()) + + if err != nil { + errMsg := fmt.Sprintf("Order ID %d: %v", order.OrderID, err) + cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) + logger.Infof(" ⚠ Failed to cancel legacy take-profit order: %s", errMsg) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled legacy take-profit order (Order ID: %d, Type: %s, Side: %s)", order.OrderID, orderType, order.PositionSide) + } + } + } + + // 2. Cancel Algo take-profit orders + algoOrders, err := t.client.NewListOpenAlgoOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err == nil { + for _, algoOrder := range algoOrders { + // Only cancel take-profit orders + if algoOrder.OrderType == futures.AlgoOrderTypeTakeProfitMarket || algoOrder.OrderType == futures.AlgoOrderTypeTakeProfit { + _, err := t.client.NewCancelAlgoOrderService(). + AlgoID(algoOrder.AlgoId). + Do(context.Background()) + + if err != nil { + errMsg := fmt.Sprintf("Algo ID %d: %v", algoOrder.AlgoId, err) + cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) + logger.Infof(" ⚠ Failed to cancel Algo take-profit order: %s", errMsg) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled Algo take-profit order (Algo ID: %d, Type: %s)", algoOrder.AlgoId, algoOrder.OrderType) + } + } + } + + if canceledCount == 0 && len(cancelErrors) == 0 { + logger.Infof(" ℹ %s has no take-profit orders to cancel", symbol) + } else if canceledCount > 0 { + logger.Infof(" ✓ Canceled %d take-profit order(s) for %s", canceledCount, symbol) + } + + // If all cancellations failed, return error + if len(cancelErrors) > 0 && canceledCount == 0 { + return fmt.Errorf("failed to cancel take-profit orders: %v", cancelErrors) + } + + return nil +} + +// CancelAllOrders cancels all pending orders for this symbol +// Now uses both legacy API and new Algo Order API +func (t *FuturesTrader) CancelAllOrders(symbol string) error { + // 1. Cancel all legacy orders + err := t.client.NewCancelAllOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err != nil { + logger.Infof(" ⚠ Failed to cancel legacy orders: %v", err) + } else { + logger.Infof(" ✓ Canceled all legacy pending orders for %s", symbol) + } + + // 2. Cancel all Algo orders + err = t.client.NewCancelAllAlgoOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err != nil { + // Ignore "no algo orders" error + if !contains(err.Error(), "no algo") && !contains(err.Error(), "No algo") { + logger.Infof(" ⚠ Failed to cancel Algo orders: %v", err) + } + } else { + logger.Infof(" ✓ Canceled all Algo orders for %s", symbol) + } + + return nil +} + +// PlaceLimitOrder places a limit order for grid trading +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { + // Format quantity to correct precision + quantityStr, err := t.FormatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Format price to correct precision + priceStr, err := t.FormatPrice(req.Symbol, req.Price) + if err != nil { + return nil, fmt.Errorf("failed to format price: %w", err) + } + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("Failed to set leverage: %v", err) + } + } + + // Determine side and position side + var side futures.SideType + var positionSide futures.PositionSideType + + if req.Side == "BUY" { + side = futures.SideTypeBuy + positionSide = futures.PositionSideTypeLong + } else { + side = futures.SideTypeSell + positionSide = futures.PositionSideTypeShort + } + + // Build order service with broker ID + orderService := t.client.NewCreateOrderService(). + Symbol(req.Symbol). + Side(side). + PositionSide(positionSide). + Type(futures.OrderTypeLimit). + TimeInForce(futures.TimeInForceTypeGTC). + Quantity(quantityStr). + Price(priceStr). + NewClientOrderID(getBrOrderID()) + + // Execute order + order, err := orderService.Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + logger.Infof("✓ [Grid] Placed limit order: %s %s %s @ %s, qty=%s, orderID=%d", + req.Symbol, req.Side, positionSide, priceStr, quantityStr, order.OrderID) + + return &types.LimitOrderResult{ + OrderID: fmt.Sprintf("%d", order.OrderID), + ClientID: order.ClientOrderID, + Symbol: order.Symbol, + Side: string(order.Side), + PositionSide: string(order.PositionSide), + Price: req.Price, + Quantity: req.Quantity, + Status: string(order.Status), + }, nil +} + +// CancelOrder cancels a specific order by ID +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) CancelOrder(symbol, orderID string) error { + // Parse order ID to int64 + orderIDInt, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return fmt.Errorf("invalid order ID: %w", err) + } + + _, err = t.client.NewCancelOrderService(). + Symbol(symbol). + OrderID(orderIDInt). + Do(context.Background()) + + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Grid] Cancelled order: %s/%s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + book, err := t.client.NewDepthService(). + Symbol(symbol). + Limit(depth). + Do(context.Background()) + + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + // Convert bids + bids = make([][]float64, len(book.Bids)) + for i, bid := range book.Bids { + price, _ := strconv.ParseFloat(bid.Price, 64) + qty, _ := strconv.ParseFloat(bid.Quantity, 64) + bids[i] = []float64{price, qty} + } + + // Convert asks + asks = make([][]float64, len(book.Asks)) + for i, ask := range book.Asks { + price, _ := strconv.ParseFloat(ask.Price, 64) + qty, _ := strconv.ParseFloat(ask.Quantity, 64) + asks[i] = []float64{price, qty} + } + + return bids, asks, nil +} + +// CancelStopOrders cancels take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions) +// Now uses both legacy API and new Algo Order API (Binance migrated stop orders to Algo system) +func (t *FuturesTrader) CancelStopOrders(symbol string) error { + canceledCount := 0 + + // 1. Cancel legacy stop orders (for backward compatibility) + orders, err := t.client.NewListOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err == nil { + for _, order := range orders { + orderType := string(order.Type) + + // Only cancel stop-loss and take-profit orders + // Use string comparison since OrderType constants were removed in v2.8.9 + if orderType == "STOP_MARKET" || + orderType == "TAKE_PROFIT_MARKET" || + orderType == "STOP" || + orderType == "TAKE_PROFIT" { + + _, err := t.client.NewCancelOrderService(). + Symbol(symbol). + OrderID(order.OrderID). + Do(context.Background()) + + if err != nil { + logger.Infof(" ⚠ Failed to cancel legacy order %d: %v", order.OrderID, err) + continue + } + + canceledCount++ + logger.Infof(" ✓ Canceled legacy stop order for %s (Order ID: %d, Type: %s)", + symbol, order.OrderID, orderType) + } + } + } + + // 2. Cancel Algo orders (new API) + err = t.client.NewCancelAllAlgoOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err != nil { + // Ignore "no algo orders" error + if !contains(err.Error(), "no algo") && !contains(err.Error(), "No algo") { + logger.Infof(" ⚠ Failed to cancel Algo orders: %v", err) + } + } else { + logger.Infof(" ✓ Canceled all Algo orders for %s", symbol) + canceledCount++ + } + + if canceledCount == 0 { + logger.Infof(" ℹ %s has no take-profit/stop-loss orders to cancel", symbol) + } + + return nil +} + +// GetOpenOrders gets all open/pending orders for a symbol +func (t *FuturesTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + var result []types.OpenOrder + + // 1. Get legacy open orders + orders, err := t.client.NewListOpenOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Price, 64) + stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64) + quantity, _ := strconv.ParseFloat(order.OrigQuantity, 64) + + result = append(result, types.OpenOrder{ + OrderID: fmt.Sprintf("%d", order.OrderID), + Symbol: order.Symbol, + Side: string(order.Side), + PositionSide: string(order.PositionSide), + Type: string(order.Type), + Price: price, + StopPrice: stopPrice, + Quantity: quantity, + Status: string(order.Status), + }) + } + + // 2. Get Algo orders (new API for stop-loss/take-profit) + algoOrders, err := t.client.NewListOpenAlgoOrdersService(). + Symbol(symbol). + Do(context.Background()) + + if err == nil { + for _, algoOrder := range algoOrders { + triggerPrice, _ := strconv.ParseFloat(algoOrder.TriggerPrice, 64) + quantity, _ := strconv.ParseFloat(algoOrder.Quantity, 64) + + result = append(result, types.OpenOrder{ + OrderID: fmt.Sprintf("%d", algoOrder.AlgoId), + Symbol: algoOrder.Symbol, + Side: string(algoOrder.Side), + PositionSide: string(algoOrder.PositionSide), + Type: string(algoOrder.OrderType), + Price: 0, // Algo orders use stop price + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + + return result, nil +} + +// SetStopLoss sets stop-loss order using new Algo Order API +// Binance has migrated stop orders to Algo Order system (error -4120 STOP_ORDER_SWITCH_ALGO) +func (t *FuturesTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + var side futures.SideType + var posSide futures.PositionSideType + + if positionSide == "LONG" { + side = futures.SideTypeSell + posSide = futures.PositionSideTypeLong + } else { + side = futures.SideTypeBuy + posSide = futures.PositionSideTypeShort + } + + // Use new Algo Order API + _, err := t.client.NewCreateAlgoOrderService(). + Symbol(symbol). + Side(side). + PositionSide(posSide). + Type(futures.AlgoOrderTypeStopMarket). + TriggerPrice(fmt.Sprintf("%.8f", stopPrice)). + WorkingType(futures.WorkingTypeContractPrice). + ClosePosition(true). + ClientAlgoId(getBrOrderID()). + Do(context.Background()) + + if err != nil { + return fmt.Errorf("failed to set stop-loss: %w", err) + } + + logger.Infof(" Stop-loss price set (Algo Order): %.4f", stopPrice) + return nil +} + +// SetTakeProfit sets take-profit order using new Algo Order API +// Binance has migrated stop orders to Algo Order system (error -4120 STOP_ORDER_SWITCH_ALGO) +func (t *FuturesTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + var side futures.SideType + var posSide futures.PositionSideType + + if positionSide == "LONG" { + side = futures.SideTypeSell + posSide = futures.PositionSideTypeLong + } else { + side = futures.SideTypeBuy + posSide = futures.PositionSideTypeShort + } + + // Use new Algo Order API + _, err := t.client.NewCreateAlgoOrderService(). + Symbol(symbol). + Side(side). + PositionSide(posSide). + Type(futures.AlgoOrderTypeTakeProfitMarket). + TriggerPrice(fmt.Sprintf("%.8f", takeProfitPrice)). + WorkingType(futures.WorkingTypeContractPrice). + ClosePosition(true). + ClientAlgoId(getBrOrderID()). + Do(context.Background()) + + if err != nil { + return fmt.Errorf("failed to set take-profit: %w", err) + } + + logger.Infof(" Take-profit price set (Algo Order): %.4f", takeProfitPrice) + return nil +} + +// GetOrderStatus gets order status +func (t *FuturesTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + // Convert orderID to int64 + orderIDInt, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid order ID: %s", orderID) + } + + order, err := t.client.NewGetOrderService(). + Symbol(symbol). + OrderID(orderIDInt). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + // Parse execution price + avgPrice, _ := strconv.ParseFloat(order.AvgPrice, 64) + executedQty, _ := strconv.ParseFloat(order.ExecutedQuantity, 64) + + result := map[string]interface{}{ + "orderId": order.OrderID, + "symbol": order.Symbol, + "status": string(order.Status), + "avgPrice": avgPrice, + "executedQty": executedQty, + "side": string(order.Side), + "type": string(order.Type), + "time": order.Time, + "updateTime": order.UpdateTime, + } + + // Binance futures commission fee needs to be obtained through GetUserTrades, not retrieved here for now + // Can be obtained later through WebSocket or separate query + result["commission"] = 0.0 + + return result, nil +} diff --git a/trader/binance/futures_positions.go b/trader/binance/futures_positions.go new file mode 100644 index 00000000..d49e960a --- /dev/null +++ b/trader/binance/futures_positions.go @@ -0,0 +1,290 @@ +package binance + +import ( + "context" + "fmt" + "nofx/logger" + "strconv" + "time" + + "github.com/adshao/go-binance/v2/futures" +) + +// GetPositions gets all positions (with cache) +func (t *FuturesTrader) GetPositions() ([]map[string]interface{}, error) { + // First check if cache is valid + t.positionsCacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { + cacheAge := time.Since(t.positionsCacheTime) + t.positionsCacheMutex.RUnlock() + logger.Infof("✓ Using cached position information (cache age: %.1f seconds ago)", cacheAge.Seconds()) + return t.cachedPositions, nil + } + t.positionsCacheMutex.RUnlock() + + // Cache expired or doesn't exist, call API + logger.Infof("🔄 Cache expired, calling Binance API to get position information...") + positions, err := t.client.NewGetPositionRiskService().Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var result []map[string]interface{} + for _, pos := range positions { + posAmt, _ := strconv.ParseFloat(pos.PositionAmt, 64) + if posAmt == 0 { + continue // Skip positions with zero amount + } + + posMap := make(map[string]interface{}) + posMap["symbol"] = pos.Symbol + posMap["positionAmt"], _ = strconv.ParseFloat(pos.PositionAmt, 64) + posMap["entryPrice"], _ = strconv.ParseFloat(pos.EntryPrice, 64) + posMap["markPrice"], _ = strconv.ParseFloat(pos.MarkPrice, 64) + posMap["unRealizedProfit"], _ = strconv.ParseFloat(pos.UnRealizedProfit, 64) + posMap["leverage"], _ = strconv.ParseFloat(pos.Leverage, 64) + posMap["liquidationPrice"], _ = strconv.ParseFloat(pos.LiquidationPrice, 64) + // Note: Binance SDK doesn't expose updateTime field, will fallback to local tracking + + // Determine direction + if posAmt > 0 { + posMap["side"] = "long" + } else { + posMap["side"] = "short" + } + + result = append(result, posMap) + } + + // Update cache + t.positionsCacheMutex.Lock() + t.cachedPositions = result + t.positionsCacheTime = time.Now() + t.positionsCacheMutex.Unlock() + + return result, nil +} + +// SetMarginMode sets margin mode +func (t *FuturesTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + var marginType futures.MarginType + if isCrossMargin { + marginType = futures.MarginTypeCrossed + } else { + marginType = futures.MarginTypeIsolated + } + + // Try to set margin mode + err := t.client.NewChangeMarginTypeService(). + Symbol(symbol). + MarginType(marginType). + Do(context.Background()) + + marginModeStr := "Cross Margin" + if !isCrossMargin { + marginModeStr = "Isolated Margin" + } + + if err != nil { + // If error message contains "No need to change", margin mode is already set to target value + if contains(err.Error(), "No need to change margin type") { + logger.Infof(" ✓ %s margin mode is already %s", symbol, marginModeStr) + return nil + } + // If there is an open position, margin mode cannot be changed, but this doesn't affect trading + if contains(err.Error(), "Margin type cannot be changed if there exists position") { + logger.Infof(" ⚠️ %s has open positions, cannot change margin mode, continuing with current mode", symbol) + return nil + } + // Detect Multi-Assets mode (error code -4168) + if contains(err.Error(), "Multi-Assets mode") || contains(err.Error(), "-4168") || contains(err.Error(), "4168") { + logger.Infof(" ⚠️ %s detected Multi-Assets mode, forcing Cross Margin mode", symbol) + logger.Infof(" 💡 Tip: To use Isolated Margin mode, please disable Multi-Assets mode in Binance") + return nil + } + // Detect Unified Account API (Portfolio Margin) + if contains(err.Error(), "unified") || contains(err.Error(), "portfolio") || contains(err.Error(), "Portfolio") { + logger.Infof(" ❌ %s detected Unified Account API, unable to trade futures", symbol) + return fmt.Errorf("please use 'Spot & Futures Trading' API permission, do not use 'Unified Account API'") + } + logger.Infof(" ⚠️ Failed to set margin mode: %v", err) + // Don't return error, let trading continue + return nil + } + + logger.Infof(" ✓ %s margin mode set to %s", symbol, marginModeStr) + return nil +} + +// SetLeverage sets leverage (with smart detection and cooldown period) +func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error { + // First try to get current leverage (from position information) + currentLeverage := 0 + positions, err := t.GetPositions() + if err == nil { + for _, pos := range positions { + if pos["symbol"] == symbol { + if lev, ok := pos["leverage"].(float64); ok { + currentLeverage = int(lev) + break + } + } + } + } + + // If current leverage is already the target leverage, skip + if currentLeverage == leverage && currentLeverage > 0 { + logger.Infof(" ✓ %s leverage is already %dx, no need to change", symbol, leverage) + return nil + } + + // Change leverage + _, err = t.client.NewChangeLeverageService(). + Symbol(symbol). + Leverage(leverage). + Do(context.Background()) + + if err != nil { + // If error message contains "No need to change", leverage is already the target value + if contains(err.Error(), "No need to change") { + logger.Infof(" ✓ %s leverage is already %dx", symbol, leverage) + return nil + } + return fmt.Errorf("failed to set leverage: %w", err) + } + + logger.Infof(" ✓ %s leverage changed to %dx", symbol, leverage) + + // Wait 5 seconds after changing leverage (to avoid cooldown period errors) + logger.Infof(" ⏱ Waiting 5 seconds for cooldown period...") + time.Sleep(5 * time.Second) + + return nil +} + +// GetMarketPrice gets market price +func (t *FuturesTrader) GetMarketPrice(symbol string) (float64, error) { + prices, err := t.client.NewListPricesService().Symbol(symbol).Do(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get price: %w", err) + } + + if len(prices) == 0 { + return 0, fmt.Errorf("price not found") + } + + price, err := strconv.ParseFloat(prices[0].Price, 64) + if err != nil { + return 0, err + } + + return price, nil +} + +// CalculatePositionSize calculates position size +func (t *FuturesTrader) CalculatePositionSize(balance, riskPercent, price float64, leverage int) float64 { + riskAmount := balance * (riskPercent / 100.0) + positionValue := riskAmount * float64(leverage) + quantity := positionValue / price + return quantity +} + +// GetMinNotional gets minimum notional value (Binance requirement) +func (t *FuturesTrader) GetMinNotional(symbol string) float64 { + // Use conservative default value of 10 USDT to ensure order passes exchange validation + return 10.0 +} + +// CheckMinNotional checks if order meets minimum notional value requirement +func (t *FuturesTrader) CheckMinNotional(symbol string, quantity float64) error { + price, err := t.GetMarketPrice(symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + notionalValue := quantity * price + minNotional := t.GetMinNotional(symbol) + + if notionalValue < minNotional { + return fmt.Errorf( + "order amount %.2f USDT is below minimum requirement %.2f USDT (quantity: %.4f, price: %.4f)", + notionalValue, minNotional, quantity, price, + ) + } + + return nil +} + +// GetSymbolPrecision gets the quantity precision for a trading pair +func (t *FuturesTrader) GetSymbolPrecision(symbol string) (int, error) { + exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get trading rules: %w", err) + } + + for _, s := range exchangeInfo.Symbols { + if s.Symbol == symbol { + // Get precision from LOT_SIZE filter + for _, filter := range s.Filters { + if filter["filterType"] == "LOT_SIZE" { + stepSize := filter["stepSize"].(string) + precision := calculatePrecision(stepSize) + logger.Infof(" %s quantity precision: %d (stepSize: %s)", symbol, precision, stepSize) + return precision, nil + } + } + } + } + + logger.Infof(" ⚠ %s precision information not found, using default precision 3", symbol) + return 3, nil // Default precision is 3 +} + +// FormatQuantity formats quantity to correct precision +func (t *FuturesTrader) FormatQuantity(symbol string, quantity float64) (string, error) { + precision, err := t.GetSymbolPrecision(symbol) + if err != nil { + // If retrieval fails, use default format + return fmt.Sprintf("%.3f", quantity), nil + } + + format := fmt.Sprintf("%%.%df", precision) + return fmt.Sprintf(format, quantity), nil +} + +// GetSymbolPricePrecision gets the price precision for a trading pair +func (t *FuturesTrader) GetSymbolPricePrecision(symbol string) (int, error) { + exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get trading rules: %w", err) + } + + for _, s := range exchangeInfo.Symbols { + if s.Symbol == symbol { + // Get precision from PRICE_FILTER filter + for _, filter := range s.Filters { + if filter["filterType"] == "PRICE_FILTER" { + tickSize := filter["tickSize"].(string) + precision := calculatePrecision(tickSize) + return precision, nil + } + } + } + } + + // Default to 2 decimal places for price + return 2, nil +} + +// FormatPrice formats price to correct precision +func (t *FuturesTrader) FormatPrice(symbol string, price float64) (string, error) { + precision, err := t.GetSymbolPricePrecision(symbol) + if err != nil { + // If retrieval fails, use default format + return fmt.Sprintf("%.2f", price), nil + } + + format := fmt.Sprintf("%%.%df", precision) + return fmt.Sprintf(format, price), nil +} + diff --git a/trader/bitget/trader.go b/trader/bitget/trader.go index d44f7e01..449d451e 100644 --- a/trader/bitget/trader.go +++ b/trader/bitget/trader.go @@ -14,22 +14,21 @@ import ( "strings" "sync" "time" - "nofx/trader/types" ) // Bitget API endpoints (V2) const ( - bitgetBaseURL = "https://api.bitget.com" - bitgetAccountPath = "/api/v2/mix/account/accounts" - bitgetPositionPath = "/api/v2/mix/position/all-position" - bitgetOrderPath = "/api/v2/mix/order/place-order" - bitgetLeveragePath = "/api/v2/mix/account/set-leverage" - bitgetTickerPath = "/api/v2/mix/market/ticker" - bitgetContractsPath = "/api/v2/mix/market/contracts" - bitgetCancelOrderPath = "/api/v2/mix/order/cancel-order" - bitgetPendingPath = "/api/v2/mix/order/orders-pending" - bitgetHistoryPath = "/api/v2/mix/order/orders-history" - bitgetMarginModePath = "/api/v2/mix/account/set-margin-mode" + bitgetBaseURL = "https://api.bitget.com" + bitgetAccountPath = "/api/v2/mix/account/accounts" + bitgetPositionPath = "/api/v2/mix/position/all-position" + bitgetOrderPath = "/api/v2/mix/order/place-order" + bitgetLeveragePath = "/api/v2/mix/account/set-leverage" + bitgetTickerPath = "/api/v2/mix/market/ticker" + bitgetContractsPath = "/api/v2/mix/market/contracts" + bitgetCancelOrderPath = "/api/v2/mix/order/cancel-order" + bitgetPendingPath = "/api/v2/mix/order/orders-pending" + bitgetHistoryPath = "/api/v2/mix/order/orders-history" + bitgetMarginModePath = "/api/v2/mix/account/set-margin-mode" bitgetPositionModePath = "/api/v2/mix/account/set-position-mode" ) @@ -63,22 +62,22 @@ type BitgetTrader struct { // BitgetContract Bitget contract info type BitgetContract struct { - Symbol string // Symbol name - BaseCoin string // Base coin - QuoteCoin string // Quote coin - MinTradeNum float64 // Minimum trade amount - MaxTradeNum float64 // Maximum trade amount + Symbol string // Symbol name + BaseCoin string // Base coin + QuoteCoin string // Quote coin + MinTradeNum float64 // Minimum trade amount + MaxTradeNum float64 // Maximum trade amount SizeMultiplier float64 // Contract size multiplier - PricePlace int // Price decimal places - VolumePlace int // Volume decimal places + PricePlace int // Price decimal places + VolumePlace int // Volume decimal places } // BitgetResponse Bitget API response type BitgetResponse struct { - Code string `json:"code"` - Msg string `json:"msg"` - Data json.RawMessage `json:"data"` - RequestTime int64 `json:"requestTime"` + Code string `json:"code"` + Msg string `json:"msg"` + Data json.RawMessage `json:"data"` + RequestTime int64 `json:"requestTime"` } // NewBitgetTrader creates a Bitget trader @@ -218,148 +217,6 @@ func (t *BitgetTrader) convertSymbol(symbol string) string { return strings.ToUpper(symbol) } -// GetBalance gets account balance -func (t *BitgetTrader) GetBalance() (map[string]interface{}, error) { - // Check cache - t.balanceCacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - t.balanceCacheMutex.RUnlock() - return t.cachedBalance, nil - } - t.balanceCacheMutex.RUnlock() - - params := map[string]interface{}{ - "productType": "USDT-FUTURES", - } - - data, err := t.doRequest("GET", bitgetAccountPath, params) - if err != nil { - return nil, fmt.Errorf("failed to get account balance: %w", err) - } - - var accounts []struct { - MarginCoin string `json:"marginCoin"` - Available string `json:"available"` // Available balance - AccountEquity string `json:"accountEquity"` // Total equity - UsdtEquity string `json:"usdtEquity"` // USDT equity - UnrealizedPL string `json:"unrealizedPL"` // Unrealized P&L - } - - if err := json.Unmarshal(data, &accounts); err != nil { - return nil, fmt.Errorf("failed to parse balance data: %w, raw: %s", err, string(data)) - } - - var totalEquity, availableBalance, unrealizedPnL float64 - for _, acc := range accounts { - if acc.MarginCoin == "USDT" { - totalEquity, _ = strconv.ParseFloat(acc.AccountEquity, 64) - availableBalance, _ = strconv.ParseFloat(acc.Available, 64) - unrealizedPnL, _ = strconv.ParseFloat(acc.UnrealizedPL, 64) - logger.Infof("✓ [Bitget] Balance: equity=%.2f, available=%.2f", totalEquity, availableBalance) - break - } - } - - result := map[string]interface{}{ - "totalWalletBalance": totalEquity - unrealizedPnL, - "availableBalance": availableBalance, - "totalUnrealizedProfit": unrealizedPnL, - "total_equity": totalEquity, - } - - // Update cache - t.balanceCacheMutex.Lock() - t.cachedBalance = result - t.balanceCacheTime = time.Now() - t.balanceCacheMutex.Unlock() - - return result, nil -} - -// GetPositions gets all positions -func (t *BitgetTrader) GetPositions() ([]map[string]interface{}, error) { - // Check cache - t.positionsCacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { - t.positionsCacheMutex.RUnlock() - return t.cachedPositions, nil - } - t.positionsCacheMutex.RUnlock() - - params := map[string]interface{}{ - "productType": "USDT-FUTURES", - "marginCoin": "USDT", - } - - data, err := t.doRequest("GET", bitgetPositionPath, params) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var positions []struct { - Symbol string `json:"symbol"` - HoldSide string `json:"holdSide"` // long, short - OpenPriceAvg string `json:"openPriceAvg"` // Average entry price - MarkPrice string `json:"markPrice"` // Mark price - Total string `json:"total"` // Total position size - Available string `json:"available"` // Available to close - UnrealizedPL string `json:"unrealizedPL"` // Unrealized P&L - Leverage string `json:"leverage"` // Leverage - LiquidationPrice string `json:"liquidationPrice"` // Liquidation price - MarginSize string `json:"marginSize"` // Position margin - CTime string `json:"cTime"` // Create time - UTime string `json:"uTime"` // Update time - } - - if err := json.Unmarshal(data, &positions); err != nil { - return nil, fmt.Errorf("failed to parse position data: %w", err) - } - - var result []map[string]interface{} - for _, pos := range positions { - total, _ := strconv.ParseFloat(pos.Total, 64) - if total == 0 { - continue - } - - entryPrice, _ := strconv.ParseFloat(pos.OpenPriceAvg, 64) - markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64) - unrealizedPnL, _ := strconv.ParseFloat(pos.UnrealizedPL, 64) - leverage, _ := strconv.ParseFloat(pos.Leverage, 64) - liqPrice, _ := strconv.ParseFloat(pos.LiquidationPrice, 64) - cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) - uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) - - // Normalize side - side := "long" - if pos.HoldSide == "short" { - side = "short" - } - - posMap := map[string]interface{}{ - "symbol": pos.Symbol, - "positionAmt": total, - "entryPrice": entryPrice, - "markPrice": markPrice, - "unRealizedProfit": unrealizedPnL, - "leverage": leverage, - "liquidationPrice": liqPrice, - "side": side, - "createdTime": cTime, - "updatedTime": uTime, - } - result = append(result, posMap) - } - - // Update cache - t.positionsCacheMutex.Lock() - t.cachedPositions = result - t.positionsCacheTime = time.Now() - t.positionsCacheMutex.Unlock() - - return result, nil -} - // getContract gets contract info func (t *BitgetTrader) getContract(symbol string) (*BitgetContract, error) { symbol = t.convertSymbol(symbol) @@ -430,513 +287,6 @@ func (t *BitgetTrader) getContract(symbol string) (*BitgetContract, error) { return nil, fmt.Errorf("contract info not found: %s", symbol) } -// SetMarginMode sets margin mode -func (t *BitgetTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - symbol = t.convertSymbol(symbol) - - marginMode := "isolated" - if isCrossMargin { - marginMode = "crossed" - } - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginCoin": "USDT", - "marginMode": marginMode, - } - - _, err := t.doRequest("POST", bitgetMarginModePath, body) - if err != nil { - if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") { - return nil - } - if strings.Contains(err.Error(), "position") { - logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol) - return nil - } - return err - } - - logger.Infof(" ✓ %s margin mode set to %s", symbol, marginMode) - return nil -} - -// SetLeverage sets leverage -func (t *BitgetTrader) SetLeverage(symbol string, leverage int) error { - symbol = t.convertSymbol(symbol) - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginCoin": "USDT", - "leverage": fmt.Sprintf("%d", leverage), - } - - _, err := t.doRequest("POST", bitgetLeveragePath, body) - if err != nil { - if strings.Contains(err.Error(), "same") { - return nil - } - logger.Infof(" ⚠️ Failed to set %s leverage: %v", symbol, err) - return err - } - - logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage) - return nil -} - -// OpenLong opens long position -func (t *BitgetTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // Cancel old orders first - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof(" ⚠️ Failed to set leverage: %v", err) - } - - // Format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "side": "buy", - "orderType": "market", - "size": qtyStr, - "clientOid": genBitgetClientOid(), - } - - logger.Infof(" 📊 Bitget OpenLong: symbol=%s, qty=%s, leverage=%d", symbol, qtyStr, leverage) - - data, err := t.doRequest("POST", bitgetOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to open long position: %w", err) - } - - var order struct { - OrderId string `json:"orderId"` - ClientOid string `json:"clientOid"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - // Clear cache - t.clearCache() - - logger.Infof("✓ Bitget opened long position successfully: %s", symbol) - - return map[string]interface{}{ - "orderId": order.OrderId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// OpenShort opens short position -func (t *BitgetTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // Cancel old orders first - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof(" ⚠️ Failed to set leverage: %v", err) - } - - // Format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "side": "sell", - "orderType": "market", - "size": qtyStr, - "clientOid": genBitgetClientOid(), - } - - logger.Infof(" 📊 Bitget OpenShort: symbol=%s, qty=%s, leverage=%d", symbol, qtyStr, leverage) - - data, err := t.doRequest("POST", bitgetOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to open short position: %w", err) - } - - var order struct { - OrderId string `json:"orderId"` - ClientOid string `json:"clientOid"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - // Clear cache - t.clearCache() - - logger.Infof("✓ Bitget opened short position successfully: %s", symbol) - - return map[string]interface{}{ - "orderId": order.OrderId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// CloseLong closes long position -func (t *BitgetTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // If quantity is 0, get current position - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "long" { - quantity = pos["positionAmt"].(float64) - break - } - } - if quantity == 0 { - return nil, fmt.Errorf("long position not found for %s", symbol) - } - } - - // Format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "side": "sell", - "orderType": "market", - "size": qtyStr, - "reduceOnly": "YES", - "clientOid": genBitgetClientOid(), - } - - logger.Infof(" 📊 Bitget CloseLong: symbol=%s, qty=%s", symbol, qtyStr) - - data, err := t.doRequest("POST", bitgetOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to close long position: %w", err) - } - - var order struct { - OrderId string `json:"orderId"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, err - } - - // Clear cache - t.clearCache() - - logger.Infof("✓ Bitget closed long position successfully: %s", symbol) - - return map[string]interface{}{ - "orderId": order.OrderId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// CloseShort closes short position -func (t *BitgetTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // If quantity is 0, get current position - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "short" { - quantity = pos["positionAmt"].(float64) - break - } - } - if quantity == 0 { - return nil, fmt.Errorf("short position not found for %s", symbol) - } - } - - // Ensure quantity is positive - if quantity < 0 { - quantity = -quantity - } - - // Format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "side": "buy", - "orderType": "market", - "size": qtyStr, - "reduceOnly": "YES", - "clientOid": genBitgetClientOid(), - } - - logger.Infof(" 📊 Bitget CloseShort: symbol=%s, qty=%s", symbol, qtyStr) - - data, err := t.doRequest("POST", bitgetOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to close short position: %w", err) - } - - var order struct { - OrderId string `json:"orderId"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, err - } - - // Clear cache - t.clearCache() - - logger.Infof("✓ Bitget closed short position successfully: %s", symbol) - - return map[string]interface{}{ - "orderId": order.OrderId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// GetMarketPrice gets market price -func (t *BitgetTrader) GetMarketPrice(symbol string) (float64, error) { - symbol = t.convertSymbol(symbol) - - params := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - } - - data, err := t.doRequest("GET", bitgetTickerPath, params) - if err != nil { - return 0, fmt.Errorf("failed to get price: %w", err) - } - - var tickers []struct { - LastPr string `json:"lastPr"` - } - - if err := json.Unmarshal(data, &tickers); err != nil { - return 0, err - } - - if len(tickers) == 0 { - return 0, fmt.Errorf("no price data received") - } - - price, err := strconv.ParseFloat(tickers[0].LastPr, 64) - if err != nil { - return 0, err - } - - return price, nil -} - -// SetStopLoss sets stop loss order -func (t *BitgetTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - // Bitget V2 uses plan order for stop loss - symbol = t.convertSymbol(symbol) - - side := "sell" - holdSide := "long" - if strings.ToUpper(positionSide) == "SHORT" { - side = "buy" - holdSide = "short" - } - - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - body := map[string]interface{}{ - "planType": "loss_plan", - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "triggerPrice": fmt.Sprintf("%.8f", stopPrice), - "triggerType": "mark_price", - "side": side, - "tradeSide": "close", - "orderType": "market", - "size": qtyStr, - "holdSide": holdSide, - "clientOid": genBitgetClientOid(), - } - - _, err := t.doRequest("POST", "/api/v2/mix/order/place-plan-order", body) - if err != nil { - return fmt.Errorf("failed to set stop loss: %w", err) - } - - logger.Infof(" ✓ [Bitget] Stop loss set: %s @ %.4f", symbol, stopPrice) - return nil -} - -// SetTakeProfit sets take profit order -func (t *BitgetTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - // Bitget V2 uses plan order for take profit - symbol = t.convertSymbol(symbol) - - side := "sell" - holdSide := "long" - if strings.ToUpper(positionSide) == "SHORT" { - side = "buy" - holdSide = "short" - } - - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - body := map[string]interface{}{ - "planType": "profit_plan", - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "triggerPrice": fmt.Sprintf("%.8f", takeProfitPrice), - "triggerType": "mark_price", - "side": side, - "tradeSide": "close", - "orderType": "market", - "size": qtyStr, - "holdSide": holdSide, - "clientOid": genBitgetClientOid(), - } - - _, err := t.doRequest("POST", "/api/v2/mix/order/place-plan-order", body) - if err != nil { - return fmt.Errorf("failed to set take profit: %w", err) - } - - logger.Infof(" ✓ [Bitget] Take profit set: %s @ %.4f", symbol, takeProfitPrice) - return nil -} - -// CancelStopLossOrders cancels stop loss orders -func (t *BitgetTrader) CancelStopLossOrders(symbol string) error { - return t.cancelPlanOrders(symbol, "loss_plan") -} - -// CancelTakeProfitOrders cancels take profit orders -func (t *BitgetTrader) CancelTakeProfitOrders(symbol string) error { - return t.cancelPlanOrders(symbol, "profit_plan") -} - -// cancelPlanOrders cancels plan orders -func (t *BitgetTrader) cancelPlanOrders(symbol string, planType string) error { - symbol = t.convertSymbol(symbol) - - // Get pending plan orders - params := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "planType": planType, - } - - data, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", params) - if err != nil { - return err - } - - var orders struct { - EntrustedList []struct { - OrderId string `json:"orderId"` - } `json:"entrustedList"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return err - } - - // Cancel each order - for _, order := range orders.EntrustedList { - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginCoin": "USDT", - "orderId": order.OrderId, - } - t.doRequest("POST", "/api/v2/mix/order/cancel-plan-order", body) - } - - return nil -} - -// CancelAllOrders cancels all pending orders -func (t *BitgetTrader) CancelAllOrders(symbol string) error { - symbol = t.convertSymbol(symbol) - - // Get pending orders - params := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - } - - data, err := t.doRequest("GET", bitgetPendingPath, params) - if err != nil { - return err - } - - var orders struct { - EntrustedList []struct { - OrderId string `json:"orderId"` - } `json:"entrustedList"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return err - } - - // Cancel each order - for _, order := range orders.EntrustedList { - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginCoin": "USDT", - "orderId": order.OrderId, - } - t.doRequest("POST", bitgetCancelOrderPath, body) - } - - // Also cancel plan orders - t.cancelPlanOrders(symbol, "loss_plan") - t.cancelPlanOrders(symbol, "profit_plan") - - return nil -} - -// CancelStopOrders cancels stop loss and take profit orders -func (t *BitgetTrader) CancelStopOrders(symbol string) error { - t.CancelStopLossOrders(symbol) - t.CancelTakeProfitOrders(symbol) - return nil -} - // FormatQuantity formats quantity func (t *BitgetTrader) FormatQuantity(symbol string, quantity float64) (string, error) { contract, err := t.getContract(symbol) @@ -949,137 +299,6 @@ func (t *BitgetTrader) FormatQuantity(symbol string, quantity float64) (string, return fmt.Sprintf(format, quantity), nil } -// GetOrderStatus gets order status -func (t *BitgetTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - params := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "orderId": orderID, - } - - data, err := t.doRequest("GET", "/api/v2/mix/order/detail", params) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - var order struct { - OrderId string `json:"orderId"` - State string `json:"state"` // filled, canceled, partially_filled, new - PriceAvg string `json:"priceAvg"` // Average fill price - BaseVolume string `json:"baseVolume"` // Filled quantity - Fee string `json:"fee"` // Fee - Side string `json:"side"` - OrderType string `json:"orderType"` - CTime string `json:"cTime"` - UTime string `json:"uTime"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, err - } - - avgPrice, _ := strconv.ParseFloat(order.PriceAvg, 64) - fillQty, _ := strconv.ParseFloat(order.BaseVolume, 64) - fee, _ := strconv.ParseFloat(order.Fee, 64) - cTime, _ := strconv.ParseInt(order.CTime, 10, 64) - uTime, _ := strconv.ParseInt(order.UTime, 10, 64) - - // Status mapping - statusMap := map[string]string{ - "filled": "FILLED", - "new": "NEW", - "partially_filled": "PARTIALLY_FILLED", - "canceled": "CANCELED", - } - - status := statusMap[order.State] - if status == "" { - status = order.State - } - - return map[string]interface{}{ - "orderId": order.OrderId, - "symbol": symbol, - "status": status, - "avgPrice": avgPrice, - "executedQty": fillQty, - "side": order.Side, - "type": order.OrderType, - "time": cTime, - "updateTime": uTime, - "commission": -fee, - }, nil -} - -// GetClosedPnL retrieves closed position PnL records -func (t *BitgetTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 100 { - limit = 100 - } - - params := map[string]interface{}{ - "productType": "USDT-FUTURES", - "startTime": fmt.Sprintf("%d", startTime.UnixMilli()), - "limit": fmt.Sprintf("%d", limit), - } - - data, err := t.doRequest("GET", "/api/v2/mix/position/history-position", params) - if err != nil { - return nil, fmt.Errorf("failed to get positions history: %w", err) - } - - var resp struct { - List []struct { - Symbol string `json:"symbol"` - HoldSide string `json:"holdSide"` - OpenPriceAvg string `json:"openPriceAvg"` - ClosePriceAvg string `json:"closePriceAvg"` - CloseVol string `json:"closeVol"` - AchievedProfits string `json:"achievedProfits"` - TotalFee string `json:"totalFee"` - Leverage string `json:"leverage"` - CTime string `json:"cTime"` - UTime string `json:"uTime"` - } `json:"list"` - } - - if err := json.Unmarshal(data, &resp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - records := make([]types.ClosedPnLRecord, 0, len(resp.List)) - for _, pos := range resp.List { - record := types.ClosedPnLRecord{ - Symbol: pos.Symbol, - Side: pos.HoldSide, - } - - record.EntryPrice, _ = strconv.ParseFloat(pos.OpenPriceAvg, 64) - record.ExitPrice, _ = strconv.ParseFloat(pos.ClosePriceAvg, 64) - record.Quantity, _ = strconv.ParseFloat(pos.CloseVol, 64) - record.RealizedPnL, _ = strconv.ParseFloat(pos.AchievedProfits, 64) - fee, _ := strconv.ParseFloat(pos.TotalFee, 64) - record.Fee = -fee - lev, _ := strconv.ParseFloat(pos.Leverage, 64) - record.Leverage = int(lev) - - cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) - uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) - record.EntryTime = time.UnixMilli(cTime).UTC() - record.ExitTime = time.UnixMilli(uTime).UTC() - - record.CloseType = "unknown" - records = append(records, record) - } - - return records, nil -} - // clearCache clears all caches func (t *BitgetTrader) clearCache() { t.balanceCacheMutex.Lock() @@ -1097,264 +316,3 @@ func genBitgetClientOid() string { rand := time.Now().Nanosecond() % 100000 return fmt.Sprintf("nofx%d%05d", timestamp, rand) } - -// GetOpenOrders gets all open/pending orders for a symbol -func (t *BitgetTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - symbol = t.convertSymbol(symbol) - var result []types.OpenOrder - - // 1. Get pending limit orders - params := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - } - - data, err := t.doRequest("GET", bitgetPendingPath, params) - if err != nil { - logger.Warnf("[Bitget] Failed to get pending orders: %v", err) - } - if err == nil && data != nil { - var orders struct { - EntrustedList []struct { - OrderId string `json:"orderId"` - Symbol string `json:"symbol"` - Side string `json:"side"` // buy/sell - TradeSide string `json:"tradeSide"` // open/close - PosSide string `json:"posSide"` // long/short - OrderType string `json:"orderType"` // limit/market - Price string `json:"price"` - Size string `json:"size"` - State string `json:"state"` - } `json:"entrustedList"` - } - if err := json.Unmarshal(data, &orders); err == nil { - for _, order := range orders.EntrustedList { - price, _ := strconv.ParseFloat(order.Price, 64) - quantity, _ := strconv.ParseFloat(order.Size, 64) - - // Convert side to standard format - side := strings.ToUpper(order.Side) - positionSide := strings.ToUpper(order.PosSide) - - result = append(result, types.OpenOrder{ - OrderID: order.OrderId, - Symbol: symbol, - Side: side, - PositionSide: positionSide, - Type: strings.ToUpper(order.OrderType), - Price: price, - StopPrice: 0, - Quantity: quantity, - Status: "NEW", - }) - } - } - } - - // 2. Get pending plan orders (stop-loss/take-profit) - // Bitget V2 API requires planType parameter: profit_loss for SL/TP orders - planParams := map[string]interface{}{ - "productType": "USDT-FUTURES", - "planType": "profit_loss", - } - - planData, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", planParams) - if err != nil { - logger.Warnf("[Bitget] Failed to get plan orders: %v", err) - } - if err == nil && planData != nil { - var planOrders struct { - EntrustedList []struct { - OrderId string `json:"orderId"` - Symbol string `json:"symbol"` - Side string `json:"side"` - PosSide string `json:"posSide"` - PlanType string `json:"planType"` // pos_loss, pos_profit - TriggerPrice string `json:"triggerPrice"` - StopLossTriggerPrice string `json:"stopLossTriggerPrice"` - StopSurplusTriggerPrice string `json:"stopSurplusTriggerPrice"` - Size string `json:"size"` - PlanStatus string `json:"planStatus"` - } `json:"entrustedList"` - } - if err := json.Unmarshal(planData, &planOrders); err == nil { - for _, order := range planOrders.EntrustedList { - // Filter by symbol if specified - if symbol != "" && order.Symbol != symbol { - continue - } - - // Determine trigger price based on plan type - var triggerPrice float64 - orderType := "STOP_MARKET" - - if order.PlanType == "pos_profit" { - // Take profit order - orderType = "TAKE_PROFIT_MARKET" - if order.StopSurplusTriggerPrice != "" { - triggerPrice, _ = strconv.ParseFloat(order.StopSurplusTriggerPrice, 64) - } else { - triggerPrice, _ = strconv.ParseFloat(order.TriggerPrice, 64) - } - } else { - // Stop loss order (pos_loss) - if order.StopLossTriggerPrice != "" { - triggerPrice, _ = strconv.ParseFloat(order.StopLossTriggerPrice, 64) - } else { - triggerPrice, _ = strconv.ParseFloat(order.TriggerPrice, 64) - } - } - - quantity, _ := strconv.ParseFloat(order.Size, 64) - side := strings.ToUpper(order.Side) - positionSide := strings.ToUpper(order.PosSide) - - result = append(result, types.OpenOrder{ - OrderID: order.OrderId, - Symbol: order.Symbol, - Side: side, - PositionSide: positionSide, - Type: orderType, - Price: 0, - StopPrice: triggerPrice, - Quantity: quantity, - Status: "NEW", - }) - } - } - } - - logger.Infof("✓ BITGET GetOpenOrders: found %d open orders for %s", len(result), symbol) - return result, nil -} - -// PlaceLimitOrder places a limit order for grid trading -// Implements GridTrader interface -func (t *BitgetTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { - symbol := t.convertSymbol(req.Symbol) - - // Set leverage if specified - if req.Leverage > 0 { - if err := t.SetLeverage(symbol, req.Leverage); err != nil { - logger.Warnf("[Bitget] Failed to set leverage: %v", err) - } - } - - // Format quantity - qtyStr, _ := t.FormatQuantity(symbol, req.Quantity) - - // Determine side - side := "buy" - if req.Side == "SELL" { - side = "sell" - } - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "marginMode": "crossed", - "marginCoin": "USDT", - "side": side, - "orderType": "limit", - "size": qtyStr, - "price": fmt.Sprintf("%.8f", req.Price), - "force": "GTC", // Good Till Cancel - "clientOid": genBitgetClientOid(), - } - - // Add reduce only if specified - if req.ReduceOnly { - body["reduceOnly"] = "YES" - } - - logger.Infof("[Bitget] PlaceLimitOrder: %s %s @ %.4f, qty=%s", symbol, side, req.Price, qtyStr) - - data, err := t.doRequest("POST", bitgetOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to place limit order: %w", err) - } - - var order struct { - OrderId string `json:"orderId"` - ClientOid string `json:"clientOid"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - logger.Infof("✓ [Bitget] Limit order placed: %s %s @ %.4f, orderID=%s", - symbol, side, req.Price, order.OrderId) - - return &types.LimitOrderResult{ - OrderID: order.OrderId, - ClientID: order.ClientOid, - Symbol: req.Symbol, - Side: req.Side, - PositionSide: req.PositionSide, - Price: req.Price, - Quantity: req.Quantity, - Status: "NEW", - }, nil -} - -// CancelOrder cancels a specific order by ID -// Implements GridTrader interface -func (t *BitgetTrader) CancelOrder(symbol, orderID string) error { - symbol = t.convertSymbol(symbol) - - body := map[string]interface{}{ - "symbol": symbol, - "productType": "USDT-FUTURES", - "orderId": orderID, - } - - _, err := t.doRequest("POST", "/api/v2/mix/order/cancel-order", body) - if err != nil { - return fmt.Errorf("failed to cancel order: %w", err) - } - - logger.Infof("✓ [Bitget] Order cancelled: %s %s", symbol, orderID) - return nil -} - -// GetOrderBook gets the order book for a symbol -// Implements GridTrader interface -func (t *BitgetTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { - symbol = t.convertSymbol(symbol) - path := fmt.Sprintf("/api/v2/mix/market/depth?symbol=%s&productType=USDT-FUTURES&limit=%d", symbol, depth) - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, nil, fmt.Errorf("failed to get order book: %w", err) - } - - var result struct { - Bids [][]string `json:"bids"` - Asks [][]string `json:"asks"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, nil, fmt.Errorf("failed to parse order book: %w", err) - } - - // Parse bids - for _, b := range result.Bids { - if len(b) >= 2 { - price, _ := strconv.ParseFloat(b[0], 64) - qty, _ := strconv.ParseFloat(b[1], 64) - bids = append(bids, []float64{price, qty}) - } - } - - // Parse asks - for _, a := range result.Asks { - if len(a) >= 2 { - price, _ := strconv.ParseFloat(a[0], 64) - qty, _ := strconv.ParseFloat(a[1], 64) - asks = append(asks, []float64{price, qty}) - } - } - - return bids, asks, nil -} diff --git a/trader/bitget/trader_account.go b/trader/bitget/trader_account.go new file mode 100644 index 00000000..5449a33c --- /dev/null +++ b/trader/bitget/trader_account.go @@ -0,0 +1,199 @@ +package bitget + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "strconv" + "strings" + "time" +) + +// GetBalance gets account balance +func (t *BitgetTrader) GetBalance() (map[string]interface{}, error) { + // Check cache + t.balanceCacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + t.balanceCacheMutex.RUnlock() + return t.cachedBalance, nil + } + t.balanceCacheMutex.RUnlock() + + params := map[string]interface{}{ + "productType": "USDT-FUTURES", + } + + data, err := t.doRequest("GET", bitgetAccountPath, params) + if err != nil { + return nil, fmt.Errorf("failed to get account balance: %w", err) + } + + var accounts []struct { + MarginCoin string `json:"marginCoin"` + Available string `json:"available"` // Available balance + AccountEquity string `json:"accountEquity"` // Total equity + UsdtEquity string `json:"usdtEquity"` // USDT equity + UnrealizedPL string `json:"unrealizedPL"` // Unrealized P&L + } + + if err := json.Unmarshal(data, &accounts); err != nil { + return nil, fmt.Errorf("failed to parse balance data: %w, raw: %s", err, string(data)) + } + + var totalEquity, availableBalance, unrealizedPnL float64 + for _, acc := range accounts { + if acc.MarginCoin == "USDT" { + totalEquity, _ = strconv.ParseFloat(acc.AccountEquity, 64) + availableBalance, _ = strconv.ParseFloat(acc.Available, 64) + unrealizedPnL, _ = strconv.ParseFloat(acc.UnrealizedPL, 64) + logger.Infof("✓ [Bitget] Balance: equity=%.2f, available=%.2f", totalEquity, availableBalance) + break + } + } + + result := map[string]interface{}{ + "totalWalletBalance": totalEquity - unrealizedPnL, + "availableBalance": availableBalance, + "totalUnrealizedProfit": unrealizedPnL, + "total_equity": totalEquity, + } + + // Update cache + t.balanceCacheMutex.Lock() + t.cachedBalance = result + t.balanceCacheTime = time.Now() + t.balanceCacheMutex.Unlock() + + return result, nil +} + +// SetMarginMode sets margin mode +func (t *BitgetTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + symbol = t.convertSymbol(symbol) + + marginMode := "isolated" + if isCrossMargin { + marginMode = "crossed" + } + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginCoin": "USDT", + "marginMode": marginMode, + } + + _, err := t.doRequest("POST", bitgetMarginModePath, body) + if err != nil { + if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") { + return nil + } + if strings.Contains(err.Error(), "position") { + logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol) + return nil + } + return err + } + + logger.Infof(" ✓ %s margin mode set to %s", symbol, marginMode) + return nil +} + +// SetLeverage sets leverage +func (t *BitgetTrader) SetLeverage(symbol string, leverage int) error { + symbol = t.convertSymbol(symbol) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginCoin": "USDT", + "leverage": fmt.Sprintf("%d", leverage), + } + + _, err := t.doRequest("POST", bitgetLeveragePath, body) + if err != nil { + if strings.Contains(err.Error(), "same") { + return nil + } + logger.Infof(" ⚠️ Failed to set %s leverage: %v", symbol, err) + return err + } + + logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage) + return nil +} + +// GetMarketPrice gets market price +func (t *BitgetTrader) GetMarketPrice(symbol string) (float64, error) { + symbol = t.convertSymbol(symbol) + + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + data, err := t.doRequest("GET", bitgetTickerPath, params) + if err != nil { + return 0, fmt.Errorf("failed to get price: %w", err) + } + + var tickers []struct { + LastPr string `json:"lastPr"` + } + + if err := json.Unmarshal(data, &tickers); err != nil { + return 0, err + } + + if len(tickers) == 0 { + return 0, fmt.Errorf("no price data received") + } + + price, err := strconv.ParseFloat(tickers[0].LastPr, 64) + if err != nil { + return 0, err + } + + return price, nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *BitgetTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + symbol = t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v2/mix/market/depth?symbol=%s&productType=USDT-FUTURES&limit=%d", symbol, depth) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + var result struct { + Bids [][]string `json:"bids"` + Asks [][]string `json:"asks"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + // Parse bids + for _, b := range result.Bids { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result.Asks { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil +} diff --git a/trader/bitget/trader_orders.go b/trader/bitget/trader_orders.go new file mode 100644 index 00000000..f08a7a38 --- /dev/null +++ b/trader/bitget/trader_orders.go @@ -0,0 +1,711 @@ +package bitget + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" +) + +// OpenLong opens long position +func (t *BitgetTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // Cancel old orders first + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof(" ⚠️ Failed to set leverage: %v", err) + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": "buy", + "orderType": "market", + "size": qtyStr, + "clientOid": genBitgetClientOid(), + } + + logger.Infof(" 📊 Bitget OpenLong: symbol=%s, qty=%s, leverage=%d", symbol, qtyStr, leverage) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to open long position: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + ClientOid string `json:"clientOid"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + // Clear cache + t.clearCache() + + logger.Infof("✓ Bitget opened long position successfully: %s", symbol) + + return map[string]interface{}{ + "orderId": order.OrderId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// OpenShort opens short position +func (t *BitgetTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // Cancel old orders first + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof(" ⚠️ Failed to set leverage: %v", err) + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": "sell", + "orderType": "market", + "size": qtyStr, + "clientOid": genBitgetClientOid(), + } + + logger.Infof(" 📊 Bitget OpenShort: symbol=%s, qty=%s, leverage=%d", symbol, qtyStr, leverage) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to open short position: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + ClientOid string `json:"clientOid"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + // Clear cache + t.clearCache() + + logger.Infof("✓ Bitget opened short position successfully: %s", symbol) + + return map[string]interface{}{ + "orderId": order.OrderId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// CloseLong closes long position +func (t *BitgetTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // If quantity is 0, get current position + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "long" { + quantity = pos["positionAmt"].(float64) + break + } + } + if quantity == 0 { + return nil, fmt.Errorf("long position not found for %s", symbol) + } + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": "sell", + "orderType": "market", + "size": qtyStr, + "reduceOnly": "YES", + "clientOid": genBitgetClientOid(), + } + + logger.Infof(" 📊 Bitget CloseLong: symbol=%s, qty=%s", symbol, qtyStr) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to close long position: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, err + } + + // Clear cache + t.clearCache() + + logger.Infof("✓ Bitget closed long position successfully: %s", symbol) + + return map[string]interface{}{ + "orderId": order.OrderId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// CloseShort closes short position +func (t *BitgetTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // If quantity is 0, get current position + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "short" { + quantity = pos["positionAmt"].(float64) + break + } + } + if quantity == 0 { + return nil, fmt.Errorf("short position not found for %s", symbol) + } + } + + // Ensure quantity is positive + if quantity < 0 { + quantity = -quantity + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": "buy", + "orderType": "market", + "size": qtyStr, + "reduceOnly": "YES", + "clientOid": genBitgetClientOid(), + } + + logger.Infof(" 📊 Bitget CloseShort: symbol=%s, qty=%s", symbol, qtyStr) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to close short position: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, err + } + + // Clear cache + t.clearCache() + + logger.Infof("✓ Bitget closed short position successfully: %s", symbol) + + return map[string]interface{}{ + "orderId": order.OrderId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// SetStopLoss sets stop loss order +func (t *BitgetTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + // Bitget V2 uses plan order for stop loss + symbol = t.convertSymbol(symbol) + + side := "sell" + holdSide := "long" + if strings.ToUpper(positionSide) == "SHORT" { + side = "buy" + holdSide = "short" + } + + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + body := map[string]interface{}{ + "planType": "loss_plan", + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "triggerPrice": fmt.Sprintf("%.8f", stopPrice), + "triggerType": "mark_price", + "side": side, + "tradeSide": "close", + "orderType": "market", + "size": qtyStr, + "holdSide": holdSide, + "clientOid": genBitgetClientOid(), + } + + _, err := t.doRequest("POST", "/api/v2/mix/order/place-plan-order", body) + if err != nil { + return fmt.Errorf("failed to set stop loss: %w", err) + } + + logger.Infof(" ✓ [Bitget] Stop loss set: %s @ %.4f", symbol, stopPrice) + return nil +} + +// SetTakeProfit sets take profit order +func (t *BitgetTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + // Bitget V2 uses plan order for take profit + symbol = t.convertSymbol(symbol) + + side := "sell" + holdSide := "long" + if strings.ToUpper(positionSide) == "SHORT" { + side = "buy" + holdSide = "short" + } + + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + body := map[string]interface{}{ + "planType": "profit_plan", + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "triggerPrice": fmt.Sprintf("%.8f", takeProfitPrice), + "triggerType": "mark_price", + "side": side, + "tradeSide": "close", + "orderType": "market", + "size": qtyStr, + "holdSide": holdSide, + "clientOid": genBitgetClientOid(), + } + + _, err := t.doRequest("POST", "/api/v2/mix/order/place-plan-order", body) + if err != nil { + return fmt.Errorf("failed to set take profit: %w", err) + } + + logger.Infof(" ✓ [Bitget] Take profit set: %s @ %.4f", symbol, takeProfitPrice) + return nil +} + +// CancelStopLossOrders cancels stop loss orders +func (t *BitgetTrader) CancelStopLossOrders(symbol string) error { + return t.cancelPlanOrders(symbol, "loss_plan") +} + +// CancelTakeProfitOrders cancels take profit orders +func (t *BitgetTrader) CancelTakeProfitOrders(symbol string) error { + return t.cancelPlanOrders(symbol, "profit_plan") +} + +// cancelPlanOrders cancels plan orders +func (t *BitgetTrader) cancelPlanOrders(symbol string, planType string) error { + symbol = t.convertSymbol(symbol) + + // Get pending plan orders + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "planType": planType, + } + + data, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", params) + if err != nil { + return err + } + + var orders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + } `json:"entrustedList"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return err + } + + // Cancel each order + for _, order := range orders.EntrustedList { + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginCoin": "USDT", + "orderId": order.OrderId, + } + t.doRequest("POST", "/api/v2/mix/order/cancel-plan-order", body) + } + + return nil +} + +// CancelAllOrders cancels all pending orders +func (t *BitgetTrader) CancelAllOrders(symbol string) error { + symbol = t.convertSymbol(symbol) + + // Get pending orders + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + data, err := t.doRequest("GET", bitgetPendingPath, params) + if err != nil { + return err + } + + var orders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + } `json:"entrustedList"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return err + } + + // Cancel each order + for _, order := range orders.EntrustedList { + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginCoin": "USDT", + "orderId": order.OrderId, + } + t.doRequest("POST", bitgetCancelOrderPath, body) + } + + // Also cancel plan orders + t.cancelPlanOrders(symbol, "loss_plan") + t.cancelPlanOrders(symbol, "profit_plan") + + return nil +} + +// CancelStopOrders cancels stop loss and take profit orders +func (t *BitgetTrader) CancelStopOrders(symbol string) error { + t.CancelStopLossOrders(symbol) + t.CancelTakeProfitOrders(symbol) + return nil +} + +// GetOrderStatus gets order status +func (t *BitgetTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "orderId": orderID, + } + + data, err := t.doRequest("GET", "/api/v2/mix/order/detail", params) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + State string `json:"state"` // filled, canceled, partially_filled, new + PriceAvg string `json:"priceAvg"` // Average fill price + BaseVolume string `json:"baseVolume"` // Filled quantity + Fee string `json:"fee"` // Fee + Side string `json:"side"` + OrderType string `json:"orderType"` + CTime string `json:"cTime"` + UTime string `json:"uTime"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, err + } + + avgPrice, _ := strconv.ParseFloat(order.PriceAvg, 64) + fillQty, _ := strconv.ParseFloat(order.BaseVolume, 64) + fee, _ := strconv.ParseFloat(order.Fee, 64) + cTime, _ := strconv.ParseInt(order.CTime, 10, 64) + uTime, _ := strconv.ParseInt(order.UTime, 10, 64) + + // Status mapping + statusMap := map[string]string{ + "filled": "FILLED", + "new": "NEW", + "partially_filled": "PARTIALLY_FILLED", + "canceled": "CANCELED", + } + + status := statusMap[order.State] + if status == "" { + status = order.State + } + + return map[string]interface{}{ + "orderId": order.OrderId, + "symbol": symbol, + "status": status, + "avgPrice": avgPrice, + "executedQty": fillQty, + "side": order.Side, + "type": order.OrderType, + "time": cTime, + "updateTime": uTime, + "commission": -fee, + }, nil +} + +// GetOpenOrders gets all open/pending orders for a symbol +func (t *BitgetTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + symbol = t.convertSymbol(symbol) + var result []types.OpenOrder + + // 1. Get pending limit orders + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + data, err := t.doRequest("GET", bitgetPendingPath, params) + if err != nil { + logger.Warnf("[Bitget] Failed to get pending orders: %v", err) + } + if err == nil && data != nil { + var orders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` // buy/sell + TradeSide string `json:"tradeSide"` // open/close + PosSide string `json:"posSide"` // long/short + OrderType string `json:"orderType"` // limit/market + Price string `json:"price"` + Size string `json:"size"` + State string `json:"state"` + } `json:"entrustedList"` + } + if err := json.Unmarshal(data, &orders); err == nil { + for _, order := range orders.EntrustedList { + price, _ := strconv.ParseFloat(order.Price, 64) + quantity, _ := strconv.ParseFloat(order.Size, 64) + + // Convert side to standard format + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + + result = append(result, types.OpenOrder{ + OrderID: order.OrderId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: strings.ToUpper(order.OrderType), + Price: price, + StopPrice: 0, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + // 2. Get pending plan orders (stop-loss/take-profit) + // Bitget V2 API requires planType parameter: profit_loss for SL/TP orders + planParams := map[string]interface{}{ + "productType": "USDT-FUTURES", + "planType": "profit_loss", + } + + planData, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", planParams) + if err != nil { + logger.Warnf("[Bitget] Failed to get plan orders: %v", err) + } + if err == nil && planData != nil { + var planOrders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PosSide string `json:"posSide"` + PlanType string `json:"planType"` // pos_loss, pos_profit + TriggerPrice string `json:"triggerPrice"` + StopLossTriggerPrice string `json:"stopLossTriggerPrice"` + StopSurplusTriggerPrice string `json:"stopSurplusTriggerPrice"` + Size string `json:"size"` + PlanStatus string `json:"planStatus"` + } `json:"entrustedList"` + } + if err := json.Unmarshal(planData, &planOrders); err == nil { + for _, order := range planOrders.EntrustedList { + // Filter by symbol if specified + if symbol != "" && order.Symbol != symbol { + continue + } + + // Determine trigger price based on plan type + var triggerPrice float64 + orderType := "STOP_MARKET" + + if order.PlanType == "pos_profit" { + // Take profit order + orderType = "TAKE_PROFIT_MARKET" + if order.StopSurplusTriggerPrice != "" { + triggerPrice, _ = strconv.ParseFloat(order.StopSurplusTriggerPrice, 64) + } else { + triggerPrice, _ = strconv.ParseFloat(order.TriggerPrice, 64) + } + } else { + // Stop loss order (pos_loss) + if order.StopLossTriggerPrice != "" { + triggerPrice, _ = strconv.ParseFloat(order.StopLossTriggerPrice, 64) + } else { + triggerPrice, _ = strconv.ParseFloat(order.TriggerPrice, 64) + } + } + + quantity, _ := strconv.ParseFloat(order.Size, 64) + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + + result = append(result, types.OpenOrder{ + OrderID: order.OrderId, + Symbol: order.Symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: 0, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + logger.Infof("✓ BITGET GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *BitgetTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { + symbol := t.convertSymbol(req.Symbol) + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(symbol, req.Leverage); err != nil { + logger.Warnf("[Bitget] Failed to set leverage: %v", err) + } + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, req.Quantity) + + // Determine side + side := "buy" + if req.Side == "SELL" { + side = "sell" + } + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": side, + "orderType": "limit", + "size": qtyStr, + "price": fmt.Sprintf("%.8f", req.Price), + "force": "GTC", // Good Till Cancel + "clientOid": genBitgetClientOid(), + } + + // Add reduce only if specified + if req.ReduceOnly { + body["reduceOnly"] = "YES" + } + + logger.Infof("[Bitget] PlaceLimitOrder: %s %s @ %.4f, qty=%s", symbol, side, req.Price, qtyStr) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + ClientOid string `json:"clientOid"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ [Bitget] Limit order placed: %s %s @ %.4f, orderID=%s", + symbol, side, req.Price, order.OrderId) + + return &types.LimitOrderResult{ + OrderID: order.OrderId, + ClientID: order.ClientOid, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *BitgetTrader) CancelOrder(symbol, orderID string) error { + symbol = t.convertSymbol(symbol) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "orderId": orderID, + } + + _, err := t.doRequest("POST", "/api/v2/mix/order/cancel-order", body) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Bitget] Order cancelled: %s %s", symbol, orderID) + return nil +} diff --git a/trader/bitget/trader_positions.go b/trader/bitget/trader_positions.go new file mode 100644 index 00000000..2a87a792 --- /dev/null +++ b/trader/bitget/trader_positions.go @@ -0,0 +1,160 @@ +package bitget + +import ( + "encoding/json" + "fmt" + "nofx/trader/types" + "strconv" + "time" +) + +// GetPositions gets all positions +func (t *BitgetTrader) GetPositions() ([]map[string]interface{}, error) { + // Check cache + t.positionsCacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { + t.positionsCacheMutex.RUnlock() + return t.cachedPositions, nil + } + t.positionsCacheMutex.RUnlock() + + params := map[string]interface{}{ + "productType": "USDT-FUTURES", + "marginCoin": "USDT", + } + + data, err := t.doRequest("GET", bitgetPositionPath, params) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var positions []struct { + Symbol string `json:"symbol"` + HoldSide string `json:"holdSide"` // long, short + OpenPriceAvg string `json:"openPriceAvg"` // Average entry price + MarkPrice string `json:"markPrice"` // Mark price + Total string `json:"total"` // Total position size + Available string `json:"available"` // Available to close + UnrealizedPL string `json:"unrealizedPL"` // Unrealized P&L + Leverage string `json:"leverage"` // Leverage + LiquidationPrice string `json:"liquidationPrice"` // Liquidation price + MarginSize string `json:"marginSize"` // Position margin + CTime string `json:"cTime"` // Create time + UTime string `json:"uTime"` // Update time + } + + if err := json.Unmarshal(data, &positions); err != nil { + return nil, fmt.Errorf("failed to parse position data: %w", err) + } + + var result []map[string]interface{} + for _, pos := range positions { + total, _ := strconv.ParseFloat(pos.Total, 64) + if total == 0 { + continue + } + + entryPrice, _ := strconv.ParseFloat(pos.OpenPriceAvg, 64) + markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64) + unrealizedPnL, _ := strconv.ParseFloat(pos.UnrealizedPL, 64) + leverage, _ := strconv.ParseFloat(pos.Leverage, 64) + liqPrice, _ := strconv.ParseFloat(pos.LiquidationPrice, 64) + cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) + uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) + + // Normalize side + side := "long" + if pos.HoldSide == "short" { + side = "short" + } + + posMap := map[string]interface{}{ + "symbol": pos.Symbol, + "positionAmt": total, + "entryPrice": entryPrice, + "markPrice": markPrice, + "unRealizedProfit": unrealizedPnL, + "leverage": leverage, + "liquidationPrice": liqPrice, + "side": side, + "createdTime": cTime, + "updatedTime": uTime, + } + result = append(result, posMap) + } + + // Update cache + t.positionsCacheMutex.Lock() + t.cachedPositions = result + t.positionsCacheTime = time.Now() + t.positionsCacheMutex.Unlock() + + return result, nil +} + +// GetClosedPnL retrieves closed position PnL records +func (t *BitgetTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 100 { + limit = 100 + } + + params := map[string]interface{}{ + "productType": "USDT-FUTURES", + "startTime": fmt.Sprintf("%d", startTime.UnixMilli()), + "limit": fmt.Sprintf("%d", limit), + } + + data, err := t.doRequest("GET", "/api/v2/mix/position/history-position", params) + if err != nil { + return nil, fmt.Errorf("failed to get positions history: %w", err) + } + + var resp struct { + List []struct { + Symbol string `json:"symbol"` + HoldSide string `json:"holdSide"` + OpenPriceAvg string `json:"openPriceAvg"` + ClosePriceAvg string `json:"closePriceAvg"` + CloseVol string `json:"closeVol"` + AchievedProfits string `json:"achievedProfits"` + TotalFee string `json:"totalFee"` + Leverage string `json:"leverage"` + CTime string `json:"cTime"` + UTime string `json:"uTime"` + } `json:"list"` + } + + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + records := make([]types.ClosedPnLRecord, 0, len(resp.List)) + for _, pos := range resp.List { + record := types.ClosedPnLRecord{ + Symbol: pos.Symbol, + Side: pos.HoldSide, + } + + record.EntryPrice, _ = strconv.ParseFloat(pos.OpenPriceAvg, 64) + record.ExitPrice, _ = strconv.ParseFloat(pos.ClosePriceAvg, 64) + record.Quantity, _ = strconv.ParseFloat(pos.CloseVol, 64) + record.RealizedPnL, _ = strconv.ParseFloat(pos.AchievedProfits, 64) + fee, _ := strconv.ParseFloat(pos.TotalFee, 64) + record.Fee = -fee + lev, _ := strconv.ParseFloat(pos.Leverage, 64) + record.Leverage = int(lev) + + cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) + uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) + record.EntryTime = time.UnixMilli(cTime).UTC() + record.ExitTime = time.UnixMilli(uTime).UTC() + + record.CloseType = "unknown" + records = append(records, record) + } + + return records, nil +} diff --git a/trader/bybit/trader.go b/trader/bybit/trader.go index 1d4e87c7..5f8e9bfb 100644 --- a/trader/bybit/trader.go +++ b/trader/bybit/trader.go @@ -1,10 +1,6 @@ package bybit import ( - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" "encoding/json" "fmt" "io" @@ -17,7 +13,6 @@ import ( "time" bybit "github.com/bybit-exchange/bybit.go.api" - "nofx/trader/types" ) // BybitTrader Bybit USDT Perpetual Futures Trader @@ -87,590 +82,6 @@ func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error return h.base.RoundTrip(req) } -// GetBalance retrieves account balance -func (t *BybitTrader) GetBalance() (map[string]interface{}, error) { - // Check cache - t.balanceCacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - balance := t.cachedBalance - t.balanceCacheMutex.RUnlock() - return balance, nil - } - t.balanceCacheMutex.RUnlock() - - // Call API - params := map[string]interface{}{ - "accountType": "UNIFIED", - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).GetAccountWallet(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get Bybit balance: %w", err) - } - - if result.RetCode != 0 { - return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg) - } - - // Extract balance information - resultData, ok := result.Result.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("Bybit balance return format error") - } - - list, _ := resultData["list"].([]interface{}) - - var totalEquity, availableBalance, totalWalletBalance, totalPerpUPL float64 = 0, 0, 0, 0 - - if len(list) > 0 { - account, _ := list[0].(map[string]interface{}) - if equityStr, ok := account["totalEquity"].(string); ok { - totalEquity, _ = strconv.ParseFloat(equityStr, 64) - } - if availStr, ok := account["totalAvailableBalance"].(string); ok { - availableBalance, _ = strconv.ParseFloat(availStr, 64) - } - // Bybit UNIFIED account wallet balance field - if walletStr, ok := account["totalWalletBalance"].(string); ok { - totalWalletBalance, _ = strconv.ParseFloat(walletStr, 64) - } - // Bybit perpetual contract unrealized PnL - if uplStr, ok := account["totalPerpUPL"].(string); ok { - totalPerpUPL, _ = strconv.ParseFloat(uplStr, 64) - } - } - - // If no totalWalletBalance, use totalEquity - if totalWalletBalance == 0 { - totalWalletBalance = totalEquity - } - - balance := map[string]interface{}{ - "totalEquity": totalEquity, - "totalWalletBalance": totalWalletBalance, - "availableBalance": availableBalance, - "totalUnrealizedProfit": totalPerpUPL, - "balance": totalEquity, // Compatible with other exchange formats - } - - // Update cache - t.balanceCacheMutex.Lock() - t.cachedBalance = balance - t.balanceCacheTime = time.Now() - t.balanceCacheMutex.Unlock() - - return balance, nil -} - -// GetPositions retrieves all positions -func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) { - // Check cache - t.positionsCacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { - positions := t.cachedPositions - t.positionsCacheMutex.RUnlock() - return positions, nil - } - t.positionsCacheMutex.RUnlock() - - // Call API - params := map[string]interface{}{ - "category": "linear", - "settleCoin": "USDT", - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).GetPositionList(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get Bybit positions: %w", err) - } - - if result.RetCode != 0 { - return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg) - } - - resultData, ok := result.Result.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("Bybit positions return format error") - } - - list, _ := resultData["list"].([]interface{}) - - var positions []map[string]interface{} - - for _, item := range list { - pos, ok := item.(map[string]interface{}) - if !ok { - continue - } - - sizeStr, _ := pos["size"].(string) - size, _ := strconv.ParseFloat(sizeStr, 64) - - // Skip empty positions - if size == 0 { - continue - } - - entryPriceStr, _ := pos["avgPrice"].(string) - entryPrice, _ := strconv.ParseFloat(entryPriceStr, 64) - - unrealisedPnlStr, _ := pos["unrealisedPnl"].(string) - unrealisedPnl, _ := strconv.ParseFloat(unrealisedPnlStr, 64) - - leverageStr, _ := pos["leverage"].(string) - leverage, _ := strconv.ParseFloat(leverageStr, 64) - - // Mark price - markPriceStr, _ := pos["markPrice"].(string) - markPrice, _ := strconv.ParseFloat(markPriceStr, 64) - - // Liquidation price - liqPriceStr, _ := pos["liqPrice"].(string) - liqPrice, _ := strconv.ParseFloat(liqPriceStr, 64) - - // Position created/updated time (milliseconds timestamp) - createdTimeStr, _ := pos["createdTime"].(string) - createdTime, _ := strconv.ParseInt(createdTimeStr, 10, 64) - updatedTimeStr, _ := pos["updatedTime"].(string) - updatedTime, _ := strconv.ParseInt(updatedTimeStr, 10, 64) - - positionSide, _ := pos["side"].(string) // Buy = long, Sell = short - - // Log raw position data for debugging - logger.Infof("[Bybit] GetPositions raw: symbol=%v, side=%s, size=%v", pos["symbol"], positionSide, sizeStr) - - // Convert to unified format (use lowercase for consistency with other exchanges) - // Bybit returns "Buy" for long, "Sell" for short - side := "long" - positionAmt := size - positionSideLower := strings.ToLower(positionSide) - if positionSideLower == "sell" { - side = "short" - positionAmt = -size - } - - logger.Infof("[Bybit] GetPositions converted: symbol=%v, rawSide=%s -> side=%s", pos["symbol"], positionSide, side) - - position := map[string]interface{}{ - "symbol": pos["symbol"], - "side": side, - "positionAmt": positionAmt, - "entryPrice": entryPrice, - "markPrice": markPrice, - "unRealizedProfit": unrealisedPnl, - "unrealizedPnL": unrealisedPnl, - "liquidationPrice": liqPrice, - "leverage": leverage, - "createdTime": createdTime, // Position open time (ms) - "updatedTime": updatedTime, // Position last update time (ms) - } - - positions = append(positions, position) - } - - // Update cache - t.positionsCacheMutex.Lock() - t.cachedPositions = positions - t.positionsCacheTime = time.Now() - t.positionsCacheMutex.Unlock() - - return positions, nil -} - -// OpenLong opens a long position -func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - logger.Infof("[Bybit] ===== OpenLong called: symbol=%s, qty=%.6f, leverage=%d =====", symbol, quantity, leverage) - - // First cancel all pending orders for this symbol (clean up old orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof("⚠️ [Bybit] Failed to cancel old pending orders: %v", err) - } - // Also cancel conditional orders (stop-loss/take-profit) - Bybit keeps them separate - if err := t.CancelStopOrders(symbol); err != nil { - logger.Infof("⚠️ [Bybit] Failed to cancel old stop orders: %v", err) - } - - // Set leverage first - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err) - } - - // Use FormatQuantity to format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "side": "Buy", - "orderType": "Market", - "qty": qtyStr, - "positionIdx": 0, // One-way position mode - } - - logger.Infof("[Bybit] OpenLong placing order: %+v", params) - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return nil, fmt.Errorf("Bybit open long failed: %w", err) - } - - // Clear cache - t.clearCache() - - return t.parseOrderResult(result) -} - -// OpenShort opens a short position -func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - logger.Infof("[Bybit] ===== OpenShort called: symbol=%s, qty=%.6f, leverage=%d =====", symbol, quantity, leverage) - - // First cancel all pending orders for this symbol (clean up old orders) - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof("⚠️ [Bybit] Failed to cancel old pending orders: %v", err) - } - // Also cancel conditional orders (stop-loss/take-profit) - Bybit keeps them separate - if err := t.CancelStopOrders(symbol); err != nil { - logger.Infof("⚠️ [Bybit] Failed to cancel old stop orders: %v", err) - } - - // Set leverage first - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err) - } - - // Use FormatQuantity to format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "side": "Sell", - "orderType": "Market", - "qty": qtyStr, - "positionIdx": 0, // One-way position mode - } - - logger.Infof("[Bybit] OpenShort placing order: %+v", params) - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return nil, fmt.Errorf("Bybit open short failed: %w", err) - } - - // Clear cache - t.clearCache() - - return t.parseOrderResult(result) -} - -// CloseLong closes a long position -func (t *BybitTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - // If quantity = 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - for _, pos := range positions { - side, _ := pos["side"].(string) - if pos["symbol"] == symbol && strings.ToLower(side) == "long" { - quantity = pos["positionAmt"].(float64) - break - } - } - } - - if quantity <= 0 { - return nil, fmt.Errorf("no long position to close") - } - - // Use FormatQuantity to format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "side": "Sell", // Close long with Sell - "orderType": "Market", - "qty": qtyStr, - "positionIdx": 0, - "reduceOnly": true, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return nil, fmt.Errorf("Bybit close long failed: %w", err) - } - - // Clear cache - t.clearCache() - - return t.parseOrderResult(result) -} - -// CloseShort closes a short position -func (t *BybitTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - // If quantity = 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - for _, pos := range positions { - side, _ := pos["side"].(string) - if pos["symbol"] == symbol && strings.ToLower(side) == "short" { - quantity = -pos["positionAmt"].(float64) // Short position is negative - break - } - } - } - - if quantity <= 0 { - return nil, fmt.Errorf("no short position to close") - } - - // Use FormatQuantity to format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "side": "Buy", // Close short with Buy - "orderType": "Market", - "qty": qtyStr, - "positionIdx": 0, - "reduceOnly": true, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return nil, fmt.Errorf("Bybit close short failed: %w", err) - } - - // Clear cache - t.clearCache() - - return t.parseOrderResult(result) -} - -// SetLeverage sets leverage -func (t *BybitTrader) SetLeverage(symbol string, leverage int) error { - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "buyLeverage": fmt.Sprintf("%d", leverage), - "sellLeverage": fmt.Sprintf("%d", leverage), - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).SetPositionLeverage(context.Background()) - if err != nil { - // If leverage is already at target value, Bybit will return an error, ignore this case - if strings.Contains(err.Error(), "leverage not modified") { - return nil - } - return fmt.Errorf("failed to set leverage: %w", err) - } - - if result.RetCode != 0 && result.RetCode != 110043 { // 110043 = leverage not modified - return fmt.Errorf("failed to set leverage: %s", result.RetMsg) - } - - return nil -} - -// SetMarginMode sets position margin mode -func (t *BybitTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - tradeMode := 1 // Isolated margin - if isCrossMargin { - tradeMode = 0 // Cross margin - } - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "tradeMode": tradeMode, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).SwitchPositionMargin(context.Background()) - if err != nil { - if strings.Contains(err.Error(), "Cross/isolated margin mode is not modified") { - return nil - } - return fmt.Errorf("failed to set margin mode: %w", err) - } - - if result.RetCode != 0 && result.RetCode != 110026 { // already in target mode - return fmt.Errorf("failed to set margin mode: %s", result.RetMsg) - } - - return nil -} - -// GetMarketPrice retrieves market price -func (t *BybitTrader) GetMarketPrice(symbol string) (float64, error) { - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).GetMarketTickers(context.Background()) - if err != nil { - return 0, fmt.Errorf("failed to get market price: %w", err) - } - - if result.RetCode != 0 { - return 0, fmt.Errorf("API error: %s", result.RetMsg) - } - - resultData, ok := result.Result.(map[string]interface{}) - if !ok { - return 0, fmt.Errorf("return format error") - } - - list, _ := resultData["list"].([]interface{}) - - if len(list) == 0 { - return 0, fmt.Errorf("price data not found for %s", symbol) - } - - ticker, _ := list[0].(map[string]interface{}) - lastPriceStr, _ := ticker["lastPrice"].(string) - lastPrice, err := strconv.ParseFloat(lastPriceStr, 64) - if err != nil { - return 0, fmt.Errorf("failed to parse price: %w", err) - } - - return lastPrice, nil -} - -// SetStopLoss sets stop loss order -func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - side := "Sell" // LONG stop loss uses Sell - if positionSide == "SHORT" { - side = "Buy" // SHORT stop loss uses Buy - } - - // Get current price to determine triggerDirection - currentPrice, err := t.GetMarketPrice(symbol) - if err != nil { - return err - } - - triggerDirection := 2 // Price fall trigger (default long stop loss) - if stopPrice > currentPrice { - triggerDirection = 1 // Price rise trigger (short stop loss) - } - - // Use FormatQuantity to format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "side": side, - "orderType": "Market", - "qty": qtyStr, - "triggerPrice": fmt.Sprintf("%v", stopPrice), - "triggerDirection": triggerDirection, - "triggerBy": "LastPrice", - "reduceOnly": true, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return fmt.Errorf("failed to set stop loss: %w", err) - } - - if result.RetCode != 0 { - return fmt.Errorf("failed to set stop loss: %s", result.RetMsg) - } - - logger.Infof(" ✓ [Bybit] Stop loss order set: %s @ %.2f", symbol, stopPrice) - return nil -} - -// SetTakeProfit sets take profit order -func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - side := "Sell" // LONG take profit uses Sell - if positionSide == "SHORT" { - side = "Buy" // SHORT take profit uses Buy - } - - // Get current price to determine triggerDirection - currentPrice, err := t.GetMarketPrice(symbol) - if err != nil { - return err - } - - triggerDirection := 1 // Price rise trigger (default long take profit) - if takeProfitPrice < currentPrice { - triggerDirection = 2 // Price fall trigger (short take profit) - } - - // Use FormatQuantity to format quantity - qtyStr, _ := t.FormatQuantity(symbol, quantity) - - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "side": side, - "orderType": "Market", - "qty": qtyStr, - "triggerPrice": fmt.Sprintf("%v", takeProfitPrice), - "triggerDirection": triggerDirection, - "triggerBy": "LastPrice", - "reduceOnly": true, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return fmt.Errorf("failed to set take profit: %w", err) - } - - if result.RetCode != 0 { - return fmt.Errorf("failed to set take profit: %s", result.RetMsg) - } - - logger.Infof(" ✓ [Bybit] Take profit order set: %s @ %.2f", symbol, takeProfitPrice) - return nil -} - -// CancelStopLossOrders cancels stop loss orders -func (t *BybitTrader) CancelStopLossOrders(symbol string) error { - return t.cancelConditionalOrders(symbol, "StopLoss") -} - -// CancelTakeProfitOrders cancels take profit orders -func (t *BybitTrader) CancelTakeProfitOrders(symbol string) error { - return t.cancelConditionalOrders(symbol, "TakeProfit") -} - -// CancelAllOrders cancels all pending orders -func (t *BybitTrader) CancelAllOrders(symbol string) error { - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - } - - _, err := t.client.NewUtaBybitServiceWithParams(params).CancelAllOrders(context.Background()) - if err != nil { - return fmt.Errorf("failed to cancel all orders: %w", err) - } - - return nil -} - -// CancelStopOrders cancels all stop loss and take profit orders -func (t *BybitTrader) CancelStopOrders(symbol string) error { - if err := t.CancelStopLossOrders(symbol); err != nil { - logger.Infof("⚠️ [Bybit] Failed to cancel stop loss orders: %v", err) - } - if err := t.CancelTakeProfitOrders(symbol); err != nil { - logger.Infof("⚠️ [Bybit] Failed to cancel take profit orders: %v", err) - } - return nil -} - // getQtyStep retrieves the quantity step for a trading pair func (t *BybitTrader) getQtyStep(symbol string) float64 { // Check cache first @@ -782,483 +193,3 @@ func (t *BybitTrader) parseOrderResult(result *bybit.ServerResponse) (map[string "status": "NEW", }, nil } - -// GetOrderStatus retrieves order status -func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "orderId": orderID, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).GetOrderHistory(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - if result.RetCode != 0 { - return nil, fmt.Errorf("API error: %s", result.RetMsg) - } - - resultData, ok := result.Result.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("return format error") - } - - list, _ := resultData["list"].([]interface{}) - if len(list) == 0 { - return nil, fmt.Errorf("order %s not found", orderID) - } - - order, _ := list[0].(map[string]interface{}) - - // Parse order data - status, _ := order["orderStatus"].(string) - avgPriceStr, _ := order["avgPrice"].(string) - cumExecQtyStr, _ := order["cumExecQty"].(string) - cumExecFeeStr, _ := order["cumExecFee"].(string) - - avgPrice, _ := strconv.ParseFloat(avgPriceStr, 64) - executedQty, _ := strconv.ParseFloat(cumExecQtyStr, 64) - commission, _ := strconv.ParseFloat(cumExecFeeStr, 64) - - // Convert status to unified format - unifiedStatus := status - switch status { - case "Filled": - unifiedStatus = "FILLED" - case "New", "Created": - unifiedStatus = "NEW" - case "Cancelled", "Rejected": - unifiedStatus = "CANCELED" - case "PartiallyFilled": - unifiedStatus = "PARTIALLY_FILLED" - } - - return map[string]interface{}{ - "orderId": orderID, - "status": unifiedStatus, - "avgPrice": avgPrice, - "executedQty": executedQty, - "commission": commission, - }, nil -} - -func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) error { - // First get all conditional orders - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "orderFilter": "StopOrder", // Conditional orders - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background()) - if err != nil { - return fmt.Errorf("failed to get conditional orders: %w", err) - } - - if result.RetCode != 0 { - return nil // No orders - } - - resultData, ok := result.Result.(map[string]interface{}) - if !ok { - return nil - } - - list, _ := resultData["list"].([]interface{}) - - // Cancel matching orders - for _, item := range list { - order, ok := item.(map[string]interface{}) - if !ok { - continue - } - - orderId, _ := order["orderId"].(string) - stopOrderType, _ := order["stopOrderType"].(string) - - // Filter by type - shouldCancel := false - if orderType == "StopLoss" && (stopOrderType == "StopLoss" || stopOrderType == "Stop") { - shouldCancel = true - } - if orderType == "TakeProfit" && (stopOrderType == "TakeProfit" || stopOrderType == "PartialTakeProfit") { - shouldCancel = true - } - - if shouldCancel && orderId != "" { - cancelParams := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "orderId": orderId, - } - t.client.NewUtaBybitServiceWithParams(cancelParams).CancelOrder(context.Background()) - } - } - - return nil -} - -// GetClosedPnL retrieves closed position PnL records from Bybit via direct HTTP API -func (t *BybitTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - // The Bybit SDK doesn't expose the closed-pnl endpoint, use direct HTTP call - return t.getClosedPnLViaHTTP(startTime, limit) -} - -// getClosedPnLViaHTTP makes direct HTTP call to Bybit API for closed PnL with proper signing -func (t *BybitTrader) getClosedPnLViaHTTP(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - // Build query string - queryParams := fmt.Sprintf("category=linear&startTime=%d&limit=%d", startTime.UnixMilli(), limit) - url := "https://api.bybit.com/v5/position/closed-pnl?" + queryParams - - // Generate timestamp - timestamp := fmt.Sprintf("%d", time.Now().UnixMilli()) - recvWindow := "5000" - - // Build signature payload: timestamp + api_key + recv_window + queryString - signPayload := timestamp + t.apiKey + recvWindow + queryParams - - // Generate HMAC-SHA256 signature - h := hmac.New(sha256.New, []byte(t.secretKey)) - h.Write([]byte(signPayload)) - signature := hex.EncodeToString(h.Sum(nil)) - - // Create request - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Add Bybit V5 API headers - req.Header.Set("X-BAPI-API-KEY", t.apiKey) - req.Header.Set("X-BAPI-SIGN", signature) - req.Header.Set("X-BAPI-SIGN-TYPE", "2") - req.Header.Set("X-BAPI-TIMESTAMP", timestamp) - req.Header.Set("X-BAPI-RECV-WINDOW", recvWindow) - req.Header.Set("Content-Type", "application/json") - - // Use http.DefaultClient for the request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to call Bybit API: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - var result struct { - RetCode int `json:"retCode"` - RetMsg string `json:"retMsg"` - Result map[string]interface{} `json:"result"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - if result.RetCode != 0 { - return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg) - } - - return t.parseClosedPnLResult(result.Result) -} - -// parseClosedPnLResult parses the closed PnL result from Bybit API -func (t *BybitTrader) parseClosedPnLResult(resultData interface{}) ([]types.ClosedPnLRecord, error) { - data, ok := resultData.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid result format") - } - - list, _ := data["list"].([]interface{}) - var records []types.ClosedPnLRecord - - for _, item := range list { - pnl, ok := item.(map[string]interface{}) - if !ok { - continue - } - - // Parse fields - symbol, _ := pnl["symbol"].(string) - side, _ := pnl["side"].(string) - orderId, _ := pnl["orderId"].(string) - - avgEntryPriceStr, _ := pnl["avgEntryPrice"].(string) - avgExitPriceStr, _ := pnl["avgExitPrice"].(string) - qtyStr, _ := pnl["qty"].(string) - closedPnLStr, _ := pnl["closedPnl"].(string) - cumEntryValueStr, _ := pnl["cumEntryValue"].(string) - cumExitValueStr, _ := pnl["cumExitValue"].(string) - leverageStr, _ := pnl["leverage"].(string) - createdTimeStr, _ := pnl["createdTime"].(string) - updatedTimeStr, _ := pnl["updatedTime"].(string) - - avgEntryPrice, _ := strconv.ParseFloat(avgEntryPriceStr, 64) - avgExitPrice, _ := strconv.ParseFloat(avgExitPriceStr, 64) - qty, _ := strconv.ParseFloat(qtyStr, 64) - closedPnL, _ := strconv.ParseFloat(closedPnLStr, 64) - leverage, _ := strconv.ParseInt(leverageStr, 10, 64) - createdTime, _ := strconv.ParseInt(createdTimeStr, 10, 64) - updatedTime, _ := strconv.ParseInt(updatedTimeStr, 10, 64) - - // Calculate approximate fee from value difference - cumEntryValue, _ := strconv.ParseFloat(cumEntryValueStr, 64) - cumExitValue, _ := strconv.ParseFloat(cumExitValueStr, 64) - expectedPnL := cumExitValue - cumEntryValue - if side == "Sell" { - expectedPnL = cumEntryValue - cumExitValue - } - fee := expectedPnL - closedPnL - if fee < 0 { - fee = 0 - } - - // Normalize side - normalizedSide := "long" - if side == "Sell" { - normalizedSide = "short" - } - - record := types.ClosedPnLRecord{ - Symbol: symbol, - Side: normalizedSide, - EntryPrice: avgEntryPrice, - ExitPrice: avgExitPrice, - Quantity: qty, - RealizedPnL: closedPnL, - Fee: fee, - Leverage: int(leverage), - EntryTime: time.UnixMilli(createdTime).UTC(), - ExitTime: time.UnixMilli(updatedTime).UTC(), - OrderID: orderId, - CloseType: "unknown", // Bybit doesn't provide close type directly - ExchangeID: orderId, // Use orderId as exchange ID - } - - records = append(records, record) - } - - return records, nil -} - -// GetOpenOrders gets all open/pending orders for a symbol -func (t *BybitTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - var result []types.OpenOrder - - // Get conditional orders (stop-loss, take-profit) - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "orderFilter": "StopOrder", - } - - resp, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - if resp.RetCode == 0 { - resultData, ok := resp.Result.(map[string]interface{}) - if ok { - list, _ := resultData["list"].([]interface{}) - for _, item := range list { - order, ok := item.(map[string]interface{}) - if !ok { - continue - } - - orderId, _ := order["orderId"].(string) - sym, _ := order["symbol"].(string) - side, _ := order["side"].(string) - orderType, _ := order["orderType"].(string) - stopOrderType, _ := order["stopOrderType"].(string) - triggerPrice, _ := order["triggerPrice"].(string) - qty, _ := order["qty"].(string) - - price, _ := strconv.ParseFloat(triggerPrice, 64) - quantity, _ := strconv.ParseFloat(qty, 64) - - // Determine type based on stopOrderType - displayType := orderType - if stopOrderType != "" { - displayType = stopOrderType - } - - result = append(result, types.OpenOrder{ - OrderID: orderId, - Symbol: sym, - Side: side, - PositionSide: "", // Bybit doesn't use positionSide for UTA - Type: displayType, - Price: 0, - StopPrice: price, - Quantity: quantity, - Status: "NEW", - }) - } - } - } - - return result, nil -} - -// PlaceLimitOrder places a limit order for grid trading -// Implements GridTrader interface -func (t *BybitTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { - // Format quantity - qtyStr, err := t.FormatQuantity(req.Symbol, req.Quantity) - if err != nil { - return nil, fmt.Errorf("failed to format quantity: %w", err) - } - - // Format price - priceStr := fmt.Sprintf("%.8f", req.Price) - - // Set leverage if specified - if req.Leverage > 0 { - if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { - logger.Warnf("[Bybit] Failed to set leverage: %v", err) - } - } - - // Determine side - side := "Buy" - if req.Side == "SELL" { - side = "Sell" - } - - params := map[string]interface{}{ - "category": "linear", - "symbol": req.Symbol, - "side": side, - "orderType": "Limit", - "qty": qtyStr, - "price": priceStr, - "timeInForce": "GTC", // Good Till Cancel - "positionIdx": 0, // One-way position mode - } - - // Add reduce only if specified - if req.ReduceOnly { - params["reduceOnly"] = true - } - - logger.Infof("[Bybit] PlaceLimitOrder: %s %s @ %s, qty=%s", req.Symbol, side, priceStr, qtyStr) - - result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to place limit order: %w", err) - } - - // Parse result - orderID := "" - if result.RetCode == 0 { - if resultData, ok := result.Result.(map[string]interface{}); ok { - if id, ok := resultData["orderId"].(string); ok { - orderID = id - } - } - } else { - return nil, fmt.Errorf("Bybit order failed: %s", result.RetMsg) - } - - logger.Infof("✓ [Bybit] Limit order placed: %s %s @ %s, qty=%s, orderID=%s", - req.Symbol, side, priceStr, qtyStr, orderID) - - return &types.LimitOrderResult{ - OrderID: orderID, - ClientID: req.ClientID, - Symbol: req.Symbol, - Side: req.Side, - PositionSide: req.PositionSide, - Price: req.Price, - Quantity: req.Quantity, - Status: "NEW", - }, nil -} - -// CancelOrder cancels a specific order by ID -// Implements GridTrader interface -func (t *BybitTrader) CancelOrder(symbol, orderID string) error { - params := map[string]interface{}{ - "category": "linear", - "symbol": symbol, - "orderId": orderID, - } - - result, err := t.client.NewUtaBybitServiceWithParams(params).CancelOrder(context.Background()) - if err != nil { - return fmt.Errorf("failed to cancel order: %w", err) - } - - if result.RetCode != 0 { - return fmt.Errorf("Bybit cancel order failed: %s", result.RetMsg) - } - - logger.Infof("✓ [Bybit] Order cancelled: %s %s", symbol, orderID) - return nil -} - -// GetOrderBook gets the order book for a symbol -// Implements GridTrader interface -func (t *BybitTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { - if depth <= 0 { - depth = 25 - } - - // Use HTTP request directly since the SDK doesn't expose GetOrderbook - url := fmt.Sprintf("https://api.bybit.com/v5/market/orderbook?category=linear&symbol=%s&limit=%d", symbol, depth) - resp, err := http.Get(url) - if err != nil { - return nil, nil, fmt.Errorf("failed to get order book: %w", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) - } - - var result struct { - RetCode int `json:"retCode"` - RetMsg string `json:"retMsg"` - Result struct { - S string `json:"s"` // symbol - B [][]string `json:"b"` // bids [[price, size], ...] - A [][]string `json:"a"` // asks [[price, size], ...] - } `json:"result"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return nil, nil, fmt.Errorf("failed to parse order book: %w", err) - } - - if result.RetCode != 0 { - return nil, nil, fmt.Errorf("Bybit get orderbook failed: %s", result.RetMsg) - } - - // Parse bids - for _, b := range result.Result.B { - if len(b) >= 2 { - price, _ := strconv.ParseFloat(b[0], 64) - qty, _ := strconv.ParseFloat(b[1], 64) - bids = append(bids, []float64{price, qty}) - } - } - - // Parse asks - for _, a := range result.Result.A { - if len(a) >= 2 { - price, _ := strconv.ParseFloat(a[0], 64) - qty, _ := strconv.ParseFloat(a[1], 64) - asks = append(asks, []float64{price, qty}) - } - } - - return bids, asks, nil -} diff --git a/trader/bybit/trader_account.go b/trader/bybit/trader_account.go new file mode 100644 index 00000000..88616a04 --- /dev/null +++ b/trader/bybit/trader_account.go @@ -0,0 +1,236 @@ +package bybit + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "nofx/trader/types" + "strconv" + "time" +) + +// GetBalance retrieves account balance +func (t *BybitTrader) GetBalance() (map[string]interface{}, error) { + // Check cache + t.balanceCacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + balance := t.cachedBalance + t.balanceCacheMutex.RUnlock() + return balance, nil + } + t.balanceCacheMutex.RUnlock() + + // Call API + params := map[string]interface{}{ + "accountType": "UNIFIED", + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).GetAccountWallet(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get Bybit balance: %w", err) + } + + if result.RetCode != 0 { + return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg) + } + + // Extract balance information + resultData, ok := result.Result.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Bybit balance return format error") + } + + list, _ := resultData["list"].([]interface{}) + + var totalEquity, availableBalance, totalWalletBalance, totalPerpUPL float64 = 0, 0, 0, 0 + + if len(list) > 0 { + account, _ := list[0].(map[string]interface{}) + if equityStr, ok := account["totalEquity"].(string); ok { + totalEquity, _ = strconv.ParseFloat(equityStr, 64) + } + if availStr, ok := account["totalAvailableBalance"].(string); ok { + availableBalance, _ = strconv.ParseFloat(availStr, 64) + } + // Bybit UNIFIED account wallet balance field + if walletStr, ok := account["totalWalletBalance"].(string); ok { + totalWalletBalance, _ = strconv.ParseFloat(walletStr, 64) + } + // Bybit perpetual contract unrealized PnL + if uplStr, ok := account["totalPerpUPL"].(string); ok { + totalPerpUPL, _ = strconv.ParseFloat(uplStr, 64) + } + } + + // If no totalWalletBalance, use totalEquity + if totalWalletBalance == 0 { + totalWalletBalance = totalEquity + } + + balance := map[string]interface{}{ + "totalEquity": totalEquity, + "totalWalletBalance": totalWalletBalance, + "availableBalance": availableBalance, + "totalUnrealizedProfit": totalPerpUPL, + "balance": totalEquity, // Compatible with other exchange formats + } + + // Update cache + t.balanceCacheMutex.Lock() + t.cachedBalance = balance + t.balanceCacheTime = time.Now() + t.balanceCacheMutex.Unlock() + + return balance, nil +} + +// GetClosedPnL retrieves closed position PnL records from Bybit via direct HTTP API +func (t *BybitTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + // The Bybit SDK doesn't expose the closed-pnl endpoint, use direct HTTP call + return t.getClosedPnLViaHTTP(startTime, limit) +} + +// getClosedPnLViaHTTP makes direct HTTP call to Bybit API for closed PnL with proper signing +func (t *BybitTrader) getClosedPnLViaHTTP(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + // Build query string + queryParams := fmt.Sprintf("category=linear&startTime=%d&limit=%d", startTime.UnixMilli(), limit) + url := "https://api.bybit.com/v5/position/closed-pnl?" + queryParams + + // Generate timestamp + timestamp := fmt.Sprintf("%d", time.Now().UnixMilli()) + recvWindow := "5000" + + // Build signature payload: timestamp + api_key + recv_window + queryString + signPayload := timestamp + t.apiKey + recvWindow + queryParams + + // Generate HMAC-SHA256 signature + h := hmac.New(sha256.New, []byte(t.secretKey)) + h.Write([]byte(signPayload)) + signature := hex.EncodeToString(h.Sum(nil)) + + // Create request + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Add Bybit V5 API headers + req.Header.Set("X-BAPI-API-KEY", t.apiKey) + req.Header.Set("X-BAPI-SIGN", signature) + req.Header.Set("X-BAPI-SIGN-TYPE", "2") + req.Header.Set("X-BAPI-TIMESTAMP", timestamp) + req.Header.Set("X-BAPI-RECV-WINDOW", recvWindow) + req.Header.Set("Content-Type", "application/json") + + // Use http.DefaultClient for the request + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call Bybit API: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var result struct { + RetCode int `json:"retCode"` + RetMsg string `json:"retMsg"` + Result map[string]interface{} `json:"result"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if result.RetCode != 0 { + return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg) + } + + return t.parseClosedPnLResult(result.Result) +} + +// parseClosedPnLResult parses the closed PnL result from Bybit API +func (t *BybitTrader) parseClosedPnLResult(resultData interface{}) ([]types.ClosedPnLRecord, error) { + data, ok := resultData.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid result format") + } + + list, _ := data["list"].([]interface{}) + var records []types.ClosedPnLRecord + + for _, item := range list { + pnl, ok := item.(map[string]interface{}) + if !ok { + continue + } + + // Parse fields + symbol, _ := pnl["symbol"].(string) + side, _ := pnl["side"].(string) + orderId, _ := pnl["orderId"].(string) + + avgEntryPriceStr, _ := pnl["avgEntryPrice"].(string) + avgExitPriceStr, _ := pnl["avgExitPrice"].(string) + qtyStr, _ := pnl["qty"].(string) + closedPnLStr, _ := pnl["closedPnl"].(string) + cumEntryValueStr, _ := pnl["cumEntryValue"].(string) + cumExitValueStr, _ := pnl["cumExitValue"].(string) + leverageStr, _ := pnl["leverage"].(string) + createdTimeStr, _ := pnl["createdTime"].(string) + updatedTimeStr, _ := pnl["updatedTime"].(string) + + avgEntryPrice, _ := strconv.ParseFloat(avgEntryPriceStr, 64) + avgExitPrice, _ := strconv.ParseFloat(avgExitPriceStr, 64) + qty, _ := strconv.ParseFloat(qtyStr, 64) + closedPnL, _ := strconv.ParseFloat(closedPnLStr, 64) + leverage, _ := strconv.ParseInt(leverageStr, 10, 64) + createdTime, _ := strconv.ParseInt(createdTimeStr, 10, 64) + updatedTime, _ := strconv.ParseInt(updatedTimeStr, 10, 64) + + // Calculate approximate fee from value difference + cumEntryValue, _ := strconv.ParseFloat(cumEntryValueStr, 64) + cumExitValue, _ := strconv.ParseFloat(cumExitValueStr, 64) + expectedPnL := cumExitValue - cumEntryValue + if side == "Sell" { + expectedPnL = cumEntryValue - cumExitValue + } + fee := expectedPnL - closedPnL + if fee < 0 { + fee = 0 + } + + // Normalize side + normalizedSide := "long" + if side == "Sell" { + normalizedSide = "short" + } + + record := types.ClosedPnLRecord{ + Symbol: symbol, + Side: normalizedSide, + EntryPrice: avgEntryPrice, + ExitPrice: avgExitPrice, + Quantity: qty, + RealizedPnL: closedPnL, + Fee: fee, + Leverage: int(leverage), + EntryTime: time.UnixMilli(createdTime).UTC(), + ExitTime: time.UnixMilli(updatedTime).UTC(), + OrderID: orderId, + CloseType: "unknown", // Bybit doesn't provide close type directly + ExchangeID: orderId, // Use orderId as exchange ID + } + + records = append(records, record) + } + + return records, nil +} diff --git a/trader/bybit/trader_orders.go b/trader/bybit/trader_orders.go new file mode 100644 index 00000000..fbf5f5d9 --- /dev/null +++ b/trader/bybit/trader_orders.go @@ -0,0 +1,741 @@ +package bybit + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" +) + +// OpenLong opens a long position +func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + logger.Infof("[Bybit] ===== OpenLong called: symbol=%s, qty=%.6f, leverage=%d =====", symbol, quantity, leverage) + + // First cancel all pending orders for this symbol (clean up old orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof("⚠️ [Bybit] Failed to cancel old pending orders: %v", err) + } + // Also cancel conditional orders (stop-loss/take-profit) - Bybit keeps them separate + if err := t.CancelStopOrders(symbol); err != nil { + logger.Infof("⚠️ [Bybit] Failed to cancel old stop orders: %v", err) + } + + // Set leverage first + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err) + } + + // Use FormatQuantity to format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "side": "Buy", + "orderType": "Market", + "qty": qtyStr, + "positionIdx": 0, // One-way position mode + } + + logger.Infof("[Bybit] OpenLong placing order: %+v", params) + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("Bybit open long failed: %w", err) + } + + // Clear cache + t.clearCache() + + return t.parseOrderResult(result) +} + +// OpenShort opens a short position +func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + logger.Infof("[Bybit] ===== OpenShort called: symbol=%s, qty=%.6f, leverage=%d =====", symbol, quantity, leverage) + + // First cancel all pending orders for this symbol (clean up old orders) + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof("⚠️ [Bybit] Failed to cancel old pending orders: %v", err) + } + // Also cancel conditional orders (stop-loss/take-profit) - Bybit keeps them separate + if err := t.CancelStopOrders(symbol); err != nil { + logger.Infof("⚠️ [Bybit] Failed to cancel old stop orders: %v", err) + } + + // Set leverage first + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err) + } + + // Use FormatQuantity to format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "side": "Sell", + "orderType": "Market", + "qty": qtyStr, + "positionIdx": 0, // One-way position mode + } + + logger.Infof("[Bybit] OpenShort placing order: %+v", params) + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("Bybit open short failed: %w", err) + } + + // Clear cache + t.clearCache() + + return t.parseOrderResult(result) +} + +// CloseLong closes a long position +func (t *BybitTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + // If quantity = 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + for _, pos := range positions { + side, _ := pos["side"].(string) + if pos["symbol"] == symbol && strings.ToLower(side) == "long" { + quantity = pos["positionAmt"].(float64) + break + } + } + } + + if quantity <= 0 { + return nil, fmt.Errorf("no long position to close") + } + + // Use FormatQuantity to format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "side": "Sell", // Close long with Sell + "orderType": "Market", + "qty": qtyStr, + "positionIdx": 0, + "reduceOnly": true, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("Bybit close long failed: %w", err) + } + + // Clear cache + t.clearCache() + + return t.parseOrderResult(result) +} + +// CloseShort closes a short position +func (t *BybitTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + // If quantity = 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + for _, pos := range positions { + side, _ := pos["side"].(string) + if pos["symbol"] == symbol && strings.ToLower(side) == "short" { + quantity = -pos["positionAmt"].(float64) // Short position is negative + break + } + } + } + + if quantity <= 0 { + return nil, fmt.Errorf("no short position to close") + } + + // Use FormatQuantity to format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "side": "Buy", // Close short with Buy + "orderType": "Market", + "qty": qtyStr, + "positionIdx": 0, + "reduceOnly": true, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("Bybit close short failed: %w", err) + } + + // Clear cache + t.clearCache() + + return t.parseOrderResult(result) +} + +// SetLeverage sets leverage +func (t *BybitTrader) SetLeverage(symbol string, leverage int) error { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "buyLeverage": fmt.Sprintf("%d", leverage), + "sellLeverage": fmt.Sprintf("%d", leverage), + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).SetPositionLeverage(context.Background()) + if err != nil { + // If leverage is already at target value, Bybit will return an error, ignore this case + if strings.Contains(err.Error(), "leverage not modified") { + return nil + } + return fmt.Errorf("failed to set leverage: %w", err) + } + + if result.RetCode != 0 && result.RetCode != 110043 { // 110043 = leverage not modified + return fmt.Errorf("failed to set leverage: %s", result.RetMsg) + } + + return nil +} + +// SetMarginMode sets position margin mode +func (t *BybitTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + tradeMode := 1 // Isolated margin + if isCrossMargin { + tradeMode = 0 // Cross margin + } + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "tradeMode": tradeMode, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).SwitchPositionMargin(context.Background()) + if err != nil { + if strings.Contains(err.Error(), "Cross/isolated margin mode is not modified") { + return nil + } + return fmt.Errorf("failed to set margin mode: %w", err) + } + + if result.RetCode != 0 && result.RetCode != 110026 { // already in target mode + return fmt.Errorf("failed to set margin mode: %s", result.RetMsg) + } + + return nil +} + +// GetMarketPrice retrieves market price +func (t *BybitTrader) GetMarketPrice(symbol string) (float64, error) { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).GetMarketTickers(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get market price: %w", err) + } + + if result.RetCode != 0 { + return 0, fmt.Errorf("API error: %s", result.RetMsg) + } + + resultData, ok := result.Result.(map[string]interface{}) + if !ok { + return 0, fmt.Errorf("return format error") + } + + list, _ := resultData["list"].([]interface{}) + + if len(list) == 0 { + return 0, fmt.Errorf("price data not found for %s", symbol) + } + + ticker, _ := list[0].(map[string]interface{}) + lastPriceStr, _ := ticker["lastPrice"].(string) + lastPrice, err := strconv.ParseFloat(lastPriceStr, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse price: %w", err) + } + + return lastPrice, nil +} + +// SetStopLoss sets stop loss order +func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + side := "Sell" // LONG stop loss uses Sell + if positionSide == "SHORT" { + side = "Buy" // SHORT stop loss uses Buy + } + + // Get current price to determine triggerDirection + currentPrice, err := t.GetMarketPrice(symbol) + if err != nil { + return err + } + + triggerDirection := 2 // Price fall trigger (default long stop loss) + if stopPrice > currentPrice { + triggerDirection = 1 // Price rise trigger (short stop loss) + } + + // Use FormatQuantity to format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "side": side, + "orderType": "Market", + "qty": qtyStr, + "triggerPrice": fmt.Sprintf("%v", stopPrice), + "triggerDirection": triggerDirection, + "triggerBy": "LastPrice", + "reduceOnly": true, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return fmt.Errorf("failed to set stop loss: %w", err) + } + + if result.RetCode != 0 { + return fmt.Errorf("failed to set stop loss: %s", result.RetMsg) + } + + logger.Infof(" ✓ [Bybit] Stop loss order set: %s @ %.2f", symbol, stopPrice) + return nil +} + +// SetTakeProfit sets take profit order +func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + side := "Sell" // LONG take profit uses Sell + if positionSide == "SHORT" { + side = "Buy" // SHORT take profit uses Buy + } + + // Get current price to determine triggerDirection + currentPrice, err := t.GetMarketPrice(symbol) + if err != nil { + return err + } + + triggerDirection := 1 // Price rise trigger (default long take profit) + if takeProfitPrice < currentPrice { + triggerDirection = 2 // Price fall trigger (short take profit) + } + + // Use FormatQuantity to format quantity + qtyStr, _ := t.FormatQuantity(symbol, quantity) + + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "side": side, + "orderType": "Market", + "qty": qtyStr, + "triggerPrice": fmt.Sprintf("%v", takeProfitPrice), + "triggerDirection": triggerDirection, + "triggerBy": "LastPrice", + "reduceOnly": true, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return fmt.Errorf("failed to set take profit: %w", err) + } + + if result.RetCode != 0 { + return fmt.Errorf("failed to set take profit: %s", result.RetMsg) + } + + logger.Infof(" ✓ [Bybit] Take profit order set: %s @ %.2f", symbol, takeProfitPrice) + return nil +} + +// CancelStopLossOrders cancels stop loss orders +func (t *BybitTrader) CancelStopLossOrders(symbol string) error { + return t.cancelConditionalOrders(symbol, "StopLoss") +} + +// CancelTakeProfitOrders cancels take profit orders +func (t *BybitTrader) CancelTakeProfitOrders(symbol string) error { + return t.cancelConditionalOrders(symbol, "TakeProfit") +} + +// CancelAllOrders cancels all pending orders +func (t *BybitTrader) CancelAllOrders(symbol string) error { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + } + + _, err := t.client.NewUtaBybitServiceWithParams(params).CancelAllOrders(context.Background()) + if err != nil { + return fmt.Errorf("failed to cancel all orders: %w", err) + } + + return nil +} + +// CancelStopOrders cancels all stop loss and take profit orders +func (t *BybitTrader) CancelStopOrders(symbol string) error { + if err := t.CancelStopLossOrders(symbol); err != nil { + logger.Infof("⚠️ [Bybit] Failed to cancel stop loss orders: %v", err) + } + if err := t.CancelTakeProfitOrders(symbol); err != nil { + logger.Infof("⚠️ [Bybit] Failed to cancel take profit orders: %v", err) + } + return nil +} + +func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) error { + // First get all conditional orders + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderFilter": "StopOrder", // Conditional orders + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background()) + if err != nil { + return fmt.Errorf("failed to get conditional orders: %w", err) + } + + if result.RetCode != 0 { + return nil // No orders + } + + resultData, ok := result.Result.(map[string]interface{}) + if !ok { + return nil + } + + list, _ := resultData["list"].([]interface{}) + + // Cancel matching orders + for _, item := range list { + order, ok := item.(map[string]interface{}) + if !ok { + continue + } + + orderId, _ := order["orderId"].(string) + stopOrderType, _ := order["stopOrderType"].(string) + + // Filter by type + shouldCancel := false + if orderType == "StopLoss" && (stopOrderType == "StopLoss" || stopOrderType == "Stop") { + shouldCancel = true + } + if orderType == "TakeProfit" && (stopOrderType == "TakeProfit" || stopOrderType == "PartialTakeProfit") { + shouldCancel = true + } + + if shouldCancel && orderId != "" { + cancelParams := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderId": orderId, + } + t.client.NewUtaBybitServiceWithParams(cancelParams).CancelOrder(context.Background()) + } + } + + return nil +} + +// GetOrderStatus retrieves order status +func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderId": orderID, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).GetOrderHistory(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + if result.RetCode != 0 { + return nil, fmt.Errorf("API error: %s", result.RetMsg) + } + + resultData, ok := result.Result.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("return format error") + } + + list, _ := resultData["list"].([]interface{}) + if len(list) == 0 { + return nil, fmt.Errorf("order %s not found", orderID) + } + + order, _ := list[0].(map[string]interface{}) + + // Parse order data + status, _ := order["orderStatus"].(string) + avgPriceStr, _ := order["avgPrice"].(string) + cumExecQtyStr, _ := order["cumExecQty"].(string) + cumExecFeeStr, _ := order["cumExecFee"].(string) + + avgPrice, _ := strconv.ParseFloat(avgPriceStr, 64) + executedQty, _ := strconv.ParseFloat(cumExecQtyStr, 64) + commission, _ := strconv.ParseFloat(cumExecFeeStr, 64) + + // Convert status to unified format + unifiedStatus := status + switch status { + case "Filled": + unifiedStatus = "FILLED" + case "New", "Created": + unifiedStatus = "NEW" + case "Cancelled", "Rejected": + unifiedStatus = "CANCELED" + case "PartiallyFilled": + unifiedStatus = "PARTIALLY_FILLED" + } + + return map[string]interface{}{ + "orderId": orderID, + "status": unifiedStatus, + "avgPrice": avgPrice, + "executedQty": executedQty, + "commission": commission, + }, nil +} + +// GetOpenOrders gets all open/pending orders for a symbol +func (t *BybitTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + var result []types.OpenOrder + + // Get conditional orders (stop-loss, take-profit) + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderFilter": "StopOrder", + } + + resp, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + if resp.RetCode == 0 { + resultData, ok := resp.Result.(map[string]interface{}) + if ok { + list, _ := resultData["list"].([]interface{}) + for _, item := range list { + order, ok := item.(map[string]interface{}) + if !ok { + continue + } + + orderId, _ := order["orderId"].(string) + sym, _ := order["symbol"].(string) + side, _ := order["side"].(string) + orderType, _ := order["orderType"].(string) + stopOrderType, _ := order["stopOrderType"].(string) + triggerPrice, _ := order["triggerPrice"].(string) + qty, _ := order["qty"].(string) + + price, _ := strconv.ParseFloat(triggerPrice, 64) + quantity, _ := strconv.ParseFloat(qty, 64) + + // Determine type based on stopOrderType + displayType := orderType + if stopOrderType != "" { + displayType = stopOrderType + } + + result = append(result, types.OpenOrder{ + OrderID: orderId, + Symbol: sym, + Side: side, + PositionSide: "", // Bybit doesn't use positionSide for UTA + Type: displayType, + Price: 0, + StopPrice: price, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *BybitTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { + // Format quantity + qtyStr, err := t.FormatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Format price + priceStr := fmt.Sprintf("%.8f", req.Price) + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Bybit] Failed to set leverage: %v", err) + } + } + + // Determine side + side := "Buy" + if req.Side == "SELL" { + side = "Sell" + } + + params := map[string]interface{}{ + "category": "linear", + "symbol": req.Symbol, + "side": side, + "orderType": "Limit", + "qty": qtyStr, + "price": priceStr, + "timeInForce": "GTC", // Good Till Cancel + "positionIdx": 0, // One-way position mode + } + + // Add reduce only if specified + if req.ReduceOnly { + params["reduceOnly"] = true + } + + logger.Infof("[Bybit] PlaceLimitOrder: %s %s @ %s, qty=%s", req.Symbol, side, priceStr, qtyStr) + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Parse result + orderID := "" + if result.RetCode == 0 { + if resultData, ok := result.Result.(map[string]interface{}); ok { + if id, ok := resultData["orderId"].(string); ok { + orderID = id + } + } + } else { + return nil, fmt.Errorf("Bybit order failed: %s", result.RetMsg) + } + + logger.Infof("✓ [Bybit] Limit order placed: %s %s @ %s, qty=%s, orderID=%s", + req.Symbol, side, priceStr, qtyStr, orderID) + + return &types.LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *BybitTrader) CancelOrder(symbol, orderID string) error { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderId": orderID, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).CancelOrder(context.Background()) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + if result.RetCode != 0 { + return fmt.Errorf("Bybit cancel order failed: %s", result.RetMsg) + } + + logger.Infof("✓ [Bybit] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *BybitTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + if depth <= 0 { + depth = 25 + } + + // Use HTTP request directly since the SDK doesn't expose GetOrderbook + url := fmt.Sprintf("https://api.bybit.com/v5/market/orderbook?category=linear&symbol=%s&limit=%d", symbol, depth) + resp, err := http.Get(url) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + RetCode int `json:"retCode"` + RetMsg string `json:"retMsg"` + Result struct { + S string `json:"s"` // symbol + B [][]string `json:"b"` // bids [[price, size], ...] + A [][]string `json:"a"` // asks [[price, size], ...] + } `json:"result"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + if result.RetCode != 0 { + return nil, nil, fmt.Errorf("Bybit get orderbook failed: %s", result.RetMsg) + } + + // Parse bids + for _, b := range result.Result.B { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result.Result.A { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil +} diff --git a/trader/bybit/trader_positions.go b/trader/bybit/trader_positions.go new file mode 100644 index 00000000..4d4d9429 --- /dev/null +++ b/trader/bybit/trader_positions.go @@ -0,0 +1,125 @@ +package bybit + +import ( + "context" + "fmt" + "nofx/logger" + "strconv" + "strings" + "time" +) + +// GetPositions retrieves all positions +func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) { + // Check cache + t.positionsCacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { + positions := t.cachedPositions + t.positionsCacheMutex.RUnlock() + return positions, nil + } + t.positionsCacheMutex.RUnlock() + + // Call API + params := map[string]interface{}{ + "category": "linear", + "settleCoin": "USDT", + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).GetPositionList(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get Bybit positions: %w", err) + } + + if result.RetCode != 0 { + return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg) + } + + resultData, ok := result.Result.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Bybit positions return format error") + } + + list, _ := resultData["list"].([]interface{}) + + var positions []map[string]interface{} + + for _, item := range list { + pos, ok := item.(map[string]interface{}) + if !ok { + continue + } + + sizeStr, _ := pos["size"].(string) + size, _ := strconv.ParseFloat(sizeStr, 64) + + // Skip empty positions + if size == 0 { + continue + } + + entryPriceStr, _ := pos["avgPrice"].(string) + entryPrice, _ := strconv.ParseFloat(entryPriceStr, 64) + + unrealisedPnlStr, _ := pos["unrealisedPnl"].(string) + unrealisedPnl, _ := strconv.ParseFloat(unrealisedPnlStr, 64) + + leverageStr, _ := pos["leverage"].(string) + leverage, _ := strconv.ParseFloat(leverageStr, 64) + + // Mark price + markPriceStr, _ := pos["markPrice"].(string) + markPrice, _ := strconv.ParseFloat(markPriceStr, 64) + + // Liquidation price + liqPriceStr, _ := pos["liqPrice"].(string) + liqPrice, _ := strconv.ParseFloat(liqPriceStr, 64) + + // Position created/updated time (milliseconds timestamp) + createdTimeStr, _ := pos["createdTime"].(string) + createdTime, _ := strconv.ParseInt(createdTimeStr, 10, 64) + updatedTimeStr, _ := pos["updatedTime"].(string) + updatedTime, _ := strconv.ParseInt(updatedTimeStr, 10, 64) + + positionSide, _ := pos["side"].(string) // Buy = long, Sell = short + + // Log raw position data for debugging + logger.Infof("[Bybit] GetPositions raw: symbol=%v, side=%s, size=%v", pos["symbol"], positionSide, sizeStr) + + // Convert to unified format (use lowercase for consistency with other exchanges) + // Bybit returns "Buy" for long, "Sell" for short + side := "long" + positionAmt := size + positionSideLower := strings.ToLower(positionSide) + if positionSideLower == "sell" { + side = "short" + positionAmt = -size + } + + logger.Infof("[Bybit] GetPositions converted: symbol=%v, rawSide=%s -> side=%s", pos["symbol"], positionSide, side) + + position := map[string]interface{}{ + "symbol": pos["symbol"], + "side": side, + "positionAmt": positionAmt, + "entryPrice": entryPrice, + "markPrice": markPrice, + "unRealizedProfit": unrealisedPnl, + "unrealizedPnL": unrealisedPnl, + "liquidationPrice": liqPrice, + "leverage": leverage, + "createdTime": createdTime, // Position open time (ms) + "updatedTime": updatedTime, // Position last update time (ms) + } + + positions = append(positions, position) + } + + // Update cache + t.positionsCacheMutex.Lock() + t.cachedPositions = positions + t.positionsCacheTime = time.Now() + t.positionsCacheMutex.Unlock() + + return positions, nil +} diff --git a/trader/bybit/trader_test.go b/trader/bybit/trader_test.go deleted file mode 100644 index 07011b01..00000000 --- a/trader/bybit/trader_test.go +++ /dev/null @@ -1,471 +0,0 @@ -package bybit - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "nofx/trader/testutil" - "nofx/trader/types" -) - -// ============================================================ -// Part 1: BybitTraderTestSuite - Inherits base test suite -// ============================================================ - -// BybitTraderTestSuite Bybit trader test suite -// Inherits TraderTestSuite and adds Bybit-specific mock logic -type BybitTraderTestSuite struct { - *testutil.TraderTestSuite // Embeds base test suite - mockServer *httptest.Server -} - -// NewBybitTraderTestSuite Create Bybit test suite -// Note: Due to Bybit SDK encapsulation design, cannot easily inject mock HTTP client -// Therefore this test suite is mainly used for interface compliance verification, not API call testing -func NewBybitTraderTestSuite(t *testing.T) *BybitTraderTestSuite { - // Create mock HTTP server (for response format verification) - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - var respBody interface{} - - switch { - case path == "/v5/account/wallet-balance": - respBody = map[string]interface{}{ - "retCode": 0, - "retMsg": "OK", - "result": map[string]interface{}{ - "list": []map[string]interface{}{ - { - "accountType": "UNIFIED", - "totalEquity": "10100.50", - "coin": []map[string]interface{}{ - { - "coin": "USDT", - "walletBalance": "10000.00", - "unrealisedPnl": "100.50", - "availableToWithdraw": "8000.00", - }, - }, - }, - }, - }, - } - default: - respBody = map[string]interface{}{ - "retCode": 0, - "retMsg": "OK", - "result": map[string]interface{}{}, - } - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - })) - - // Create real Bybit trader (for interface compliance testing) - traderInstance := NewBybitTrader("test_api_key", "test_secret_key") - - // Create base suite - baseSuite := testutil.NewTraderTestSuite(t, traderInstance) - - return &BybitTraderTestSuite{ - TraderTestSuite: baseSuite, - mockServer: mockServer, - } -} - -// Cleanup Clean up resources -func (s *BybitTraderTestSuite) Cleanup() { - if s.mockServer != nil { - s.mockServer.Close() - } - s.TraderTestSuite.Cleanup() -} - -// ============================================================ -// Part 2: Interface compliance tests -// ============================================================ - -// TestBybitTrader_InterfaceCompliance Test interface compliance -func TestBybitTrader_InterfaceCompliance(t *testing.T) { - var _ types.Trader = (*BybitTrader)(nil) -} - -// ============================================================ -// Part 3: Bybit-specific feature unit tests -// ============================================================ - -// TestNewBybitTrader Test creating Bybit trader -func TestNewBybitTrader(t *testing.T) { - tests := []struct { - name string - apiKey string - secretKey string - wantNil bool - }{ - { - name: "Successfully create", - apiKey: "test_api_key", - secretKey: "test_secret_key", - wantNil: false, - }, - { - name: "Empty API Key can still create", - apiKey: "", - secretKey: "test_secret_key", - wantNil: false, - }, - { - name: "Empty Secret Key can still create", - apiKey: "test_api_key", - secretKey: "", - wantNil: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bt := NewBybitTrader(tt.apiKey, tt.secretKey) - - if tt.wantNil { - assert.Nil(t, bt) - } else { - assert.NotNil(t, bt) - assert.NotNil(t, bt.client) - } - }) - } -} - -// TestBybitTrader_SymbolFormat Test symbol format -func TestBybitTrader_SymbolFormat(t *testing.T) { - // Bybit uses uppercase symbol format (e.g. BTCUSDT) - tests := []struct { - name string - symbol string - isValid bool - }{ - { - name: "Standard USDT contract", - symbol: "BTCUSDT", - isValid: true, - }, - { - name: "ETH contract", - symbol: "ETHUSDT", - isValid: true, - }, - { - name: "SOL contract", - symbol: "SOLUSDT", - isValid: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Verify symbol format is correct (all uppercase, ends with USDT) - assert.True(t, tt.symbol == strings.ToUpper(tt.symbol)) - assert.True(t, strings.HasSuffix(tt.symbol, "USDT")) - }) - } -} - -// TestBybitTrader_FormatQuantity Test quantity formatting -func TestBybitTrader_FormatQuantity(t *testing.T) { - bt := NewBybitTrader("test", "test") - - tests := []struct { - name string - symbol string - quantity float64 - expected string - hasError bool - }{ - { - name: "BTC quantity formatting", - symbol: "BTCUSDT", - quantity: 0.12345, - expected: "0.123", // Bybit defaults to 3 decimal places - hasError: false, - }, - { - name: "ETH quantity formatting", - symbol: "ETHUSDT", - quantity: 1.2345, - expected: "1.234", - hasError: false, - }, - { - name: "Integer quantity", - symbol: "SOLUSDT", - quantity: 10.0, - expected: "10.000", - hasError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := bt.FormatQuantity(tt.symbol, tt.quantity) - if tt.hasError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -// TestBybitTrader_ParseResponse Test response parsing -func TestBybitTrader_ParseResponse(t *testing.T) { - tests := []struct { - name string - retCode int - retMsg string - expectErr bool - errContain string - }{ - { - name: "Success response", - retCode: 0, - retMsg: "OK", - expectErr: false, - }, - { - name: "API error", - retCode: 10001, - retMsg: "Invalid symbol", - expectErr: true, - errContain: "Invalid symbol", - }, - { - name: "Permission error", - retCode: 10003, - retMsg: "Invalid API key", - expectErr: true, - errContain: "Invalid API key", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := checkBybitResponse(tt.retCode, tt.retMsg) - if tt.expectErr { - assert.Error(t, err) - if tt.errContain != "" { - assert.Contains(t, err.Error(), tt.errContain) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -// checkBybitResponse Check if Bybit API response has errors -func checkBybitResponse(retCode int, retMsg string) error { - if retCode != 0 { - return &BybitAPIError{ - Code: retCode, - Message: retMsg, - } - } - return nil -} - -// BybitAPIError Bybit API error type -type BybitAPIError struct { - Code int - Message string -} - -func (e *BybitAPIError) Error() string { - return e.Message -} - -// TestBybitTrader_PositionSideConversion Test position side conversion -func TestBybitTrader_PositionSideConversion(t *testing.T) { - tests := []struct { - name string - side string - expected string - }{ - { - name: "Buy to Long", - side: "Buy", - expected: "long", - }, - { - name: "Sell to Short", - side: "Sell", - expected: "short", - }, - { - name: "Other values remain unchanged", - side: "Unknown", - expected: "unknown", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertBybitSide(tt.side) - assert.Equal(t, tt.expected, result) - }) - } -} - -// convertBybitSide Convert Bybit position side -func convertBybitSide(side string) string { - switch side { - case "Buy": - return "long" - case "Sell": - return "short" - default: - return "unknown" - } -} - -// TestBybitTrader_CategoryLinear Test using only linear category -func TestBybitTrader_CategoryLinear(t *testing.T) { - // Bybit trader should only use linear category (USDT perpetual contracts) - bt := NewBybitTrader("test", "test") - assert.NotNil(t, bt) - - // Verify default configuration - assert.NotNil(t, bt.client) -} - -// TestBybitTrader_CacheDuration Test cache duration -func TestBybitTrader_CacheDuration(t *testing.T) { - bt := NewBybitTrader("test", "test") - - // Verify default cache time is 15 seconds - assert.Equal(t, 15*time.Second, bt.cacheDuration) -} - -// ============================================================ -// Part 4: Mock server integration tests -// ============================================================ - -// TestBybitTrader_MockServerGetBalance Test getting balance through Mock server -func TestBybitTrader_MockServerGetBalance(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v5/account/wallet-balance" { - respBody := map[string]interface{}{ - "retCode": 0, - "retMsg": "OK", - "result": map[string]interface{}{ - "list": []map[string]interface{}{ - { - "accountType": "UNIFIED", - "totalEquity": "10100.50", - "coin": []map[string]interface{}{ - { - "coin": "USDT", - "walletBalance": "10000.00", - "unrealisedPnl": "100.50", - "availableToWithdraw": "8000.00", - }, - }, - }, - }, - }, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - return - } - http.NotFound(w, r) - })) - defer mockServer.Close() - - // Due to Bybit SDK encapsulation, cannot directly inject mock URL - // This test verifies mock server response format is correct - assert.NotNil(t, mockServer) -} - -// TestBybitTrader_MockServerGetPositions Test getting positions through Mock server -func TestBybitTrader_MockServerGetPositions(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v5/position/list" { - respBody := map[string]interface{}{ - "retCode": 0, - "retMsg": "OK", - "result": map[string]interface{}{ - "list": []map[string]interface{}{ - { - "symbol": "BTCUSDT", - "side": "Buy", - "size": "0.5", - "avgPrice": "50000.00", - "markPrice": "50500.00", - "unrealisedPnl": "250.00", - "liqPrice": "45000.00", - "leverage": "10", - "positionIdx": 0, - }, - }, - }, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - return - } - http.NotFound(w, r) - })) - defer mockServer.Close() - - assert.NotNil(t, mockServer) -} - -// TestBybitTrader_MockServerPlaceOrder Test placing order through Mock server -func TestBybitTrader_MockServerPlaceOrder(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v5/order/create" && r.Method == "POST" { - respBody := map[string]interface{}{ - "retCode": 0, - "retMsg": "OK", - "result": map[string]interface{}{ - "orderId": "1234567890", - "orderLinkId": "test-order-id", - }, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - return - } - http.NotFound(w, r) - })) - defer mockServer.Close() - - assert.NotNil(t, mockServer) -} - -// TestBybitTrader_MockServerSetLeverage Test setting leverage through Mock server -func TestBybitTrader_MockServerSetLeverage(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v5/position/set-leverage" && r.Method == "POST" { - respBody := map[string]interface{}{ - "retCode": 0, - "retMsg": "OK", - "result": map[string]interface{}{}, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - return - } - http.NotFound(w, r) - })) - defer mockServer.Close() - - assert.NotNil(t, mockServer) -} diff --git a/trader/gate/trader.go b/trader/gate/trader.go index 5b9e7706..27a1d1e1 100644 --- a/trader/gate/trader.go +++ b/trader/gate/trader.go @@ -3,16 +3,12 @@ package gate import ( "context" "fmt" - "math" - "strconv" + "nofx/trader/types" "strings" "sync" "time" - "github.com/antihax/optional" "github.com/gateio/gateapi-go/v6" - "nofx/logger" - "nofx/trader/types" ) // GateTrader implements types.Trader interface for Gate.io Futures @@ -58,118 +54,6 @@ func NewGateTrader(apiKey, secretKey string) *GateTrader { } } -// GetBalance retrieves account balance -func (t *GateTrader) GetBalance() (map[string]interface{}, error) { - // Check cache - t.balanceCacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - cached := t.cachedBalance - t.balanceCacheMutex.RUnlock() - return cached, nil - } - t.balanceCacheMutex.RUnlock() - - // Fetch from API - accounts, _, err := t.client.FuturesApi.ListFuturesAccounts(t.ctx, "usdt") - if err != nil { - return nil, fmt.Errorf("failed to get balance: %w", err) - } - - total, _ := strconv.ParseFloat(accounts.Total, 64) - available, _ := strconv.ParseFloat(accounts.Available, 64) - unrealizedPnl, _ := strconv.ParseFloat(accounts.UnrealisedPnl, 64) - - result := map[string]interface{}{ - "totalWalletBalance": total, - "availableBalance": available, - "totalUnrealizedProfit": unrealizedPnl, - } - - // Update cache - t.balanceCacheMutex.Lock() - t.cachedBalance = result - t.balanceCacheTime = time.Now() - t.balanceCacheMutex.Unlock() - - return result, nil -} - -// GetPositions retrieves all open positions -func (t *GateTrader) GetPositions() ([]map[string]interface{}, error) { - // Check cache - t.positionsCacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { - cached := t.cachedPositions - t.positionsCacheMutex.RUnlock() - return cached, nil - } - t.positionsCacheMutex.RUnlock() - - // Fetch from API - positions, _, err := t.client.FuturesApi.ListPositions(t.ctx, "usdt", nil) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var result []map[string]interface{} - for _, pos := range positions { - if pos.Size == 0 { - continue // Skip empty positions - } - - entryPrice, _ := strconv.ParseFloat(pos.EntryPrice, 64) - markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64) - liqPrice, _ := strconv.ParseFloat(pos.LiqPrice, 64) - unrealizedPnl, _ := strconv.ParseFloat(pos.UnrealisedPnl, 64) - leverage, _ := strconv.ParseFloat(pos.Leverage, 64) - - // Gate returns position size in contracts, need to convert to base currency - // Each contract = quanto_multiplier base currency - contractSize := float64(pos.Size) - if pos.Size < 0 { - contractSize = float64(-pos.Size) - } - - // Get quanto_multiplier from contract info to convert contracts to actual quantity - quantoMultiplier := 1.0 - contract, err := t.getContract(pos.Contract) - if err == nil && contract != nil { - qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - if qm > 0 { - quantoMultiplier = qm - } - } - - // Convert contract count to actual token quantity - positionAmt := contractSize * quantoMultiplier - - // Determine side based on position size - side := "long" - if pos.Size < 0 { - side = "short" - } - - result = append(result, map[string]interface{}{ - "symbol": pos.Contract, - "positionAmt": positionAmt, - "entryPrice": entryPrice, - "markPrice": markPrice, - "unRealizedProfit": unrealizedPnl, - "leverage": int(leverage), - "liquidationPrice": liqPrice, - "side": side, - }) - } - - // Update cache - t.positionsCacheMutex.Lock() - t.cachedPositions = result - t.positionsCacheTime = time.Now() - t.positionsCacheMutex.Unlock() - - return result, nil -} - // convertSymbol converts symbol format (e.g., BTCUSDT -> BTC_USDT) func (t *GateTrader) convertSymbol(symbol string) string { // If already in correct format @@ -215,674 +99,6 @@ func (t *GateTrader) getContract(symbol string) (*gateapi.Contract, error) { return &contract, nil } -// SetLeverage sets the leverage for a symbol -func (t *GateTrader) SetLeverage(symbol string, leverage int) error { - symbol = t.convertSymbol(symbol) - - _, _, err := t.client.FuturesApi.UpdatePositionLeverage(t.ctx, "usdt", symbol, fmt.Sprintf("%d", leverage), nil) - if err != nil { - // Gate.io may return error if leverage is already set - if strings.Contains(err.Error(), "RISK_LIMIT_EXCEEDED") { - logger.Warnf(" [Gate] Leverage %d exceeds limit for %s", leverage, symbol) - return nil - } - return fmt.Errorf("failed to set leverage: %w", err) - } - - logger.Infof(" [Gate] Leverage set to %dx for %s", leverage, symbol) - return nil -} - -// SetMarginMode sets margin mode (cross or isolated) -func (t *GateTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - // Gate.io uses leverage=0 for cross margin, positive number for isolated - // This is handled through UpdatePositionLeverage with cross_leverage_limit - // For now, we'll skip explicit margin mode setting as it's tied to leverage - logger.Infof(" [Gate] Margin mode is set through leverage (0=cross)") - return nil -} - -// OpenLong opens a long position -func (t *GateTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // Cancel old orders first - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Warnf(" [Gate] Failed to set leverage: %v", err) - } - - // Get contract info for size calculation - contract, err := t.getContract(symbol) - if err != nil { - return nil, err - } - - // Gate uses contract size units (each contract = quanto_multiplier base currency) - // size = quantity / quanto_multiplier - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - size := int64(quantity / quantoMultiplier) - if size <= 0 { - size = 1 - } - - order := gateapi.FuturesOrder{ - Contract: symbol, - Size: size, // Positive for long - Price: "0", // Market order - Tif: "ioc", - Text: "t-nofx", - } - - logger.Infof(" [Gate] OpenLong: symbol=%s, size=%d, leverage=%d", symbol, size, leverage) - - result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) - if err != nil { - return nil, fmt.Errorf("failed to open long position: %w", err) - } - - // Clear cache - t.clearCache() - - // Parse fill price from result - fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) - - logger.Infof(" [Gate] Opened long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) - - return map[string]interface{}{ - "orderId": fmt.Sprintf("%d", result.Id), - "symbol": t.revertSymbol(symbol), - "status": "FILLED", - "fillPrice": fillPrice, - "avgPrice": fillPrice, - }, nil -} - -// OpenShort opens a short position -func (t *GateTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // Cancel old orders first - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Warnf(" [Gate] Failed to set leverage: %v", err) - } - - // Get contract info for size calculation - contract, err := t.getContract(symbol) - if err != nil { - return nil, err - } - - // Gate uses contract size units - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - size := int64(quantity / quantoMultiplier) - if size <= 0 { - size = 1 - } - - order := gateapi.FuturesOrder{ - Contract: symbol, - Size: -size, // Negative for short - Price: "0", // Market order - Tif: "ioc", - Text: "t-nofx", - } - - logger.Infof(" [Gate] OpenShort: symbol=%s, size=%d, leverage=%d", symbol, -size, leverage) - - result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) - if err != nil { - return nil, fmt.Errorf("failed to open short position: %w", err) - } - - // Clear cache - t.clearCache() - - // Parse fill price from result - fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) - - logger.Infof(" [Gate] Opened short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) - - return map[string]interface{}{ - "orderId": fmt.Sprintf("%d", result.Id), - "symbol": t.revertSymbol(symbol), - "status": "FILLED", - "fillPrice": fillPrice, - "avgPrice": fillPrice, - }, nil -} - -// CloseLong closes a long position -func (t *GateTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // If quantity is 0, get current position - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - for _, pos := range positions { - posSymbol := t.convertSymbol(pos["symbol"].(string)) - if posSymbol == symbol && pos["side"] == "long" { - quantity = pos["positionAmt"].(float64) - break - } - } - if quantity == 0 { - return nil, fmt.Errorf("long position not found for %s", symbol) - } - } - - // Get contract info for size calculation - contract, err := t.getContract(symbol) - if err != nil { - return nil, err - } - - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - size := int64(quantity / quantoMultiplier) - if size <= 0 { - size = 1 - } - - // Close long = sell (use ReduceOnly, not Close which requires Size=0) - order := gateapi.FuturesOrder{ - Contract: symbol, - Size: -size, // Negative to close long - Price: "0", - Tif: "ioc", - ReduceOnly: true, - Text: "t-nofx-close", - } - - logger.Infof(" [Gate] CloseLong: symbol=%s, size=%d", symbol, -size) - - result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) - if err != nil { - return nil, fmt.Errorf("failed to close long position: %w", err) - } - - // Clear cache - t.clearCache() - - // Parse fill price from result - fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) - - logger.Infof(" [Gate] Closed long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) - - return map[string]interface{}{ - "orderId": fmt.Sprintf("%d", result.Id), - "symbol": t.revertSymbol(symbol), - "status": "FILLED", - "fillPrice": fillPrice, - "avgPrice": fillPrice, - }, nil -} - -// CloseShort closes a short position -func (t *GateTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - // If quantity is 0, get current position - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - for _, pos := range positions { - posSymbol := t.convertSymbol(pos["symbol"].(string)) - if posSymbol == symbol && pos["side"] == "short" { - quantity = pos["positionAmt"].(float64) - break - } - } - if quantity == 0 { - return nil, fmt.Errorf("short position not found for %s", symbol) - } - } - - // Ensure quantity is positive - if quantity < 0 { - quantity = -quantity - } - - // Get contract info for size calculation - contract, err := t.getContract(symbol) - if err != nil { - return nil, err - } - - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - size := int64(quantity / quantoMultiplier) - if size <= 0 { - size = 1 - } - - // Close short = buy (use ReduceOnly, not Close which requires Size=0) - order := gateapi.FuturesOrder{ - Contract: symbol, - Size: size, // Positive to close short - Price: "0", - Tif: "ioc", - ReduceOnly: true, - Text: "t-nofx-close", - } - - logger.Infof(" [Gate] CloseShort: symbol=%s, size=%d", symbol, size) - - result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) - if err != nil { - return nil, fmt.Errorf("failed to close short position: %w", err) - } - - // Clear cache - t.clearCache() - - // Parse fill price from result - fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) - - logger.Infof(" [Gate] Closed short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) - - return map[string]interface{}{ - "orderId": fmt.Sprintf("%d", result.Id), - "symbol": t.revertSymbol(symbol), - "status": "FILLED", - "fillPrice": fillPrice, - "avgPrice": fillPrice, - }, nil -} - -// GetMarketPrice gets the current market price -func (t *GateTrader) GetMarketPrice(symbol string) (float64, error) { - symbol = t.convertSymbol(symbol) - - opts := &gateapi.ListFuturesTickersOpts{ - Contract: optional.NewString(symbol), - } - - tickers, _, err := t.client.FuturesApi.ListFuturesTickers(t.ctx, "usdt", opts) - if err != nil { - return 0, fmt.Errorf("failed to get market price: %w", err) - } - - if len(tickers) == 0 { - return 0, fmt.Errorf("no ticker data for %s", symbol) - } - - price, _ := strconv.ParseFloat(tickers[0].Last, 64) - return price, nil -} - -// SetStopLoss sets a stop loss order -func (t *GateTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - symbol = t.convertSymbol(symbol) - - contract, err := t.getContract(symbol) - if err != nil { - return err - } - - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - size := int64(quantity / quantoMultiplier) - if size <= 0 { - size = 1 - } - - // For long position, stop loss means sell when price drops - // For short position, stop loss means buy when price rises - if strings.ToUpper(positionSide) == "LONG" { - size = -size - } - - // Use price trigger order - trigger := gateapi.FuturesPriceTriggeredOrder{ - Initial: gateapi.FuturesInitialOrder{ - Contract: symbol, - Size: size, - Price: "0", // Market order - Tif: "ioc", - ReduceOnly: true, - Close: true, - }, - Trigger: gateapi.FuturesPriceTrigger{ - StrategyType: 0, // Close position - PriceType: 0, // Latest price - Price: fmt.Sprintf("%.8f", stopPrice), - Rule: 1, // Price <= trigger price - }, - } - - if strings.ToUpper(positionSide) == "SHORT" { - trigger.Trigger.Rule = 2 // Price >= trigger price for short stop loss - } - - _, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger) - if err != nil { - return fmt.Errorf("failed to set stop loss: %w", err) - } - - logger.Infof(" [Gate] Stop loss set: %s @ %.4f", symbol, stopPrice) - return nil -} - -// SetTakeProfit sets a take profit order -func (t *GateTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - symbol = t.convertSymbol(symbol) - - contract, err := t.getContract(symbol) - if err != nil { - return err - } - - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - size := int64(quantity / quantoMultiplier) - if size <= 0 { - size = 1 - } - - // For long position, take profit means sell when price rises - // For short position, take profit means buy when price drops - if strings.ToUpper(positionSide) == "LONG" { - size = -size - } - - trigger := gateapi.FuturesPriceTriggeredOrder{ - Initial: gateapi.FuturesInitialOrder{ - Contract: symbol, - Size: size, - Price: "0", // Market order - Tif: "ioc", - ReduceOnly: true, - Close: true, - }, - Trigger: gateapi.FuturesPriceTrigger{ - StrategyType: 0, // Close position - PriceType: 0, // Latest price - Price: fmt.Sprintf("%.8f", takeProfitPrice), - Rule: 2, // Price >= trigger price for long take profit - }, - } - - if strings.ToUpper(positionSide) == "SHORT" { - trigger.Trigger.Rule = 1 // Price <= trigger price for short take profit - } - - _, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger) - if err != nil { - return fmt.Errorf("failed to set take profit: %w", err) - } - - logger.Infof(" [Gate] Take profit set: %s @ %.4f", symbol, takeProfitPrice) - return nil -} - -// CancelStopLossOrders cancels stop loss orders -func (t *GateTrader) CancelStopLossOrders(symbol string) error { - return t.cancelTriggerOrders(symbol, "stop_loss") -} - -// CancelTakeProfitOrders cancels take profit orders -func (t *GateTrader) CancelTakeProfitOrders(symbol string) error { - return t.cancelTriggerOrders(symbol, "take_profit") -} - -// cancelTriggerOrders cancels trigger orders of a specific type -func (t *GateTrader) cancelTriggerOrders(symbol string, orderType string) error { - symbol = t.convertSymbol(symbol) - - opts := &gateapi.ListPriceTriggeredOrdersOpts{ - Contract: optional.NewString(symbol), - } - - orders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", opts) - if err != nil { - return err - } - - for _, order := range orders { - // Determine if it's stop loss or take profit based on trigger rule and position - // For simplicity, cancel all matching symbol orders - _, _, err := t.client.FuturesApi.CancelPriceTriggeredOrder(t.ctx, "usdt", fmt.Sprintf("%d", order.Id)) - if err != nil { - logger.Warnf(" [Gate] Failed to cancel trigger order %d: %v", order.Id, err) - } - } - - return nil -} - -// CancelAllOrders cancels all pending orders for a symbol -func (t *GateTrader) CancelAllOrders(symbol string) error { - symbol = t.convertSymbol(symbol) - - // Cancel regular orders - _, _, err := t.client.FuturesApi.CancelFuturesOrders(t.ctx, "usdt", symbol, nil) - if err != nil { - // Ignore if no orders to cancel - if !strings.Contains(err.Error(), "ORDER_NOT_FOUND") { - logger.Warnf(" [Gate] Error canceling orders: %v", err) - } - } - - // Cancel trigger orders - t.cancelTriggerOrders(symbol, "") - - return nil -} - -// CancelStopOrders cancels all stop orders (stop loss and take profit) -func (t *GateTrader) CancelStopOrders(symbol string) error { - t.CancelStopLossOrders(symbol) - t.CancelTakeProfitOrders(symbol) - return nil -} - -// FormatQuantity formats quantity to correct precision -func (t *GateTrader) FormatQuantity(symbol string, quantity float64) (string, error) { - contract, err := t.getContract(symbol) - if err != nil { - return fmt.Sprintf("%.4f", quantity), nil - } - - // Gate uses quanto_multiplier for contract size - quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - if quantoMultiplier > 0 { - // Calculate number of contracts - numContracts := quantity / quantoMultiplier - return fmt.Sprintf("%.0f", math.Floor(numContracts)), nil - } - - return fmt.Sprintf("%.4f", quantity), nil -} - -// GetOrderStatus gets the status of an order -func (t *GateTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - symbol = t.convertSymbol(symbol) - - order, _, err := t.client.FuturesApi.GetFuturesOrder(t.ctx, "usdt", orderID) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - fillPrice, _ := strconv.ParseFloat(order.FillPrice, 64) - tkFee, _ := strconv.ParseFloat(order.Tkfr, 64) - mkFee, _ := strconv.ParseFloat(order.Mkfr, 64) - totalFee := tkFee + mkFee - - // Get quanto_multiplier to convert contracts to actual quantity - quantoMultiplier := 1.0 - contract, contractErr := t.getContract(symbol) - if contractErr == nil && contract != nil { - qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - if qm > 0 { - quantoMultiplier = qm - } - } - - // Map status - status := "NEW" - switch order.Status { - case "finished": - if order.FinishAs == "filled" { - status = "FILLED" - } else if order.FinishAs == "cancelled" { - status = "CANCELED" - } else { - status = "CLOSED" - } - case "open": - status = "NEW" - } - - side := "BUY" - if order.Size < 0 { - side = "SELL" - } - - // Convert contract count to actual token quantity - executedQty := math.Abs(float64(order.Size-order.Left)) * quantoMultiplier - - return map[string]interface{}{ - "orderId": orderID, - "symbol": t.revertSymbol(symbol), - "status": status, - "avgPrice": fillPrice, - "executedQty": executedQty, - "side": side, - "type": order.Tif, - "time": int64(order.CreateTime * 1000), - "updateTime": int64(order.FinishTime * 1000), - "commission": totalFee, - }, nil -} - -// GetClosedPnL retrieves closed position PnL records -func (t *GateTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 100 { - limit = 100 - } - - opts := &gateapi.ListPositionCloseOpts{ - Limit: optional.NewInt32(int32(limit)), - From: optional.NewInt64(startTime.Unix()), - } - - closedPositions, _, err := t.client.FuturesApi.ListPositionClose(t.ctx, "usdt", opts) - if err != nil { - return nil, fmt.Errorf("failed to get closed positions: %w", err) - } - - records := make([]types.ClosedPnLRecord, 0, len(closedPositions)) - for _, pos := range closedPositions { - pnl, _ := strconv.ParseFloat(pos.Pnl, 64) - - record := types.ClosedPnLRecord{ - Symbol: t.revertSymbol(pos.Contract), - Side: pos.Side, - RealizedPnL: pnl, - ExitTime: time.Unix(int64(pos.Time), 0).UTC(), - CloseType: "unknown", - } - - records = append(records, record) - } - - return records, nil -} - -// GetOpenOrders gets open/pending orders -func (t *GateTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - symbol = t.convertSymbol(symbol) - - opts := &gateapi.ListFuturesOrdersOpts{ - Contract: optional.NewString(symbol), - } - - orders, _, err := t.client.FuturesApi.ListFuturesOrders(t.ctx, "usdt", "open", opts) - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - // Get quanto_multiplier to convert contracts to actual quantity - quantoMultiplier := 1.0 - contract, err := t.getContract(symbol) - if err == nil && contract != nil { - qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) - if qm > 0 { - quantoMultiplier = qm - } - } - - var result []types.OpenOrder - for _, order := range orders { - price, _ := strconv.ParseFloat(order.Price, 64) - - side := "BUY" - if order.Size < 0 { - side = "SELL" - } - - // Convert contract count to actual token quantity - quantity := math.Abs(float64(order.Size)) * quantoMultiplier - - result = append(result, types.OpenOrder{ - OrderID: fmt.Sprintf("%d", order.Id), - Symbol: t.revertSymbol(order.Contract), - Side: side, - Type: "LIMIT", - Price: price, - Quantity: quantity, - Status: "NEW", - }) - } - - // Also get trigger orders - triggerOpts := &gateapi.ListPriceTriggeredOrdersOpts{ - Contract: optional.NewString(symbol), - } - - triggerOrders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", triggerOpts) - if err == nil { - for _, order := range triggerOrders { - triggerPrice, _ := strconv.ParseFloat(order.Trigger.Price, 64) - - side := "BUY" - if order.Initial.Size < 0 { - side = "SELL" - } - - orderType := "STOP_MARKET" - if order.Trigger.Rule == 2 { - orderType = "TAKE_PROFIT_MARKET" - } - - // Convert contract count to actual token quantity - quantity := math.Abs(float64(order.Initial.Size)) * quantoMultiplier - - result = append(result, types.OpenOrder{ - OrderID: fmt.Sprintf("%d", order.Id), - Symbol: t.revertSymbol(order.Initial.Contract), - Side: side, - Type: orderType, - StopPrice: triggerPrice, - Quantity: quantity, - Status: "NEW", - }) - } - } - - return result, nil -} - // clearCache clears all caches func (t *GateTrader) clearCache() { t.balanceCacheMutex.Lock() diff --git a/trader/gate/trader_account.go b/trader/gate/trader_account.go new file mode 100644 index 00000000..1c912315 --- /dev/null +++ b/trader/gate/trader_account.go @@ -0,0 +1,160 @@ +package gate + +import ( + "fmt" + "nofx/trader/types" + "strconv" + "time" + + "github.com/antihax/optional" + "github.com/gateio/gateapi-go/v6" +) + +// GetBalance retrieves account balance +func (t *GateTrader) GetBalance() (map[string]interface{}, error) { + // Check cache + t.balanceCacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + cached := t.cachedBalance + t.balanceCacheMutex.RUnlock() + return cached, nil + } + t.balanceCacheMutex.RUnlock() + + // Fetch from API + accounts, _, err := t.client.FuturesApi.ListFuturesAccounts(t.ctx, "usdt") + if err != nil { + return nil, fmt.Errorf("failed to get balance: %w", err) + } + + total, _ := strconv.ParseFloat(accounts.Total, 64) + available, _ := strconv.ParseFloat(accounts.Available, 64) + unrealizedPnl, _ := strconv.ParseFloat(accounts.UnrealisedPnl, 64) + + result := map[string]interface{}{ + "totalWalletBalance": total, + "availableBalance": available, + "totalUnrealizedProfit": unrealizedPnl, + } + + // Update cache + t.balanceCacheMutex.Lock() + t.cachedBalance = result + t.balanceCacheTime = time.Now() + t.balanceCacheMutex.Unlock() + + return result, nil +} + +// GetPositions retrieves all open positions +func (t *GateTrader) GetPositions() ([]map[string]interface{}, error) { + // Check cache + t.positionsCacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { + cached := t.cachedPositions + t.positionsCacheMutex.RUnlock() + return cached, nil + } + t.positionsCacheMutex.RUnlock() + + // Fetch from API + positions, _, err := t.client.FuturesApi.ListPositions(t.ctx, "usdt", nil) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var result []map[string]interface{} + for _, pos := range positions { + if pos.Size == 0 { + continue // Skip empty positions + } + + entryPrice, _ := strconv.ParseFloat(pos.EntryPrice, 64) + markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64) + liqPrice, _ := strconv.ParseFloat(pos.LiqPrice, 64) + unrealizedPnl, _ := strconv.ParseFloat(pos.UnrealisedPnl, 64) + leverage, _ := strconv.ParseFloat(pos.Leverage, 64) + + // Gate returns position size in contracts, need to convert to base currency + // Each contract = quanto_multiplier base currency + contractSize := float64(pos.Size) + if pos.Size < 0 { + contractSize = float64(-pos.Size) + } + + // Get quanto_multiplier from contract info to convert contracts to actual quantity + quantoMultiplier := 1.0 + contract, err := t.getContract(pos.Contract) + if err == nil && contract != nil { + qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + if qm > 0 { + quantoMultiplier = qm + } + } + + // Convert contract count to actual token quantity + positionAmt := contractSize * quantoMultiplier + + // Determine side based on position size + side := "long" + if pos.Size < 0 { + side = "short" + } + + result = append(result, map[string]interface{}{ + "symbol": pos.Contract, + "positionAmt": positionAmt, + "entryPrice": entryPrice, + "markPrice": markPrice, + "unRealizedProfit": unrealizedPnl, + "leverage": int(leverage), + "liquidationPrice": liqPrice, + "side": side, + }) + } + + // Update cache + t.positionsCacheMutex.Lock() + t.cachedPositions = result + t.positionsCacheTime = time.Now() + t.positionsCacheMutex.Unlock() + + return result, nil +} + +// GetClosedPnL retrieves closed position PnL records +func (t *GateTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 100 { + limit = 100 + } + + opts := &gateapi.ListPositionCloseOpts{ + Limit: optional.NewInt32(int32(limit)), + From: optional.NewInt64(startTime.Unix()), + } + + closedPositions, _, err := t.client.FuturesApi.ListPositionClose(t.ctx, "usdt", opts) + if err != nil { + return nil, fmt.Errorf("failed to get closed positions: %w", err) + } + + records := make([]types.ClosedPnLRecord, 0, len(closedPositions)) + for _, pos := range closedPositions { + pnl, _ := strconv.ParseFloat(pos.Pnl, 64) + + record := types.ClosedPnLRecord{ + Symbol: t.revertSymbol(pos.Contract), + Side: pos.Side, + RealizedPnL: pnl, + ExitTime: time.Unix(int64(pos.Time), 0).UTC(), + CloseType: "unknown", + } + + records = append(records, record) + } + + return records, nil +} diff --git a/trader/gate/trader_orders.go b/trader/gate/trader_orders.go new file mode 100644 index 00000000..198a6acb --- /dev/null +++ b/trader/gate/trader_orders.go @@ -0,0 +1,644 @@ +package gate + +import ( + "fmt" + "math" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" + + "github.com/antihax/optional" + "github.com/gateio/gateapi-go/v6" +) + +// SetLeverage sets the leverage for a symbol +func (t *GateTrader) SetLeverage(symbol string, leverage int) error { + symbol = t.convertSymbol(symbol) + + _, _, err := t.client.FuturesApi.UpdatePositionLeverage(t.ctx, "usdt", symbol, fmt.Sprintf("%d", leverage), nil) + if err != nil { + // Gate.io may return error if leverage is already set + if strings.Contains(err.Error(), "RISK_LIMIT_EXCEEDED") { + logger.Warnf(" [Gate] Leverage %d exceeds limit for %s", leverage, symbol) + return nil + } + return fmt.Errorf("failed to set leverage: %w", err) + } + + logger.Infof(" [Gate] Leverage set to %dx for %s", leverage, symbol) + return nil +} + +// SetMarginMode sets margin mode (cross or isolated) +func (t *GateTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + // Gate.io uses leverage=0 for cross margin, positive number for isolated + // This is handled through UpdatePositionLeverage with cross_leverage_limit + // For now, we'll skip explicit margin mode setting as it's tied to leverage + logger.Infof(" [Gate] Margin mode is set through leverage (0=cross)") + return nil +} + +// OpenLong opens a long position +func (t *GateTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // Cancel old orders first + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Warnf(" [Gate] Failed to set leverage: %v", err) + } + + // Get contract info for size calculation + contract, err := t.getContract(symbol) + if err != nil { + return nil, err + } + + // Gate uses contract size units (each contract = quanto_multiplier base currency) + // size = quantity / quanto_multiplier + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + size := int64(quantity / quantoMultiplier) + if size <= 0 { + size = 1 + } + + order := gateapi.FuturesOrder{ + Contract: symbol, + Size: size, // Positive for long + Price: "0", // Market order + Tif: "ioc", + Text: "t-nofx", + } + + logger.Infof(" [Gate] OpenLong: symbol=%s, size=%d, leverage=%d", symbol, size, leverage) + + result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) + if err != nil { + return nil, fmt.Errorf("failed to open long position: %w", err) + } + + // Clear cache + t.clearCache() + + // Parse fill price from result + fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) + + logger.Infof(" [Gate] Opened long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) + + return map[string]interface{}{ + "orderId": fmt.Sprintf("%d", result.Id), + "symbol": t.revertSymbol(symbol), + "status": "FILLED", + "fillPrice": fillPrice, + "avgPrice": fillPrice, + }, nil +} + +// OpenShort opens a short position +func (t *GateTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // Cancel old orders first + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Warnf(" [Gate] Failed to set leverage: %v", err) + } + + // Get contract info for size calculation + contract, err := t.getContract(symbol) + if err != nil { + return nil, err + } + + // Gate uses contract size units + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + size := int64(quantity / quantoMultiplier) + if size <= 0 { + size = 1 + } + + order := gateapi.FuturesOrder{ + Contract: symbol, + Size: -size, // Negative for short + Price: "0", // Market order + Tif: "ioc", + Text: "t-nofx", + } + + logger.Infof(" [Gate] OpenShort: symbol=%s, size=%d, leverage=%d", symbol, -size, leverage) + + result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) + if err != nil { + return nil, fmt.Errorf("failed to open short position: %w", err) + } + + // Clear cache + t.clearCache() + + // Parse fill price from result + fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) + + logger.Infof(" [Gate] Opened short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) + + return map[string]interface{}{ + "orderId": fmt.Sprintf("%d", result.Id), + "symbol": t.revertSymbol(symbol), + "status": "FILLED", + "fillPrice": fillPrice, + "avgPrice": fillPrice, + }, nil +} + +// CloseLong closes a long position +func (t *GateTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // If quantity is 0, get current position + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + for _, pos := range positions { + posSymbol := t.convertSymbol(pos["symbol"].(string)) + if posSymbol == symbol && pos["side"] == "long" { + quantity = pos["positionAmt"].(float64) + break + } + } + if quantity == 0 { + return nil, fmt.Errorf("long position not found for %s", symbol) + } + } + + // Get contract info for size calculation + contract, err := t.getContract(symbol) + if err != nil { + return nil, err + } + + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + size := int64(quantity / quantoMultiplier) + if size <= 0 { + size = 1 + } + + // Close long = sell (use ReduceOnly, not Close which requires Size=0) + order := gateapi.FuturesOrder{ + Contract: symbol, + Size: -size, // Negative to close long + Price: "0", + Tif: "ioc", + ReduceOnly: true, + Text: "t-nofx-close", + } + + logger.Infof(" [Gate] CloseLong: symbol=%s, size=%d", symbol, -size) + + result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) + if err != nil { + return nil, fmt.Errorf("failed to close long position: %w", err) + } + + // Clear cache + t.clearCache() + + // Parse fill price from result + fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) + + logger.Infof(" [Gate] Closed long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) + + return map[string]interface{}{ + "orderId": fmt.Sprintf("%d", result.Id), + "symbol": t.revertSymbol(symbol), + "status": "FILLED", + "fillPrice": fillPrice, + "avgPrice": fillPrice, + }, nil +} + +// CloseShort closes a short position +func (t *GateTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + // If quantity is 0, get current position + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + for _, pos := range positions { + posSymbol := t.convertSymbol(pos["symbol"].(string)) + if posSymbol == symbol && pos["side"] == "short" { + quantity = pos["positionAmt"].(float64) + break + } + } + if quantity == 0 { + return nil, fmt.Errorf("short position not found for %s", symbol) + } + } + + // Ensure quantity is positive + if quantity < 0 { + quantity = -quantity + } + + // Get contract info for size calculation + contract, err := t.getContract(symbol) + if err != nil { + return nil, err + } + + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + size := int64(quantity / quantoMultiplier) + if size <= 0 { + size = 1 + } + + // Close short = buy (use ReduceOnly, not Close which requires Size=0) + order := gateapi.FuturesOrder{ + Contract: symbol, + Size: size, // Positive to close short + Price: "0", + Tif: "ioc", + ReduceOnly: true, + Text: "t-nofx-close", + } + + logger.Infof(" [Gate] CloseShort: symbol=%s, size=%d", symbol, size) + + result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil) + if err != nil { + return nil, fmt.Errorf("failed to close short position: %w", err) + } + + // Clear cache + t.clearCache() + + // Parse fill price from result + fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64) + + logger.Infof(" [Gate] Closed short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice) + + return map[string]interface{}{ + "orderId": fmt.Sprintf("%d", result.Id), + "symbol": t.revertSymbol(symbol), + "status": "FILLED", + "fillPrice": fillPrice, + "avgPrice": fillPrice, + }, nil +} + +// GetMarketPrice gets the current market price +func (t *GateTrader) GetMarketPrice(symbol string) (float64, error) { + symbol = t.convertSymbol(symbol) + + opts := &gateapi.ListFuturesTickersOpts{ + Contract: optional.NewString(symbol), + } + + tickers, _, err := t.client.FuturesApi.ListFuturesTickers(t.ctx, "usdt", opts) + if err != nil { + return 0, fmt.Errorf("failed to get market price: %w", err) + } + + if len(tickers) == 0 { + return 0, fmt.Errorf("no ticker data for %s", symbol) + } + + price, _ := strconv.ParseFloat(tickers[0].Last, 64) + return price, nil +} + +// SetStopLoss sets a stop loss order +func (t *GateTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + symbol = t.convertSymbol(symbol) + + contract, err := t.getContract(symbol) + if err != nil { + return err + } + + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + size := int64(quantity / quantoMultiplier) + if size <= 0 { + size = 1 + } + + // For long position, stop loss means sell when price drops + // For short position, stop loss means buy when price rises + if strings.ToUpper(positionSide) == "LONG" { + size = -size + } + + // Use price trigger order + trigger := gateapi.FuturesPriceTriggeredOrder{ + Initial: gateapi.FuturesInitialOrder{ + Contract: symbol, + Size: size, + Price: "0", // Market order + Tif: "ioc", + ReduceOnly: true, + Close: true, + }, + Trigger: gateapi.FuturesPriceTrigger{ + StrategyType: 0, // Close position + PriceType: 0, // Latest price + Price: fmt.Sprintf("%.8f", stopPrice), + Rule: 1, // Price <= trigger price + }, + } + + if strings.ToUpper(positionSide) == "SHORT" { + trigger.Trigger.Rule = 2 // Price >= trigger price for short stop loss + } + + _, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger) + if err != nil { + return fmt.Errorf("failed to set stop loss: %w", err) + } + + logger.Infof(" [Gate] Stop loss set: %s @ %.4f", symbol, stopPrice) + return nil +} + +// SetTakeProfit sets a take profit order +func (t *GateTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + symbol = t.convertSymbol(symbol) + + contract, err := t.getContract(symbol) + if err != nil { + return err + } + + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + size := int64(quantity / quantoMultiplier) + if size <= 0 { + size = 1 + } + + // For long position, take profit means sell when price rises + // For short position, take profit means buy when price drops + if strings.ToUpper(positionSide) == "LONG" { + size = -size + } + + trigger := gateapi.FuturesPriceTriggeredOrder{ + Initial: gateapi.FuturesInitialOrder{ + Contract: symbol, + Size: size, + Price: "0", // Market order + Tif: "ioc", + ReduceOnly: true, + Close: true, + }, + Trigger: gateapi.FuturesPriceTrigger{ + StrategyType: 0, // Close position + PriceType: 0, // Latest price + Price: fmt.Sprintf("%.8f", takeProfitPrice), + Rule: 2, // Price >= trigger price for long take profit + }, + } + + if strings.ToUpper(positionSide) == "SHORT" { + trigger.Trigger.Rule = 1 // Price <= trigger price for short take profit + } + + _, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger) + if err != nil { + return fmt.Errorf("failed to set take profit: %w", err) + } + + logger.Infof(" [Gate] Take profit set: %s @ %.4f", symbol, takeProfitPrice) + return nil +} + +// CancelStopLossOrders cancels stop loss orders +func (t *GateTrader) CancelStopLossOrders(symbol string) error { + return t.cancelTriggerOrders(symbol, "stop_loss") +} + +// CancelTakeProfitOrders cancels take profit orders +func (t *GateTrader) CancelTakeProfitOrders(symbol string) error { + return t.cancelTriggerOrders(symbol, "take_profit") +} + +// cancelTriggerOrders cancels trigger orders of a specific type +func (t *GateTrader) cancelTriggerOrders(symbol string, orderType string) error { + symbol = t.convertSymbol(symbol) + + opts := &gateapi.ListPriceTriggeredOrdersOpts{ + Contract: optional.NewString(symbol), + } + + orders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", opts) + if err != nil { + return err + } + + for _, order := range orders { + // Determine if it's stop loss or take profit based on trigger rule and position + // For simplicity, cancel all matching symbol orders + _, _, err := t.client.FuturesApi.CancelPriceTriggeredOrder(t.ctx, "usdt", fmt.Sprintf("%d", order.Id)) + if err != nil { + logger.Warnf(" [Gate] Failed to cancel trigger order %d: %v", order.Id, err) + } + } + + return nil +} + +// CancelAllOrders cancels all pending orders for a symbol +func (t *GateTrader) CancelAllOrders(symbol string) error { + symbol = t.convertSymbol(symbol) + + // Cancel regular orders + _, _, err := t.client.FuturesApi.CancelFuturesOrders(t.ctx, "usdt", symbol, nil) + if err != nil { + // Ignore if no orders to cancel + if !strings.Contains(err.Error(), "ORDER_NOT_FOUND") { + logger.Warnf(" [Gate] Error canceling orders: %v", err) + } + } + + // Cancel trigger orders + t.cancelTriggerOrders(symbol, "") + + return nil +} + +// CancelStopOrders cancels all stop orders (stop loss and take profit) +func (t *GateTrader) CancelStopOrders(symbol string) error { + t.CancelStopLossOrders(symbol) + t.CancelTakeProfitOrders(symbol) + return nil +} + +// FormatQuantity formats quantity to correct precision +func (t *GateTrader) FormatQuantity(symbol string, quantity float64) (string, error) { + contract, err := t.getContract(symbol) + if err != nil { + return fmt.Sprintf("%.4f", quantity), nil + } + + // Gate uses quanto_multiplier for contract size + quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + if quantoMultiplier > 0 { + // Calculate number of contracts + numContracts := quantity / quantoMultiplier + return fmt.Sprintf("%.0f", math.Floor(numContracts)), nil + } + + return fmt.Sprintf("%.4f", quantity), nil +} + +// GetOrderStatus gets the status of an order +func (t *GateTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + symbol = t.convertSymbol(symbol) + + order, _, err := t.client.FuturesApi.GetFuturesOrder(t.ctx, "usdt", orderID) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + fillPrice, _ := strconv.ParseFloat(order.FillPrice, 64) + tkFee, _ := strconv.ParseFloat(order.Tkfr, 64) + mkFee, _ := strconv.ParseFloat(order.Mkfr, 64) + totalFee := tkFee + mkFee + + // Get quanto_multiplier to convert contracts to actual quantity + quantoMultiplier := 1.0 + contract, contractErr := t.getContract(symbol) + if contractErr == nil && contract != nil { + qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + if qm > 0 { + quantoMultiplier = qm + } + } + + // Map status + status := "NEW" + switch order.Status { + case "finished": + if order.FinishAs == "filled" { + status = "FILLED" + } else if order.FinishAs == "cancelled" { + status = "CANCELED" + } else { + status = "CLOSED" + } + case "open": + status = "NEW" + } + + side := "BUY" + if order.Size < 0 { + side = "SELL" + } + + // Convert contract count to actual token quantity + executedQty := math.Abs(float64(order.Size-order.Left)) * quantoMultiplier + + return map[string]interface{}{ + "orderId": orderID, + "symbol": t.revertSymbol(symbol), + "status": status, + "avgPrice": fillPrice, + "executedQty": executedQty, + "side": side, + "type": order.Tif, + "time": int64(order.CreateTime * 1000), + "updateTime": int64(order.FinishTime * 1000), + "commission": totalFee, + }, nil +} + +// GetOpenOrders gets open/pending orders +func (t *GateTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + symbol = t.convertSymbol(symbol) + + opts := &gateapi.ListFuturesOrdersOpts{ + Contract: optional.NewString(symbol), + } + + orders, _, err := t.client.FuturesApi.ListFuturesOrders(t.ctx, "usdt", "open", opts) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + // Get quanto_multiplier to convert contracts to actual quantity + quantoMultiplier := 1.0 + contract, err := t.getContract(symbol) + if err == nil && contract != nil { + qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64) + if qm > 0 { + quantoMultiplier = qm + } + } + + var result []types.OpenOrder + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Price, 64) + + side := "BUY" + if order.Size < 0 { + side = "SELL" + } + + // Convert contract count to actual token quantity + quantity := math.Abs(float64(order.Size)) * quantoMultiplier + + result = append(result, types.OpenOrder{ + OrderID: fmt.Sprintf("%d", order.Id), + Symbol: t.revertSymbol(order.Contract), + Side: side, + Type: "LIMIT", + Price: price, + Quantity: quantity, + Status: "NEW", + }) + } + + // Also get trigger orders + triggerOpts := &gateapi.ListPriceTriggeredOrdersOpts{ + Contract: optional.NewString(symbol), + } + + triggerOrders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", triggerOpts) + if err == nil { + for _, order := range triggerOrders { + triggerPrice, _ := strconv.ParseFloat(order.Trigger.Price, 64) + + side := "BUY" + if order.Initial.Size < 0 { + side = "SELL" + } + + orderType := "STOP_MARKET" + if order.Trigger.Rule == 2 { + orderType = "TAKE_PROFIT_MARKET" + } + + // Convert contract count to actual token quantity + quantity := math.Abs(float64(order.Initial.Size)) * quantoMultiplier + + result = append(result, types.OpenOrder{ + OrderID: fmt.Sprintf("%d", order.Id), + Symbol: t.revertSymbol(order.Initial.Contract), + Side: side, + Type: orderType, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + + return result, nil +} diff --git a/trader/hyperliquid/balance_test.go b/trader/hyperliquid/balance_test.go deleted file mode 100644 index 491a8fd6..00000000 --- a/trader/hyperliquid/balance_test.go +++ /dev/null @@ -1,295 +0,0 @@ -package hyperliquid - -import ( - "os" - "testing" - "time" -) - -// TestHyperliquidBalanceCalculation tests the balance calculation for Hyperliquid -// including perp, spot, and xyz dex (stocks, forex, metals) accounts -// Run with: TEST_PRIVATE_KEY=xxx TEST_WALLET_ADDR=xxx go test -v -run TestHyperliquidBalanceCalculation ./trader/ -func TestHyperliquidBalanceCalculation(t *testing.T) { - // Get credentials from environment - privateKeyHex := os.Getenv("TEST_PRIVATE_KEY") - walletAddr := os.Getenv("TEST_WALLET_ADDR") - - if privateKeyHex == "" || walletAddr == "" { - t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required") - } - - t.Logf("=== Testing Hyperliquid Balance Calculation ===") - t.Logf("Wallet: %s", walletAddr) - - // Create trader instance - trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false) - if err != nil { - t.Fatalf("Failed to create trader: %v", err) - } - - // Test GetBalance - t.Log("\n--- Testing GetBalance ---") - balance, err := trader.GetBalance() - if err != nil { - t.Fatalf("GetBalance failed: %v", err) - } - - // Extract values - totalWalletBalance, _ := balance["totalWalletBalance"].(float64) - totalEquity, _ := balance["totalEquity"].(float64) - totalUnrealizedProfit, _ := balance["totalUnrealizedProfit"].(float64) - availableBalance, _ := balance["availableBalance"].(float64) - spotBalance, _ := balance["spotBalance"].(float64) - xyzDexBalance, _ := balance["xyzDexBalance"].(float64) - xyzDexUnrealizedPnl, _ := balance["xyzDexUnrealizedPnl"].(float64) - perpAccountValue, _ := balance["perpAccountValue"].(float64) - - t.Logf("\n📊 Balance Results:") - t.Logf(" Perp Account Value: %.4f USDC", perpAccountValue) - t.Logf(" Spot Balance: %.4f USDC", spotBalance) - t.Logf(" xyz Dex Balance: %.4f USDC", xyzDexBalance) - t.Logf(" xyz Dex Unrealized PnL: %.4f USDC", xyzDexUnrealizedPnl) - t.Logf(" ---") - t.Logf(" Total Wallet Balance: %.4f USDC", totalWalletBalance) - t.Logf(" Total Unrealized PnL: %.4f USDC", totalUnrealizedProfit) - t.Logf(" Total Equity: %.4f USDC", totalEquity) - t.Logf(" Available Balance: %.4f USDC", availableBalance) - - // Verify calculation: totalEquity should equal perpAccountValue + spotBalance + xyzDexBalance - expectedEquity := perpAccountValue + spotBalance + xyzDexBalance - t.Logf("\n🔍 Verification:") - t.Logf(" Expected Equity (Perp + Spot + xyz): %.4f", expectedEquity) - t.Logf(" Actual Total Equity: %.4f", totalEquity) - - if abs(totalEquity-expectedEquity) > 0.01 { - t.Errorf("❌ Equity mismatch! Expected %.4f, got %.4f", expectedEquity, totalEquity) - } else { - t.Logf("✅ Equity calculation correct!") - } - - // Verify: totalWalletBalance + totalUnrealizedProfit should equal totalEquity - calculatedEquity := totalWalletBalance + totalUnrealizedProfit - t.Logf("\n🔍 Secondary Verification:") - t.Logf(" Wallet + Unrealized = %.4f + %.4f = %.4f", totalWalletBalance, totalUnrealizedProfit, calculatedEquity) - t.Logf(" Total Equity: %.4f", totalEquity) - - if abs(calculatedEquity-totalEquity) > 0.01 { - t.Errorf("❌ Secondary check failed! Wallet+Unrealized=%.4f != Equity=%.4f", calculatedEquity, totalEquity) - } else { - t.Logf("✅ Secondary verification passed!") - } - - // Test GetPositions - t.Log("\n--- Testing GetPositions ---") - positions, err := trader.GetPositions() - if err != nil { - t.Fatalf("GetPositions failed: %v", err) - } - - t.Logf("Found %d positions:", len(positions)) - totalPositionValue := 0.0 - totalPositionPnL := 0.0 - - for i, pos := range positions { - symbol, _ := pos["symbol"].(string) - side, _ := pos["side"].(string) - positionAmt, _ := pos["positionAmt"].(float64) - entryPrice, _ := pos["entryPrice"].(float64) - markPrice, _ := pos["markPrice"].(float64) - unrealizedPnL, _ := pos["unRealizedProfit"].(float64) - leverage, _ := pos["leverage"].(float64) - isXyzDex, _ := pos["isXyzDex"].(bool) - - posValue := positionAmt * markPrice - totalPositionValue += posValue - totalPositionPnL += unrealizedPnL - - assetType := "Crypto" - if isXyzDex { - assetType = "xyz Dex" - } - - t.Logf(" [%d] %s (%s)", i+1, symbol, assetType) - t.Logf(" Side: %s, Qty: %.4f, Leverage: %.0fx", side, positionAmt, leverage) - t.Logf(" Entry: %.4f, Mark: %.4f", entryPrice, markPrice) - t.Logf(" Value: %.4f, PnL: %.4f", posValue, unrealizedPnL) - - // Verify xyz dex position has valid entry/mark prices - if isXyzDex { - if entryPrice == 0 { - t.Errorf("❌ xyz dex position %s has zero entry price!", symbol) - } - if markPrice == 0 { - t.Errorf("❌ xyz dex position %s has zero mark price!", symbol) - } - } - } - - t.Logf("\n📊 Position Summary:") - t.Logf(" Total Position Value: %.4f USDC", totalPositionValue) - t.Logf(" Total Position PnL: %.4f USDC", totalPositionPnL) - - // Compare position PnL with balance unrealized PnL - t.Logf("\n🔍 PnL Comparison:") - t.Logf(" Balance Unrealized PnL: %.4f", totalUnrealizedProfit) - t.Logf(" Position Sum PnL: %.4f", totalPositionPnL) - - if abs(totalUnrealizedProfit-totalPositionPnL) > 0.1 { - t.Logf("⚠️ PnL mismatch (may be due to funding fees or timing)") - } else { - t.Logf("✅ PnL values match!") - } -} - -// TestXyzDexBalanceDirectQuery directly queries xyz dex balance for debugging -func TestXyzDexBalanceDirectQuery(t *testing.T) { - privateKeyHex := os.Getenv("TEST_PRIVATE_KEY") - walletAddr := os.Getenv("TEST_WALLET_ADDR") - - if privateKeyHex == "" || walletAddr == "" { - t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required") - } - - trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false) - if err != nil { - t.Fatalf("Failed to create trader: %v", err) - } - - t.Log("=== Direct xyz Dex Balance Query ===") - - accountValue, unrealizedPnl, positions, err := trader.getXYZDexBalance() - if err != nil { - t.Fatalf("getXYZDexBalance failed: %v", err) - } - - t.Logf("xyz Dex Account Value: %.4f", accountValue) - t.Logf("xyz Dex Unrealized PnL: %.4f", unrealizedPnl) - t.Logf("xyz Dex Wallet Balance: %.4f", accountValue-unrealizedPnl) - t.Logf("xyz Dex Positions: %d", len(positions)) - - for i, pos := range positions { - entryPx := "nil" - if pos.Position.EntryPx != nil { - entryPx = *pos.Position.EntryPx - } - liqPx := "nil" - if pos.Position.LiquidationPx != nil { - liqPx = *pos.Position.LiquidationPx - } - - t.Logf(" [%d] %s:", i+1, pos.Position.Coin) - t.Logf(" Size: %s", pos.Position.Szi) - t.Logf(" Entry Price: %s", entryPx) - t.Logf(" Position Value: %s", pos.Position.PositionValue) - t.Logf(" Unrealized PnL: %s", pos.Position.UnrealizedPnl) - t.Logf(" Liquidation Price: %s", liqPx) - t.Logf(" Leverage: %d (%s)", pos.Position.Leverage.Value, pos.Position.Leverage.Type) - } -} - -// TestEquityAfterOpeningPosition simulates opening a position and verifies equity -func TestEquityAfterOpeningPosition(t *testing.T) { - privateKeyHex := os.Getenv("TEST_PRIVATE_KEY") - walletAddr := os.Getenv("TEST_WALLET_ADDR") - - if privateKeyHex == "" || walletAddr == "" { - t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required") - } - - if os.Getenv("XYZ_DEX_LIVE_TEST") != "1" { - t.Skip("Set XYZ_DEX_LIVE_TEST=1 to run live position test") - } - - trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false) - if err != nil { - t.Fatalf("Failed to create trader: %v", err) - } - - // Step 1: Record initial balance - t.Log("=== Step 1: Record Initial Balance ===") - initialBalance, _ := trader.GetBalance() - initialEquity, _ := initialBalance["totalEquity"].(float64) - t.Logf("Initial Equity: %.4f", initialEquity) - - // Step 2: Fetch xyz meta - if err := trader.fetchXyzMeta(); err != nil { - t.Fatalf("Failed to fetch xyz meta: %v", err) - } - - // Step 3: Get current price and place a small order - price, err := trader.getXyzMarketPrice("xyz:SILVER") - if err != nil { - t.Fatalf("Failed to get price: %v", err) - } - t.Logf("Current xyz:SILVER price: %.4f", price) - - // Place a small buy order (minimum ~$10) - testSize := 0.14 - testPrice := price * 1.05 // 5% above for IOC - - t.Log("\n=== Step 2: Place Test Order ===") - t.Logf("Opening position: xyz:SILVER BUY %.4f @ %.4f", testSize, testPrice) - - err = trader.placeXyzOrder("xyz:SILVER", true, testSize, testPrice, false) - if err != nil { - t.Logf("Order result: %v", err) - // Even if IOC doesn't fill, continue to check balance - } - - // Wait a moment for the order to process - time.Sleep(2 * time.Second) - - // Step 3: Check balance after order - t.Log("\n=== Step 3: Check Balance After Order ===") - afterBalance, _ := trader.GetBalance() - afterEquity, _ := afterBalance["totalEquity"].(float64) - afterPerpAV, _ := afterBalance["perpAccountValue"].(float64) - afterXyzAV, _ := afterBalance["xyzDexBalance"].(float64) - - t.Logf("After Order:") - t.Logf(" Perp Account Value: %.4f", afterPerpAV) - t.Logf(" xyz Dex Balance: %.4f", afterXyzAV) - t.Logf(" Total Equity: %.4f", afterEquity) - - equityChange := afterEquity - initialEquity - t.Logf("\nEquity Change: %.4f (%.2f%%)", equityChange, (equityChange/initialEquity)*100) - - // Equity should not change significantly (only by trading fees/slippage) - if abs(equityChange) > initialEquity*0.05 { // More than 5% change is suspicious - t.Errorf("❌ Equity changed too much! Initial=%.4f, After=%.4f, Change=%.4f", - initialEquity, afterEquity, equityChange) - } else { - t.Logf("✅ Equity change is within acceptable range") - } - - // Step 4: Close position if opened - t.Log("\n=== Step 4: Close Position ===") - positions, _ := trader.GetPositions() - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - if symbol == "xyz:SILVER" { - posAmt, _ := pos["positionAmt"].(float64) - if posAmt > 0 { - closePrice := price * 0.95 // 5% below for IOC sell - t.Logf("Closing position: SELL %.4f @ %.4f", posAmt, closePrice) - trader.placeXyzOrder("xyz:SILVER", false, posAmt, closePrice, true) - } - } - } - - time.Sleep(2 * time.Second) - - // Final balance check - t.Log("\n=== Step 5: Final Balance ===") - finalBalance, _ := trader.GetBalance() - finalEquity, _ := finalBalance["totalEquity"].(float64) - t.Logf("Final Equity: %.4f", finalEquity) - t.Logf("Net Change: %.4f", finalEquity-initialEquity) -} - -func abs(x float64) float64 { - if x < 0 { - return -x - } - return x -} diff --git a/trader/hyperliquid/trader.go b/trader/hyperliquid/trader.go index f5572c78..2339a229 100644 --- a/trader/hyperliquid/trader.go +++ b/trader/hyperliquid/trader.go @@ -1,22 +1,16 @@ package hyperliquid import ( - "bytes" "context" "crypto/ecdsa" - "encoding/json" "fmt" - "io" - "net/http" "nofx/logger" "strconv" "strings" "sync" - "time" "github.com/ethereum/go-ethereum/crypto" "github.com/sonirico/go-hyperliquid" - "nofx/trader/types" ) // HyperliquidTrader Hyperliquid trader @@ -64,6 +58,15 @@ var xyzDexAssets = map[string]bool{ "XYZ100": true, } +// defaultBuilder is the builder info for order routing +// Set to nil to avoid requiring builder fee approval +// +// var defaultBuilder = &hyperliquid.BuilderInfo{ +// Builder: "0x891dc6f05ad47a3c1a05da55e7a7517971faaf0d", +// Fee: 10, +// } +var defaultBuilder *hyperliquid.BuilderInfo = nil + // isXyzDexAsset checks if a symbol is an xyz dex asset func isXyzDexAsset(symbol string) bool { // Remove common suffixes to get base symbol @@ -80,6 +83,39 @@ func isXyzDexAsset(symbol string) bool { return xyzDexAssets[base] } +// convertSymbolToHyperliquid converts standard symbol to Hyperliquid format +// Example: "BTCUSDT" -> "BTC", "TSLA" -> "xyz:TSLA", "silver" -> "xyz:SILVER" +func convertSymbolToHyperliquid(symbol string) string { + // Convert to uppercase for consistent handling + base := strings.ToUpper(symbol) + + // Remove common suffixes to get base symbol + for _, suffix := range []string{"USDT", "USD", "-USDC", "-USD"} { + if strings.HasSuffix(base, suffix) { + base = strings.TrimSuffix(base, suffix) + break + } + } + // Remove xyz: prefix if present (case-insensitive, will be re-added if needed) + if strings.HasPrefix(strings.ToLower(base), "xyz:") { + base = base[4:] // Remove first 4 characters + } + + // Check if this is an xyz dex asset (stocks, forex, commodities) + if isXyzDexAsset(base) { + return "xyz:" + base + } + return base +} + +// absFloat returns absolute value of float +func absFloat(x float64) float64 { + if x < 0 { + return -x + } + return x +} + // NewHyperliquidTrader creates a Hyperliquid trader // unifiedAccount: when true, Spot USDC balance is used as collateral for Perp trading func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool, unifiedAccount bool) (*HyperliquidTrader, error) { @@ -144,7 +180,7 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool, return nil, fmt.Errorf("failed to get meta information: %w", err) } - // 🔍 Security check: Validate Agent wallet balance (should be close to 0) + // Security check: Validate Agent wallet balance (should be close to 0) // Only check if using separate Agent wallet (not when main wallet is used as agent) if !strings.EqualFold(walletAddr, agentAddr) { agentState, err := exchange.Info().UserState(ctx, agentAddr) @@ -193,1613 +229,6 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool, }, nil } -// GetBalance gets account balance -func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { - logger.Infof("🔄 Calling Hyperliquid API to get account balance...") - - // ✅ Step 1: Query Spot account balance - spotState, err := t.exchange.Info().SpotUserState(t.ctx, t.walletAddr) - var spotUSDCBalance float64 = 0.0 - if err != nil { - logger.Infof("⚠️ Failed to query Spot balance (may have no spot assets): %v", err) - } else if spotState != nil && len(spotState.Balances) > 0 { - for _, balance := range spotState.Balances { - if balance.Coin == "USDC" { - spotUSDCBalance, _ = strconv.ParseFloat(balance.Total, 64) - logger.Infof("✓ Found Spot balance: %.2f USDC", spotUSDCBalance) - break - } - } - } - - // ✅ Step 2: Query Perpetuals contract account status - accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr) - if err != nil { - logger.Infof("❌ Hyperliquid Perpetuals API call failed: %v", err) - return nil, fmt.Errorf("failed to get account information: %w", err) - } - - // Parse balance information (MarginSummary fields are all strings) - result := make(map[string]interface{}) - - // ✅ Step 3: Dynamically select correct summary based on margin mode (CrossMarginSummary or MarginSummary) - var accountValue, totalMarginUsed float64 - var summaryType string - var summary interface{} - - if t.isCrossMargin { - // Cross margin mode: use CrossMarginSummary - accountValue, _ = strconv.ParseFloat(accountState.CrossMarginSummary.AccountValue, 64) - totalMarginUsed, _ = strconv.ParseFloat(accountState.CrossMarginSummary.TotalMarginUsed, 64) - summaryType = "CrossMarginSummary (cross margin)" - summary = accountState.CrossMarginSummary - } else { - // Isolated margin mode: use MarginSummary - accountValue, _ = strconv.ParseFloat(accountState.MarginSummary.AccountValue, 64) - totalMarginUsed, _ = strconv.ParseFloat(accountState.MarginSummary.TotalMarginUsed, 64) - summaryType = "MarginSummary (isolated margin)" - summary = accountState.MarginSummary - } - - // 🔍 Debug: Print complete summary structure returned by API - summaryJSON, _ := json.MarshalIndent(summary, " ", " ") - logger.Infof("🔍 [DEBUG] Hyperliquid API %s complete data:", summaryType) - logger.Infof("%s", string(summaryJSON)) - - // ⚠️ Critical fix: Accumulate actual unrealized PnL from all positions - totalUnrealizedPnl := 0.0 - for _, assetPos := range accountState.AssetPositions { - unrealizedPnl, _ := strconv.ParseFloat(assetPos.Position.UnrealizedPnl, 64) - totalUnrealizedPnl += unrealizedPnl - } - - // ✅ Correctly understand Hyperliquid fields: - // AccountValue = Total account equity (includes idle funds + position value + unrealized PnL) - // TotalMarginUsed = Margin used by positions (included in AccountValue, for display only) - // - // To be compatible with auto_types.go calculation logic (totalEquity = totalWalletBalance + totalUnrealizedProfit) - // Need to return "wallet balance without unrealized PnL" - walletBalanceWithoutUnrealized := accountValue - totalUnrealizedPnl - - // ✅ Step 4: Use Withdrawable field (PR #443) - // Withdrawable is the official real withdrawable balance, more reliable than simple calculation - availableBalance := 0.0 - if accountState.Withdrawable != "" { - withdrawable, err := strconv.ParseFloat(accountState.Withdrawable, 64) - if err == nil && withdrawable > 0 { - availableBalance = withdrawable - logger.Infof("✓ Using Withdrawable as available balance: %.2f", availableBalance) - } - } - - // Fallback: If no Withdrawable, use simple calculation - if availableBalance == 0 && accountState.Withdrawable == "" { - availableBalance = accountValue - totalMarginUsed - if availableBalance < 0 { - logger.Infof("⚠️ Calculated available balance is negative (%.2f), reset to 0", availableBalance) - availableBalance = 0 - } - } - - // ✅ Step 5: Query xyz dex balance (stock perps, forex, commodities) - var xyzAccountValue, xyzUnrealizedPnl float64 - var xyzPositions []xyzAssetPosition - xyzAccountValue, xyzUnrealizedPnl, xyzPositions, err = t.getXYZDexBalance() - if err != nil { - // xyz dex query failed - log warning but don't fail the entire balance query - logger.Infof("⚠️ Failed to query xyz dex balance: %v", err) - } - // Always log xyz dex state for debugging - logger.Infof("🔍 xyz dex state: accountValue=%.4f, unrealizedPnl=%.4f, positions=%d", - xyzAccountValue, xyzUnrealizedPnl, len(xyzPositions)) - for _, pos := range xyzPositions { - entryPx := "nil" - if pos.Position.EntryPx != nil { - entryPx = *pos.Position.EntryPx - } - logger.Infof(" └─ %s: size=%s, entryPx=%s, posValue=%s, pnl=%s", - pos.Position.Coin, pos.Position.Szi, entryPx, pos.Position.PositionValue, pos.Position.UnrealizedPnl) - } - xyzWalletBalance := xyzAccountValue - xyzUnrealizedPnl - - // ✅ Step 6: Correctly handle Spot + Perpetuals + xyz dex balance - // Important: Each account is independent, manual transfers required - totalWalletBalance := walletBalanceWithoutUnrealized + spotUSDCBalance + xyzWalletBalance - totalUnrealizedPnlAll := totalUnrealizedPnl + xyzUnrealizedPnl - - // Calculate total equity properly: perpAccountValue + spotUSDCBalance + xyzAccountValue - // Note: totalWalletBalance + totalUnrealizedPnlAll should equal this - totalEquityCalculated := accountValue + spotUSDCBalance + xyzAccountValue - - // ✅ Step 7: Unified Account mode - Spot USDC is used as collateral for Perps - // In this mode, available balance includes Spot USDC since it can be used for Perp margin - if t.isUnifiedAccount && spotUSDCBalance > 0 { - // Add Spot balance to available balance for trading - availableBalance = availableBalance + spotUSDCBalance - logger.Infof("✓ Unified Account: Spot %.2f USDC added to available balance (total: %.2f)", - spotUSDCBalance, availableBalance) - } - - result["totalWalletBalance"] = totalWalletBalance // Total assets (Perp + Spot + xyz) - unrealized - result["totalEquity"] = totalEquityCalculated // Total equity = Perp AV + Spot + xyz AV - result["availableBalance"] = availableBalance // Available balance (Perp + Spot if unified) - result["totalUnrealizedProfit"] = totalUnrealizedPnlAll // Unrealized PnL (Perpetuals + xyz) - result["spotBalance"] = spotUSDCBalance // Spot balance - result["xyzDexBalance"] = xyzAccountValue // xyz dex equity (stock perps, forex, commodities) - result["xyzDexUnrealizedPnl"] = xyzUnrealizedPnl // xyz dex unrealized PnL - result["perpAccountValue"] = accountValue // Perp account value for debugging - - logger.Infof("✓ Hyperliquid complete account:") - logger.Infof(" • Spot balance: %.2f USDC", spotUSDCBalance) - logger.Infof(" • Perpetuals equity: %.2f USDC (wallet %.2f + unrealized %.2f)", - accountValue, - walletBalanceWithoutUnrealized, - totalUnrealizedPnl) - logger.Infof(" • Perpetuals available balance: %.2f USDC", availableBalance) - logger.Infof(" • Margin used: %.2f USDC", totalMarginUsed) - logger.Infof(" • xyz dex equity: %.2f USDC (wallet %.2f + unrealized %.2f)", - xyzAccountValue, - xyzWalletBalance, - xyzUnrealizedPnl) - logger.Infof(" • Total assets (Perp+Spot+xyz): %.2f USDC", totalWalletBalance) - logger.Infof(" ⭐ Total: %.2f USDC | Perp: %.2f | Spot: %.2f | xyz: %.2f", - totalWalletBalance, availableBalance, spotUSDCBalance, xyzAccountValue) - - return result, nil -} - -// xyzDexState represents the clearinghouse state for xyz dex -type xyzDexState struct { - MarginSummary *xyzMarginSummary `json:"marginSummary,omitempty"` - CrossMarginSummary *xyzMarginSummary `json:"crossMarginSummary,omitempty"` - Withdrawable string `json:"withdrawable,omitempty"` - AssetPositions []xyzAssetPosition `json:"assetPositions,omitempty"` -} - -type xyzMarginSummary struct { - AccountValue string `json:"accountValue"` - TotalMarginUsed string `json:"totalMarginUsed"` -} - -type xyzAssetPosition struct { - Position struct { - Coin string `json:"coin"` - Szi string `json:"szi"` - EntryPx *string `json:"entryPx"` - PositionValue string `json:"positionValue"` - UnrealizedPnl string `json:"unrealizedPnl"` - LiquidationPx *string `json:"liquidationPx"` - Leverage struct { - Type string `json:"type"` - Value int `json:"value"` - } `json:"leverage"` - } `json:"position"` -} - -// getXYZDexBalance queries the xyz dex balance (stock perps, forex, commodities) -func (t *HyperliquidTrader) getXYZDexBalance() (accountValue float64, unrealizedPnl float64, positions []xyzAssetPosition, err error) { - // Build request for xyz dex clearinghouse state - reqBody := map[string]interface{}{ - "type": "clearinghouseState", - "user": t.walletAddr, - "dex": "xyz", - } - - jsonBody, err := json.Marshal(reqBody) - if err != nil { - return 0, 0, nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // Determine API URL - apiURL := "https://api.hyperliquid.xyz/info" - // Note: xyz dex may not be available on testnet - - req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) - if err != nil { - return 0, 0, nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return 0, 0, nil, fmt.Errorf("failed to execute request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return 0, 0, nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return 0, 0, nil, fmt.Errorf("xyz dex API error (status %d): %s", resp.StatusCode, string(body)) - } - - var state xyzDexState - if err := json.Unmarshal(body, &state); err != nil { - return 0, 0, nil, fmt.Errorf("failed to parse response: %w", err) - } - - // Parse account value - xyz dex uses MarginSummary for isolated margin mode - // CrossMarginSummary may exist but with 0 values, so check MarginSummary first - if state.MarginSummary != nil && state.MarginSummary.AccountValue != "" { - av, _ := strconv.ParseFloat(state.MarginSummary.AccountValue, 64) - if av > 0 { - accountValue = av - } - } - // Fallback to CrossMarginSummary if MarginSummary is 0 - if accountValue == 0 && state.CrossMarginSummary != nil && state.CrossMarginSummary.AccountValue != "" { - accountValue, _ = strconv.ParseFloat(state.CrossMarginSummary.AccountValue, 64) - } - - // Calculate total unrealized PnL from positions - for _, pos := range state.AssetPositions { - pnl, _ := strconv.ParseFloat(pos.Position.UnrealizedPnl, 64) - unrealizedPnl += pnl - } - - return accountValue, unrealizedPnl, state.AssetPositions, nil -} - -// fetchXyzMeta fetches metadata for xyz dex assets (stocks, forex, commodities) -func (t *HyperliquidTrader) fetchXyzMeta() error { - // Build request for xyz dex meta - reqBody := map[string]string{ - "type": "meta", - "dex": "xyz", - } - - jsonBody, err := json.Marshal(reqBody) - if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) - } - - apiURL := "https://api.hyperliquid.xyz/info" - - req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to execute request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("xyz dex meta API error (status %d): %s", resp.StatusCode, string(body)) - } - - var meta xyzDexMeta - if err := json.Unmarshal(body, &meta); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - t.xyzMetaMutex.Lock() - t.xyzMeta = &meta - t.xyzMetaMutex.Unlock() - - logger.Infof("✅ xyz dex meta fetched, contains %d assets", len(meta.Universe)) - return nil -} - -// getXyzSzDecimals gets quantity precision for xyz dex asset -func (t *HyperliquidTrader) getXyzSzDecimals(coin string) int { - t.xyzMetaMutex.RLock() - defer t.xyzMetaMutex.RUnlock() - - if t.xyzMeta == nil { - logger.Infof("⚠️ xyz meta information is empty, using default precision 2") - return 2 // Default precision for stocks/forex - } - - // The meta API returns names with xyz: prefix, so ensure we match correctly - lookupName := coin - if !strings.HasPrefix(lookupName, "xyz:") { - lookupName = "xyz:" + lookupName - } - - // Find corresponding asset in xyzMeta.Universe - for _, asset := range t.xyzMeta.Universe { - if asset.Name == lookupName { - return asset.SzDecimals - } - } - - logger.Infof("⚠️ Precision information not found for %s, using default precision 2", lookupName) - return 2 // Default precision for stocks/forex -} - -// GetPositions gets all positions (including xyz dex positions) -func (t *HyperliquidTrader) GetPositions() ([]map[string]interface{}, error) { - // Get account status - accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var result []map[string]interface{} - - // Iterate through all perp positions - for _, assetPos := range accountState.AssetPositions { - position := assetPos.Position - - // Position amount (string type) - posAmt, _ := strconv.ParseFloat(position.Szi, 64) - - if posAmt == 0 { - continue // Skip positions with zero amount - } - - posMap := make(map[string]interface{}) - - // Normalize symbol format (Hyperliquid uses "BTC", we convert to "BTCUSDT") - symbol := position.Coin + "USDT" - posMap["symbol"] = symbol - - // Position amount and direction - if posAmt > 0 { - posMap["side"] = "long" - posMap["positionAmt"] = posAmt - } else { - posMap["side"] = "short" - posMap["positionAmt"] = -posAmt // Convert to positive number - } - - // Price information (EntryPx and LiquidationPx are pointer types) - var entryPrice, liquidationPx float64 - if position.EntryPx != nil { - entryPrice, _ = strconv.ParseFloat(*position.EntryPx, 64) - } - if position.LiquidationPx != nil { - liquidationPx, _ = strconv.ParseFloat(*position.LiquidationPx, 64) - } - - positionValue, _ := strconv.ParseFloat(position.PositionValue, 64) - unrealizedPnl, _ := strconv.ParseFloat(position.UnrealizedPnl, 64) - - // Calculate mark price (positionValue / abs(posAmt)) - var markPrice float64 - if posAmt != 0 { - markPrice = positionValue / absFloat(posAmt) - } - - posMap["entryPrice"] = entryPrice - posMap["markPrice"] = markPrice - posMap["unRealizedProfit"] = unrealizedPnl - posMap["leverage"] = float64(position.Leverage.Value) - posMap["liquidationPrice"] = liquidationPx - - result = append(result, posMap) - } - - // Also get xyz dex positions (stocks, forex, commodities) - _, _, xyzPositions, err := t.getXYZDexBalance() - if err != nil { - // xyz dex query failed - log warning but don't fail - logger.Infof("⚠️ Failed to get xyz dex positions: %v", err) - } else { - for _, pos := range xyzPositions { - posAmt, _ := strconv.ParseFloat(pos.Position.Szi, 64) - if posAmt == 0 { - continue - } - - posMap := make(map[string]interface{}) - - // xyz dex positions - the API returns coin names with xyz: prefix (e.g., "xyz:SILVER") - // Only add prefix if not already present - symbol := pos.Position.Coin - if !strings.HasPrefix(symbol, "xyz:") { - symbol = "xyz:" + symbol - } - posMap["symbol"] = symbol - - if posAmt > 0 { - posMap["side"] = "long" - posMap["positionAmt"] = posAmt - } else { - posMap["side"] = "short" - posMap["positionAmt"] = -posAmt - } - - // Parse price information - var entryPrice, liquidationPx float64 - if pos.Position.EntryPx != nil { - entryPrice, _ = strconv.ParseFloat(*pos.Position.EntryPx, 64) - } - if pos.Position.LiquidationPx != nil { - liquidationPx, _ = strconv.ParseFloat(*pos.Position.LiquidationPx, 64) - } - - positionValue, _ := strconv.ParseFloat(pos.Position.PositionValue, 64) - unrealizedPnl, _ := strconv.ParseFloat(pos.Position.UnrealizedPnl, 64) - - // Calculate mark price from position value - var markPrice float64 - if posAmt != 0 { - markPrice = positionValue / absFloat(posAmt) - } - - // Get leverage (default to 1 if not available) - leverage := float64(pos.Position.Leverage.Value) - if leverage == 0 { - leverage = 1.0 - } - - posMap["entryPrice"] = entryPrice - posMap["markPrice"] = markPrice - posMap["unRealizedProfit"] = unrealizedPnl - posMap["leverage"] = leverage - posMap["liquidationPrice"] = liquidationPx - posMap["isXyzDex"] = true // Mark as xyz dex position - - result = append(result, posMap) - } - } - - return result, nil -} - -// SetMarginMode sets margin mode (set together with SetLeverage) -func (t *HyperliquidTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - // Hyperliquid's margin mode is set in SetLeverage, only record here - t.isCrossMargin = isCrossMargin - marginModeStr := "cross margin" - if !isCrossMargin { - marginModeStr = "isolated margin" - } - logger.Infof(" ✓ %s will use %s mode", symbol, marginModeStr) - return nil -} - -// SetLeverage sets leverage -func (t *HyperliquidTrader) SetLeverage(symbol string, leverage int) error { - // Hyperliquid symbol format (remove USDT suffix) - coin := convertSymbolToHyperliquid(symbol) - - // Call UpdateLeverage (leverage int, name string, isCross bool) - // Third parameter: true=cross margin mode, false=isolated margin mode - _, err := t.exchange.UpdateLeverage(t.ctx, leverage, coin, t.isCrossMargin) - if err != nil { - return fmt.Errorf("failed to set leverage: %w", err) - } - - logger.Infof(" ✓ %s leverage switched to %dx", symbol, leverage) - return nil -} - -// refreshMetaIfNeeded refreshes meta information when invalid (triggered when Asset ID is 0) -func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error { - assetID := t.exchange.Info().NameToAsset(coin) - if assetID != 0 { - return nil // Meta is normal, no refresh needed - } - - logger.Infof("⚠️ Asset ID for %s is 0, attempting to refresh Meta information...", coin) - - // Refresh Meta information - meta, err := t.exchange.Info().Meta(t.ctx) - if err != nil { - return fmt.Errorf("failed to refresh Meta information: %w", err) - } - - // ✅ Concurrency safe: Use write lock to protect meta field update - t.metaMutex.Lock() - t.meta = meta - t.metaMutex.Unlock() - - logger.Infof("✅ Meta information refreshed, contains %d assets", len(meta.Universe)) - - // Verify Asset ID after refresh - assetID = t.exchange.Info().NameToAsset(coin) - if assetID == 0 { - return fmt.Errorf("❌ Even after refreshing Meta, Asset ID for %s is still 0. Possible reasons:\n"+ - " 1. This coin is not listed on Hyperliquid\n"+ - " 2. Coin name is incorrect (should be BTC not BTCUSDT)\n"+ - " 3. API connection issue", coin) - } - - logger.Infof("✅ Asset ID check passed after refresh: %s -> %d", coin, assetID) - return nil -} - -// OpenLong opens a long position (supports both crypto and xyz dex) -func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // First cancel all pending orders for this coin - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel old pending orders: %v", err) - } - - // Hyperliquid symbol format - coin := convertSymbolToHyperliquid(symbol) - - // Check if this is an xyz dex asset - isXyz := strings.HasPrefix(coin, "xyz:") - - // Set leverage (skip for xyz dex as it may not support leverage adjustment) - if !isXyz { - if err := t.SetLeverage(symbol, leverage); err != nil { - return nil, err - } - } else { - logger.Infof(" ℹ xyz dex asset %s - using default leverage", coin) - } - - // Get current price (for market order) - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - // ⚠️ Critical: Price needs to be processed to 5 significant figures - aggressivePrice := t.roundPriceToSigfigs(price * 1.01) - logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*1.01, aggressivePrice) - - // Handle xyz dex assets differently - if isXyz { - // xyz dex order - if err := t.placeXyzOrder(coin, true, quantity, aggressivePrice, false); err != nil { - return nil, fmt.Errorf("failed to open long position on xyz dex: %w", err) - } - } else { - // Standard crypto order - roundedQuantity := t.roundToSzDecimals(coin, quantity) - logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) - - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: true, - Size: roundedQuantity, - Price: aggressivePrice, - OrderType: hyperliquid.OrderType{ - Limit: &hyperliquid.LimitOrderType{ - Tif: hyperliquid.TifIoc, - }, - }, - ReduceOnly: false, - } - - _, err = t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return nil, fmt.Errorf("failed to open long position: %w", err) - } - } - - logger.Infof("✓ Long position opened successfully: %s quantity: %.4f", symbol, quantity) - - result := make(map[string]interface{}) - result["orderId"] = 0 - result["symbol"] = symbol - result["status"] = "FILLED" - - return result, nil -} - -// OpenShort opens a short position (supports both crypto and xyz dex) -func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // First cancel all pending orders for this coin - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel old pending orders: %v", err) - } - - // Hyperliquid symbol format - coin := convertSymbolToHyperliquid(symbol) - - // Check if this is an xyz dex asset - isXyz := strings.HasPrefix(coin, "xyz:") - - // Set leverage (skip for xyz dex) - if !isXyz { - if err := t.SetLeverage(symbol, leverage); err != nil { - return nil, err - } - } else { - logger.Infof(" ℹ xyz dex asset %s - using default leverage", coin) - } - - // Get current price - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - // ⚠️ Critical: Price needs to be processed to 5 significant figures - aggressivePrice := t.roundPriceToSigfigs(price * 0.99) - logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*0.99, aggressivePrice) - - // Handle xyz dex assets differently - if isXyz { - // xyz dex order - if err := t.placeXyzOrder(coin, false, quantity, aggressivePrice, false); err != nil { - return nil, fmt.Errorf("failed to open short position on xyz dex: %w", err) - } - } else { - // Standard crypto order - roundedQuantity := t.roundToSzDecimals(coin, quantity) - logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) - - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: false, - Size: roundedQuantity, - Price: aggressivePrice, - OrderType: hyperliquid.OrderType{ - Limit: &hyperliquid.LimitOrderType{ - Tif: hyperliquid.TifIoc, - }, - }, - ReduceOnly: false, - } - - _, err = t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return nil, fmt.Errorf("failed to open short position: %w", err) - } - } - - logger.Infof("✓ Short position opened successfully: %s quantity: %.4f", symbol, quantity) - - result := make(map[string]interface{}) - result["orderId"] = 0 - result["symbol"] = symbol - result["status"] = "FILLED" - - return result, nil -} - -// CloseLong closes a long position (supports both crypto and xyz dex) -func (t *HyperliquidTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - // Hyperliquid symbol format - coin := convertSymbolToHyperliquid(symbol) - isXyz := strings.HasPrefix(coin, "xyz:") - - // If quantity is 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - - // For xyz dex, also check xyz: prefixed symbols - searchSymbol := symbol - if isXyz { - searchSymbol = coin // Use xyz:SYMBOL format for comparison - } - - for _, pos := range positions { - posSymbol := pos["symbol"].(string) - if (posSymbol == symbol || posSymbol == searchSymbol) && pos["side"] == "long" { - quantity = pos["positionAmt"].(float64) - break - } - } - - if quantity == 0 { - return nil, fmt.Errorf("no long position found for %s", symbol) - } - } - - // Get current price - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - // ⚠️ Critical: Price needs to be processed to 5 significant figures - aggressivePrice := t.roundPriceToSigfigs(price * 0.99) - logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*0.99, aggressivePrice) - - // Handle xyz dex assets differently - if isXyz { - // xyz dex close order - if err := t.placeXyzOrder(coin, false, quantity, aggressivePrice, true); err != nil { - return nil, fmt.Errorf("failed to close long position on xyz dex: %w", err) - } - } else { - // Standard crypto close order - roundedQuantity := t.roundToSzDecimals(coin, quantity) - logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) - - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: false, - Size: roundedQuantity, - Price: aggressivePrice, - OrderType: hyperliquid.OrderType{ - Limit: &hyperliquid.LimitOrderType{ - Tif: hyperliquid.TifIoc, - }, - }, - ReduceOnly: true, - } - - _, err = t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return nil, fmt.Errorf("failed to close long position: %w", err) - } - } - - logger.Infof("✓ Long position closed successfully: %s quantity: %.4f", symbol, quantity) - - // Cancel all pending orders for this coin after closing position - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) - } - - result := make(map[string]interface{}) - result["orderId"] = 0 - result["symbol"] = symbol - result["status"] = "FILLED" - - return result, nil -} - -// CloseShort closes a short position (supports both crypto and xyz dex) -func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - // Hyperliquid symbol format - coin := convertSymbolToHyperliquid(symbol) - isXyz := strings.HasPrefix(coin, "xyz:") - - // If quantity is 0, get current position quantity - if quantity == 0 { - positions, err := t.GetPositions() - if err != nil { - return nil, err - } - - // For xyz dex, also check xyz: prefixed symbols - searchSymbol := symbol - if isXyz { - searchSymbol = coin - } - - for _, pos := range positions { - posSymbol := pos["symbol"].(string) - if (posSymbol == symbol || posSymbol == searchSymbol) && pos["side"] == "short" { - quantity = pos["positionAmt"].(float64) - break - } - } - - if quantity == 0 { - return nil, fmt.Errorf("no short position found for %s", symbol) - } - } - - // Get current price - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, err - } - - // ⚠️ Critical: Price needs to be processed to 5 significant figures - aggressivePrice := t.roundPriceToSigfigs(price * 1.01) - logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*1.01, aggressivePrice) - - // Handle xyz dex assets differently - if isXyz { - // xyz dex close order - if err := t.placeXyzOrder(coin, true, quantity, aggressivePrice, true); err != nil { - return nil, fmt.Errorf("failed to close short position on xyz dex: %w", err) - } - } else { - // Standard crypto close order - roundedQuantity := t.roundToSzDecimals(coin, quantity) - logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) - - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: true, - Size: roundedQuantity, - Price: aggressivePrice, - OrderType: hyperliquid.OrderType{ - Limit: &hyperliquid.LimitOrderType{ - Tif: hyperliquid.TifIoc, - }, - }, - ReduceOnly: true, - } - - _, err = t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return nil, fmt.Errorf("failed to close short position: %w", err) - } - } - - logger.Infof("✓ Short position closed successfully: %s quantity: %.4f", symbol, quantity) - - // Cancel all pending orders for this coin after closing position - if err := t.CancelAllOrders(symbol); err != nil { - logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) - } - - result := make(map[string]interface{}) - result["orderId"] = 0 - result["symbol"] = symbol - result["status"] = "FILLED" - - return result, nil -} - -// CancelStopLossOrders only cancels stop loss orders (Hyperliquid cannot distinguish stop loss and take profit, cancel all) -func (t *HyperliquidTrader) CancelStopLossOrders(symbol string) error { - // Hyperliquid SDK's OpenOrder structure does not expose trigger field - // Cannot distinguish stop loss and take profit orders, so cancel all pending orders for this coin - logger.Infof(" ⚠️ Hyperliquid cannot distinguish stop loss/take profit orders, will cancel all pending orders") - return t.CancelStopOrders(symbol) -} - -// CancelTakeProfitOrders only cancels take profit orders (Hyperliquid cannot distinguish stop loss and take profit, cancel all) -func (t *HyperliquidTrader) CancelTakeProfitOrders(symbol string) error { - // Hyperliquid SDK's OpenOrder structure does not expose trigger field - // Cannot distinguish stop loss and take profit orders, so cancel all pending orders for this coin - logger.Infof(" ⚠️ Hyperliquid cannot distinguish stop loss/take profit orders, will cancel all pending orders") - return t.CancelStopOrders(symbol) -} - -// CancelAllOrders cancels all pending orders for this coin -func (t *HyperliquidTrader) CancelAllOrders(symbol string) error { - coin := convertSymbolToHyperliquid(symbol) - - // Check if this is an xyz dex asset - isXyz := strings.HasPrefix(coin, "xyz:") - - if isXyz { - // xyz dex orders - use direct API call - return t.cancelXyzOrders(coin) - } - - // Standard crypto orders - openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) - if err != nil { - return fmt.Errorf("failed to get pending orders: %w", err) - } - - // Cancel all pending orders for this coin - for _, order := range openOrders { - if order.Coin == coin { - _, err := t.exchange.Cancel(t.ctx, coin, order.Oid) - if err != nil { - logger.Infof(" ⚠ Failed to cancel order (oid=%d): %v", order.Oid, err) - } - } - } - - logger.Infof(" ✓ Cancelled all pending orders for %s", symbol) - return nil -} - -// CancelStopOrders cancels take profit/stop loss orders for this coin (used to adjust TP/SL positions) -func (t *HyperliquidTrader) CancelStopOrders(symbol string) error { - coin := convertSymbolToHyperliquid(symbol) - - // Check if this is an xyz dex asset - isXyz := strings.HasPrefix(coin, "xyz:") - - if isXyz { - // xyz dex orders - use direct API call - return t.cancelXyzOrders(coin) - } - - // Get all pending orders for standard crypto - openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) - if err != nil { - return fmt.Errorf("failed to get pending orders: %w", err) - } - - // Note: Hyperliquid SDK's OpenOrder structure does not expose trigger field - // Therefore temporarily cancel all pending orders for this coin (including TP/SL orders) - // This is safe because all old orders should be cleaned up before setting new TP/SL - canceledCount := 0 - for _, order := range openOrders { - if order.Coin == coin { - _, err := t.exchange.Cancel(t.ctx, coin, order.Oid) - if err != nil { - logger.Infof(" ⚠ Failed to cancel order (oid=%d): %v", order.Oid, err) - continue - } - canceledCount++ - } - } - - if canceledCount == 0 { - logger.Infof(" ℹ No pending orders to cancel for %s", symbol) - } else { - logger.Infof(" ✓ Cancelled %d pending orders for %s (including TP/SL orders)", canceledCount, symbol) - } - - return nil -} - -// cancelXyzOrders cancels all pending orders for xyz dex assets (stocks, forex, commodities) -func (t *HyperliquidTrader) cancelXyzOrders(coin string) error { - // Query xyz dex open orders - reqBody := map[string]interface{}{ - "type": "openOrders", - "user": t.walletAddr, - "dex": "xyz", - } - - jsonBody, err := json.Marshal(reqBody) - if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) - } - - apiURL := "https://api.hyperliquid.xyz/info" - - req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to execute request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("xyz dex openOrders API error (status %d): %s", resp.StatusCode, string(body)) - } - - // Parse open orders - var openOrders []struct { - Coin string `json:"coin"` - Oid int64 `json:"oid"` - } - if err := json.Unmarshal(body, &openOrders); err != nil { - return fmt.Errorf("failed to parse open orders: %w", err) - } - - // Filter orders for this coin and cancel them - canceledCount := 0 - for _, order := range openOrders { - if order.Coin == coin { - if err := t.cancelXyzOrder(order.Oid); err != nil { - logger.Infof(" ⚠ Failed to cancel xyz dex order (oid=%d): %v", order.Oid, err) - continue - } - canceledCount++ - } - } - - if canceledCount == 0 { - logger.Infof(" ℹ No pending xyz dex orders to cancel for %s", coin) - } else { - logger.Infof(" ✓ Cancelled %d xyz dex orders for %s", canceledCount, coin) - } - - return nil -} - -// cancelXyzOrder cancels a single xyz dex order by oid -func (t *HyperliquidTrader) cancelXyzOrder(oid int64) error { - // Get asset index for this order (we need it for cancel action) - // For cancel, we construct a cancel action with the oid - - action := map[string]interface{}{ - "type": "cancel", - "cancels": []map[string]interface{}{ - { - "a": oid, // asset index not needed for cancel by oid in xyz dex - "o": oid, - }, - }, - } - - // Sign the action - nonce := time.Now().UnixMilli() - isMainnet := !t.isTestnet - vaultAddress := "" - - sig, err := hyperliquid.SignL1Action(t.privateKey, action, vaultAddress, nonce, nil, isMainnet) - if err != nil { - return fmt.Errorf("failed to sign cancel action: %w", err) - } - - payload := map[string]any{ - "action": action, - "nonce": nonce, - "signature": sig, - } - - apiURL := hyperliquid.MainnetAPIURL - if t.isTestnet { - apiURL = hyperliquid.TestnetAPIURL - } - - jsonData, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, apiURL+"/exchange", bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - // Check response - var result struct { - Status string `json:"status"` - } - if err := json.Unmarshal(body, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if result.Status != "ok" { - return fmt.Errorf("cancel failed: %s", string(body)) - } - - return nil -} - -// GetMarketPrice gets market price (supports both crypto and xyz dex assets) -func (t *HyperliquidTrader) GetMarketPrice(symbol string) (float64, error) { - coin := convertSymbolToHyperliquid(symbol) - - // Check if this is an xyz dex asset - if strings.HasPrefix(coin, "xyz:") { - return t.getXyzMarketPrice(coin) - } - - // Get all market prices for crypto - allMids, err := t.exchange.Info().AllMids(t.ctx) - if err != nil { - return 0, fmt.Errorf("failed to get price: %w", err) - } - - // Find price for corresponding coin (allMids is map[string]string) - if priceStr, ok := allMids[coin]; ok { - priceFloat, err := strconv.ParseFloat(priceStr, 64) - if err == nil { - return priceFloat, nil - } - return 0, fmt.Errorf("price format error: %v", err) - } - - return 0, fmt.Errorf("price not found for %s", symbol) -} - -// getXyzMarketPrice gets market price for xyz dex assets -func (t *HyperliquidTrader) getXyzMarketPrice(coin string) (float64, error) { - // Build request for xyz dex allMids - reqBody := map[string]string{ - "type": "allMids", - "dex": "xyz", - } - - jsonBody, err := json.Marshal(reqBody) - if err != nil { - return 0, fmt.Errorf("failed to marshal request: %w", err) - } - - apiURL := "https://api.hyperliquid.xyz/info" - - req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) - if err != nil { - return 0, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return 0, fmt.Errorf("failed to execute request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return 0, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return 0, fmt.Errorf("xyz dex allMids API error (status %d): %s", resp.StatusCode, string(body)) - } - - var mids map[string]string - if err := json.Unmarshal(body, &mids); err != nil { - return 0, fmt.Errorf("failed to parse response: %w", err) - } - - // The API returns keys with xyz: prefix, so ensure the coin has it - lookupKey := coin - if !strings.HasPrefix(lookupKey, "xyz:") { - lookupKey = "xyz:" + lookupKey - } - - if priceStr, ok := mids[lookupKey]; ok { - priceFloat, err := strconv.ParseFloat(priceStr, 64) - if err == nil { - return priceFloat, nil - } - return 0, fmt.Errorf("price format error: %v", err) - } - - return 0, fmt.Errorf("xyz dex price not found for %s (lookup key: %s)", coin, lookupKey) -} - -// floatToWireStr converts a float to wire format string (8 decimal places, trimmed zeros) -// This matches the SDK's floatToWire function -func floatToWireStr(x float64) string { - // Format to 8 decimal places - result := fmt.Sprintf("%.8f", x) - // Remove trailing zeros - result = strings.TrimRight(result, "0") - // Remove trailing decimal point if no decimals left - result = strings.TrimRight(result, ".") - return result -} - -// placeXyzOrder places an order on the xyz dex (stocks, forex, commodities) -// Note: xyz dex orders use builder-deployed perpetuals and require different handling -// xyz dex asset indices start from 10000 (10000 + meta_index) -// This implementation bypasses the SDK's NameToAsset lookup and directly constructs the order -func (t *HyperliquidTrader) placeXyzOrder(coin string, isBuy bool, size float64, price float64, reduceOnly bool) error { - // Fetch xyz meta if not cached - t.xyzMetaMutex.RLock() - hasMeta := t.xyzMeta != nil - t.xyzMetaMutex.RUnlock() - - if !hasMeta { - if err := t.fetchXyzMeta(); err != nil { - return fmt.Errorf("failed to fetch xyz meta: %w", err) - } - } - - // Get asset index from xyz meta (returns 0-based index) - metaIndex := t.getXyzAssetIndex(coin) - if metaIndex < 0 { - return fmt.Errorf("xyz asset %s not found in meta", coin) - } - - // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta - // xyz dex is at perp_dex_index = 1 (verified from perpDexs API: [null, {name:"xyz",...}]) - // So xyz asset index = 100000 + 1 * 10000 + metaIndex = 110000 + metaIndex - const xyzPerpDexIndex = 1 - assetIndex := 100000 + xyzPerpDexIndex*10000 + metaIndex - - // Round size to correct precision - szDecimals := t.getXyzSzDecimals(coin) - multiplier := 1.0 - for i := 0; i < szDecimals; i++ { - multiplier *= 10.0 - } - roundedSize := float64(int(size*multiplier+0.5)) / multiplier - - // Round price to 5 significant figures - roundedPrice := t.roundPriceToSigfigs(price) - - logger.Infof("📝 Placing xyz dex order (direct): %s %s size=%.4f price=%.4f metaIndex=%d assetIndex=%d (formula: 100000 + 1*10000 + %d) reduceOnly=%v", - map[bool]string{true: "BUY", false: "SELL"}[isBuy], - coin, roundedSize, roundedPrice, metaIndex, assetIndex, metaIndex, reduceOnly) - - // Construct OrderWire directly with correct asset index (bypassing SDK's NameToAsset) - orderWire := hyperliquid.OrderWire{ - Asset: assetIndex, - IsBuy: isBuy, - LimitPx: floatToWireStr(roundedPrice), - Size: floatToWireStr(roundedSize), - ReduceOnly: reduceOnly, - OrderType: hyperliquid.OrderWireType{ - Limit: &hyperliquid.OrderWireTypeLimit{ - Tif: hyperliquid.TifIoc, - }, - }, - } - - // Create OrderAction (no builder to avoid requiring builder fee approval) - action := hyperliquid.OrderAction{ - Type: "order", - Orders: []hyperliquid.OrderWire{orderWire}, - Grouping: "na", - Builder: nil, - } - - // Sign the action - nonce := time.Now().UnixMilli() - isMainnet := !t.isTestnet - vaultAddress := "" // No vault for personal account - - sig, err := hyperliquid.SignL1Action(t.privateKey, action, vaultAddress, nonce, nil, isMainnet) - if err != nil { - return fmt.Errorf("failed to sign xyz dex order: %w", err) - } - - // Construct payload for /exchange endpoint - payload := map[string]any{ - "action": action, - "nonce": nonce, - "signature": sig, - } - - // Determine API URL - apiURL := hyperliquid.MainnetAPIURL - if t.isTestnet { - apiURL = hyperliquid.TestnetAPIURL - } - - // POST to /exchange - jsonData, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - logger.Infof("📤 Sending xyz dex order to %s/exchange", apiURL) - - req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, apiURL+"/exchange", bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - - // Parse response - var result struct { - Status string `json:"status"` - Response struct { - Type string `json:"type"` - Data struct { - Statuses []struct { - Resting *struct { - Oid int64 `json:"oid"` - } `json:"resting,omitempty"` - Filled *struct { - TotalSz string `json:"totalSz"` - AvgPx string `json:"avgPx"` - Oid int `json:"oid"` - } `json:"filled,omitempty"` - Error *string `json:"error,omitempty"` - } `json:"statuses"` - } `json:"data"` - } `json:"response"` - } - - if err := json.Unmarshal(body, &result); err != nil { - // Try to parse as error response - logger.Infof("⚠️ Failed to parse response as success, raw body: %s", string(body)) - return fmt.Errorf("xyz dex order failed, status=%d, body=%s", resp.StatusCode, string(body)) - } - - // Check for errors in response - if result.Status != "ok" { - return fmt.Errorf("xyz dex order failed: status=%s, body=%s", result.Status, string(body)) - } - - // Check order statuses - if len(result.Response.Data.Statuses) > 0 { - status := result.Response.Data.Statuses[0] - if status.Error != nil { - return fmt.Errorf("xyz dex order error (coin=%s, assetIndex=%d, size=%.4f, price=%.4f): %s", coin, assetIndex, roundedSize, roundedPrice, *status.Error) - } - if status.Filled != nil { - logger.Infof("✅ xyz dex order filled: totalSz=%s avgPx=%s oid=%d", - status.Filled.TotalSz, status.Filled.AvgPx, status.Filled.Oid) - } else if status.Resting != nil { - logger.Infof("✅ xyz dex order resting: oid=%d", status.Resting.Oid) - } - } - - logger.Infof("✅ xyz dex order placed successfully: %s (response: %s)", coin, string(body)) - return nil -} - -// getXyzAssetIndex gets the asset index for an xyz dex asset -func (t *HyperliquidTrader) getXyzAssetIndex(baseCoin string) int { - t.xyzMetaMutex.RLock() - defer t.xyzMetaMutex.RUnlock() - - if t.xyzMeta == nil { - return -1 - } - - // The meta API returns names with xyz: prefix, so ensure we match correctly - lookupName := baseCoin - if !strings.HasPrefix(lookupName, "xyz:") { - lookupName = "xyz:" + lookupName - } - - for i, asset := range t.xyzMeta.Universe { - if asset.Name == lookupName { - return i - } - } - return -1 -} - -// placeXyzTriggerOrder places a trigger order (stop loss / take profit) on the xyz dex -// tpsl: "sl" for stop loss, "tp" for take profit -func (t *HyperliquidTrader) placeXyzTriggerOrder(coin string, isBuy bool, size float64, triggerPrice float64, tpsl string) error { - // Fetch xyz meta if not cached - t.xyzMetaMutex.RLock() - hasMeta := t.xyzMeta != nil - t.xyzMetaMutex.RUnlock() - - if !hasMeta { - if err := t.fetchXyzMeta(); err != nil { - return fmt.Errorf("failed to fetch xyz meta: %w", err) - } - } - - // Get asset index from xyz meta (returns 0-based index) - metaIndex := t.getXyzAssetIndex(coin) - if metaIndex < 0 { - return fmt.Errorf("xyz asset %s not found in meta", coin) - } - - // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta - // xyz dex is at perp_dex_index = 1 - const xyzPerpDexIndex = 1 - assetIndex := 100000 + xyzPerpDexIndex*10000 + metaIndex - - // Round size to correct precision - szDecimals := t.getXyzSzDecimals(coin) - multiplier := 1.0 - for i := 0; i < szDecimals; i++ { - multiplier *= 10.0 - } - roundedSize := float64(int(size*multiplier+0.5)) / multiplier - - // Round price to 5 significant figures - roundedPrice := t.roundPriceToSigfigs(triggerPrice) - - logger.Infof("📝 Placing xyz dex %s order: %s %s size=%.4f triggerPrice=%.4f assetIndex=%d", - tpsl, - map[bool]string{true: "BUY", false: "SELL"}[isBuy], - coin, roundedSize, roundedPrice, assetIndex) - - // Construct OrderWire with trigger type for stop loss / take profit - orderWire := hyperliquid.OrderWire{ - Asset: assetIndex, - IsBuy: isBuy, - LimitPx: floatToWireStr(roundedPrice), - Size: floatToWireStr(roundedSize), - ReduceOnly: true, // TP/SL orders are always reduce-only - OrderType: hyperliquid.OrderWireType{ - Trigger: &hyperliquid.OrderWireTypeTrigger{ - TriggerPx: floatToWireStr(roundedPrice), - IsMarket: true, - Tpsl: hyperliquid.Tpsl(tpsl), // "sl" or "tp" - convert string to Tpsl type - }, - }, - } - - // Create OrderAction (no builder to avoid requiring builder fee approval) - action := hyperliquid.OrderAction{ - Type: "order", - Orders: []hyperliquid.OrderWire{orderWire}, - Grouping: "na", - Builder: nil, - } - - // Sign the action - nonce := time.Now().UnixMilli() - isMainnet := !t.isTestnet - vaultAddress := "" - - sig, err := hyperliquid.SignL1Action(t.privateKey, action, vaultAddress, nonce, nil, isMainnet) - if err != nil { - return fmt.Errorf("failed to sign xyz dex trigger order: %w", err) - } - - // Construct payload for /exchange endpoint - payload := map[string]any{ - "action": action, - "nonce": nonce, - "signature": sig, - } - - // Determine API URL - apiURL := hyperliquid.MainnetAPIURL - if t.isTestnet { - apiURL = hyperliquid.TestnetAPIURL - } - - // POST to /exchange - jsonData, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - logger.Infof("📤 Sending xyz dex %s order to %s/exchange", tpsl, apiURL) - - req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, apiURL+"/exchange", bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - - // Parse response - var result struct { - Status string `json:"status"` - Response struct { - Type string `json:"type"` - Data struct { - Statuses []struct { - Resting *struct { - Oid int64 `json:"oid"` - } `json:"resting,omitempty"` - Error *string `json:"error,omitempty"` - } `json:"statuses"` - } `json:"data"` - } `json:"response"` - } - - if err := json.Unmarshal(body, &result); err != nil { - logger.Infof("⚠️ Failed to parse response, raw body: %s", string(body)) - return fmt.Errorf("xyz dex %s order failed, status=%d, body=%s", tpsl, resp.StatusCode, string(body)) - } - - // Check for errors in response - if result.Status != "ok" { - return fmt.Errorf("xyz dex %s order failed: status=%s, body=%s", tpsl, result.Status, string(body)) - } - - // Check order statuses - if len(result.Response.Data.Statuses) > 0 { - status := result.Response.Data.Statuses[0] - if status.Error != nil { - return fmt.Errorf("xyz dex %s order error: %s", tpsl, *status.Error) - } - if status.Resting != nil { - logger.Infof("✅ xyz dex %s order placed: oid=%d", tpsl, status.Resting.Oid) - } - } - - logger.Infof("✅ xyz dex %s order placed successfully: %s", tpsl, coin) - return nil -} - -// SetStopLoss sets stop loss order -func (t *HyperliquidTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - coin := convertSymbolToHyperliquid(symbol) - - isBuy := positionSide == "SHORT" // Short position stop loss = buy, long position stop loss = sell - - // ⚠️ Critical: Price needs to be processed to 5 significant figures - roundedStopPrice := t.roundPriceToSigfigs(stopPrice) - - // Check if this is an xyz dex asset (stocks, forex, commodities) - isXyz := strings.HasPrefix(coin, "xyz:") - - if isXyz { - // xyz dex stop loss order - use direct API call similar to placeXyzOrder - if err := t.placeXyzTriggerOrder(coin, isBuy, quantity, roundedStopPrice, "sl"); err != nil { - return fmt.Errorf("failed to set xyz dex stop loss: %w", err) - } - } else { - // Standard crypto stop loss order - // ⚠️ Critical: Round quantity according to coin precision requirements - roundedQuantity := t.roundToSzDecimals(coin, quantity) - - // Create stop loss order (Trigger Order) - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: isBuy, - Size: roundedQuantity, // Use rounded quantity - Price: roundedStopPrice, // Use processed price - OrderType: hyperliquid.OrderType{ - Trigger: &hyperliquid.TriggerOrderType{ - TriggerPx: roundedStopPrice, - IsMarket: true, - Tpsl: "sl", // stop loss - }, - }, - ReduceOnly: true, - } - - _, err := t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return fmt.Errorf("failed to set stop loss: %w", err) - } - } - - logger.Infof(" Stop loss price set: %.4f", roundedStopPrice) - return nil -} - -// SetTakeProfit sets take profit order -func (t *HyperliquidTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - coin := convertSymbolToHyperliquid(symbol) - - isBuy := positionSide == "SHORT" // Short position take profit = buy, long position take profit = sell - - // ⚠️ Critical: Price needs to be processed to 5 significant figures - roundedTakeProfitPrice := t.roundPriceToSigfigs(takeProfitPrice) - - // Check if this is an xyz dex asset (stocks, forex, commodities) - isXyz := strings.HasPrefix(coin, "xyz:") - - if isXyz { - // xyz dex take profit order - use direct API call similar to placeXyzOrder - if err := t.placeXyzTriggerOrder(coin, isBuy, quantity, roundedTakeProfitPrice, "tp"); err != nil { - return fmt.Errorf("failed to set xyz dex take profit: %w", err) - } - } else { - // Standard crypto take profit order - // ⚠️ Critical: Round quantity according to coin precision requirements - roundedQuantity := t.roundToSzDecimals(coin, quantity) - - // Create take profit order (Trigger Order) - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: isBuy, - Size: roundedQuantity, // Use rounded quantity - Price: roundedTakeProfitPrice, // Use processed price - OrderType: hyperliquid.OrderType{ - Trigger: &hyperliquid.TriggerOrderType{ - TriggerPx: roundedTakeProfitPrice, - IsMarket: true, - Tpsl: "tp", // take profit - }, - }, - ReduceOnly: true, - } - - _, err := t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return fmt.Errorf("failed to set take profit: %w", err) - } - } - - logger.Infof(" Take profit price set: %.4f", roundedTakeProfitPrice) - return nil -} - // FormatQuantity formats quantity to correct precision func (t *HyperliquidTrader) FormatQuantity(symbol string, quantity float64) (string, error) { coin := convertSymbolToHyperliquid(symbol) @@ -1812,7 +241,7 @@ func (t *HyperliquidTrader) FormatQuantity(symbol string, quantity float64) (str // getSzDecimals gets quantity precision for coin func (t *HyperliquidTrader) getSzDecimals(coin string) int { - // ✅ Concurrency safe: Use read lock to protect meta field access + // Concurrency safe: Use read lock to protect meta field access t.metaMutex.RLock() defer t.metaMutex.RUnlock() @@ -1883,366 +312,3 @@ func (t *HyperliquidTrader) roundPriceToSigfigs(price float64) float64 { rounded := float64(int(price*multiplier+0.5)) / multiplier return rounded } - -// convertSymbolToHyperliquid converts standard symbol to Hyperliquid format -// Example: "BTCUSDT" -> "BTC", "TSLA" -> "xyz:TSLA", "silver" -> "xyz:SILVER" -func convertSymbolToHyperliquid(symbol string) string { - // Convert to uppercase for consistent handling - base := strings.ToUpper(symbol) - - // Remove common suffixes to get base symbol - for _, suffix := range []string{"USDT", "USD", "-USDC", "-USD"} { - if strings.HasSuffix(base, suffix) { - base = strings.TrimSuffix(base, suffix) - break - } - } - // Remove xyz: prefix if present (case-insensitive, will be re-added if needed) - if strings.HasPrefix(strings.ToLower(base), "xyz:") { - base = base[4:] // Remove first 4 characters - } - - // Check if this is an xyz dex asset (stocks, forex, commodities) - if isXyzDexAsset(base) { - return "xyz:" + base - } - return base -} - -// GetOrderStatus gets order status -// Hyperliquid uses IOC orders, usually filled or cancelled immediately -// For completed orders, need to query historical records -func (t *HyperliquidTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - // Hyperliquid's IOC orders are completed almost immediately - // If order was placed through this system, returned status will be FILLED - // Try to query open orders to determine if still pending - coin := convertSymbolToHyperliquid(symbol) - - // First check if in open orders - openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) - if err != nil { - // If query fails, assume order is completed - return map[string]interface{}{ - "orderId": orderID, - "status": "FILLED", - "avgPrice": 0.0, - "executedQty": 0.0, - "commission": 0.0, - }, nil - } - - // Check if order is in open orders list - for _, order := range openOrders { - if order.Coin == coin && fmt.Sprintf("%d", order.Oid) == orderID { - // Order is still pending - return map[string]interface{}{ - "orderId": orderID, - "status": "NEW", - "avgPrice": 0.0, - "executedQty": 0.0, - "commission": 0.0, - }, nil - } - } - - // Order not in open list, meaning completed or cancelled - // Hyperliquid IOC orders not in open list are usually filled - return map[string]interface{}{ - "orderId": orderID, - "status": "FILLED", - "avgPrice": 0.0, // Hyperliquid does not directly return execution price, need to get from position info - "executedQty": 0.0, - "commission": 0.0, - }, nil -} - -// absFloat returns absolute value of float -func absFloat(x float64) float64 { - if x < 0 { - return -x - } - return x -} - -// GetClosedPnL gets recent closing trades from Hyperliquid -// Note: Hyperliquid does NOT have a position history API, only fill history. -// This returns individual closing trades for real-time position closure detection. -func (t *HyperliquidTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - trades, err := t.GetTrades(startTime, limit) - if err != nil { - return nil, err - } - - // Filter only closing trades (realizedPnl != 0) - var records []types.ClosedPnLRecord - for _, trade := range trades { - if trade.RealizedPnL == 0 { - continue - } - - // Determine side (Hyperliquid uses one-way mode) - side := "long" - if trade.Side == "SELL" || trade.Side == "Sell" { - side = "long" // Selling closes long - } else { - side = "short" // Buying closes short - } - - // Calculate entry price from PnL - var entryPrice float64 - if trade.Quantity > 0 { - if side == "long" { - entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity - } else { - entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity - } - } - - records = append(records, types.ClosedPnLRecord{ - Symbol: trade.Symbol, - Side: side, - EntryPrice: entryPrice, - ExitPrice: trade.Price, - Quantity: trade.Quantity, - RealizedPnL: trade.RealizedPnL, - Fee: trade.Fee, - ExitTime: trade.Time, - EntryTime: trade.Time, - OrderID: trade.TradeID, - ExchangeID: trade.TradeID, - CloseType: "unknown", - }) - } - - return records, nil -} - -// GetTrades retrieves trade history from Hyperliquid -func (t *HyperliquidTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) { - // Use UserFillsByTime API - startTimeMs := startTime.UnixMilli() - fills, err := t.exchange.Info().UserFillsByTime(t.ctx, t.walletAddr, startTimeMs, nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to get user fills: %w", err) - } - - var trades []types.TradeRecord - for _, fill := range fills { - price, _ := strconv.ParseFloat(fill.Price, 64) - qty, _ := strconv.ParseFloat(fill.Size, 64) - fee, _ := strconv.ParseFloat(fill.Fee, 64) - pnl, _ := strconv.ParseFloat(fill.ClosedPnl, 64) - - // Determine side: "B" = Buy, "S" = Sell (or "A" = Ask, "B" = Bid) - var side string - if fill.Side == "B" || fill.Side == "Buy" || fill.Side == "bid" { - side = "BUY" - } else { - side = "SELL" - } - - // Parse Dir field to get order action - // Hyperliquid Dir values: "Open Long", "Open Short", "Close Long", "Close Short" - var orderAction string - switch strings.ToLower(fill.Dir) { - case "open long": - orderAction = "open_long" - case "open short": - orderAction = "open_short" - case "close long": - orderAction = "close_long" - case "close short": - orderAction = "close_short" - default: - // Fallback: use RealizedPnL if Dir is missing/unknown - if pnl != 0 { - if side == "BUY" { - orderAction = "close_short" - } else { - orderAction = "close_long" - } - } else { - if side == "BUY" { - orderAction = "open_long" - } else { - orderAction = "open_short" - } - } - } - - // Hyperliquid uses one-way mode, so PositionSide is "BOTH" - trade := types.TradeRecord{ - TradeID: strconv.FormatInt(fill.Tid, 10), - Symbol: fill.Coin, - Side: side, - PositionSide: "BOTH", // Hyperliquid doesn't have hedge mode - OrderAction: orderAction, - Price: price, - Quantity: qty, - RealizedPnL: pnl, - Fee: fee, - Time: time.UnixMilli(fill.Time).UTC(), - } - trades = append(trades, trade) - } - - return trades, nil -} - -// defaultBuilder is the builder info for order routing -// Set to nil to avoid requiring builder fee approval -// -// var defaultBuilder = &hyperliquid.BuilderInfo{ -// Builder: "0x891dc6f05ad47a3c1a05da55e7a7517971faaf0d", -// Fee: 10, -// } -var defaultBuilder *hyperliquid.BuilderInfo = nil - -// GetOpenOrders gets all open/pending orders for a symbol -func (t *HyperliquidTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - var result []types.OpenOrder - for _, order := range openOrders { - if order.Coin != symbol { - continue - } - - side := "BUY" - if order.Side == "A" { - side = "SELL" - } - - result = append(result, types.OpenOrder{ - OrderID: fmt.Sprintf("%d", order.Oid), - Symbol: order.Coin, - Side: side, - PositionSide: "", - Type: "LIMIT", - Price: order.LimitPx, - StopPrice: 0, - Quantity: order.Size, - Status: "NEW", - }) - } - - return result, nil -} - -// PlaceLimitOrder places a limit order for grid trading -// Implements GridTrader interface -func (t *HyperliquidTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { - coin := convertSymbolToHyperliquid(req.Symbol) - - // Set leverage if specified and not xyz dex - isXyz := strings.HasPrefix(coin, "xyz:") - if req.Leverage > 0 && !isXyz { - if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { - logger.Warnf("[Hyperliquid] Failed to set leverage: %v", err) - } - } - - // Round quantity to allowed decimals - roundedQuantity := t.roundToSzDecimals(coin, req.Quantity) - - // Round price to 5 significant figures - roundedPrice := t.roundPriceToSigfigs(req.Price) - - // Determine if buy or sell - isBuy := req.Side == "BUY" - - logger.Infof("[Hyperliquid] PlaceLimitOrder: %s %s @ %.4f, qty=%.4f", coin, req.Side, roundedPrice, roundedQuantity) - - order := hyperliquid.CreateOrderRequest{ - Coin: coin, - IsBuy: isBuy, - Size: roundedQuantity, - Price: roundedPrice, - OrderType: hyperliquid.OrderType{ - Limit: &hyperliquid.LimitOrderType{ - Tif: hyperliquid.TifGtc, // Good Till Cancel for grid orders - }, - }, - ReduceOnly: req.ReduceOnly, - } - - _, err := t.exchange.Order(t.ctx, order, defaultBuilder) - if err != nil { - return nil, fmt.Errorf("failed to place limit order: %w", err) - } - - // Note: Hyperliquid's Order response doesn't return the order ID directly - // We would need to query open orders to get it, but for grid trading - // we can track orders by price level instead - orderID := fmt.Sprintf("%d", time.Now().UnixNano()) - - logger.Infof("✓ [Hyperliquid] Limit order placed: %s %s @ %.4f", - coin, req.Side, roundedPrice) - - return &types.LimitOrderResult{ - OrderID: orderID, - ClientID: req.ClientID, - Symbol: req.Symbol, - Side: req.Side, - PositionSide: req.PositionSide, - Price: roundedPrice, - Quantity: roundedQuantity, - Status: "NEW", - }, nil -} - -// CancelOrder cancels a specific order by ID -// Implements GridTrader interface -func (t *HyperliquidTrader) CancelOrder(symbol, orderID string) error { - coin := convertSymbolToHyperliquid(symbol) - - // Parse order ID - oid, err := strconv.ParseInt(orderID, 10, 64) - if err != nil { - return fmt.Errorf("invalid order ID: %w", err) - } - - _, err = t.exchange.Cancel(t.ctx, coin, oid) - if err != nil { - return fmt.Errorf("failed to cancel order: %w", err) - } - - logger.Infof("✓ [Hyperliquid] Order cancelled: %s %s", symbol, orderID) - return nil -} - -// GetOrderBook gets the order book for a symbol -// Implements GridTrader interface -func (t *HyperliquidTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { - coin := convertSymbolToHyperliquid(symbol) - - l2Book, err := t.exchange.Info().L2Snapshot(t.ctx, coin) - if err != nil { - return nil, nil, fmt.Errorf("failed to get order book: %w", err) - } - - if l2Book == nil || len(l2Book.Levels) < 2 { - return nil, nil, fmt.Errorf("invalid order book data") - } - - // Parse bids (first level array) - for i, level := range l2Book.Levels[0] { - if i >= depth { - break - } - bids = append(bids, []float64{level.Px, level.Sz}) - } - - // Parse asks (second level array) - for i, level := range l2Book.Levels[1] { - if i >= depth { - break - } - asks = append(asks, []float64{level.Px, level.Sz}) - } - - return bids, asks, nil -} diff --git a/trader/hyperliquid/trader_account.go b/trader/hyperliquid/trader_account.go new file mode 100644 index 00000000..f556455a --- /dev/null +++ b/trader/hyperliquid/trader_account.go @@ -0,0 +1,592 @@ +package hyperliquid + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" + "time" +) + +// GetBalance gets account balance +func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { + logger.Infof("🔄 Calling Hyperliquid API to get account balance...") + + // Step 1: Query Spot account balance + spotState, err := t.exchange.Info().SpotUserState(t.ctx, t.walletAddr) + var spotUSDCBalance float64 = 0.0 + if err != nil { + logger.Infof("⚠️ Failed to query Spot balance (may have no spot assets): %v", err) + } else if spotState != nil && len(spotState.Balances) > 0 { + for _, balance := range spotState.Balances { + if balance.Coin == "USDC" { + spotUSDCBalance, _ = strconv.ParseFloat(balance.Total, 64) + logger.Infof("✓ Found Spot balance: %.2f USDC", spotUSDCBalance) + break + } + } + } + + // Step 2: Query Perpetuals contract account status + accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr) + if err != nil { + logger.Infof("❌ Hyperliquid Perpetuals API call failed: %v", err) + return nil, fmt.Errorf("failed to get account information: %w", err) + } + + // Parse balance information (MarginSummary fields are all strings) + result := make(map[string]interface{}) + + // Step 3: Dynamically select correct summary based on margin mode (CrossMarginSummary or MarginSummary) + var accountValue, totalMarginUsed float64 + var summaryType string + var summary interface{} + + if t.isCrossMargin { + // Cross margin mode: use CrossMarginSummary + accountValue, _ = strconv.ParseFloat(accountState.CrossMarginSummary.AccountValue, 64) + totalMarginUsed, _ = strconv.ParseFloat(accountState.CrossMarginSummary.TotalMarginUsed, 64) + summaryType = "CrossMarginSummary (cross margin)" + summary = accountState.CrossMarginSummary + } else { + // Isolated margin mode: use MarginSummary + accountValue, _ = strconv.ParseFloat(accountState.MarginSummary.AccountValue, 64) + totalMarginUsed, _ = strconv.ParseFloat(accountState.MarginSummary.TotalMarginUsed, 64) + summaryType = "MarginSummary (isolated margin)" + summary = accountState.MarginSummary + } + + // Debug: Print complete summary structure returned by API + summaryJSON, _ := json.MarshalIndent(summary, " ", " ") + logger.Infof("🔍 [DEBUG] Hyperliquid API %s complete data:", summaryType) + logger.Infof("%s", string(summaryJSON)) + + // Critical fix: Accumulate actual unrealized PnL from all positions + totalUnrealizedPnl := 0.0 + for _, assetPos := range accountState.AssetPositions { + unrealizedPnl, _ := strconv.ParseFloat(assetPos.Position.UnrealizedPnl, 64) + totalUnrealizedPnl += unrealizedPnl + } + + // Correctly understand Hyperliquid fields: + // AccountValue = Total account equity (includes idle funds + position value + unrealized PnL) + // TotalMarginUsed = Margin used by positions (included in AccountValue, for display only) + // + // To be compatible with auto_types.go calculation logic (totalEquity = totalWalletBalance + totalUnrealizedProfit) + // Need to return "wallet balance without unrealized PnL" + walletBalanceWithoutUnrealized := accountValue - totalUnrealizedPnl + + // Step 4: Use Withdrawable field (PR #443) + // Withdrawable is the official real withdrawable balance, more reliable than simple calculation + availableBalance := 0.0 + if accountState.Withdrawable != "" { + withdrawable, err := strconv.ParseFloat(accountState.Withdrawable, 64) + if err == nil && withdrawable > 0 { + availableBalance = withdrawable + logger.Infof("✓ Using Withdrawable as available balance: %.2f", availableBalance) + } + } + + // Fallback: If no Withdrawable, use simple calculation + if availableBalance == 0 && accountState.Withdrawable == "" { + availableBalance = accountValue - totalMarginUsed + if availableBalance < 0 { + logger.Infof("⚠️ Calculated available balance is negative (%.2f), reset to 0", availableBalance) + availableBalance = 0 + } + } + + // Step 5: Query xyz dex balance (stock perps, forex, commodities) + var xyzAccountValue, xyzUnrealizedPnl float64 + var xyzPositions []xyzAssetPosition + xyzAccountValue, xyzUnrealizedPnl, xyzPositions, err = t.getXYZDexBalance() + if err != nil { + // xyz dex query failed - log warning but don't fail the entire balance query + logger.Infof("⚠️ Failed to query xyz dex balance: %v", err) + } + // Always log xyz dex state for debugging + logger.Infof("🔍 xyz dex state: accountValue=%.4f, unrealizedPnl=%.4f, positions=%d", + xyzAccountValue, xyzUnrealizedPnl, len(xyzPositions)) + for _, pos := range xyzPositions { + entryPx := "nil" + if pos.Position.EntryPx != nil { + entryPx = *pos.Position.EntryPx + } + logger.Infof(" └─ %s: size=%s, entryPx=%s, posValue=%s, pnl=%s", + pos.Position.Coin, pos.Position.Szi, entryPx, pos.Position.PositionValue, pos.Position.UnrealizedPnl) + } + xyzWalletBalance := xyzAccountValue - xyzUnrealizedPnl + + // Step 6: Correctly handle Spot + Perpetuals + xyz dex balance + // Important: Each account is independent, manual transfers required + totalWalletBalance := walletBalanceWithoutUnrealized + spotUSDCBalance + xyzWalletBalance + totalUnrealizedPnlAll := totalUnrealizedPnl + xyzUnrealizedPnl + + // Calculate total equity properly: perpAccountValue + spotUSDCBalance + xyzAccountValue + // Note: totalWalletBalance + totalUnrealizedPnlAll should equal this + totalEquityCalculated := accountValue + spotUSDCBalance + xyzAccountValue + + // Step 7: Unified Account mode - Spot USDC is used as collateral for Perps + // In this mode, available balance includes Spot USDC since it can be used for Perp margin + if t.isUnifiedAccount && spotUSDCBalance > 0 { + // Add Spot balance to available balance for trading + availableBalance = availableBalance + spotUSDCBalance + logger.Infof("✓ Unified Account: Spot %.2f USDC added to available balance (total: %.2f)", + spotUSDCBalance, availableBalance) + } + + // Suppress unused variable warning + _ = totalUnrealizedPnlAll + + result["totalWalletBalance"] = totalWalletBalance // Total assets (Perp + Spot + xyz) - unrealized + result["totalEquity"] = totalEquityCalculated // Total equity = Perp AV + Spot + xyz AV + result["availableBalance"] = availableBalance // Available balance (Perp + Spot if unified) + result["totalUnrealizedProfit"] = totalUnrealizedPnlAll // Unrealized PnL (Perpetuals + xyz) + result["spotBalance"] = spotUSDCBalance // Spot balance + result["xyzDexBalance"] = xyzAccountValue // xyz dex equity (stock perps, forex, commodities) + result["xyzDexUnrealizedPnl"] = xyzUnrealizedPnl // xyz dex unrealized PnL + result["perpAccountValue"] = accountValue // Perp account value for debugging + + logger.Infof("✓ Hyperliquid complete account:") + logger.Infof(" • Spot balance: %.2f USDC", spotUSDCBalance) + logger.Infof(" • Perpetuals equity: %.2f USDC (wallet %.2f + unrealized %.2f)", + accountValue, + walletBalanceWithoutUnrealized, + totalUnrealizedPnl) + logger.Infof(" • Perpetuals available balance: %.2f USDC", availableBalance) + logger.Infof(" • Margin used: %.2f USDC", totalMarginUsed) + logger.Infof(" • xyz dex equity: %.2f USDC (wallet %.2f + unrealized %.2f)", + xyzAccountValue, + xyzWalletBalance, + xyzUnrealizedPnl) + logger.Infof(" • Total assets (Perp+Spot+xyz): %.2f USDC", totalWalletBalance) + logger.Infof(" ⭐ Total: %.2f USDC | Perp: %.2f | Spot: %.2f | xyz: %.2f", + totalWalletBalance, availableBalance, spotUSDCBalance, xyzAccountValue) + + return result, nil +} + +// xyzDexState represents the clearinghouse state for xyz dex +type xyzDexState struct { + MarginSummary *xyzMarginSummary `json:"marginSummary,omitempty"` + CrossMarginSummary *xyzMarginSummary `json:"crossMarginSummary,omitempty"` + Withdrawable string `json:"withdrawable,omitempty"` + AssetPositions []xyzAssetPosition `json:"assetPositions,omitempty"` +} + +type xyzMarginSummary struct { + AccountValue string `json:"accountValue"` + TotalMarginUsed string `json:"totalMarginUsed"` +} + +type xyzAssetPosition struct { + Position struct { + Coin string `json:"coin"` + Szi string `json:"szi"` + EntryPx *string `json:"entryPx"` + PositionValue string `json:"positionValue"` + UnrealizedPnl string `json:"unrealizedPnl"` + LiquidationPx *string `json:"liquidationPx"` + Leverage struct { + Type string `json:"type"` + Value int `json:"value"` + } `json:"leverage"` + } `json:"position"` +} + +// getXYZDexBalance queries the xyz dex balance (stock perps, forex, commodities) +func (t *HyperliquidTrader) getXYZDexBalance() (accountValue float64, unrealizedPnl float64, positions []xyzAssetPosition, err error) { + // Build request for xyz dex clearinghouse state + reqBody := map[string]interface{}{ + "type": "clearinghouseState", + "user": t.walletAddr, + "dex": "xyz", + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return 0, 0, nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Determine API URL + apiURL := "https://api.hyperliquid.xyz/info" + // Note: xyz dex may not be available on testnet + + req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) + if err != nil { + return 0, 0, nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return 0, 0, nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, 0, nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return 0, 0, nil, fmt.Errorf("xyz dex API error (status %d): %s", resp.StatusCode, string(body)) + } + + var state xyzDexState + if err := json.Unmarshal(body, &state); err != nil { + return 0, 0, nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Parse account value - xyz dex uses MarginSummary for isolated margin mode + // CrossMarginSummary may exist but with 0 values, so check MarginSummary first + if state.MarginSummary != nil && state.MarginSummary.AccountValue != "" { + av, _ := strconv.ParseFloat(state.MarginSummary.AccountValue, 64) + if av > 0 { + accountValue = av + } + } + // Fallback to CrossMarginSummary if MarginSummary is 0 + if accountValue == 0 && state.CrossMarginSummary != nil && state.CrossMarginSummary.AccountValue != "" { + accountValue, _ = strconv.ParseFloat(state.CrossMarginSummary.AccountValue, 64) + } + + // Calculate total unrealized PnL from positions + for _, pos := range state.AssetPositions { + pnl, _ := strconv.ParseFloat(pos.Position.UnrealizedPnl, 64) + unrealizedPnl += pnl + } + + return accountValue, unrealizedPnl, state.AssetPositions, nil +} + +// GetMarketPrice gets market price (supports both crypto and xyz dex assets) +func (t *HyperliquidTrader) GetMarketPrice(symbol string) (float64, error) { + coin := convertSymbolToHyperliquid(symbol) + + // Check if this is an xyz dex asset + if strings.HasPrefix(coin, "xyz:") { + return t.getXyzMarketPrice(coin) + } + + // Get all market prices for crypto + allMids, err := t.exchange.Info().AllMids(t.ctx) + if err != nil { + return 0, fmt.Errorf("failed to get price: %w", err) + } + + // Find price for corresponding coin (allMids is map[string]string) + if priceStr, ok := allMids[coin]; ok { + priceFloat, err := strconv.ParseFloat(priceStr, 64) + if err == nil { + return priceFloat, nil + } + return 0, fmt.Errorf("price format error: %v", err) + } + + return 0, fmt.Errorf("price not found for %s", symbol) +} + +// getXyzMarketPrice gets market price for xyz dex assets +func (t *HyperliquidTrader) getXyzMarketPrice(coin string) (float64, error) { + // Build request for xyz dex allMids + reqBody := map[string]string{ + "type": "allMids", + "dex": "xyz", + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return 0, fmt.Errorf("failed to marshal request: %w", err) + } + + apiURL := "https://api.hyperliquid.xyz/info" + + req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) + if err != nil { + return 0, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return 0, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("xyz dex allMids API error (status %d): %s", resp.StatusCode, string(body)) + } + + var mids map[string]string + if err := json.Unmarshal(body, &mids); err != nil { + return 0, fmt.Errorf("failed to parse response: %w", err) + } + + // The API returns keys with xyz: prefix, so ensure the coin has it + lookupKey := coin + if !strings.HasPrefix(lookupKey, "xyz:") { + lookupKey = "xyz:" + lookupKey + } + + if priceStr, ok := mids[lookupKey]; ok { + priceFloat, err := strconv.ParseFloat(priceStr, 64) + if err == nil { + return priceFloat, nil + } + return 0, fmt.Errorf("price format error: %v", err) + } + + return 0, fmt.Errorf("xyz dex price not found for %s (lookup key: %s)", coin, lookupKey) +} + +// GetOrderStatus gets order status +// Hyperliquid uses IOC orders, usually filled or cancelled immediately +// For completed orders, need to query historical records +func (t *HyperliquidTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + // Hyperliquid's IOC orders are completed almost immediately + // If order was placed through this system, returned status will be FILLED + // Try to query open orders to determine if still pending + coin := convertSymbolToHyperliquid(symbol) + + // First check if in open orders + openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) + if err != nil { + // If query fails, assume order is completed + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + + // Check if order is in open orders list + for _, order := range openOrders { + if order.Coin == coin && fmt.Sprintf("%d", order.Oid) == orderID { + // Order is still pending + return map[string]interface{}{ + "orderId": orderID, + "status": "NEW", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + } + + // Order not in open list, meaning completed or cancelled + // Hyperliquid IOC orders not in open list are usually filled + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, // Hyperliquid does not directly return execution price, need to get from position info + "executedQty": 0.0, + "commission": 0.0, + }, nil +} + +// GetClosedPnL gets recent closing trades from Hyperliquid +// Note: Hyperliquid does NOT have a position history API, only fill history. +// This returns individual closing trades for real-time position closure detection. +func (t *HyperliquidTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + trades, err := t.GetTrades(startTime, limit) + if err != nil { + return nil, err + } + + // Filter only closing trades (realizedPnl != 0) + var records []types.ClosedPnLRecord + for _, trade := range trades { + if trade.RealizedPnL == 0 { + continue + } + + // Determine side (Hyperliquid uses one-way mode) + side := "long" + if trade.Side == "SELL" || trade.Side == "Sell" { + side = "long" // Selling closes long + } else { + side = "short" // Buying closes short + } + + // Calculate entry price from PnL + var entryPrice float64 + if trade.Quantity > 0 { + if side == "long" { + entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity + } else { + entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity + } + } + + records = append(records, types.ClosedPnLRecord{ + Symbol: trade.Symbol, + Side: side, + EntryPrice: entryPrice, + ExitPrice: trade.Price, + Quantity: trade.Quantity, + RealizedPnL: trade.RealizedPnL, + Fee: trade.Fee, + ExitTime: trade.Time, + EntryTime: trade.Time, + OrderID: trade.TradeID, + ExchangeID: trade.TradeID, + CloseType: "unknown", + }) + } + + return records, nil +} + +// GetTrades retrieves trade history from Hyperliquid +func (t *HyperliquidTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) { + // Use UserFillsByTime API + startTimeMs := startTime.UnixMilli() + fills, err := t.exchange.Info().UserFillsByTime(t.ctx, t.walletAddr, startTimeMs, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get user fills: %w", err) + } + + var trades []types.TradeRecord + for _, fill := range fills { + price, _ := strconv.ParseFloat(fill.Price, 64) + qty, _ := strconv.ParseFloat(fill.Size, 64) + fee, _ := strconv.ParseFloat(fill.Fee, 64) + pnl, _ := strconv.ParseFloat(fill.ClosedPnl, 64) + + // Determine side: "B" = Buy, "S" = Sell (or "A" = Ask, "B" = Bid) + var side string + if fill.Side == "B" || fill.Side == "Buy" || fill.Side == "bid" { + side = "BUY" + } else { + side = "SELL" + } + + // Parse Dir field to get order action + // Hyperliquid Dir values: "Open Long", "Open Short", "Close Long", "Close Short" + var orderAction string + switch strings.ToLower(fill.Dir) { + case "open long": + orderAction = "open_long" + case "open short": + orderAction = "open_short" + case "close long": + orderAction = "close_long" + case "close short": + orderAction = "close_short" + default: + // Fallback: use RealizedPnL if Dir is missing/unknown + if pnl != 0 { + if side == "BUY" { + orderAction = "close_short" + } else { + orderAction = "close_long" + } + } else { + if side == "BUY" { + orderAction = "open_long" + } else { + orderAction = "open_short" + } + } + } + + // Hyperliquid uses one-way mode, so PositionSide is "BOTH" + trade := types.TradeRecord{ + TradeID: strconv.FormatInt(fill.Tid, 10), + Symbol: fill.Coin, + Side: side, + PositionSide: "BOTH", // Hyperliquid doesn't have hedge mode + OrderAction: orderAction, + Price: price, + Quantity: qty, + RealizedPnL: pnl, + Fee: fee, + Time: time.UnixMilli(fill.Time).UTC(), + } + trades = append(trades, trade) + } + + return trades, nil +} + +// GetOpenOrders gets all open/pending orders for a symbol +func (t *HyperliquidTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + var result []types.OpenOrder + for _, order := range openOrders { + if order.Coin != symbol { + continue + } + + side := "BUY" + if order.Side == "A" { + side = "SELL" + } + + result = append(result, types.OpenOrder{ + OrderID: fmt.Sprintf("%d", order.Oid), + Symbol: order.Coin, + Side: side, + PositionSide: "", + Type: "LIMIT", + Price: order.LimitPx, + StopPrice: 0, + Quantity: order.Size, + Status: "NEW", + }) + } + + return result, nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *HyperliquidTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + coin := convertSymbolToHyperliquid(symbol) + + l2Book, err := t.exchange.Info().L2Snapshot(t.ctx, coin) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + if l2Book == nil || len(l2Book.Levels) < 2 { + return nil, nil, fmt.Errorf("invalid order book data") + } + + // Parse bids (first level array) + for i, level := range l2Book.Levels[0] { + if i >= depth { + break + } + bids = append(bids, []float64{level.Px, level.Sz}) + } + + // Parse asks (second level array) + for i, level := range l2Book.Levels[1] { + if i >= depth { + break + } + asks = append(asks, []float64{level.Px, level.Sz}) + } + + return bids, asks, nil +} diff --git a/trader/hyperliquid/trader_orders.go b/trader/hyperliquid/trader_orders.go new file mode 100644 index 00000000..3543d655 --- /dev/null +++ b/trader/hyperliquid/trader_orders.go @@ -0,0 +1,1075 @@ +package hyperliquid + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" + "time" + + "github.com/sonirico/go-hyperliquid" +) + +// OpenLong opens a long position (supports both crypto and xyz dex) +func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // First cancel all pending orders for this coin + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel old pending orders: %v", err) + } + + // Hyperliquid symbol format + coin := convertSymbolToHyperliquid(symbol) + + // Check if this is an xyz dex asset + isXyz := strings.HasPrefix(coin, "xyz:") + + // Set leverage (skip for xyz dex as it may not support leverage adjustment) + if !isXyz { + if err := t.SetLeverage(symbol, leverage); err != nil { + return nil, err + } + } else { + logger.Infof(" ℹ xyz dex asset %s - using default leverage", coin) + } + + // Get current price (for market order) + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + // Price needs to be processed to 5 significant figures + aggressivePrice := t.roundPriceToSigfigs(price * 1.01) + logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*1.01, aggressivePrice) + + // Handle xyz dex assets differently + if isXyz { + // xyz dex order + if err := t.placeXyzOrder(coin, true, quantity, aggressivePrice, false); err != nil { + return nil, fmt.Errorf("failed to open long position on xyz dex: %w", err) + } + } else { + // Standard crypto order + roundedQuantity := t.roundToSzDecimals(coin, quantity) + logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: true, + Size: roundedQuantity, + Price: aggressivePrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifIoc, + }, + }, + ReduceOnly: false, + } + + _, err = t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to open long position: %w", err) + } + } + + logger.Infof("✓ Long position opened successfully: %s quantity: %.4f", symbol, quantity) + + result := make(map[string]interface{}) + result["orderId"] = 0 + result["symbol"] = symbol + result["status"] = "FILLED" + + return result, nil +} + +// OpenShort opens a short position (supports both crypto and xyz dex) +func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // First cancel all pending orders for this coin + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel old pending orders: %v", err) + } + + // Hyperliquid symbol format + coin := convertSymbolToHyperliquid(symbol) + + // Check if this is an xyz dex asset + isXyz := strings.HasPrefix(coin, "xyz:") + + // Set leverage (skip for xyz dex) + if !isXyz { + if err := t.SetLeverage(symbol, leverage); err != nil { + return nil, err + } + } else { + logger.Infof(" ℹ xyz dex asset %s - using default leverage", coin) + } + + // Get current price + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + // Price needs to be processed to 5 significant figures + aggressivePrice := t.roundPriceToSigfigs(price * 0.99) + logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*0.99, aggressivePrice) + + // Handle xyz dex assets differently + if isXyz { + // xyz dex order + if err := t.placeXyzOrder(coin, false, quantity, aggressivePrice, false); err != nil { + return nil, fmt.Errorf("failed to open short position on xyz dex: %w", err) + } + } else { + // Standard crypto order + roundedQuantity := t.roundToSzDecimals(coin, quantity) + logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: false, + Size: roundedQuantity, + Price: aggressivePrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifIoc, + }, + }, + ReduceOnly: false, + } + + _, err = t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to open short position: %w", err) + } + } + + logger.Infof("✓ Short position opened successfully: %s quantity: %.4f", symbol, quantity) + + result := make(map[string]interface{}) + result["orderId"] = 0 + result["symbol"] = symbol + result["status"] = "FILLED" + + return result, nil +} + +// CloseLong closes a long position (supports both crypto and xyz dex) +func (t *HyperliquidTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + // Hyperliquid symbol format + coin := convertSymbolToHyperliquid(symbol) + isXyz := strings.HasPrefix(coin, "xyz:") + + // If quantity is 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + + // For xyz dex, also check xyz: prefixed symbols + searchSymbol := symbol + if isXyz { + searchSymbol = coin // Use xyz:SYMBOL format for comparison + } + + for _, pos := range positions { + posSymbol := pos["symbol"].(string) + if (posSymbol == symbol || posSymbol == searchSymbol) && pos["side"] == "long" { + quantity = pos["positionAmt"].(float64) + break + } + } + + if quantity == 0 { + return nil, fmt.Errorf("no long position found for %s", symbol) + } + } + + // Get current price + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + // Price needs to be processed to 5 significant figures + aggressivePrice := t.roundPriceToSigfigs(price * 0.99) + logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*0.99, aggressivePrice) + + // Handle xyz dex assets differently + if isXyz { + // xyz dex close order + if err := t.placeXyzOrder(coin, false, quantity, aggressivePrice, true); err != nil { + return nil, fmt.Errorf("failed to close long position on xyz dex: %w", err) + } + } else { + // Standard crypto close order + roundedQuantity := t.roundToSzDecimals(coin, quantity) + logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: false, + Size: roundedQuantity, + Price: aggressivePrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifIoc, + }, + }, + ReduceOnly: true, + } + + _, err = t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to close long position: %w", err) + } + } + + logger.Infof("✓ Long position closed successfully: %s quantity: %.4f", symbol, quantity) + + // Cancel all pending orders for this coin after closing position + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) + } + + result := make(map[string]interface{}) + result["orderId"] = 0 + result["symbol"] = symbol + result["status"] = "FILLED" + + return result, nil +} + +// CloseShort closes a short position (supports both crypto and xyz dex) +func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + // Hyperliquid symbol format + coin := convertSymbolToHyperliquid(symbol) + isXyz := strings.HasPrefix(coin, "xyz:") + + // If quantity is 0, get current position quantity + if quantity == 0 { + positions, err := t.GetPositions() + if err != nil { + return nil, err + } + + // For xyz dex, also check xyz: prefixed symbols + searchSymbol := symbol + if isXyz { + searchSymbol = coin + } + + for _, pos := range positions { + posSymbol := pos["symbol"].(string) + if (posSymbol == symbol || posSymbol == searchSymbol) && pos["side"] == "short" { + quantity = pos["positionAmt"].(float64) + break + } + } + + if quantity == 0 { + return nil, fmt.Errorf("no short position found for %s", symbol) + } + } + + // Get current price + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, err + } + + // Price needs to be processed to 5 significant figures + aggressivePrice := t.roundPriceToSigfigs(price * 1.01) + logger.Infof(" 💰 Price precision handling: %.8f -> %.8f (5 significant figures)", price*1.01, aggressivePrice) + + // Handle xyz dex assets differently + if isXyz { + // xyz dex close order + if err := t.placeXyzOrder(coin, true, quantity, aggressivePrice, true); err != nil { + return nil, fmt.Errorf("failed to close short position on xyz dex: %w", err) + } + } else { + // Standard crypto close order + roundedQuantity := t.roundToSzDecimals(coin, quantity) + logger.Infof(" 📏 Quantity precision handling: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: true, + Size: roundedQuantity, + Price: aggressivePrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifIoc, + }, + }, + ReduceOnly: true, + } + + _, err = t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to close short position: %w", err) + } + } + + logger.Infof("✓ Short position closed successfully: %s quantity: %.4f", symbol, quantity) + + // Cancel all pending orders for this coin after closing position + if err := t.CancelAllOrders(symbol); err != nil { + logger.Infof(" ⚠ Failed to cancel pending orders: %v", err) + } + + result := make(map[string]interface{}) + result["orderId"] = 0 + result["symbol"] = symbol + result["status"] = "FILLED" + + return result, nil +} + +// CancelStopLossOrders only cancels stop loss orders (Hyperliquid cannot distinguish stop loss and take profit, cancel all) +func (t *HyperliquidTrader) CancelStopLossOrders(symbol string) error { + // Hyperliquid SDK's OpenOrder structure does not expose trigger field + // Cannot distinguish stop loss and take profit orders, so cancel all pending orders for this coin + logger.Infof(" ⚠️ Hyperliquid cannot distinguish stop loss/take profit orders, will cancel all pending orders") + return t.CancelStopOrders(symbol) +} + +// CancelTakeProfitOrders only cancels take profit orders (Hyperliquid cannot distinguish stop loss and take profit, cancel all) +func (t *HyperliquidTrader) CancelTakeProfitOrders(symbol string) error { + // Hyperliquid SDK's OpenOrder structure does not expose trigger field + // Cannot distinguish stop loss and take profit orders, so cancel all pending orders for this coin + logger.Infof(" ⚠️ Hyperliquid cannot distinguish stop loss/take profit orders, will cancel all pending orders") + return t.CancelStopOrders(symbol) +} + +// CancelAllOrders cancels all pending orders for this coin +func (t *HyperliquidTrader) CancelAllOrders(symbol string) error { + coin := convertSymbolToHyperliquid(symbol) + + // Check if this is an xyz dex asset + isXyz := strings.HasPrefix(coin, "xyz:") + + if isXyz { + // xyz dex orders - use direct API call + return t.cancelXyzOrders(coin) + } + + // Standard crypto orders + openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) + if err != nil { + return fmt.Errorf("failed to get pending orders: %w", err) + } + + // Cancel all pending orders for this coin + for _, order := range openOrders { + if order.Coin == coin { + _, err := t.exchange.Cancel(t.ctx, coin, order.Oid) + if err != nil { + logger.Infof(" ⚠ Failed to cancel order (oid=%d): %v", order.Oid, err) + } + } + } + + logger.Infof(" ✓ Cancelled all pending orders for %s", symbol) + return nil +} + +// CancelStopOrders cancels take profit/stop loss orders for this coin (used to adjust TP/SL positions) +func (t *HyperliquidTrader) CancelStopOrders(symbol string) error { + coin := convertSymbolToHyperliquid(symbol) + + // Check if this is an xyz dex asset + isXyz := strings.HasPrefix(coin, "xyz:") + + if isXyz { + // xyz dex orders - use direct API call + return t.cancelXyzOrders(coin) + } + + // Get all pending orders for standard crypto + openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) + if err != nil { + return fmt.Errorf("failed to get pending orders: %w", err) + } + + // Note: Hyperliquid SDK's OpenOrder structure does not expose trigger field + // Therefore temporarily cancel all pending orders for this coin (including TP/SL orders) + // This is safe because all old orders should be cleaned up before setting new TP/SL + canceledCount := 0 + for _, order := range openOrders { + if order.Coin == coin { + _, err := t.exchange.Cancel(t.ctx, coin, order.Oid) + if err != nil { + logger.Infof(" ⚠ Failed to cancel order (oid=%d): %v", order.Oid, err) + continue + } + canceledCount++ + } + } + + if canceledCount == 0 { + logger.Infof(" ℹ No pending orders to cancel for %s", symbol) + } else { + logger.Infof(" ✓ Cancelled %d pending orders for %s (including TP/SL orders)", canceledCount, symbol) + } + + return nil +} + +// cancelXyzOrders cancels all pending orders for xyz dex assets (stocks, forex, commodities) +func (t *HyperliquidTrader) cancelXyzOrders(coin string) error { + // Query xyz dex open orders + reqBody := map[string]interface{}{ + "type": "openOrders", + "user": t.walletAddr, + "dex": "xyz", + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + apiURL := "https://api.hyperliquid.xyz/info" + + req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("xyz dex openOrders API error (status %d): %s", resp.StatusCode, string(body)) + } + + // Parse open orders + var openOrders []struct { + Coin string `json:"coin"` + Oid int64 `json:"oid"` + } + if err := json.Unmarshal(body, &openOrders); err != nil { + return fmt.Errorf("failed to parse open orders: %w", err) + } + + // Filter orders for this coin and cancel them + canceledCount := 0 + for _, order := range openOrders { + if order.Coin == coin { + if err := t.cancelXyzOrder(order.Oid); err != nil { + logger.Infof(" ⚠ Failed to cancel xyz dex order (oid=%d): %v", order.Oid, err) + continue + } + canceledCount++ + } + } + + if canceledCount == 0 { + logger.Infof(" ℹ No pending xyz dex orders to cancel for %s", coin) + } else { + logger.Infof(" ✓ Cancelled %d xyz dex orders for %s", canceledCount, coin) + } + + return nil +} + +// cancelXyzOrder cancels a single xyz dex order by oid +func (t *HyperliquidTrader) cancelXyzOrder(oid int64) error { + // Get asset index for this order (we need it for cancel action) + // For cancel, we construct a cancel action with the oid + + action := map[string]interface{}{ + "type": "cancel", + "cancels": []map[string]interface{}{ + { + "a": oid, // asset index not needed for cancel by oid in xyz dex + "o": oid, + }, + }, + } + + // Sign the action + nonce := time.Now().UnixMilli() + isMainnet := !t.isTestnet + vaultAddress := "" + + sig, err := hyperliquid.SignL1Action(t.privateKey, action, vaultAddress, nonce, nil, isMainnet) + if err != nil { + return fmt.Errorf("failed to sign cancel action: %w", err) + } + + payload := map[string]any{ + "action": action, + "nonce": nonce, + "signature": sig, + } + + apiURL := hyperliquid.MainnetAPIURL + if t.isTestnet { + apiURL = hyperliquid.TestnetAPIURL + } + + jsonData, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, apiURL+"/exchange", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + // Check response + var result struct { + Status string `json:"status"` + } + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if result.Status != "ok" { + return fmt.Errorf("cancel failed: %s", string(body)) + } + + return nil +} + +// floatToWireStr converts a float to wire format string (8 decimal places, trimmed zeros) +// This matches the SDK's floatToWire function +func floatToWireStr(x float64) string { + // Format to 8 decimal places + result := fmt.Sprintf("%.8f", x) + // Remove trailing zeros + result = strings.TrimRight(result, "0") + // Remove trailing decimal point if no decimals left + result = strings.TrimRight(result, ".") + return result +} + +// placeXyzOrder places an order on the xyz dex (stocks, forex, commodities) +// Note: xyz dex orders use builder-deployed perpetuals and require different handling +// xyz dex asset indices start from 10000 (10000 + meta_index) +// This implementation bypasses the SDK's NameToAsset lookup and directly constructs the order +func (t *HyperliquidTrader) placeXyzOrder(coin string, isBuy bool, size float64, price float64, reduceOnly bool) error { + // Fetch xyz meta if not cached + t.xyzMetaMutex.RLock() + hasMeta := t.xyzMeta != nil + t.xyzMetaMutex.RUnlock() + + if !hasMeta { + if err := t.fetchXyzMeta(); err != nil { + return fmt.Errorf("failed to fetch xyz meta: %w", err) + } + } + + // Get asset index from xyz meta (returns 0-based index) + metaIndex := t.getXyzAssetIndex(coin) + if metaIndex < 0 { + return fmt.Errorf("xyz asset %s not found in meta", coin) + } + + // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta + // xyz dex is at perp_dex_index = 1 (verified from perpDexs API: [null, {name:"xyz",...}]) + // So xyz asset index = 100000 + 1 * 10000 + metaIndex = 110000 + metaIndex + const xyzPerpDexIndex = 1 + assetIndex := 100000 + xyzPerpDexIndex*10000 + metaIndex + + // Round size to correct precision + szDecimals := t.getXyzSzDecimals(coin) + multiplier := 1.0 + for i := 0; i < szDecimals; i++ { + multiplier *= 10.0 + } + roundedSize := float64(int(size*multiplier+0.5)) / multiplier + + // Round price to 5 significant figures + roundedPrice := t.roundPriceToSigfigs(price) + + logger.Infof("📝 Placing xyz dex order (direct): %s %s size=%.4f price=%.4f metaIndex=%d assetIndex=%d (formula: 100000 + 1*10000 + %d) reduceOnly=%v", + map[bool]string{true: "BUY", false: "SELL"}[isBuy], + coin, roundedSize, roundedPrice, metaIndex, assetIndex, metaIndex, reduceOnly) + + // Construct OrderWire directly with correct asset index (bypassing SDK's NameToAsset) + orderWire := hyperliquid.OrderWire{ + Asset: assetIndex, + IsBuy: isBuy, + LimitPx: floatToWireStr(roundedPrice), + Size: floatToWireStr(roundedSize), + ReduceOnly: reduceOnly, + OrderType: hyperliquid.OrderWireType{ + Limit: &hyperliquid.OrderWireTypeLimit{ + Tif: hyperliquid.TifIoc, + }, + }, + } + + // Create OrderAction (no builder to avoid requiring builder fee approval) + action := hyperliquid.OrderAction{ + Type: "order", + Orders: []hyperliquid.OrderWire{orderWire}, + Grouping: "na", + Builder: nil, + } + + // Sign the action + nonce := time.Now().UnixMilli() + isMainnet := !t.isTestnet + vaultAddress := "" // No vault for personal account + + sig, err := hyperliquid.SignL1Action(t.privateKey, action, vaultAddress, nonce, nil, isMainnet) + if err != nil { + return fmt.Errorf("failed to sign xyz dex order: %w", err) + } + + // Construct payload for /exchange endpoint + payload := map[string]any{ + "action": action, + "nonce": nonce, + "signature": sig, + } + + // Determine API URL + apiURL := hyperliquid.MainnetAPIURL + if t.isTestnet { + apiURL = hyperliquid.TestnetAPIURL + } + + // POST to /exchange + jsonData, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + logger.Infof("📤 Sending xyz dex order to %s/exchange", apiURL) + + req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, apiURL+"/exchange", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + // Parse response + var result struct { + Status string `json:"status"` + Response struct { + Type string `json:"type"` + Data struct { + Statuses []struct { + Resting *struct { + Oid int64 `json:"oid"` + } `json:"resting,omitempty"` + Filled *struct { + TotalSz string `json:"totalSz"` + AvgPx string `json:"avgPx"` + Oid int `json:"oid"` + } `json:"filled,omitempty"` + Error *string `json:"error,omitempty"` + } `json:"statuses"` + } `json:"data"` + } `json:"response"` + } + + if err := json.Unmarshal(body, &result); err != nil { + // Try to parse as error response + logger.Infof("⚠️ Failed to parse response as success, raw body: %s", string(body)) + return fmt.Errorf("xyz dex order failed, status=%d, body=%s", resp.StatusCode, string(body)) + } + + // Check for errors in response + if result.Status != "ok" { + return fmt.Errorf("xyz dex order failed: status=%s, body=%s", result.Status, string(body)) + } + + // Check order statuses + if len(result.Response.Data.Statuses) > 0 { + status := result.Response.Data.Statuses[0] + if status.Error != nil { + return fmt.Errorf("xyz dex order error (coin=%s, assetIndex=%d, size=%.4f, price=%.4f): %s", coin, assetIndex, roundedSize, roundedPrice, *status.Error) + } + if status.Filled != nil { + logger.Infof("✅ xyz dex order filled: totalSz=%s avgPx=%s oid=%d", + status.Filled.TotalSz, status.Filled.AvgPx, status.Filled.Oid) + } else if status.Resting != nil { + logger.Infof("✅ xyz dex order resting: oid=%d", status.Resting.Oid) + } + } + + logger.Infof("✅ xyz dex order placed successfully: %s (response: %s)", coin, string(body)) + return nil +} + +// placeXyzTriggerOrder places a trigger order (stop loss / take profit) on the xyz dex +// tpsl: "sl" for stop loss, "tp" for take profit +func (t *HyperliquidTrader) placeXyzTriggerOrder(coin string, isBuy bool, size float64, triggerPrice float64, tpsl string) error { + // Fetch xyz meta if not cached + t.xyzMetaMutex.RLock() + hasMeta := t.xyzMeta != nil + t.xyzMetaMutex.RUnlock() + + if !hasMeta { + if err := t.fetchXyzMeta(); err != nil { + return fmt.Errorf("failed to fetch xyz meta: %w", err) + } + } + + // Get asset index from xyz meta (returns 0-based index) + metaIndex := t.getXyzAssetIndex(coin) + if metaIndex < 0 { + return fmt.Errorf("xyz asset %s not found in meta", coin) + } + + // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta + // xyz dex is at perp_dex_index = 1 + const xyzPerpDexIndex = 1 + assetIndex := 100000 + xyzPerpDexIndex*10000 + metaIndex + + // Round size to correct precision + szDecimals := t.getXyzSzDecimals(coin) + multiplier := 1.0 + for i := 0; i < szDecimals; i++ { + multiplier *= 10.0 + } + roundedSize := float64(int(size*multiplier+0.5)) / multiplier + + // Round price to 5 significant figures + roundedPrice := t.roundPriceToSigfigs(triggerPrice) + + logger.Infof("📝 Placing xyz dex %s order: %s %s size=%.4f triggerPrice=%.4f assetIndex=%d", + tpsl, + map[bool]string{true: "BUY", false: "SELL"}[isBuy], + coin, roundedSize, roundedPrice, assetIndex) + + // Construct OrderWire with trigger type for stop loss / take profit + orderWire := hyperliquid.OrderWire{ + Asset: assetIndex, + IsBuy: isBuy, + LimitPx: floatToWireStr(roundedPrice), + Size: floatToWireStr(roundedSize), + ReduceOnly: true, // TP/SL orders are always reduce-only + OrderType: hyperliquid.OrderWireType{ + Trigger: &hyperliquid.OrderWireTypeTrigger{ + TriggerPx: floatToWireStr(roundedPrice), + IsMarket: true, + Tpsl: hyperliquid.Tpsl(tpsl), // "sl" or "tp" - convert string to Tpsl type + }, + }, + } + + // Create OrderAction (no builder to avoid requiring builder fee approval) + action := hyperliquid.OrderAction{ + Type: "order", + Orders: []hyperliquid.OrderWire{orderWire}, + Grouping: "na", + Builder: nil, + } + + // Sign the action + nonce := time.Now().UnixMilli() + isMainnet := !t.isTestnet + vaultAddress := "" + + sig, err := hyperliquid.SignL1Action(t.privateKey, action, vaultAddress, nonce, nil, isMainnet) + if err != nil { + return fmt.Errorf("failed to sign xyz dex trigger order: %w", err) + } + + // Construct payload for /exchange endpoint + payload := map[string]any{ + "action": action, + "nonce": nonce, + "signature": sig, + } + + // Determine API URL + apiURL := hyperliquid.MainnetAPIURL + if t.isTestnet { + apiURL = hyperliquid.TestnetAPIURL + } + + // POST to /exchange + jsonData, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + logger.Infof("📤 Sending xyz dex %s order to %s/exchange", tpsl, apiURL) + + req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, apiURL+"/exchange", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + // Parse response + var result struct { + Status string `json:"status"` + Response struct { + Type string `json:"type"` + Data struct { + Statuses []struct { + Resting *struct { + Oid int64 `json:"oid"` + } `json:"resting,omitempty"` + Error *string `json:"error,omitempty"` + } `json:"statuses"` + } `json:"data"` + } `json:"response"` + } + + if err := json.Unmarshal(body, &result); err != nil { + logger.Infof("⚠️ Failed to parse response, raw body: %s", string(body)) + return fmt.Errorf("xyz dex %s order failed, status=%d, body=%s", tpsl, resp.StatusCode, string(body)) + } + + // Check for errors in response + if result.Status != "ok" { + return fmt.Errorf("xyz dex %s order failed: status=%s, body=%s", tpsl, result.Status, string(body)) + } + + // Check order statuses + if len(result.Response.Data.Statuses) > 0 { + status := result.Response.Data.Statuses[0] + if status.Error != nil { + return fmt.Errorf("xyz dex %s order error: %s", tpsl, *status.Error) + } + if status.Resting != nil { + logger.Infof("✅ xyz dex %s order placed: oid=%d", tpsl, status.Resting.Oid) + } + } + + logger.Infof("✅ xyz dex %s order placed successfully: %s", tpsl, coin) + return nil +} + +// SetStopLoss sets stop loss order +func (t *HyperliquidTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + coin := convertSymbolToHyperliquid(symbol) + + isBuy := positionSide == "SHORT" // Short position stop loss = buy, long position stop loss = sell + + // Price needs to be processed to 5 significant figures + roundedStopPrice := t.roundPriceToSigfigs(stopPrice) + + // Check if this is an xyz dex asset (stocks, forex, commodities) + isXyz := strings.HasPrefix(coin, "xyz:") + + if isXyz { + // xyz dex stop loss order - use direct API call similar to placeXyzOrder + if err := t.placeXyzTriggerOrder(coin, isBuy, quantity, roundedStopPrice, "sl"); err != nil { + return fmt.Errorf("failed to set xyz dex stop loss: %w", err) + } + } else { + // Standard crypto stop loss order + // Round quantity according to coin precision requirements + roundedQuantity := t.roundToSzDecimals(coin, quantity) + + // Create stop loss order (Trigger Order) + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: isBuy, + Size: roundedQuantity, // Use rounded quantity + Price: roundedStopPrice, // Use processed price + OrderType: hyperliquid.OrderType{ + Trigger: &hyperliquid.TriggerOrderType{ + TriggerPx: roundedStopPrice, + IsMarket: true, + Tpsl: "sl", // stop loss + }, + }, + ReduceOnly: true, + } + + _, err := t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return fmt.Errorf("failed to set stop loss: %w", err) + } + } + + logger.Infof(" Stop loss price set: %.4f", roundedStopPrice) + return nil +} + +// SetTakeProfit sets take profit order +func (t *HyperliquidTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + coin := convertSymbolToHyperliquid(symbol) + + isBuy := positionSide == "SHORT" // Short position take profit = buy, long position take profit = sell + + // Price needs to be processed to 5 significant figures + roundedTakeProfitPrice := t.roundPriceToSigfigs(takeProfitPrice) + + // Check if this is an xyz dex asset (stocks, forex, commodities) + isXyz := strings.HasPrefix(coin, "xyz:") + + if isXyz { + // xyz dex take profit order - use direct API call similar to placeXyzOrder + if err := t.placeXyzTriggerOrder(coin, isBuy, quantity, roundedTakeProfitPrice, "tp"); err != nil { + return fmt.Errorf("failed to set xyz dex take profit: %w", err) + } + } else { + // Standard crypto take profit order + // Round quantity according to coin precision requirements + roundedQuantity := t.roundToSzDecimals(coin, quantity) + + // Create take profit order (Trigger Order) + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: isBuy, + Size: roundedQuantity, // Use rounded quantity + Price: roundedTakeProfitPrice, // Use processed price + OrderType: hyperliquid.OrderType{ + Trigger: &hyperliquid.TriggerOrderType{ + TriggerPx: roundedTakeProfitPrice, + IsMarket: true, + Tpsl: "tp", // take profit + }, + }, + ReduceOnly: true, + } + + _, err := t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return fmt.Errorf("failed to set take profit: %w", err) + } + } + + logger.Infof(" Take profit price set: %.4f", roundedTakeProfitPrice) + return nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *HyperliquidTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { + coin := convertSymbolToHyperliquid(req.Symbol) + + // Set leverage if specified and not xyz dex + isXyz := strings.HasPrefix(coin, "xyz:") + if req.Leverage > 0 && !isXyz { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Hyperliquid] Failed to set leverage: %v", err) + } + } + + // Round quantity to allowed decimals + roundedQuantity := t.roundToSzDecimals(coin, req.Quantity) + + // Round price to 5 significant figures + roundedPrice := t.roundPriceToSigfigs(req.Price) + + // Determine if buy or sell + isBuy := req.Side == "BUY" + + logger.Infof("[Hyperliquid] PlaceLimitOrder: %s %s @ %.4f, qty=%.4f", coin, req.Side, roundedPrice, roundedQuantity) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: isBuy, + Size: roundedQuantity, + Price: roundedPrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifGtc, // Good Till Cancel for grid orders + }, + }, + ReduceOnly: req.ReduceOnly, + } + + _, err := t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Note: Hyperliquid's Order response doesn't return the order ID directly + // We would need to query open orders to get it, but for grid trading + // we can track orders by price level instead + orderID := fmt.Sprintf("%d", time.Now().UnixNano()) + + logger.Infof("✓ [Hyperliquid] Limit order placed: %s %s @ %.4f", + coin, req.Side, roundedPrice) + + return &types.LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: roundedPrice, + Quantity: roundedQuantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *HyperliquidTrader) CancelOrder(symbol, orderID string) error { + coin := convertSymbolToHyperliquid(symbol) + + // Parse order ID + oid, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return fmt.Errorf("invalid order ID: %w", err) + } + + _, err = t.exchange.Cancel(t.ctx, coin, oid) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Hyperliquid] Order cancelled: %s %s", symbol, orderID) + return nil +} diff --git a/trader/hyperliquid/trader_positions.go b/trader/hyperliquid/trader_positions.go new file mode 100644 index 00000000..c7a977eb --- /dev/null +++ b/trader/hyperliquid/trader_positions.go @@ -0,0 +1,167 @@ +package hyperliquid + +import ( + "fmt" + "nofx/logger" + "strconv" + "strings" +) + +// GetPositions gets all positions (including xyz dex positions) +func (t *HyperliquidTrader) GetPositions() ([]map[string]interface{}, error) { + // Get account status + accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var result []map[string]interface{} + + // Iterate through all perp positions + for _, assetPos := range accountState.AssetPositions { + position := assetPos.Position + + // Position amount (string type) + posAmt, _ := strconv.ParseFloat(position.Szi, 64) + + if posAmt == 0 { + continue // Skip positions with zero amount + } + + posMap := make(map[string]interface{}) + + // Normalize symbol format (Hyperliquid uses "BTC", we convert to "BTCUSDT") + symbol := position.Coin + "USDT" + posMap["symbol"] = symbol + + // Position amount and direction + if posAmt > 0 { + posMap["side"] = "long" + posMap["positionAmt"] = posAmt + } else { + posMap["side"] = "short" + posMap["positionAmt"] = -posAmt // Convert to positive number + } + + // Price information (EntryPx and LiquidationPx are pointer types) + var entryPrice, liquidationPx float64 + if position.EntryPx != nil { + entryPrice, _ = strconv.ParseFloat(*position.EntryPx, 64) + } + if position.LiquidationPx != nil { + liquidationPx, _ = strconv.ParseFloat(*position.LiquidationPx, 64) + } + + positionValue, _ := strconv.ParseFloat(position.PositionValue, 64) + unrealizedPnl, _ := strconv.ParseFloat(position.UnrealizedPnl, 64) + + // Calculate mark price (positionValue / abs(posAmt)) + var markPrice float64 + if posAmt != 0 { + markPrice = positionValue / absFloat(posAmt) + } + + posMap["entryPrice"] = entryPrice + posMap["markPrice"] = markPrice + posMap["unRealizedProfit"] = unrealizedPnl + posMap["leverage"] = float64(position.Leverage.Value) + posMap["liquidationPrice"] = liquidationPx + + result = append(result, posMap) + } + + // Also get xyz dex positions (stocks, forex, commodities) + _, _, xyzPositions, err := t.getXYZDexBalance() + if err != nil { + // xyz dex query failed - log warning but don't fail + logger.Infof("⚠️ Failed to get xyz dex positions: %v", err) + } else { + for _, pos := range xyzPositions { + posAmt, _ := strconv.ParseFloat(pos.Position.Szi, 64) + if posAmt == 0 { + continue + } + + posMap := make(map[string]interface{}) + + // xyz dex positions - the API returns coin names with xyz: prefix (e.g., "xyz:SILVER") + // Only add prefix if not already present + symbol := pos.Position.Coin + if !strings.HasPrefix(symbol, "xyz:") { + symbol = "xyz:" + symbol + } + posMap["symbol"] = symbol + + if posAmt > 0 { + posMap["side"] = "long" + posMap["positionAmt"] = posAmt + } else { + posMap["side"] = "short" + posMap["positionAmt"] = -posAmt + } + + // Parse price information + var entryPrice, liquidationPx float64 + if pos.Position.EntryPx != nil { + entryPrice, _ = strconv.ParseFloat(*pos.Position.EntryPx, 64) + } + if pos.Position.LiquidationPx != nil { + liquidationPx, _ = strconv.ParseFloat(*pos.Position.LiquidationPx, 64) + } + + positionValue, _ := strconv.ParseFloat(pos.Position.PositionValue, 64) + unrealizedPnl, _ := strconv.ParseFloat(pos.Position.UnrealizedPnl, 64) + + // Calculate mark price from position value + var markPrice float64 + if posAmt != 0 { + markPrice = positionValue / absFloat(posAmt) + } + + // Get leverage (default to 1 if not available) + leverage := float64(pos.Position.Leverage.Value) + if leverage == 0 { + leverage = 1.0 + } + + posMap["entryPrice"] = entryPrice + posMap["markPrice"] = markPrice + posMap["unRealizedProfit"] = unrealizedPnl + posMap["leverage"] = leverage + posMap["liquidationPrice"] = liquidationPx + posMap["isXyzDex"] = true // Mark as xyz dex position + + result = append(result, posMap) + } + } + + return result, nil +} + +// SetMarginMode sets margin mode (set together with SetLeverage) +func (t *HyperliquidTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + // Hyperliquid's margin mode is set in SetLeverage, only record here + t.isCrossMargin = isCrossMargin + marginModeStr := "cross margin" + if !isCrossMargin { + marginModeStr = "isolated margin" + } + logger.Infof(" ✓ %s will use %s mode", symbol, marginModeStr) + return nil +} + +// SetLeverage sets leverage +func (t *HyperliquidTrader) SetLeverage(symbol string, leverage int) error { + // Hyperliquid symbol format (remove USDT suffix) + coin := convertSymbolToHyperliquid(symbol) + + // Call UpdateLeverage (leverage int, name string, isCross bool) + // Third parameter: true=cross margin mode, false=isolated margin mode + _, err := t.exchange.UpdateLeverage(t.ctx, leverage, coin, t.isCrossMargin) + if err != nil { + return fmt.Errorf("failed to set leverage: %w", err) + } + + logger.Infof(" ✓ %s leverage switched to %dx", symbol, leverage) + return nil +} diff --git a/trader/hyperliquid/trader_sync.go b/trader/hyperliquid/trader_sync.go new file mode 100644 index 00000000..5a7a21ec --- /dev/null +++ b/trader/hyperliquid/trader_sync.go @@ -0,0 +1,147 @@ +package hyperliquid + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "nofx/logger" + "strings" + "time" +) + +// refreshMetaIfNeeded refreshes meta information when invalid (triggered when Asset ID is 0) +func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error { + assetID := t.exchange.Info().NameToAsset(coin) + if assetID != 0 { + return nil // Meta is normal, no refresh needed + } + + logger.Infof("⚠️ Asset ID for %s is 0, attempting to refresh Meta information...", coin) + + // Refresh Meta information + meta, err := t.exchange.Info().Meta(t.ctx) + if err != nil { + return fmt.Errorf("failed to refresh Meta information: %w", err) + } + + // Concurrency safe: Use write lock to protect meta field update + t.metaMutex.Lock() + t.meta = meta + t.metaMutex.Unlock() + + logger.Infof("✅ Meta information refreshed, contains %d assets", len(meta.Universe)) + + // Verify Asset ID after refresh + assetID = t.exchange.Info().NameToAsset(coin) + if assetID == 0 { + return fmt.Errorf("❌ Even after refreshing Meta, Asset ID for %s is still 0. Possible reasons:\n"+ + " 1. This coin is not listed on Hyperliquid\n"+ + " 2. Coin name is incorrect (should be BTC not BTCUSDT)\n"+ + " 3. API connection issue", coin) + } + + logger.Infof("✅ Asset ID check passed after refresh: %s -> %d", coin, assetID) + return nil +} + +// fetchXyzMeta fetches metadata for xyz dex assets (stocks, forex, commodities) +func (t *HyperliquidTrader) fetchXyzMeta() error { + // Build request for xyz dex meta + reqBody := map[string]string{ + "type": "meta", + "dex": "xyz", + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + apiURL := "https://api.hyperliquid.xyz/info" + + req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("xyz dex meta API error (status %d): %s", resp.StatusCode, string(body)) + } + + var meta xyzDexMeta + if err := json.Unmarshal(body, &meta); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + t.xyzMetaMutex.Lock() + t.xyzMeta = &meta + t.xyzMetaMutex.Unlock() + + logger.Infof("✅ xyz dex meta fetched, contains %d assets", len(meta.Universe)) + return nil +} + +// getXyzSzDecimals gets quantity precision for xyz dex asset +func (t *HyperliquidTrader) getXyzSzDecimals(coin string) int { + t.xyzMetaMutex.RLock() + defer t.xyzMetaMutex.RUnlock() + + if t.xyzMeta == nil { + logger.Infof("⚠️ xyz meta information is empty, using default precision 2") + return 2 // Default precision for stocks/forex + } + + // The meta API returns names with xyz: prefix, so ensure we match correctly + lookupName := coin + if !strings.HasPrefix(lookupName, "xyz:") { + lookupName = "xyz:" + lookupName + } + + // Find corresponding asset in xyzMeta.Universe + for _, asset := range t.xyzMeta.Universe { + if asset.Name == lookupName { + return asset.SzDecimals + } + } + + logger.Infof("⚠️ Precision information not found for %s, using default precision 2", lookupName) + return 2 // Default precision for stocks/forex +} + +// getXyzAssetIndex gets the asset index for an xyz dex asset +func (t *HyperliquidTrader) getXyzAssetIndex(baseCoin string) int { + t.xyzMetaMutex.RLock() + defer t.xyzMetaMutex.RUnlock() + + if t.xyzMeta == nil { + return -1 + } + + // The meta API returns names with xyz: prefix, so ensure we match correctly + lookupName := baseCoin + if !strings.HasPrefix(lookupName, "xyz:") { + lookupName = "xyz:" + lookupName + } + + for i, asset := range t.xyzMeta.Universe { + if asset.Name == lookupName { + return i + } + } + return -1 +} diff --git a/trader/hyperliquid/trader_test.go b/trader/hyperliquid/trader_test.go deleted file mode 100644 index 668cd4ca..00000000 --- a/trader/hyperliquid/trader_test.go +++ /dev/null @@ -1,648 +0,0 @@ -package hyperliquid - -import ( - "context" - "crypto/ecdsa" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/sonirico/go-hyperliquid" - "github.com/stretchr/testify/assert" - "nofx/trader/testutil" - "nofx/trader/types" -) - -// ============================================================ -// Part 1: HyperliquidTestSuite - Inherits base test suite -// ============================================================ - -// HyperliquidTestSuite Hyperliquid trader test suite -// Inherits TraderTestSuite and adds Hyperliquid-specific mock logic -type HyperliquidTestSuite struct { - *testutil.TraderTestSuite // Embeds base test suite - mockServer *httptest.Server - privateKey *ecdsa.PrivateKey -} - -// NewHyperliquidTestSuite Create Hyperliquid test suite -func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite { - // Create test private key - privateKey, err := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") - if err != nil { - t.Fatalf("Failed to create test private key: %v", err) - } - - // Create mock HTTP server - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return different mock responses based on request path - var respBody interface{} - - // Hyperliquid API uses POST requests with JSON body - // We need to distinguish different requests by the "type" field in request body - var reqBody map[string]interface{} - if r.Method == "POST" { - json.NewDecoder(r.Body).Decode(&reqBody) - } - - // Try to get type from top level first, then from action object - reqType, _ := reqBody["type"].(string) - if reqType == "" && reqBody["action"] != nil { - if action, ok := reqBody["action"].(map[string]interface{}); ok { - reqType, _ = action["type"].(string) - } - } - - switch reqType { - // Mock Meta - Get market metadata - case "meta": - respBody = map[string]interface{}{ - "universe": []map[string]interface{}{ - { - "name": "BTC", - "szDecimals": 4, - "maxLeverage": 50, - "onlyIsolated": false, - "isDelisted": false, - "marginTableId": 0, - }, - { - "name": "ETH", - "szDecimals": 3, - "maxLeverage": 50, - "onlyIsolated": false, - "isDelisted": false, - "marginTableId": 0, - }, - }, - "marginTables": []interface{}{}, - } - - // Mock UserState - Get user account state (for GetBalance and GetPositions) - case "clearinghouseState": - user, _ := reqBody["user"].(string) - - // Check if querying Agent wallet balance (for security check) - agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex() - if user == agentAddr { - // Agent wallet balance should be low - respBody = map[string]interface{}{ - "crossMarginSummary": map[string]interface{}{ - "accountValue": "5.00", - "totalMarginUsed": "0.00", - }, - "withdrawable": "5.00", - "assetPositions": []interface{}{}, - } - } else { - // Main wallet account state - respBody = map[string]interface{}{ - "crossMarginSummary": map[string]interface{}{ - "accountValue": "10000.00", - "totalMarginUsed": "2000.00", - }, - "withdrawable": "8000.00", - "assetPositions": []map[string]interface{}{ - { - "position": map[string]interface{}{ - "coin": "BTC", - "szi": "0.5", - "entryPx": "50000.00", - "liquidationPx": "45000.00", - "positionValue": "25000.00", - "unrealizedPnl": "100.50", - "leverage": map[string]interface{}{ - "type": "cross", - "value": 10, - }, - }, - }, - }, - } - } - - // Mock SpotUserState - Get spot account state - case "spotClearinghouseState": - respBody = map[string]interface{}{ - "balances": []map[string]interface{}{ - { - "coin": "USDC", - "total": "500.00", - }, - }, - } - - // Mock SpotMeta - Get spot market metadata - case "spotMeta": - respBody = map[string]interface{}{ - "universe": []map[string]interface{}{}, - "tokens": []map[string]interface{}{}, - } - - // Mock AllMids - Get all market prices - case "allMids": - respBody = map[string]string{ - "BTC": "50000.00", - "ETH": "3000.00", - } - - // Mock OpenOrders - Get open orders list - case "openOrders": - respBody = []interface{}{} - - // Mock Order - Create order (open, close, stop-loss, take-profit) - case "order": - respBody = map[string]interface{}{ - "status": "ok", - "response": map[string]interface{}{ - "type": "order", - "data": map[string]interface{}{ - "statuses": []map[string]interface{}{ - { - "filled": map[string]interface{}{ - "totalSz": "0.01", - "avgPx": "50000.00", - }, - }, - }, - }, - }, - } - - // Mock UpdateLeverage - Set leverage - case "updateLeverage": - respBody = map[string]interface{}{ - "status": "ok", - } - - // Mock Cancel - Cancel order - case "cancel": - respBody = map[string]interface{}{ - "status": "ok", - } - - default: - // Default return success response - respBody = map[string]interface{}{ - "status": "ok", - } - } - - // Serialize response - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - })) - - // Create HyperliquidTrader, using mock server URL - walletAddr := "0x9999999999999999999999999999999999999999" - ctx := context.Background() - - // Create Exchange client, pointing to mock server - exchange := hyperliquid.NewExchange( - ctx, - privateKey, - mockServer.URL, // Use mock server URL - nil, - "", - walletAddr, - nil, - ) - - // Create meta (simulate successful fetch) - meta := &hyperliquid.Meta{ - Universe: []hyperliquid.AssetInfo{ - {Name: "BTC", SzDecimals: 4}, - {Name: "ETH", SzDecimals: 3}, - }, - } - - traderInstance := &HyperliquidTrader{ - exchange: exchange, - ctx: ctx, - walletAddr: walletAddr, - meta: meta, - isCrossMargin: true, - } - - // Create base suite - baseSuite := testutil.NewTraderTestSuite(t, traderInstance) - - return &HyperliquidTestSuite{ - TraderTestSuite: baseSuite, - mockServer: mockServer, - privateKey: privateKey, - } -} - -// Cleanup Clean up resources -func (s *HyperliquidTestSuite) Cleanup() { - if s.mockServer != nil { - s.mockServer.Close() - } - s.TraderTestSuite.Cleanup() -} - -// ============================================================ -// Part 2: Run common tests using HyperliquidTestSuite -// ============================================================ - -// TestHyperliquidTrader_InterfaceCompliance Test interface compliance -func TestHyperliquidTrader_InterfaceCompliance(t *testing.T) { - var _ types.Trader = (*HyperliquidTrader)(nil) -} - -// TestHyperliquidTrader_CommonInterface Run all common interface tests using test suite -func TestHyperliquidTrader_CommonInterface(t *testing.T) { - // Create test suite - suite := NewHyperliquidTestSuite(t) - defer suite.Cleanup() - - // Run all common interface tests - suite.RunAllTests() -} - -// ============================================================ -// Part 3: Hyperliquid-specific feature unit tests -// ============================================================ - -// TestNewHyperliquidTrader Test creating Hyperliquid trader -func TestNewHyperliquidTrader(t *testing.T) { - tests := []struct { - name string - privateKeyHex string - walletAddr string - testnet bool - wantError bool - errorContains string - }{ - { - name: "Invalid private key format", - privateKeyHex: "invalid_key", - walletAddr: "0x1234567890123456789012345678901234567890", - testnet: true, - wantError: true, - errorContains: "Failed to parse private key", - }, - { - name: "Empty wallet address", - privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - walletAddr: "", - testnet: true, - wantError: true, - errorContains: "Configuration error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - trader, err := NewHyperliquidTrader(tt.privateKeyHex, tt.walletAddr, tt.testnet) - - if tt.wantError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - assert.Nil(t, trader) - } else { - assert.NoError(t, err) - assert.NotNil(t, trader) - if trader != nil { - assert.Equal(t, tt.walletAddr, trader.walletAddr) - assert.NotNil(t, trader.exchange) - } - } - }) - } -} - -// TestNewHyperliquidTrader_Success Test successfully creating trader (requires mock HTTP) -func TestNewHyperliquidTrader_Success(t *testing.T) { - // Create test private key - privateKey, _ := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") - agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex() - - // Create mock HTTP server - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var reqBody map[string]interface{} - json.NewDecoder(r.Body).Decode(&reqBody) - reqType, _ := reqBody["type"].(string) - - var respBody interface{} - switch reqType { - case "meta": - respBody = map[string]interface{}{ - "universe": []map[string]interface{}{ - { - "name": "BTC", - "szDecimals": 4, - "maxLeverage": 50, - "onlyIsolated": false, - "isDelisted": false, - "marginTableId": 0, - }, - }, - "marginTables": []interface{}{}, - } - case "clearinghouseState": - user, _ := reqBody["user"].(string) - if user == agentAddr { - // Agent wallet low balance - respBody = map[string]interface{}{ - "crossMarginSummary": map[string]interface{}{ - "accountValue": "5.00", - }, - "assetPositions": []interface{}{}, - } - } else { - // Main wallet - respBody = map[string]interface{}{ - "crossMarginSummary": map[string]interface{}{ - "accountValue": "10000.00", - }, - "assetPositions": []interface{}{}, - } - } - default: - respBody = map[string]interface{}{"status": "ok"} - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(respBody) - })) - defer mockServer.Close() - - // Note: This test would actually call NewHyperliquidTrader, but will fail - // Because hyperliquid SDK doesn't allow us to inject custom URL in constructor - // So this test is only for verifying parameter handling logic - t.Skip("Skip this test: hyperliquid SDK calls real API during construction, cannot inject mock URL") -} - -// ============================================================ -// Part 4: Utility function unit tests (Hyperliquid-specific) -// ============================================================ - -// TestConvertSymbolToHyperliquid Test symbol conversion function -func TestConvertSymbolToHyperliquid(t *testing.T) { - tests := []struct { - name string - symbol string - expected string - }{ - { - name: "BTCUSDT conversion", - symbol: "BTCUSDT", - expected: "BTC", - }, - { - name: "ETHUSDT conversion", - symbol: "ETHUSDT", - expected: "ETH", - }, - { - name: "No USDT suffix", - symbol: "BTC", - expected: "BTC", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertSymbolToHyperliquid(tt.symbol) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestAbsFloat Test absolute value function -func TestAbsFloat(t *testing.T) { - tests := []struct { - name string - input float64 - expected float64 - }{ - { - name: "Positive number", - input: 10.5, - expected: 10.5, - }, - { - name: "Negative number", - input: -10.5, - expected: 10.5, - }, - { - name: "Zero", - input: 0, - expected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := absFloat(tt.input) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestHyperliquidTrader_RoundToSzDecimals Test quantity precision handling -func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) { - trader := &HyperliquidTrader{ - meta: &hyperliquid.Meta{ - Universe: []hyperliquid.AssetInfo{ - {Name: "BTC", SzDecimals: 4}, - {Name: "ETH", SzDecimals: 3}, - }, - }, - } - - tests := []struct { - name string - coin string - quantity float64 - expected float64 - }{ - { - name: "BTC - round to 4 decimals", - coin: "BTC", - quantity: 1.23456789, - expected: 1.2346, - }, - { - name: "ETH - round to 3 decimals", - coin: "ETH", - quantity: 10.12345, - expected: 10.123, - }, - { - name: "Unknown coin - use default 4 decimals", - coin: "UNKNOWN", - quantity: 1.23456789, - expected: 1.2346, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := trader.roundToSzDecimals(tt.coin, tt.quantity) - assert.InDelta(t, tt.expected, result, 0.0001) - }) - } -} - -// TestHyperliquidTrader_RoundPriceToSigfigs Test price significant figures handling -func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) { - trader := &HyperliquidTrader{} - - tests := []struct { - name string - price float64 - expected float64 - }{ - { - name: "BTC price - 5 significant figures", - price: 50123.456789, - expected: 50123.0, - }, - { - name: "Decimal price - 5 significant figures", - price: 0.0012345678, - expected: 0.0012346, - }, - { - name: "Zero price", - price: 0, - expected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := trader.roundPriceToSigfigs(tt.price) - assert.InDelta(t, tt.expected, result, tt.expected*0.001) - }) - } -} - -// TestHyperliquidTrader_GetSzDecimals Test getting precision -func TestHyperliquidTrader_GetSzDecimals(t *testing.T) { - tests := []struct { - name string - meta *hyperliquid.Meta - coin string - expected int - }{ - { - name: "meta is nil - return default precision", - meta: nil, - coin: "BTC", - expected: 4, - }, - { - name: "Found BTC - return correct precision", - meta: &hyperliquid.Meta{ - Universe: []hyperliquid.AssetInfo{ - {Name: "BTC", SzDecimals: 5}, - }, - }, - coin: "BTC", - expected: 5, - }, - { - name: "Coin not found - return default precision", - meta: &hyperliquid.Meta{ - Universe: []hyperliquid.AssetInfo{ - {Name: "ETH", SzDecimals: 3}, - }, - }, - coin: "BTC", - expected: 4, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ht := &HyperliquidTrader{meta: tt.meta} - result := ht.getSzDecimals(tt.coin) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestHyperliquidTrader_SetMarginMode Test setting margin mode -func TestHyperliquidTrader_SetMarginMode(t *testing.T) { - trader := &HyperliquidTrader{ - ctx: context.Background(), - isCrossMargin: true, - } - - tests := []struct { - name string - symbol string - isCrossMargin bool - wantError bool - }{ - { - name: "Set to cross margin mode", - symbol: "BTCUSDT", - isCrossMargin: true, - wantError: false, - }, - { - name: "Set to isolated margin mode", - symbol: "ETHUSDT", - isCrossMargin: false, - wantError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := trader.SetMarginMode(tt.symbol, tt.isCrossMargin) - - if tt.wantError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.isCrossMargin, trader.isCrossMargin) - } - }) - } -} - -// TestNewHyperliquidTrader_PrivateKeyProcessing Test private key processing -func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) { - tests := []struct { - name string - privateKeyHex string - shouldStripOx bool - expectedLength int - }{ - { - name: "Private key with 0x prefix", - privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - shouldStripOx: true, - expectedLength: 64, - }, - { - name: "Private key without prefix", - privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - shouldStripOx: false, - expectedLength: 64, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test private key prefix handling logic (without actually creating trader) - processed := tt.privateKeyHex - if len(processed) > 2 && (processed[:2] == "0x" || processed[:2] == "0X") { - processed = processed[2:] - } - - assert.Equal(t, tt.expectedLength, len(processed)) - }) - } -} diff --git a/trader/hyperliquid/xyz_dex_test.go b/trader/hyperliquid/xyz_dex_test.go deleted file mode 100644 index e04f9622..00000000 --- a/trader/hyperliquid/xyz_dex_test.go +++ /dev/null @@ -1,669 +0,0 @@ -package hyperliquid - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "strings" - "testing" - "time" -) - -// testXyzDexAsset is a local copy of testXyzDexAsset for testing -type testXyzDexAsset struct { - Name string `json:"name"` - SzDecimals int `json:"szDecimals"` - MaxLeverage int `json:"maxLeverage"` -} - -// testXyzDexMeta is a local copy of xyzDexMeta for testing -type testXyzDexMeta struct { - Universe []testXyzDexAsset `json:"universe"` -} - -// TestXyzDexMetaFetch tests fetching xyz dex meta from Hyperliquid API -func TestXyzDexMetaFetch(t *testing.T) { - reqBody := map[string]string{ - "type": "meta", - "dex": "xyz", - } - jsonBody, err := json.Marshal(reqBody) - if err != nil { - t.Fatalf("Failed to marshal request: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody)) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Failed to execute request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Fatalf("API returned status %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("Failed to read response: %v", err) - } - - var meta testXyzDexMeta - if err := json.Unmarshal(body, &meta); err != nil { - t.Fatalf("Failed to parse response: %v", err) - } - - if len(meta.Universe) == 0 { - t.Fatal("xyz meta universe is empty") - } - - t.Logf("✅ xyz dex meta contains %d assets", len(meta.Universe)) - - // Check that SILVER exists - // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta - // xyz dex is at perp_dex_index = 1 - found := false - for i, asset := range meta.Universe { - if asset.Name == "xyz:SILVER" { - found = true - assetIndex := 100000 + 1*10000 + i // xyz dex index = 1 - t.Logf("✅ Found xyz:SILVER at index %d (asset ID: %d)", i, assetIndex) - t.Logf(" SzDecimals: %d, MaxLeverage: %d", asset.SzDecimals, asset.MaxLeverage) - break - } - } - if !found { - t.Fatal("xyz:SILVER not found in meta") - } -} - -// TestXyzDexPriceFetch tests fetching xyz dex prices from Hyperliquid API -func TestXyzDexPriceFetch(t *testing.T) { - reqBody := map[string]string{ - "type": "allMids", - "dex": "xyz", - } - jsonBody, err := json.Marshal(reqBody) - if err != nil { - t.Fatalf("Failed to marshal request: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody)) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Failed to execute request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Fatalf("API returned status %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("Failed to read response: %v", err) - } - - var mids map[string]string - if err := json.Unmarshal(body, &mids); err != nil { - t.Fatalf("Failed to parse response: %v", err) - } - - // Check that prices have xyz: prefix - silverPrice, ok := mids["xyz:SILVER"] - if !ok { - t.Fatal("xyz:SILVER price not found (key should include xyz: prefix)") - } - t.Logf("✅ xyz:SILVER price: %s", silverPrice) - - // Verify a few more assets - testAssets := []string{"xyz:GOLD", "xyz:TSLA", "xyz:NVDA"} - for _, asset := range testAssets { - if price, ok := mids[asset]; ok { - t.Logf("✅ %s price: %s", asset, price) - } else { - t.Logf("⚠️ %s not found in prices", asset) - } - } -} - -// TestXyzAssetIndexLookup tests the asset index lookup for xyz dex assets -func TestXyzAssetIndexLookup(t *testing.T) { - // Fetch xyz meta - reqBody := map[string]string{ - "type": "meta", - "dex": "xyz", - } - jsonBody, _ := json.Marshal(reqBody) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Failed to fetch meta: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var meta testXyzDexMeta - json.Unmarshal(body, &meta) - - // Test lookup with different formats - testCases := []struct { - input string - expected string // expected match in meta - }{ - {"SILVER", "xyz:SILVER"}, - {"xyz:SILVER", "xyz:SILVER"}, - {"GOLD", "xyz:GOLD"}, - {"xyz:TSLA", "xyz:TSLA"}, - } - - for _, tc := range testCases { - lookupName := tc.input - if !strings.HasPrefix(lookupName, "xyz:") { - lookupName = "xyz:" + lookupName - } - - found := false - for i, asset := range meta.Universe { - if asset.Name == lookupName { - found = true - assetIndex := 100000 + 1*10000 + i // HIP-3 formula: 100000 + xyz_dex_index(1) * 10000 + meta_index - t.Logf("✅ Lookup '%s' -> found at index %d (asset ID: %d)", tc.input, i, assetIndex) - break - } - } - if !found { - t.Errorf("❌ Lookup '%s' -> NOT FOUND (expected to match %s)", tc.input, tc.expected) - } - } -} - -// TestXyzSzDecimalsLookup tests the szDecimals lookup for different xyz assets -func TestXyzSzDecimalsLookup(t *testing.T) { - reqBody := map[string]string{ - "type": "meta", - "dex": "xyz", - } - jsonBody, _ := json.Marshal(reqBody) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Failed to fetch meta: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var meta testXyzDexMeta - json.Unmarshal(body, &meta) - - // Check szDecimals for various assets - expectedDecimals := map[string]int{ - "xyz:SILVER": 2, - "xyz:GOLD": 4, - "xyz:TSLA": 3, - } - - for name, expected := range expectedDecimals { - for _, asset := range meta.Universe { - if asset.Name == name { - if asset.SzDecimals == expected { - t.Logf("✅ %s szDecimals: %d (expected %d)", name, asset.SzDecimals, expected) - } else { - t.Logf("⚠️ %s szDecimals: %d (expected %d, may have changed)", name, asset.SzDecimals, expected) - } - break - } - } - } -} - -// TestXyzOrderParameters tests order parameter calculation -func TestXyzOrderParameters(t *testing.T) { - // Simulate order parameter calculation - testCases := []struct { - price float64 - size float64 - szDecimals int - expectedSz float64 - }{ - {75.33, 1.0, 2, 1.00}, - {75.33, 1.234, 2, 1.23}, - {75.33, 5.567, 2, 5.57}, - {188.15, 0.5, 3, 0.500}, - {188.15, 0.1234, 3, 0.123}, - } - - for _, tc := range testCases { - multiplier := 1.0 - for i := 0; i < tc.szDecimals; i++ { - multiplier *= 10.0 - } - roundedSize := float64(int(tc.size*multiplier+0.5)) / multiplier - - if roundedSize != tc.expectedSz { - t.Errorf("Size rounding failed: input=%v, decimals=%d, got=%v, expected=%v", - tc.size, tc.szDecimals, roundedSize, tc.expectedSz) - } else { - t.Logf("✅ Size rounding: %v (decimals=%d) -> %v", tc.size, tc.szDecimals, roundedSize) - } - } -} - -// TestXyzAssetIndexCalculation tests the HIP-3 asset index calculation -// Formula: 100000 + perp_dex_index * 10000 + meta_index -// For xyz dex: perp_dex_index = 1, so asset_index = 110000 + meta_index -func TestXyzAssetIndexCalculation(t *testing.T) { - reqBody := map[string]string{ - "type": "meta", - "dex": "xyz", - } - jsonBody, _ := json.Marshal(reqBody) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Failed to fetch meta: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var meta testXyzDexMeta - json.Unmarshal(body, &meta) - - // Test asset index calculation for SILVER - // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta - // xyz dex is at perp_dex_index = 1 - const xyzPerpDexIndex = 1 - for i, asset := range meta.Universe { - if asset.Name == "xyz:SILVER" { - assetIndex := 100000 + xyzPerpDexIndex*10000 + i - t.Logf("✅ xyz:SILVER: meta_index=%d, asset_index=%d", i, assetIndex) - - if assetIndex < 110000 { - t.Errorf("Asset index should be >= 110000, got %d", assetIndex) - } - break - } - } - - // Log first few assets for reference - t.Log("\nFirst 5 xyz assets:") - for i := 0; i < 5 && i < len(meta.Universe); i++ { - asset := meta.Universe[i] - assetIndex := 100000 + xyzPerpDexIndex*10000 + i - t.Logf(" [%d] %s -> asset_index=%d, szDecimals=%d", i, asset.Name, assetIndex, asset.SzDecimals) - } -} - -// TestIsXyzDexAsset tests the isXyzDexAsset function -func TestIsXyzDexAsset(t *testing.T) { - testCases := []struct { - symbol string - expected bool - }{ - {"xyz:SILVER", true}, - {"SILVER", true}, - {"silver", true}, - {"xyz:GOLD", true}, - {"GOLD", true}, - {"xyz:TSLA", true}, - {"TSLA", true}, - {"BTCUSDT", false}, - {"BTC", false}, - {"ETHUSDT", false}, - {"SOLUSDT", false}, - {"xyz:BTC", false}, // BTC is not an xyz asset - } - - for _, tc := range testCases { - result := isXyzDexAsset(tc.symbol) - if result != tc.expected { - t.Errorf("isXyzDexAsset(%q) = %v, expected %v", tc.symbol, result, tc.expected) - } else { - t.Logf("✅ isXyzDexAsset(%q) = %v", tc.symbol, result) - } - } -} - -// TestConvertSymbolToHyperliquidXyz tests symbol conversion for xyz assets -func TestConvertSymbolToHyperliquidXyz(t *testing.T) { - testCases := []struct { - input string - expected string - }{ - {"SILVER", "xyz:SILVER"}, - {"silver", "xyz:SILVER"}, - {"xyz:SILVER", "xyz:SILVER"}, - {"GOLD", "xyz:GOLD"}, - {"TSLA", "xyz:TSLA"}, - {"BTC", "BTC"}, - {"BTCUSDT", "BTC"}, - {"ETH", "ETH"}, - {"ETHUSDT", "ETH"}, - } - - for _, tc := range testCases { - result := convertSymbolToHyperliquid(tc.input) - if result != tc.expected { - t.Errorf("convertSymbolToHyperliquid(%q) = %q, expected %q", tc.input, result, tc.expected) - } else { - t.Logf("✅ convertSymbolToHyperliquid(%q) = %q", tc.input, result) - } - } -} - -// TestXyzDexOrderFlow tests the complete order flow (without actually placing an order) -func TestXyzDexOrderFlow(t *testing.T) { - t.Log("=== Testing xyz Dex Order Flow ===") - - // Step 1: Fetch meta - t.Log("\nStep 1: Fetching xyz meta...") - reqBody := map[string]string{"type": "meta", "dex": "xyz"} - jsonBody, _ := json.Marshal(reqBody) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Failed to fetch meta: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var meta testXyzDexMeta - json.Unmarshal(body, &meta) - t.Logf("✅ Fetched %d xyz assets", len(meta.Universe)) - - // Step 2: Find SILVER - t.Log("\nStep 2: Looking up xyz:SILVER...") - var silverIndex int = -1 - var silverAsset *testXyzDexAsset - for i, asset := range meta.Universe { - if asset.Name == "xyz:SILVER" { - silverIndex = i - silverAsset = &meta.Universe[i] - break - } - } - if silverIndex < 0 { - t.Fatal("SILVER not found in xyz meta") - } - t.Logf("✅ Found at index %d", silverIndex) - - // Step 3: Fetch price - t.Log("\nStep 3: Fetching price...") - priceReq := map[string]string{"type": "allMids", "dex": "xyz"} - priceBody, _ := json.Marshal(priceReq) - req2, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(priceBody)) - req2.Header.Set("Content-Type", "application/json") - resp2, _ := client.Do(req2) - body2, _ := io.ReadAll(resp2.Body) - resp2.Body.Close() - - var mids map[string]string - json.Unmarshal(body2, &mids) - priceStr := mids["xyz:SILVER"] - var price float64 - fmt.Sscanf(priceStr, "%f", &price) - t.Logf("✅ Price: %s", priceStr) - - // Step 4: Calculate order parameters - t.Log("\nStep 4: Calculating order parameters...") - orderSize := 1.0 - multiplier := 1.0 - for i := 0; i < silverAsset.SzDecimals; i++ { - multiplier *= 10.0 - } - roundedSize := float64(int(orderSize*multiplier+0.5)) / multiplier - roundedPrice := price * 1.001 // 0.1% slippage - // HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta - // xyz dex is at perp_dex_index = 1 - assetIndex := 100000 + 1*10000 + silverIndex - - t.Logf(" Asset Index: %d (110000 + %d)", assetIndex, silverIndex) - t.Logf(" Size: %.4f (szDecimals=%d)", roundedSize, silverAsset.SzDecimals) - t.Logf(" Price: %.4f (with slippage)", roundedPrice) - - // Step 5: Summary - t.Log("\n=== Order Flow Test Summary ===") - t.Log("✅ Meta fetch: OK") - t.Log("✅ Asset lookup: OK") - t.Log("✅ Price fetch: OK") - t.Log("✅ Parameter calculation: OK") - t.Logf("\n📋 Order would be placed with:") - t.Logf(" coin: xyz:SILVER") - t.Logf(" assetIndex: %d", assetIndex) - t.Logf(" isBuy: true") - t.Logf(" size: %.4f", roundedSize) - t.Logf(" price: %.4f", roundedPrice) -} - -// TestXyzDexLiveOrder tests placing a real order on xyz dex -// This test requires: -// - XYZ_DEX_LIVE_TEST=1 to enable -// - TEST_PRIVATE_KEY - the private key for signing -// - TEST_WALLET_ADDR - the wallet address with funds -func TestXyzDexLiveOrder(t *testing.T) { - // Skip unless explicitly enabled - if os.Getenv("XYZ_DEX_LIVE_TEST") != "1" { - t.Skip("Skipping live order test. Set XYZ_DEX_LIVE_TEST=1 to run") - } - - // Get credentials from environment variables - privateKeyHex := os.Getenv("TEST_PRIVATE_KEY") - walletAddr := os.Getenv("TEST_WALLET_ADDR") - - if privateKeyHex == "" || walletAddr == "" { - t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required") - } - - t.Logf("=== Live xyz Dex Order Test ===") - t.Logf("Wallet: %s", walletAddr) - - // Create trader instance - trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false) - if err != nil { - t.Fatalf("Failed to create trader: %v", err) - } - - // Check xyz dex balance first - xyzState, _ := trader.exchange.Info().UserState(trader.ctx, walletAddr, "xyz") - if xyzState != nil && xyzState.CrossMarginSummary.AccountValue == "0.0" { - t.Logf("⚠️ xyz dex account has no funds (balance: %s)", xyzState.CrossMarginSummary.AccountValue) - t.Logf(" To trade xyz dex, you need to transfer funds using perpDexClassTransfer") - t.Logf(" The test will still verify order signing and submission...") - } - - // Fetch xyz meta first - if err := trader.fetchXyzMeta(); err != nil { - t.Fatalf("Failed to fetch xyz meta: %v", err) - } - - // Get current price for xyz:SILVER - price, err := trader.getXyzMarketPrice("xyz:SILVER") - if err != nil { - t.Fatalf("Failed to get price: %v", err) - } - t.Logf("Current xyz:SILVER price: %.4f", price) - - // Place a test order (minimum $10 value = 0.14 SILVER at ~$75) - // With 5% slippage for IOC (market order) - testSize := 0.14 // ~$10.5 at current price - testPrice := price * 1.05 // 5% above market for IOC buy (market order) - - t.Logf("Attempting to place order:") - t.Logf(" Symbol: xyz:SILVER") - t.Logf(" Side: BUY") - t.Logf(" Size: %.4f", testSize) - t.Logf(" Price: %.4f", testPrice) - - // Place the order using the new direct method - err = trader.placeXyzOrder("xyz:SILVER", true, testSize, testPrice, false) - if err != nil { - t.Logf("⚠️ Order result: %v", err) - // Check if this is an expected error (e.g., insufficient margin, no matching orders for IOC) - if strings.Contains(err.Error(), "insufficient") || strings.Contains(err.Error(), "margin") || strings.Contains(err.Error(), "minimum value") { - t.Logf("This may be expected if the test wallet has no margin in xyz dex") - t.Logf("✅ Order was properly signed and submitted (API validated format/signature)") - } else if strings.Contains(err.Error(), "could not immediately match") { - // IOC order didn't fill - this is actually SUCCESS! - // It means the order was properly signed, submitted, and processed - t.Logf("✅ Order was properly submitted but didn't fill (IOC with no matching orders)") - t.Logf(" This confirms the asset index (%d) and signing are correct!", 110026) - } else if strings.Contains(err.Error(), "Order has invalid price") || strings.Contains(err.Error(), "95% away") { - t.Errorf("FAILED: Order has invalid price - asset index issue") - } else { - t.Errorf("FAILED: Unexpected error: %v", err) - } - } else { - t.Logf("✅ Order placed and filled successfully!") - } -} - -// TestXyzDexClosePosition tests closing a position on xyz dex -// This test requires the XYZ_DEX_LIVE_TEST environment variable to be set -func TestXyzDexClosePosition(t *testing.T) { - // Skip unless explicitly enabled - if os.Getenv("XYZ_DEX_LIVE_TEST") != "1" { - t.Skip("Skipping live close position test. Set XYZ_DEX_LIVE_TEST=1 to run") - } - - // Get credentials from environment variables - privateKeyHex := os.Getenv("TEST_PRIVATE_KEY") - walletAddr := os.Getenv("TEST_WALLET_ADDR") - - if privateKeyHex == "" || walletAddr == "" { - t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required") - } - - t.Logf("=== Live xyz Dex Close Position Test ===") - t.Logf("Wallet: %s", walletAddr) - - // Create trader instance - trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false) - if err != nil { - t.Fatalf("Failed to create trader: %v", err) - } - - // Check current xyz dex position - xyzState, err := trader.exchange.Info().UserState(trader.ctx, walletAddr, "xyz") - if err != nil { - t.Fatalf("Failed to get xyz state: %v", err) - } - - if len(xyzState.AssetPositions) == 0 { - t.Logf("No xyz dex positions to close") - return - } - - // Get the position details - pos := xyzState.AssetPositions[0].Position - entryPx := "" - if pos.EntryPx != nil { - entryPx = *pos.EntryPx - } - t.Logf("Current position: %s size=%s entryPx=%s", pos.Coin, pos.Szi, entryPx) - - // Fetch xyz meta - if err := trader.fetchXyzMeta(); err != nil { - t.Fatalf("Failed to fetch xyz meta: %v", err) - } - - // Get current price - price, err := trader.getXyzMarketPrice(pos.Coin) - if err != nil { - t.Fatalf("Failed to get price: %v", err) - } - t.Logf("Current %s price: %.4f", pos.Coin, price) - - // Parse position size - var posSize float64 - fmt.Sscanf(pos.Szi, "%f", &posSize) - - // Close position: if long (szi > 0), sell; if short (szi < 0), buy - isBuy := posSize < 0 - closeSize := posSize - if closeSize < 0 { - closeSize = -closeSize - } - - // Use aggressive slippage for close - closePrice := price * 0.95 // 5% below for sell - if isBuy { - closePrice = price * 1.05 // 5% above for buy - } - - t.Logf("Closing position:") - t.Logf(" Side: %s", map[bool]string{true: "BUY", false: "SELL"}[isBuy]) - t.Logf(" Size: %.4f", closeSize) - t.Logf(" Price: %.4f", closePrice) - - // Place close order with reduceOnly=true - err = trader.placeXyzOrder(pos.Coin, isBuy, closeSize, closePrice, true) - if err != nil { - t.Logf("⚠️ Close order result: %v", err) - if strings.Contains(err.Error(), "could not immediately match") { - t.Logf("✅ Close order submitted but didn't fill (IOC)") - } else { - t.Errorf("FAILED: %v", err) - } - } else { - t.Logf("✅ Position closed successfully!") - } - - // Verify position is closed - xyzState2, _ := trader.exchange.Info().UserState(trader.ctx, walletAddr, "xyz") - if len(xyzState2.AssetPositions) == 0 { - t.Logf("✅ Position confirmed closed (no positions remaining)") - } else { - newPos := xyzState2.AssetPositions[0].Position - t.Logf("Position after close: %s size=%s", newPos.Coin, newPos.Szi) - } -} diff --git a/trader/indodax/trader.go b/trader/indodax/trader.go index ac49ac07..b1762ba9 100644 --- a/trader/indodax/trader.go +++ b/trader/indodax/trader.go @@ -7,11 +7,9 @@ import ( "encoding/json" "fmt" "io" - "math" "net/http" "net/url" "nofx/logger" - "nofx/trader/types" "strconv" "strings" "sync" @@ -299,561 +297,6 @@ func (t *IndodaxTrader) clearCache() { t.cachedPositions = nil } -// ============================================================ -// types.Trader interface implementation -// ============================================================ - -// GetBalance gets account balance from Indodax -func (t *IndodaxTrader) GetBalance() (map[string]interface{}, error) { - // Check cache - t.cacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - cached := t.cachedBalance - t.cacheMutex.RUnlock() - return cached, nil - } - t.cacheMutex.RUnlock() - - params := url.Values{} - params.Set("method", "getInfo") - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to get account info: %w", err) - } - - var result struct { - ServerTime int64 `json:"server_time"` - Balance map[string]interface{} `json:"balance"` - BalanceHold map[string]interface{} `json:"balance_hold"` - UserID string `json:"user_id"` - Name string `json:"name"` - Email string `json:"email"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse balance: %w", err) - } - - // Calculate total balance in IDR - idrBalance := parseFloat(result.Balance["idr"]) - idrHold := parseFloat(result.BalanceHold["idr"]) - totalIDR := idrBalance + idrHold - - balance := map[string]interface{}{ - "totalWalletBalance": totalIDR, - "availableBalance": idrBalance, - "totalUnrealizedProfit": 0.0, - "totalEquity": totalIDR, - "balance": totalIDR, - "idr_balance": idrBalance, - "idr_hold": idrHold, - "currency": "IDR", - "user_id": result.UserID, - "server_time": result.ServerTime, - } - - // Add individual crypto balances - for currency, amount := range result.Balance { - if currency != "idr" { - balance["balance_"+currency] = parseFloat(amount) - } - } - for currency, amount := range result.BalanceHold { - if currency != "idr" { - balance["hold_"+currency] = parseFloat(amount) - } - } - - // Update cache - t.cacheMutex.Lock() - t.cachedBalance = balance - t.balanceCacheTime = time.Now() - t.cacheMutex.Unlock() - - return balance, nil -} - -// GetPositions returns currently held crypto balances as "positions" -// Since Indodax is spot-only, each non-zero crypto balance is treated as a position -func (t *IndodaxTrader) GetPositions() ([]map[string]interface{}, error) { - // Check cache - t.cacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionCacheTime) < t.cacheDuration { - cached := t.cachedPositions - t.cacheMutex.RUnlock() - return cached, nil - } - t.cacheMutex.RUnlock() - - params := url.Values{} - params.Set("method", "getInfo") - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var result struct { - Balance map[string]interface{} `json:"balance"` - BalanceHold map[string]interface{} `json:"balance_hold"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse positions: %w", err) - } - - var positions []map[string]interface{} - - for currency, amountRaw := range result.Balance { - if currency == "idr" { - continue - } - - amount := parseFloat(amountRaw) - holdAmount := parseFloat(result.BalanceHold[currency]) - totalAmount := amount + holdAmount - - if totalAmount <= 0 { - continue - } - - // Get market price for this coin - markPrice, _ := t.GetMarketPrice(strings.ToUpper(currency) + "IDR") - - // Calculate position value in IDR - notionalValue := totalAmount * markPrice - - position := map[string]interface{}{ - "symbol": strings.ToUpper(currency) + "IDR", - "side": "LONG", - "positionAmt": totalAmount, - "entryPrice": markPrice, // Spot doesn't track entry price - "markPrice": markPrice, - "unRealizedProfit": 0.0, // Spot doesn't track unrealized PnL - "leverage": 1.0, - "mgnMode": "spot", - "notionalValue": notionalValue, - "currency": currency, - "available": amount, - "hold": holdAmount, - } - - positions = append(positions, position) - } - - // Update cache - t.cacheMutex.Lock() - t.cachedPositions = positions - t.positionCacheTime = time.Now() - t.cacheMutex.Unlock() - - return positions, nil -} - -// OpenLong opens a spot buy order -func (t *IndodaxTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - t.clearCache() - - pair := t.convertSymbol(symbol) - coin := t.getCoinFromSymbol(symbol) - - // Get market price to calculate IDR amount - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, fmt.Errorf("failed to get market price: %w", err) - } - - params := url.Values{} - params.Set("method", "trade") - params.Set("pair", pair) - params.Set("type", "buy") - params.Set("price", strconv.FormatFloat(price, 'f', 0, 64)) - params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64)) - params.Set("order_type", "limit") - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to place buy order: %w", err) - } - - var result map[string]interface{} - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse trade response: %w", err) - } - - logger.Infof("[Indodax] Buy order placed: %s qty=%.8f price=%.0f", symbol, quantity, price) - - return map[string]interface{}{ - "orderId": result["order_id"], - "symbol": symbol, - "side": "BUY", - "price": price, - "qty": quantity, - "status": "NEW", - }, nil -} - -// OpenShort is not supported on Indodax (spot-only exchange) -func (t *IndodaxTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)") -} - -// CloseLong closes a spot position by selling -func (t *IndodaxTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - t.clearCache() - - pair := t.convertSymbol(symbol) - coin := t.getCoinFromSymbol(symbol) - - // If quantity is 0, sell all available balance - if quantity <= 0 { - balance, err := t.GetBalance() - if err != nil { - return nil, fmt.Errorf("failed to get balance for close all: %w", err) - } - available := parseFloat(balance["balance_"+coin]) - if available <= 0 { - return nil, fmt.Errorf("no %s balance to sell", coin) - } - quantity = available - } - - // Get market price - price, err := t.GetMarketPrice(symbol) - if err != nil { - return nil, fmt.Errorf("failed to get market price: %w", err) - } - - params := url.Values{} - params.Set("method", "trade") - params.Set("pair", pair) - params.Set("type", "sell") - params.Set("price", strconv.FormatFloat(price, 'f', 0, 64)) - params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64)) - params.Set("order_type", "limit") - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to place sell order: %w", err) - } - - var result map[string]interface{} - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse trade response: %w", err) - } - - logger.Infof("[Indodax] Sell order placed: %s qty=%.8f price=%.0f", symbol, quantity, price) - - return map[string]interface{}{ - "orderId": result["order_id"], - "symbol": symbol, - "side": "SELL", - "price": price, - "qty": quantity, - "status": "NEW", - }, nil -} - -// CloseShort is not supported on Indodax (spot-only exchange) -func (t *IndodaxTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)") -} - -// SetLeverage is a no-op for Indodax (spot-only, no leverage) -func (t *IndodaxTrader) SetLeverage(symbol string, leverage int) error { - logger.Infof("[Indodax] SetLeverage ignored (spot-only exchange, no leverage support)") - return nil -} - -// SetMarginMode is a no-op for Indodax (spot-only, no margin) -func (t *IndodaxTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - logger.Infof("[Indodax] SetMarginMode ignored (spot-only exchange, no margin support)") - return nil -} - -// GetMarketPrice gets the current market price for a symbol -func (t *IndodaxTrader) GetMarketPrice(symbol string) (float64, error) { - pairID := strings.ToLower(strings.ReplaceAll(t.convertSymbol(symbol), "_", "")) - - data, err := t.doPublicRequest("/ticker/" + pairID) - if err != nil { - return 0, fmt.Errorf("failed to get ticker: %w", err) - } - - var tickerResp IndodaxTickerResponse - if err := json.Unmarshal(data, &tickerResp); err != nil { - return 0, fmt.Errorf("failed to parse ticker: %w", err) - } - - price, err := strconv.ParseFloat(tickerResp.Ticker.Last, 64) - if err != nil { - return 0, fmt.Errorf("failed to parse price '%s': %w", tickerResp.Ticker.Last, err) - } - - return price, nil -} - -// SetStopLoss is not supported on Indodax (spot-only exchange) -func (t *IndodaxTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - return fmt.Errorf("stop-loss orders are not supported on Indodax (spot-only exchange)") -} - -// SetTakeProfit is not supported on Indodax (spot-only exchange) -func (t *IndodaxTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - return fmt.Errorf("take-profit orders are not supported on Indodax (spot-only exchange)") -} - -// CancelStopLossOrders is a no-op for Indodax -func (t *IndodaxTrader) CancelStopLossOrders(symbol string) error { - return nil -} - -// CancelTakeProfitOrders is a no-op for Indodax -func (t *IndodaxTrader) CancelTakeProfitOrders(symbol string) error { - return nil -} - -// CancelAllOrders cancels all open orders for a given symbol -func (t *IndodaxTrader) CancelAllOrders(symbol string) error { - t.clearCache() - - pair := t.convertSymbol(symbol) - - // First get open orders - params := url.Values{} - params.Set("method", "openOrders") - params.Set("pair", pair) - - data, err := t.doPrivateRequest(params) - if err != nil { - return fmt.Errorf("failed to get open orders: %w", err) - } - - var result struct { - Orders []struct { - OrderID json.Number `json:"order_id"` - Type string `json:"type"` - OrderType string `json:"order_type"` - } `json:"orders"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return fmt.Errorf("failed to parse open orders: %w", err) - } - - // Cancel each order - for _, order := range result.Orders { - cancelParams := url.Values{} - cancelParams.Set("method", "cancelOrder") - cancelParams.Set("pair", pair) - cancelParams.Set("order_id", order.OrderID.String()) - cancelParams.Set("type", order.Type) - - if _, err := t.doPrivateRequest(cancelParams); err != nil { - logger.Warnf("[Indodax] Failed to cancel order %s: %v", order.OrderID, err) - } else { - logger.Infof("[Indodax] Cancelled order: %s", order.OrderID) - } - } - - return nil -} - -// CancelStopOrders is a no-op for Indodax (no stop orders) -func (t *IndodaxTrader) CancelStopOrders(symbol string) error { - return nil -} - -// FormatQuantity formats quantity to correct precision for Indodax -func (t *IndodaxTrader) FormatQuantity(symbol string, quantity float64) (string, error) { - pair, err := t.getPair(symbol) - if err != nil { - // Default: 8 decimal places - return strconv.FormatFloat(quantity, 'f', 8, 64), nil - } - - precision := pair.PriceRound - if precision <= 0 { - precision = 8 - } - - // Round down to avoid exceeding balance - factor := math.Pow(10, float64(precision)) - rounded := math.Floor(quantity*factor) / factor - - return strconv.FormatFloat(rounded, 'f', precision, 64), nil -} - -// GetOrderStatus gets the status of a specific order -func (t *IndodaxTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - pair := t.convertSymbol(symbol) - - params := url.Values{} - params.Set("method", "getOrder") - params.Set("pair", pair) - params.Set("order_id", orderID) - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - var result struct { - Order struct { - OrderID string `json:"order_id"` - Price string `json:"price"` - Type string `json:"type"` - Status string `json:"status"` - SubmitTime string `json:"submit_time"` - FinishTime string `json:"finish_time"` - ClientOrderID string `json:"client_order_id"` - } `json:"order"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse order: %w", err) - } - - // Map Indodax status to standard status - status := "NEW" - switch result.Order.Status { - case "filled": - status = "FILLED" - case "cancelled": - status = "CANCELED" - case "open": - status = "NEW" - } - - price, _ := strconv.ParseFloat(result.Order.Price, 64) - - return map[string]interface{}{ - "status": status, - "avgPrice": price, - "executedQty": 0.0, // Indodax doesn't return executed qty in getOrder - "commission": 0.0, - "orderId": result.Order.OrderID, - }, nil -} - -// GetClosedPnL gets closed position PnL records (trade history) -func (t *IndodaxTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - // Indodax trade history is limited to 7 days range - params := url.Values{} - params.Set("method", "tradeHistory") - params.Set("pair", "btc_idr") // Default pair; Indodax requires a pair - if limit > 0 { - params.Set("count", strconv.Itoa(limit)) - } - if !startTime.IsZero() { - params.Set("since", strconv.FormatInt(startTime.Unix(), 10)) - } - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to get trade history: %w", err) - } - - var result struct { - Trades []struct { - TradeID string `json:"trade_id"` - OrderID string `json:"order_id"` - Type string `json:"type"` - Price string `json:"price"` - Fee string `json:"fee"` - TradeTime string `json:"trade_time"` - ClientOrderID string `json:"client_order_id"` - } `json:"trades"` - } - - if err := json.Unmarshal(data, &result); err != nil { - // Trade history might return empty, that's fine - return nil, nil - } - - var records []types.ClosedPnLRecord - for _, trade := range result.Trades { - price, _ := strconv.ParseFloat(trade.Price, 64) - fee, _ := strconv.ParseFloat(trade.Fee, 64) - tradeTime, _ := strconv.ParseInt(trade.TradeTime, 10, 64) - - side := "long" - if trade.Type == "sell" { - side = "long" // Selling from a spot position is closing long - } - - records = append(records, types.ClosedPnLRecord{ - Symbol: "BTCIDR", - Side: side, - ExitPrice: price, - Fee: fee, - ExitTime: time.Unix(tradeTime, 0), - OrderID: trade.OrderID, - CloseType: "manual", - }) - } - - return records, nil -} - -// GetOpenOrders gets open/pending orders -func (t *IndodaxTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - pair := t.convertSymbol(symbol) - - params := url.Values{} - params.Set("method", "openOrders") - if pair != "" { - params.Set("pair", pair) - } - - data, err := t.doPrivateRequest(params) - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - var result struct { - Orders []struct { - OrderID json.Number `json:"order_id"` - ClientOrderID string `json:"client_order_id"` - SubmitTime string `json:"submit_time"` - Price string `json:"price"` - Type string `json:"type"` - OrderType string `json:"order_type"` - } `json:"orders"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse open orders: %w", err) - } - - var orders []types.OpenOrder - for _, order := range result.Orders { - price, _ := strconv.ParseFloat(order.Price, 64) - - side := "BUY" - if order.Type == "sell" { - side = "SELL" - } - - orders = append(orders, types.OpenOrder{ - OrderID: order.OrderID.String(), - Symbol: t.convertSymbolBack(pair), - Side: side, - PositionSide: "LONG", - Type: "LIMIT", - Price: price, - Status: "NEW", - }) - } - - return orders, nil -} - -// ============================================================ -// Helper functions -// ============================================================ - // parseFloat safely parses a float from interface{} func parseFloat(v interface{}) float64 { if v == nil { diff --git a/trader/indodax/trader_account.go b/trader/indodax/trader_account.go new file mode 100644 index 00000000..cd40fa88 --- /dev/null +++ b/trader/indodax/trader_account.go @@ -0,0 +1,221 @@ +package indodax + +import ( + "encoding/json" + "fmt" + "net/url" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" + "time" +) + +// GetBalance gets account balance from Indodax +func (t *IndodaxTrader) GetBalance() (map[string]interface{}, error) { + // Check cache + t.cacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + cached := t.cachedBalance + t.cacheMutex.RUnlock() + return cached, nil + } + t.cacheMutex.RUnlock() + + params := url.Values{} + params.Set("method", "getInfo") + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to get account info: %w", err) + } + + var result struct { + ServerTime int64 `json:"server_time"` + Balance map[string]interface{} `json:"balance"` + BalanceHold map[string]interface{} `json:"balance_hold"` + UserID string `json:"user_id"` + Name string `json:"name"` + Email string `json:"email"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse balance: %w", err) + } + + // Calculate total balance in IDR + idrBalance := parseFloat(result.Balance["idr"]) + idrHold := parseFloat(result.BalanceHold["idr"]) + totalIDR := idrBalance + idrHold + + balance := map[string]interface{}{ + "totalWalletBalance": totalIDR, + "availableBalance": idrBalance, + "totalUnrealizedProfit": 0.0, + "totalEquity": totalIDR, + "balance": totalIDR, + "idr_balance": idrBalance, + "idr_hold": idrHold, + "currency": "IDR", + "user_id": result.UserID, + "server_time": result.ServerTime, + } + + // Add individual crypto balances + for currency, amount := range result.Balance { + if currency != "idr" { + balance["balance_"+currency] = parseFloat(amount) + } + } + for currency, amount := range result.BalanceHold { + if currency != "idr" { + balance["hold_"+currency] = parseFloat(amount) + } + } + + // Update cache + t.cacheMutex.Lock() + t.cachedBalance = balance + t.balanceCacheTime = time.Now() + t.cacheMutex.Unlock() + + return balance, nil +} + +// GetPositions returns currently held crypto balances as "positions" +// Since Indodax is spot-only, each non-zero crypto balance is treated as a position +func (t *IndodaxTrader) GetPositions() ([]map[string]interface{}, error) { + // Check cache + t.cacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionCacheTime) < t.cacheDuration { + cached := t.cachedPositions + t.cacheMutex.RUnlock() + return cached, nil + } + t.cacheMutex.RUnlock() + + params := url.Values{} + params.Set("method", "getInfo") + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var result struct { + Balance map[string]interface{} `json:"balance"` + BalanceHold map[string]interface{} `json:"balance_hold"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse positions: %w", err) + } + + var positions []map[string]interface{} + + for currency, amountRaw := range result.Balance { + if currency == "idr" { + continue + } + + amount := parseFloat(amountRaw) + holdAmount := parseFloat(result.BalanceHold[currency]) + totalAmount := amount + holdAmount + + if totalAmount <= 0 { + continue + } + + // Get market price for this coin + markPrice, _ := t.GetMarketPrice(strings.ToUpper(currency) + "IDR") + + // Calculate position value in IDR + notionalValue := totalAmount * markPrice + + position := map[string]interface{}{ + "symbol": strings.ToUpper(currency) + "IDR", + "side": "LONG", + "positionAmt": totalAmount, + "entryPrice": markPrice, // Spot doesn't track entry price + "markPrice": markPrice, + "unRealizedProfit": 0.0, // Spot doesn't track unrealized PnL + "leverage": 1.0, + "mgnMode": "spot", + "notionalValue": notionalValue, + "currency": currency, + "available": amount, + "hold": holdAmount, + } + + positions = append(positions, position) + } + + // Update cache + t.cacheMutex.Lock() + t.cachedPositions = positions + t.positionCacheTime = time.Now() + t.cacheMutex.Unlock() + + return positions, nil +} + +// GetClosedPnL gets closed position PnL records (trade history) +func (t *IndodaxTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + // Indodax trade history is limited to 7 days range + params := url.Values{} + params.Set("method", "tradeHistory") + params.Set("pair", "btc_idr") // Default pair; Indodax requires a pair + if limit > 0 { + params.Set("count", strconv.Itoa(limit)) + } + if !startTime.IsZero() { + params.Set("since", strconv.FormatInt(startTime.Unix(), 10)) + } + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to get trade history: %w", err) + } + + var result struct { + Trades []struct { + TradeID string `json:"trade_id"` + OrderID string `json:"order_id"` + Type string `json:"type"` + Price string `json:"price"` + Fee string `json:"fee"` + TradeTime string `json:"trade_time"` + ClientOrderID string `json:"client_order_id"` + } `json:"trades"` + } + + if err := json.Unmarshal(data, &result); err != nil { + // Trade history might return empty, that's fine + logger.Infof("[Indodax] Trade history parse note: %v", err) + return nil, nil + } + + var records []types.ClosedPnLRecord + for _, trade := range result.Trades { + price, _ := strconv.ParseFloat(trade.Price, 64) + fee, _ := strconv.ParseFloat(trade.Fee, 64) + tradeTime, _ := strconv.ParseInt(trade.TradeTime, 10, 64) + + side := "long" + if trade.Type == "sell" { + side = "long" // Selling from a spot position is closing long + } + + records = append(records, types.ClosedPnLRecord{ + Symbol: "BTCIDR", + Side: side, + ExitPrice: price, + Fee: fee, + ExitTime: time.Unix(tradeTime, 0), + OrderID: trade.OrderID, + CloseType: "manual", + }) + } + + return records, nil +} diff --git a/trader/indodax/trader_orders.go b/trader/indodax/trader_orders.go new file mode 100644 index 00000000..3a8fe513 --- /dev/null +++ b/trader/indodax/trader_orders.go @@ -0,0 +1,351 @@ +package indodax + +import ( + "encoding/json" + "fmt" + "math" + "net/url" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" +) + +// OpenLong opens a spot buy order +func (t *IndodaxTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + t.clearCache() + + pair := t.convertSymbol(symbol) + coin := t.getCoinFromSymbol(symbol) + + // Get market price to calculate IDR amount + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get market price: %w", err) + } + + params := url.Values{} + params.Set("method", "trade") + params.Set("pair", pair) + params.Set("type", "buy") + params.Set("price", strconv.FormatFloat(price, 'f', 0, 64)) + params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64)) + params.Set("order_type", "limit") + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to place buy order: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse trade response: %w", err) + } + + logger.Infof("[Indodax] Buy order placed: %s qty=%.8f price=%.0f", symbol, quantity, price) + + return map[string]interface{}{ + "orderId": result["order_id"], + "symbol": symbol, + "side": "BUY", + "price": price, + "qty": quantity, + "status": "NEW", + }, nil +} + +// OpenShort is not supported on Indodax (spot-only exchange) +func (t *IndodaxTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)") +} + +// CloseLong closes a spot position by selling +func (t *IndodaxTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + t.clearCache() + + pair := t.convertSymbol(symbol) + coin := t.getCoinFromSymbol(symbol) + + // If quantity is 0, sell all available balance + if quantity <= 0 { + balance, err := t.GetBalance() + if err != nil { + return nil, fmt.Errorf("failed to get balance for close all: %w", err) + } + available := parseFloat(balance["balance_"+coin]) + if available <= 0 { + return nil, fmt.Errorf("no %s balance to sell", coin) + } + quantity = available + } + + // Get market price + price, err := t.GetMarketPrice(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get market price: %w", err) + } + + params := url.Values{} + params.Set("method", "trade") + params.Set("pair", pair) + params.Set("type", "sell") + params.Set("price", strconv.FormatFloat(price, 'f', 0, 64)) + params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64)) + params.Set("order_type", "limit") + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to place sell order: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse trade response: %w", err) + } + + logger.Infof("[Indodax] Sell order placed: %s qty=%.8f price=%.0f", symbol, quantity, price) + + return map[string]interface{}{ + "orderId": result["order_id"], + "symbol": symbol, + "side": "SELL", + "price": price, + "qty": quantity, + "status": "NEW", + }, nil +} + +// CloseShort is not supported on Indodax (spot-only exchange) +func (t *IndodaxTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)") +} + +// SetLeverage is a no-op for Indodax (spot-only, no leverage) +func (t *IndodaxTrader) SetLeverage(symbol string, leverage int) error { + logger.Infof("[Indodax] SetLeverage ignored (spot-only exchange, no leverage support)") + return nil +} + +// SetMarginMode is a no-op for Indodax (spot-only, no margin) +func (t *IndodaxTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + logger.Infof("[Indodax] SetMarginMode ignored (spot-only exchange, no margin support)") + return nil +} + +// GetMarketPrice gets the current market price for a symbol +func (t *IndodaxTrader) GetMarketPrice(symbol string) (float64, error) { + pairID := strings.ToLower(strings.ReplaceAll(t.convertSymbol(symbol), "_", "")) + + data, err := t.doPublicRequest("/ticker/" + pairID) + if err != nil { + return 0, fmt.Errorf("failed to get ticker: %w", err) + } + + var tickerResp IndodaxTickerResponse + if err := json.Unmarshal(data, &tickerResp); err != nil { + return 0, fmt.Errorf("failed to parse ticker: %w", err) + } + + price, err := strconv.ParseFloat(tickerResp.Ticker.Last, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse price '%s': %w", tickerResp.Ticker.Last, err) + } + + return price, nil +} + +// SetStopLoss is not supported on Indodax (spot-only exchange) +func (t *IndodaxTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + return fmt.Errorf("stop-loss orders are not supported on Indodax (spot-only exchange)") +} + +// SetTakeProfit is not supported on Indodax (spot-only exchange) +func (t *IndodaxTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + return fmt.Errorf("take-profit orders are not supported on Indodax (spot-only exchange)") +} + +// CancelStopLossOrders is a no-op for Indodax +func (t *IndodaxTrader) CancelStopLossOrders(symbol string) error { + return nil +} + +// CancelTakeProfitOrders is a no-op for Indodax +func (t *IndodaxTrader) CancelTakeProfitOrders(symbol string) error { + return nil +} + +// CancelAllOrders cancels all open orders for a given symbol +func (t *IndodaxTrader) CancelAllOrders(symbol string) error { + t.clearCache() + + pair := t.convertSymbol(symbol) + + // First get open orders + params := url.Values{} + params.Set("method", "openOrders") + params.Set("pair", pair) + + data, err := t.doPrivateRequest(params) + if err != nil { + return fmt.Errorf("failed to get open orders: %w", err) + } + + var result struct { + Orders []struct { + OrderID json.Number `json:"order_id"` + Type string `json:"type"` + OrderType string `json:"order_type"` + } `json:"orders"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return fmt.Errorf("failed to parse open orders: %w", err) + } + + // Cancel each order + for _, order := range result.Orders { + cancelParams := url.Values{} + cancelParams.Set("method", "cancelOrder") + cancelParams.Set("pair", pair) + cancelParams.Set("order_id", order.OrderID.String()) + cancelParams.Set("type", order.Type) + + if _, err := t.doPrivateRequest(cancelParams); err != nil { + logger.Warnf("[Indodax] Failed to cancel order %s: %v", order.OrderID, err) + } else { + logger.Infof("[Indodax] Cancelled order: %s", order.OrderID) + } + } + + return nil +} + +// CancelStopOrders is a no-op for Indodax (no stop orders) +func (t *IndodaxTrader) CancelStopOrders(symbol string) error { + return nil +} + +// FormatQuantity formats quantity to correct precision for Indodax +func (t *IndodaxTrader) FormatQuantity(symbol string, quantity float64) (string, error) { + pair, err := t.getPair(symbol) + if err != nil { + // Default: 8 decimal places + return strconv.FormatFloat(quantity, 'f', 8, 64), nil + } + + precision := pair.PriceRound + if precision <= 0 { + precision = 8 + } + + // Round down to avoid exceeding balance + factor := math.Pow(10, float64(precision)) + rounded := math.Floor(quantity*factor) / factor + + return strconv.FormatFloat(rounded, 'f', precision, 64), nil +} + +// GetOrderStatus gets the status of a specific order +func (t *IndodaxTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + pair := t.convertSymbol(symbol) + + params := url.Values{} + params.Set("method", "getOrder") + params.Set("pair", pair) + params.Set("order_id", orderID) + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + var result struct { + Order struct { + OrderID string `json:"order_id"` + Price string `json:"price"` + Type string `json:"type"` + Status string `json:"status"` + SubmitTime string `json:"submit_time"` + FinishTime string `json:"finish_time"` + ClientOrderID string `json:"client_order_id"` + } `json:"order"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse order: %w", err) + } + + // Map Indodax status to standard status + status := "NEW" + switch result.Order.Status { + case "filled": + status = "FILLED" + case "cancelled": + status = "CANCELED" + case "open": + status = "NEW" + } + + price, _ := strconv.ParseFloat(result.Order.Price, 64) + + return map[string]interface{}{ + "status": status, + "avgPrice": price, + "executedQty": 0.0, // Indodax doesn't return executed qty in getOrder + "commission": 0.0, + "orderId": result.Order.OrderID, + }, nil +} + +// GetOpenOrders gets open/pending orders +func (t *IndodaxTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + pair := t.convertSymbol(symbol) + + params := url.Values{} + params.Set("method", "openOrders") + if pair != "" { + params.Set("pair", pair) + } + + data, err := t.doPrivateRequest(params) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + var result struct { + Orders []struct { + OrderID json.Number `json:"order_id"` + ClientOrderID string `json:"client_order_id"` + SubmitTime string `json:"submit_time"` + Price string `json:"price"` + Type string `json:"type"` + OrderType string `json:"order_type"` + } `json:"orders"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse open orders: %w", err) + } + + var orders []types.OpenOrder + for _, order := range result.Orders { + price, _ := strconv.ParseFloat(order.Price, 64) + + side := "BUY" + if order.Type == "sell" { + side = "SELL" + } + + orders = append(orders, types.OpenOrder{ + OrderID: order.OrderID.String(), + Symbol: t.convertSymbolBack(pair), + Side: side, + PositionSide: "LONG", + Type: "LIMIT", + Price: price, + Status: "NEW", + }) + } + + return orders, nil +} diff --git a/trader/kucoin/trader.go b/trader/kucoin/trader.go index d012526a..746ff634 100644 --- a/trader/kucoin/trader.go +++ b/trader/kucoin/trader.go @@ -11,7 +11,6 @@ import ( "math" "net/http" "nofx/logger" - "nofx/trader/types" "strconv" "strings" "sync" @@ -281,164 +280,6 @@ func (t *KuCoinTrader) convertSymbolBack(kcSymbol string) string { return sym } -// GetBalance gets account balance -func (t *KuCoinTrader) GetBalance() (map[string]interface{}, error) { - // Check cache - t.balanceCacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - t.balanceCacheMutex.RUnlock() - return t.cachedBalance, nil - } - t.balanceCacheMutex.RUnlock() - - data, err := t.doRequest("GET", kucoinAccountPath+"?currency=USDT", nil) - if err != nil { - return nil, fmt.Errorf("failed to get account balance: %w", err) - } - - var account struct { - AccountEquity float64 `json:"accountEquity"` - UnrealisedPNL float64 `json:"unrealisedPNL"` - MarginBalance float64 `json:"marginBalance"` - PositionMargin float64 `json:"positionMargin"` - OrderMargin float64 `json:"orderMargin"` - FrozenFunds float64 `json:"frozenFunds"` - AvailableBalance float64 `json:"availableBalance"` - Currency string `json:"currency"` - } - - if err := json.Unmarshal(data, &account); err != nil { - return nil, fmt.Errorf("failed to parse balance data: %w", err) - } - - result := map[string]interface{}{ - "totalWalletBalance": account.MarginBalance, // Wallet balance (without unrealized PnL) - "availableBalance": account.AvailableBalance, - "totalUnrealizedProfit": account.UnrealisedPNL, - "total_equity": account.AccountEquity, - "totalEquity": account.AccountEquity, // For GetAccountInfo compatibility - } - - logger.Infof("✓ KuCoin balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f", - account.AccountEquity, account.AvailableBalance, account.UnrealisedPNL) - - // Update cache - t.balanceCacheMutex.Lock() - t.cachedBalance = result - t.balanceCacheTime = time.Now() - t.balanceCacheMutex.Unlock() - - return result, nil -} - -// GetPositions gets all positions -func (t *KuCoinTrader) GetPositions() ([]map[string]interface{}, error) { - // Check cache - t.positionsCacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { - t.positionsCacheMutex.RUnlock() - return t.cachedPositions, nil - } - t.positionsCacheMutex.RUnlock() - - data, err := t.doRequest("GET", kucoinPositionPath, nil) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var positions []struct { - Symbol string `json:"symbol"` - CurrentQty int64 `json:"currentQty"` // Position quantity (in lots, integer) - AvgEntryPrice float64 `json:"avgEntryPrice"` // Average entry price (string in API) - MarkPrice float64 `json:"markPrice"` // Mark price - UnrealisedPnl float64 `json:"unrealisedPnl"` // Unrealized PnL - Leverage float64 `json:"leverage"` // Leverage setting - RealLeverage float64 `json:"realLeverage"` // Effective leverage (may be nil in cross mode) - LiquidationPrice float64 `json:"liquidationPrice"`// Liquidation price - Multiplier float64 `json:"multiplier"` // Contract multiplier - IsOpen bool `json:"isOpen"` - CrossMode bool `json:"crossMode"` - OpeningTimestamp int64 `json:"openingTimestamp"` - SettleCurrency string `json:"settleCurrency"` - } - - if err := json.Unmarshal(data, &positions); err != nil { - return nil, fmt.Errorf("failed to parse position data: %w", err) - } - - var result []map[string]interface{} - for _, pos := range positions { - if !pos.IsOpen || pos.CurrentQty == 0 { - continue - } - - // Convert symbol format - symbol := t.convertSymbolBack(pos.Symbol) - - // Determine side based on position quantity - // KuCoin: positive qty = long, negative qty = short - side := "long" - qty := pos.CurrentQty - if qty < 0 { - side = "short" - qty = -qty - } - - // Convert lots to actual quantity using multiplier - // Position quantity = lots * multiplier - multiplier := pos.Multiplier - if multiplier == 0 { - multiplier = 0.001 // Default for BTC - } - positionAmt := float64(qty) * multiplier - - // Determine margin mode - mgnMode := "isolated" - if pos.CrossMode { - mgnMode = "cross" - } - - // Use Leverage field (setting), fallback to RealLeverage (effective), default to 10 - leverage := pos.Leverage - if leverage == 0 { - leverage = pos.RealLeverage - } - if leverage == 0 { - leverage = 10 // Default leverage - } - - posMap := map[string]interface{}{ - "symbol": symbol, - "positionAmt": positionAmt, - "entryPrice": pos.AvgEntryPrice, - "markPrice": pos.MarkPrice, - "unRealizedProfit": pos.UnrealisedPnl, - "leverage": leverage, - "liquidationPrice": pos.LiquidationPrice, - "side": side, - "mgnMode": mgnMode, - "createdTime": pos.OpeningTimestamp, - } - result = append(result, posMap) - } - - // Update cache - t.positionsCacheMutex.Lock() - t.cachedPositions = result - t.positionsCacheTime = time.Now() - t.positionsCacheMutex.Unlock() - - return result, nil -} - -// InvalidatePositionCache clears the position cache -func (t *KuCoinTrader) InvalidatePositionCache() { - t.positionsCacheMutex.Lock() - t.cachedPositions = nil - t.positionsCacheTime = time.Time{} - t.positionsCacheMutex.Unlock() -} - // getContract gets contract info func (t *KuCoinTrader) getContract(symbol string) (*KuCoinContract, error) { kcSymbol := t.convertSymbol(symbol) @@ -526,768 +367,3 @@ func (t *KuCoinTrader) quantityToLots(symbol string, quantity float64) (int64, e return lotsInt, nil } - -// SetMarginMode sets margin mode -func (t *KuCoinTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - // KuCoin sets margin mode per position, handled automatically - logger.Infof("✓ KuCoin margin mode: %v (handled per position)", isCrossMargin) - return nil -} - -// SetLeverage sets leverage for a symbol -func (t *KuCoinTrader) SetLeverage(symbol string, leverage int) error { - kcSymbol := t.convertSymbol(symbol) - - body := map[string]interface{}{ - "symbol": kcSymbol, - "leverage": fmt.Sprintf("%d", leverage), - } - - _, err := t.doRequest("POST", kucoinLeveragePath, body) - if err != nil { - // Ignore if already at target leverage - if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") { - logger.Infof("✓ %s leverage is already %dx", symbol, leverage) - return nil - } - return fmt.Errorf("failed to set leverage: %w", err) - } - - logger.Infof("✓ %s leverage set to %dx", symbol, leverage) - return nil -} - -// OpenLong opens long position -func (t *KuCoinTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // Cancel old orders - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof("⚠️ Failed to set leverage: %v", err) - } - - kcSymbol := t.convertSymbol(symbol) - - // Convert quantity to lots - lots, err := t.quantityToLots(symbol, quantity) - if err != nil { - return nil, fmt.Errorf("failed to calculate lots: %w", err) - } - - body := map[string]interface{}{ - "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), - "symbol": kcSymbol, - "side": "buy", - "type": "market", - "size": lots, - "leverage": fmt.Sprintf("%d", leverage), - "reduceOnly": false, - "marginMode": "CROSS", // Use cross margin mode - } - - data, err := t.doRequest("POST", kucoinOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to open long position: %w", err) - } - - var result struct { - OrderId string `json:"orderId"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - logger.Infof("✓ KuCoin opened long position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId) - - // Query order to get fill price - fillPrice := t.queryOrderFillPrice(result.OrderId) - - return map[string]interface{}{ - "orderId": result.OrderId, - "symbol": symbol, - "status": "FILLED", - "fillPrice": fillPrice, - }, nil -} - -// OpenShort opens short position -func (t *KuCoinTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // Cancel old orders - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof("⚠️ Failed to set leverage: %v", err) - } - - kcSymbol := t.convertSymbol(symbol) - - // Convert quantity to lots - lots, err := t.quantityToLots(symbol, quantity) - if err != nil { - return nil, fmt.Errorf("failed to calculate lots: %w", err) - } - - body := map[string]interface{}{ - "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), - "symbol": kcSymbol, - "side": "sell", - "type": "market", - "size": lots, - "leverage": fmt.Sprintf("%d", leverage), - "reduceOnly": false, - "marginMode": "CROSS", // Use cross margin mode - } - - data, err := t.doRequest("POST", kucoinOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to open short position: %w", err) - } - - var result struct { - OrderId string `json:"orderId"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - logger.Infof("✓ KuCoin opened short position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId) - - // Query order to get fill price - fillPrice := t.queryOrderFillPrice(result.OrderId) - - return map[string]interface{}{ - "orderId": result.OrderId, - "symbol": symbol, - "status": "FILLED", - "fillPrice": fillPrice, - }, nil -} - -// queryOrderFillPrice queries order status and returns fill price -func (t *KuCoinTrader) queryOrderFillPrice(orderId string) float64 { - // Wait a bit for order to fill - time.Sleep(500 * time.Millisecond) - - path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderId) - data, err := t.doRequest("GET", path, nil) - if err != nil { - logger.Warnf("Failed to query order %s: %v", orderId, err) - return 0 - } - - var order struct { - DealAvgPrice float64 `json:"dealAvgPrice"` - Status string `json:"status"` - DealSize int64 `json:"dealSize"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return 0 - } - - return order.DealAvgPrice -} - -// CloseLong closes long position -func (t *KuCoinTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - // Invalidate position cache and get fresh positions - t.InvalidatePositionCache() - positions, err := t.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - // Find actual position and get margin mode - var actualQty float64 - var posFound bool - var marginMode string = "CROSS" // Default to CROSS - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "long" { - actualQty = pos["positionAmt"].(float64) - posFound = true - // Get margin mode from position - if mgnMode, ok := pos["mgnMode"].(string); ok { - marginMode = strings.ToUpper(mgnMode) - } - break - } - } - - if !posFound || actualQty == 0 { - return map[string]interface{}{ - "status": "NO_POSITION", - "message": fmt.Sprintf("No long position found for %s on KuCoin", symbol), - }, nil - } - - // Use actual quantity from exchange - if quantity == 0 || quantity > actualQty { - quantity = actualQty - } - - kcSymbol := t.convertSymbol(symbol) - - // Convert quantity to lots - lots, err := t.quantityToLots(symbol, quantity) - if err != nil { - return nil, fmt.Errorf("failed to calculate lots: %w", err) - } - - body := map[string]interface{}{ - "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), - "symbol": kcSymbol, - "side": "sell", - "type": "market", - "size": lots, - "reduceOnly": true, - "closeOrder": true, - "marginMode": marginMode, // Use position's margin mode - } - - data, err := t.doRequest("POST", kucoinOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to close long position: %w", err) - } - - var result struct { - OrderId string `json:"orderId"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - logger.Infof("✓ KuCoin closed long position: %s", symbol) - - // Cancel pending orders - t.CancelAllOrders(symbol) - - return map[string]interface{}{ - "orderId": result.OrderId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// CloseShort closes short position -func (t *KuCoinTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - // Invalidate position cache and get fresh positions - t.InvalidatePositionCache() - positions, err := t.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - // Find actual position and get margin mode - var actualQty float64 - var posFound bool - var marginMode string = "CROSS" // Default to CROSS - for _, pos := range positions { - if pos["symbol"] == symbol && pos["side"] == "short" { - actualQty = pos["positionAmt"].(float64) - posFound = true - // Get margin mode from position - if mgnMode, ok := pos["mgnMode"].(string); ok { - marginMode = strings.ToUpper(mgnMode) - } - break - } - } - - if !posFound || actualQty == 0 { - return map[string]interface{}{ - "status": "NO_POSITION", - "message": fmt.Sprintf("No short position found for %s on KuCoin", symbol), - }, nil - } - - // Use actual quantity from exchange - if quantity == 0 || quantity > actualQty { - quantity = actualQty - } - - kcSymbol := t.convertSymbol(symbol) - - // Convert quantity to lots - lots, err := t.quantityToLots(symbol, quantity) - if err != nil { - return nil, fmt.Errorf("failed to calculate lots: %w", err) - } - - body := map[string]interface{}{ - "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), - "symbol": kcSymbol, - "side": "buy", - "type": "market", - "size": lots, - "reduceOnly": true, - "closeOrder": true, - "marginMode": marginMode, // Use position's margin mode - } - - data, err := t.doRequest("POST", kucoinOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to close short position: %w", err) - } - - var result struct { - OrderId string `json:"orderId"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - logger.Infof("✓ KuCoin closed short position: %s", symbol) - - // Cancel pending orders - t.CancelAllOrders(symbol) - - return map[string]interface{}{ - "orderId": result.OrderId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// GetMarketPrice gets market price -func (t *KuCoinTrader) GetMarketPrice(symbol string) (float64, error) { - kcSymbol := t.convertSymbol(symbol) - path := fmt.Sprintf("%s?symbol=%s", kucoinTickerPath, kcSymbol) - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return 0, fmt.Errorf("failed to get price: %w", err) - } - - var ticker struct { - Price string `json:"price"` - } - - if err := json.Unmarshal(data, &ticker); err != nil { - return 0, err - } - - price, _ := strconv.ParseFloat(ticker.Price, 64) - return price, nil -} - -// SetStopLoss sets stop loss order -func (t *KuCoinTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - kcSymbol := t.convertSymbol(symbol) - - // Convert quantity to lots - lots, err := t.quantityToLots(symbol, quantity) - if err != nil { - return fmt.Errorf("failed to calculate lots: %w", err) - } - - // Determine side: close long = sell, close short = buy - side := "sell" - stop := "down" // Long position: stop loss triggers when price goes down - if strings.ToUpper(positionSide) == "SHORT" { - side = "buy" - stop = "up" // Short position: stop loss triggers when price goes up - } - - body := map[string]interface{}{ - "clientOid": fmt.Sprintf("nfxsl%d", time.Now().UnixNano()), - "symbol": kcSymbol, - "side": side, - "type": "market", - "size": lots, - "stop": stop, - "stopPriceType": "MP", // Mark Price - "stopPrice": fmt.Sprintf("%.8f", stopPrice), - "reduceOnly": true, - "closeOrder": true, - } - - _, err = t.doRequest("POST", kucoinStopOrderPath, body) - if err != nil { - return fmt.Errorf("failed to set stop loss: %w", err) - } - - logger.Infof("✓ Stop loss set: %.4f", stopPrice) - return nil -} - -// SetTakeProfit sets take profit order -func (t *KuCoinTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - kcSymbol := t.convertSymbol(symbol) - - // Convert quantity to lots - lots, err := t.quantityToLots(symbol, quantity) - if err != nil { - return fmt.Errorf("failed to calculate lots: %w", err) - } - - // Determine side: close long = sell, close short = buy - side := "sell" - stop := "up" // Long position: take profit triggers when price goes up - if strings.ToUpper(positionSide) == "SHORT" { - side = "buy" - stop = "down" // Short position: take profit triggers when price goes down - } - - body := map[string]interface{}{ - "clientOid": fmt.Sprintf("nfxtp%d", time.Now().UnixNano()), - "symbol": kcSymbol, - "side": side, - "type": "market", - "size": lots, - "stop": stop, - "stopPriceType": "MP", // Mark Price - "stopPrice": fmt.Sprintf("%.8f", takeProfitPrice), - "reduceOnly": true, - "closeOrder": true, - } - - _, err = t.doRequest("POST", kucoinStopOrderPath, body) - if err != nil { - return fmt.Errorf("failed to set take profit: %w", err) - } - - logger.Infof("✓ Take profit set: %.4f", takeProfitPrice) - return nil -} - -// CancelStopLossOrders cancels stop loss orders -func (t *KuCoinTrader) CancelStopLossOrders(symbol string) error { - return t.cancelStopOrdersByType(symbol, "sl") -} - -// CancelTakeProfitOrders cancels take profit orders -func (t *KuCoinTrader) CancelTakeProfitOrders(symbol string) error { - return t.cancelStopOrdersByType(symbol, "tp") -} - -// cancelStopOrdersByType cancels stop orders by type -func (t *KuCoinTrader) cancelStopOrdersByType(symbol string, orderType string) error { - kcSymbol := t.convertSymbol(symbol) - - // Get pending stop orders - path := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol) - data, err := t.doRequest("GET", path, nil) - if err != nil { - return err - } - - var response struct { - Items []struct { - Id string `json:"id"` - ClientOid string `json:"clientOid"` - Stop string `json:"stop"` - } `json:"items"` - } - - if err := json.Unmarshal(data, &response); err != nil { - // Try alternate format (direct array) - var items []struct { - Id string `json:"id"` - ClientOid string `json:"clientOid"` - Stop string `json:"stop"` - } - if err := json.Unmarshal(data, &items); err != nil { - return err - } - response.Items = items - } - - // Cancel matching orders - for _, order := range response.Items { - // Check if order matches type based on clientOid prefix - if orderType == "sl" && !strings.Contains(order.ClientOid, "sl") { - continue - } - if orderType == "tp" && !strings.Contains(order.ClientOid, "tp") { - continue - } - - cancelPath := fmt.Sprintf("%s/%s", kucoinCancelStopPath, order.Id) - _, err := t.doRequest("DELETE", cancelPath, nil) - if err != nil { - logger.Warnf("Failed to cancel stop order %s: %v", order.Id, err) - } - } - - return nil -} - -// CancelStopOrders cancels all stop orders for symbol -func (t *KuCoinTrader) CancelStopOrders(symbol string) error { - kcSymbol := t.convertSymbol(symbol) - - path := fmt.Sprintf("%s?symbol=%s", kucoinCancelStopPath, kcSymbol) - _, err := t.doRequest("DELETE", path, nil) - if err != nil { - // Ignore if no orders to cancel - if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "400100") { - return nil - } - return err - } - - logger.Infof("✓ Cancelled stop orders for %s", symbol) - return nil -} - -// CancelAllOrders cancels all pending orders for symbol -func (t *KuCoinTrader) CancelAllOrders(symbol string) error { - kcSymbol := t.convertSymbol(symbol) - - // Cancel regular orders - path := fmt.Sprintf("%s?symbol=%s", kucoinCancelOrderPath, kcSymbol) - _, err := t.doRequest("DELETE", path, nil) - if err != nil && !strings.Contains(err.Error(), "not found") { - logger.Warnf("Failed to cancel regular orders: %v", err) - } - - // Cancel stop orders - t.CancelStopOrders(symbol) - - return nil -} - -// FormatQuantity formats quantity to correct precision -func (t *KuCoinTrader) FormatQuantity(symbol string, quantity float64) (string, error) { - contract, err := t.getContract(symbol) - if err != nil { - return "", err - } - - // Calculate lots - lots := quantity / contract.Multiplier - - // Round to integer - lotsInt := int64(math.Round(lots)) - - return strconv.FormatInt(lotsInt, 10), nil -} - -// GetOrderStatus gets order status -func (t *KuCoinTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderID) - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - var order struct { - Id string `json:"id"` - Symbol string `json:"symbol"` - Status string `json:"status"` - DealAvgPrice float64 `json:"dealAvgPrice"` - DealSize int64 `json:"dealSize"` - Fee float64 `json:"fee"` - Side string `json:"side"` - } - - if err := json.Unmarshal(data, &order); err != nil { - return nil, err - } - - // Convert status - status := "NEW" - if order.Status == "done" { - status = "FILLED" - } else if order.Status == "cancelled" || order.Status == "canceled" { - status = "CANCELED" - } - - return map[string]interface{}{ - "orderId": order.Id, - "symbol": t.convertSymbolBack(order.Symbol), - "status": status, - "avgPrice": order.DealAvgPrice, - "executedQty": order.DealSize, - "commission": order.Fee, - }, nil -} - -// GetClosedPnL gets closed position PnL records -func (t *KuCoinTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 100 { - limit = 100 - } - - // KuCoin closed positions API - path := fmt.Sprintf("/api/v1/history-positions?status=CLOSE&limit=%d", limit) - if !startTime.IsZero() { - path += fmt.Sprintf("&from=%d", startTime.UnixMilli()) - } - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, fmt.Errorf("failed to get closed PnL: %w", err) - } - - var response struct { - HasMore bool `json:"hasMore"` - DataList []struct { - Symbol string `json:"symbol"` - OpenPrice float64 `json:"avgEntryPrice"` - ClosePrice float64 `json:"avgClosePrice"` - Qty int64 `json:"qty"` - RealisedPnl float64 `json:"realisedGrossCost"` - CloseTime int64 `json:"closeTime"` - OpenTime int64 `json:"openTime"` - PositionId string `json:"id"` - CloseType string `json:"type"` - Leverage int `json:"leverage"` - SettleCurrency string `json:"settleCurrency"` - } `json:"dataList"` - } - - if err := json.Unmarshal(data, &response); err != nil { - return nil, fmt.Errorf("failed to parse closed PnL: %w", err) - } - - var records []types.ClosedPnLRecord - for _, item := range response.DataList { - side := "long" - qty := item.Qty - if qty < 0 { - side = "short" - qty = -qty - } - - // Map close type - closeType := "unknown" - switch strings.ToUpper(item.CloseType) { - case "CLOSE", "MANUAL": - closeType = "manual" - case "STOP", "STOPLOSS": - closeType = "stop_loss" - case "TAKEPROFIT", "TP": - closeType = "take_profit" - case "LIQUIDATION", "LIQ", "ADL": - closeType = "liquidation" - } - - records = append(records, types.ClosedPnLRecord{ - Symbol: t.convertSymbolBack(item.Symbol), - Side: side, - EntryPrice: item.OpenPrice, - ExitPrice: item.ClosePrice, - Quantity: float64(qty), - RealizedPnL: item.RealisedPnl, - Leverage: item.Leverage, - EntryTime: time.UnixMilli(item.OpenTime), - ExitTime: time.UnixMilli(item.CloseTime), - ExchangeID: item.PositionId, - CloseType: closeType, - }) - } - - return records, nil -} - -// GetOpenOrders gets open/pending orders -func (t *KuCoinTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - kcSymbol := t.convertSymbol(symbol) - - // Get regular orders - path := fmt.Sprintf("%s?symbol=%s&status=active", kucoinOrderPath, kcSymbol) - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, fmt.Errorf("failed to get open orders: %w", err) - } - - var response struct { - Items []struct { - Id string `json:"id"` - Symbol string `json:"symbol"` - Side string `json:"side"` - Type string `json:"type"` - Price string `json:"price"` - Size int64 `json:"size"` - StopType string `json:"stopType"` - } `json:"items"` - } - - if err := json.Unmarshal(data, &response); err != nil { - // Try alternate format - var items []struct { - Id string `json:"id"` - Symbol string `json:"symbol"` - Side string `json:"side"` - Type string `json:"type"` - Price string `json:"price"` - Size int64 `json:"size"` - StopType string `json:"stopType"` - } - if err := json.Unmarshal(data, &items); err != nil { - return nil, err - } - response.Items = items - } - - var orders []types.OpenOrder - for _, item := range response.Items { - // Determine position side based on order side - positionSide := "LONG" - if item.Side == "sell" { - positionSide = "SHORT" - } - - price, _ := strconv.ParseFloat(item.Price, 64) - - orders = append(orders, types.OpenOrder{ - OrderID: item.Id, - Symbol: t.convertSymbolBack(item.Symbol), - Side: strings.ToUpper(item.Side), - PositionSide: positionSide, - Type: strings.ToUpper(item.Type), - Price: price, - Quantity: float64(item.Size), - Status: "NEW", - }) - } - - // Get stop orders - stopPath := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol) - stopData, err := t.doRequest("GET", stopPath, nil) - if err == nil { - var stopResponse struct { - Items []struct { - Id string `json:"id"` - Symbol string `json:"symbol"` - Side string `json:"side"` - StopPrice string `json:"stopPrice"` - Size int64 `json:"size"` - } `json:"items"` - } - - if json.Unmarshal(stopData, &stopResponse) == nil { - for _, item := range stopResponse.Items { - positionSide := "LONG" - if item.Side == "sell" { - positionSide = "SHORT" - } - - stopPrice, _ := strconv.ParseFloat(item.StopPrice, 64) - - orders = append(orders, types.OpenOrder{ - OrderID: item.Id, - Symbol: t.convertSymbolBack(item.Symbol), - Side: strings.ToUpper(item.Side), - PositionSide: positionSide, - Type: "STOP_MARKET", - StopPrice: stopPrice, - Quantity: float64(item.Size), - Status: "NEW", - }) - } - } - } - - return orders, nil -} diff --git a/trader/kucoin/trader_account.go b/trader/kucoin/trader_account.go new file mode 100644 index 00000000..bb9ccdff --- /dev/null +++ b/trader/kucoin/trader_account.go @@ -0,0 +1,58 @@ +package kucoin + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "time" +) + +// GetBalance gets account balance +func (t *KuCoinTrader) GetBalance() (map[string]interface{}, error) { + // Check cache + t.balanceCacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + t.balanceCacheMutex.RUnlock() + return t.cachedBalance, nil + } + t.balanceCacheMutex.RUnlock() + + data, err := t.doRequest("GET", kucoinAccountPath+"?currency=USDT", nil) + if err != nil { + return nil, fmt.Errorf("failed to get account balance: %w", err) + } + + var account struct { + AccountEquity float64 `json:"accountEquity"` + UnrealisedPNL float64 `json:"unrealisedPNL"` + MarginBalance float64 `json:"marginBalance"` + PositionMargin float64 `json:"positionMargin"` + OrderMargin float64 `json:"orderMargin"` + FrozenFunds float64 `json:"frozenFunds"` + AvailableBalance float64 `json:"availableBalance"` + Currency string `json:"currency"` + } + + if err := json.Unmarshal(data, &account); err != nil { + return nil, fmt.Errorf("failed to parse balance data: %w", err) + } + + result := map[string]interface{}{ + "totalWalletBalance": account.MarginBalance, // Wallet balance (without unrealized PnL) + "availableBalance": account.AvailableBalance, + "totalUnrealizedProfit": account.UnrealisedPNL, + "total_equity": account.AccountEquity, + "totalEquity": account.AccountEquity, // For GetAccountInfo compatibility + } + + logger.Infof("✓ KuCoin balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f", + account.AccountEquity, account.AvailableBalance, account.UnrealisedPNL) + + // Update cache + t.balanceCacheMutex.Lock() + t.cachedBalance = result + t.balanceCacheTime = time.Now() + t.balanceCacheMutex.Unlock() + + return result, nil +} diff --git a/trader/kucoin/trader_orders.go b/trader/kucoin/trader_orders.go new file mode 100644 index 00000000..41819202 --- /dev/null +++ b/trader/kucoin/trader_orders.go @@ -0,0 +1,777 @@ +package kucoin + +import ( + "encoding/json" + "fmt" + "math" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" + "time" +) + +// OpenLong opens long position +func (t *KuCoinTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // Cancel old orders + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof("⚠️ Failed to set leverage: %v", err) + } + + kcSymbol := t.convertSymbol(symbol) + + // Convert quantity to lots + lots, err := t.quantityToLots(symbol, quantity) + if err != nil { + return nil, fmt.Errorf("failed to calculate lots: %w", err) + } + + body := map[string]interface{}{ + "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), + "symbol": kcSymbol, + "side": "buy", + "type": "market", + "size": lots, + "leverage": fmt.Sprintf("%d", leverage), + "reduceOnly": false, + "marginMode": "CROSS", // Use cross margin mode + } + + data, err := t.doRequest("POST", kucoinOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to open long position: %w", err) + } + + var result struct { + OrderId string `json:"orderId"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ KuCoin opened long position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId) + + // Query order to get fill price + fillPrice := t.queryOrderFillPrice(result.OrderId) + + return map[string]interface{}{ + "orderId": result.OrderId, + "symbol": symbol, + "status": "FILLED", + "fillPrice": fillPrice, + }, nil +} + +// OpenShort opens short position +func (t *KuCoinTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // Cancel old orders + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof("⚠️ Failed to set leverage: %v", err) + } + + kcSymbol := t.convertSymbol(symbol) + + // Convert quantity to lots + lots, err := t.quantityToLots(symbol, quantity) + if err != nil { + return nil, fmt.Errorf("failed to calculate lots: %w", err) + } + + body := map[string]interface{}{ + "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), + "symbol": kcSymbol, + "side": "sell", + "type": "market", + "size": lots, + "leverage": fmt.Sprintf("%d", leverage), + "reduceOnly": false, + "marginMode": "CROSS", // Use cross margin mode + } + + data, err := t.doRequest("POST", kucoinOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to open short position: %w", err) + } + + var result struct { + OrderId string `json:"orderId"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ KuCoin opened short position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId) + + // Query order to get fill price + fillPrice := t.queryOrderFillPrice(result.OrderId) + + return map[string]interface{}{ + "orderId": result.OrderId, + "symbol": symbol, + "status": "FILLED", + "fillPrice": fillPrice, + }, nil +} + +// queryOrderFillPrice queries order status and returns fill price +func (t *KuCoinTrader) queryOrderFillPrice(orderId string) float64 { + // Wait a bit for order to fill + time.Sleep(500 * time.Millisecond) + + path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + logger.Warnf("Failed to query order %s: %v", orderId, err) + return 0 + } + + var order struct { + DealAvgPrice float64 `json:"dealAvgPrice"` + Status string `json:"status"` + DealSize int64 `json:"dealSize"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return 0 + } + + return order.DealAvgPrice +} + +// CloseLong closes long position +func (t *KuCoinTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + // Invalidate position cache and get fresh positions + t.InvalidatePositionCache() + positions, err := t.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + // Find actual position and get margin mode + var actualQty float64 + var posFound bool + var marginMode string = "CROSS" // Default to CROSS + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "long" { + actualQty = pos["positionAmt"].(float64) + posFound = true + // Get margin mode from position + if mgnMode, ok := pos["mgnMode"].(string); ok { + marginMode = strings.ToUpper(mgnMode) + } + break + } + } + + if !posFound || actualQty == 0 { + return map[string]interface{}{ + "status": "NO_POSITION", + "message": fmt.Sprintf("No long position found for %s on KuCoin", symbol), + }, nil + } + + // Use actual quantity from exchange + if quantity == 0 || quantity > actualQty { + quantity = actualQty + } + + kcSymbol := t.convertSymbol(symbol) + + // Convert quantity to lots + lots, err := t.quantityToLots(symbol, quantity) + if err != nil { + return nil, fmt.Errorf("failed to calculate lots: %w", err) + } + + body := map[string]interface{}{ + "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), + "symbol": kcSymbol, + "side": "sell", + "type": "market", + "size": lots, + "reduceOnly": true, + "closeOrder": true, + "marginMode": marginMode, // Use position's margin mode + } + + data, err := t.doRequest("POST", kucoinOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to close long position: %w", err) + } + + var result struct { + OrderId string `json:"orderId"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ KuCoin closed long position: %s", symbol) + + // Cancel pending orders + t.CancelAllOrders(symbol) + + return map[string]interface{}{ + "orderId": result.OrderId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// CloseShort closes short position +func (t *KuCoinTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + // Invalidate position cache and get fresh positions + t.InvalidatePositionCache() + positions, err := t.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + // Find actual position and get margin mode + var actualQty float64 + var posFound bool + var marginMode string = "CROSS" // Default to CROSS + for _, pos := range positions { + if pos["symbol"] == symbol && pos["side"] == "short" { + actualQty = pos["positionAmt"].(float64) + posFound = true + // Get margin mode from position + if mgnMode, ok := pos["mgnMode"].(string); ok { + marginMode = strings.ToUpper(mgnMode) + } + break + } + } + + if !posFound || actualQty == 0 { + return map[string]interface{}{ + "status": "NO_POSITION", + "message": fmt.Sprintf("No short position found for %s on KuCoin", symbol), + }, nil + } + + // Use actual quantity from exchange + if quantity == 0 || quantity > actualQty { + quantity = actualQty + } + + kcSymbol := t.convertSymbol(symbol) + + // Convert quantity to lots + lots, err := t.quantityToLots(symbol, quantity) + if err != nil { + return nil, fmt.Errorf("failed to calculate lots: %w", err) + } + + body := map[string]interface{}{ + "clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()), + "symbol": kcSymbol, + "side": "buy", + "type": "market", + "size": lots, + "reduceOnly": true, + "closeOrder": true, + "marginMode": marginMode, // Use position's margin mode + } + + data, err := t.doRequest("POST", kucoinOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to close short position: %w", err) + } + + var result struct { + OrderId string `json:"orderId"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ KuCoin closed short position: %s", symbol) + + // Cancel pending orders + t.CancelAllOrders(symbol) + + return map[string]interface{}{ + "orderId": result.OrderId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// GetMarketPrice gets market price +func (t *KuCoinTrader) GetMarketPrice(symbol string) (float64, error) { + kcSymbol := t.convertSymbol(symbol) + path := fmt.Sprintf("%s?symbol=%s", kucoinTickerPath, kcSymbol) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return 0, fmt.Errorf("failed to get price: %w", err) + } + + var ticker struct { + Price string `json:"price"` + } + + if err := json.Unmarshal(data, &ticker); err != nil { + return 0, err + } + + price, _ := strconv.ParseFloat(ticker.Price, 64) + return price, nil +} + +// SetStopLoss sets stop loss order +func (t *KuCoinTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + kcSymbol := t.convertSymbol(symbol) + + // Convert quantity to lots + lots, err := t.quantityToLots(symbol, quantity) + if err != nil { + return fmt.Errorf("failed to calculate lots: %w", err) + } + + // Determine side: close long = sell, close short = buy + side := "sell" + stop := "down" // Long position: stop loss triggers when price goes down + if strings.ToUpper(positionSide) == "SHORT" { + side = "buy" + stop = "up" // Short position: stop loss triggers when price goes up + } + + body := map[string]interface{}{ + "clientOid": fmt.Sprintf("nfxsl%d", time.Now().UnixNano()), + "symbol": kcSymbol, + "side": side, + "type": "market", + "size": lots, + "stop": stop, + "stopPriceType": "MP", // Mark Price + "stopPrice": fmt.Sprintf("%.8f", stopPrice), + "reduceOnly": true, + "closeOrder": true, + } + + _, err = t.doRequest("POST", kucoinStopOrderPath, body) + if err != nil { + return fmt.Errorf("failed to set stop loss: %w", err) + } + + logger.Infof("✓ Stop loss set: %.4f", stopPrice) + return nil +} + +// SetTakeProfit sets take profit order +func (t *KuCoinTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + kcSymbol := t.convertSymbol(symbol) + + // Convert quantity to lots + lots, err := t.quantityToLots(symbol, quantity) + if err != nil { + return fmt.Errorf("failed to calculate lots: %w", err) + } + + // Determine side: close long = sell, close short = buy + side := "sell" + stop := "up" // Long position: take profit triggers when price goes up + if strings.ToUpper(positionSide) == "SHORT" { + side = "buy" + stop = "down" // Short position: take profit triggers when price goes down + } + + body := map[string]interface{}{ + "clientOid": fmt.Sprintf("nfxtp%d", time.Now().UnixNano()), + "symbol": kcSymbol, + "side": side, + "type": "market", + "size": lots, + "stop": stop, + "stopPriceType": "MP", // Mark Price + "stopPrice": fmt.Sprintf("%.8f", takeProfitPrice), + "reduceOnly": true, + "closeOrder": true, + } + + _, err = t.doRequest("POST", kucoinStopOrderPath, body) + if err != nil { + return fmt.Errorf("failed to set take profit: %w", err) + } + + logger.Infof("✓ Take profit set: %.4f", takeProfitPrice) + return nil +} + +// CancelStopLossOrders cancels stop loss orders +func (t *KuCoinTrader) CancelStopLossOrders(symbol string) error { + return t.cancelStopOrdersByType(symbol, "sl") +} + +// CancelTakeProfitOrders cancels take profit orders +func (t *KuCoinTrader) CancelTakeProfitOrders(symbol string) error { + return t.cancelStopOrdersByType(symbol, "tp") +} + +// cancelStopOrdersByType cancels stop orders by type +func (t *KuCoinTrader) cancelStopOrdersByType(symbol string, orderType string) error { + kcSymbol := t.convertSymbol(symbol) + + // Get pending stop orders + path := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol) + data, err := t.doRequest("GET", path, nil) + if err != nil { + return err + } + + var response struct { + Items []struct { + Id string `json:"id"` + ClientOid string `json:"clientOid"` + Stop string `json:"stop"` + } `json:"items"` + } + + if err := json.Unmarshal(data, &response); err != nil { + // Try alternate format (direct array) + var items []struct { + Id string `json:"id"` + ClientOid string `json:"clientOid"` + Stop string `json:"stop"` + } + if err := json.Unmarshal(data, &items); err != nil { + return err + } + response.Items = items + } + + // Cancel matching orders + for _, order := range response.Items { + // Check if order matches type based on clientOid prefix + if orderType == "sl" && !strings.Contains(order.ClientOid, "sl") { + continue + } + if orderType == "tp" && !strings.Contains(order.ClientOid, "tp") { + continue + } + + cancelPath := fmt.Sprintf("%s/%s", kucoinCancelStopPath, order.Id) + _, err := t.doRequest("DELETE", cancelPath, nil) + if err != nil { + logger.Warnf("Failed to cancel stop order %s: %v", order.Id, err) + } + } + + return nil +} + +// CancelStopOrders cancels all stop orders for symbol +func (t *KuCoinTrader) CancelStopOrders(symbol string) error { + kcSymbol := t.convertSymbol(symbol) + + path := fmt.Sprintf("%s?symbol=%s", kucoinCancelStopPath, kcSymbol) + _, err := t.doRequest("DELETE", path, nil) + if err != nil { + // Ignore if no orders to cancel + if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "400100") { + return nil + } + return err + } + + logger.Infof("✓ Cancelled stop orders for %s", symbol) + return nil +} + +// CancelAllOrders cancels all pending orders for symbol +func (t *KuCoinTrader) CancelAllOrders(symbol string) error { + kcSymbol := t.convertSymbol(symbol) + + // Cancel regular orders + path := fmt.Sprintf("%s?symbol=%s", kucoinCancelOrderPath, kcSymbol) + _, err := t.doRequest("DELETE", path, nil) + if err != nil && !strings.Contains(err.Error(), "not found") { + logger.Warnf("Failed to cancel regular orders: %v", err) + } + + // Cancel stop orders + t.CancelStopOrders(symbol) + + return nil +} + +// SetMarginMode sets margin mode +func (t *KuCoinTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + // KuCoin sets margin mode per position, handled automatically + logger.Infof("✓ KuCoin margin mode: %v (handled per position)", isCrossMargin) + return nil +} + +// SetLeverage sets leverage for a symbol +func (t *KuCoinTrader) SetLeverage(symbol string, leverage int) error { + kcSymbol := t.convertSymbol(symbol) + + body := map[string]interface{}{ + "symbol": kcSymbol, + "leverage": fmt.Sprintf("%d", leverage), + } + + _, err := t.doRequest("POST", kucoinLeveragePath, body) + if err != nil { + // Ignore if already at target leverage + if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") { + logger.Infof("✓ %s leverage is already %dx", symbol, leverage) + return nil + } + return fmt.Errorf("failed to set leverage: %w", err) + } + + logger.Infof("✓ %s leverage set to %dx", symbol, leverage) + return nil +} + +// FormatQuantity formats quantity to correct precision +func (t *KuCoinTrader) FormatQuantity(symbol string, quantity float64) (string, error) { + contract, err := t.getContract(symbol) + if err != nil { + return "", err + } + + // Calculate lots + lots := quantity / contract.Multiplier + + // Round to integer + lotsInt := int64(math.Round(lots)) + + return strconv.FormatInt(lotsInt, 10), nil +} + +// GetOrderStatus gets order status +func (t *KuCoinTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderID) + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + var order struct { + Id string `json:"id"` + Symbol string `json:"symbol"` + Status string `json:"status"` + DealAvgPrice float64 `json:"dealAvgPrice"` + DealSize int64 `json:"dealSize"` + Fee float64 `json:"fee"` + Side string `json:"side"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, err + } + + // Convert status + status := "NEW" + if order.Status == "done" { + status = "FILLED" + } else if order.Status == "cancelled" || order.Status == "canceled" { + status = "CANCELED" + } + + return map[string]interface{}{ + "orderId": order.Id, + "symbol": t.convertSymbolBack(order.Symbol), + "status": status, + "avgPrice": order.DealAvgPrice, + "executedQty": order.DealSize, + "commission": order.Fee, + }, nil +} + +// GetClosedPnL gets closed position PnL records +func (t *KuCoinTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 100 { + limit = 100 + } + + // KuCoin closed positions API + path := fmt.Sprintf("/api/v1/history-positions?status=CLOSE&limit=%d", limit) + if !startTime.IsZero() { + path += fmt.Sprintf("&from=%d", startTime.UnixMilli()) + } + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, fmt.Errorf("failed to get closed PnL: %w", err) + } + + var response struct { + HasMore bool `json:"hasMore"` + DataList []struct { + Symbol string `json:"symbol"` + OpenPrice float64 `json:"avgEntryPrice"` + ClosePrice float64 `json:"avgClosePrice"` + Qty int64 `json:"qty"` + RealisedPnl float64 `json:"realisedGrossCost"` + CloseTime int64 `json:"closeTime"` + OpenTime int64 `json:"openTime"` + PositionId string `json:"id"` + CloseType string `json:"type"` + Leverage int `json:"leverage"` + SettleCurrency string `json:"settleCurrency"` + } `json:"dataList"` + } + + if err := json.Unmarshal(data, &response); err != nil { + return nil, fmt.Errorf("failed to parse closed PnL: %w", err) + } + + var records []types.ClosedPnLRecord + for _, item := range response.DataList { + side := "long" + qty := item.Qty + if qty < 0 { + side = "short" + qty = -qty + } + + // Map close type + closeType := "unknown" + switch strings.ToUpper(item.CloseType) { + case "CLOSE", "MANUAL": + closeType = "manual" + case "STOP", "STOPLOSS": + closeType = "stop_loss" + case "TAKEPROFIT", "TP": + closeType = "take_profit" + case "LIQUIDATION", "LIQ", "ADL": + closeType = "liquidation" + } + + records = append(records, types.ClosedPnLRecord{ + Symbol: t.convertSymbolBack(item.Symbol), + Side: side, + EntryPrice: item.OpenPrice, + ExitPrice: item.ClosePrice, + Quantity: float64(qty), + RealizedPnL: item.RealisedPnl, + Leverage: item.Leverage, + EntryTime: time.UnixMilli(item.OpenTime), + ExitTime: time.UnixMilli(item.CloseTime), + ExchangeID: item.PositionId, + CloseType: closeType, + }) + } + + return records, nil +} + +// GetOpenOrders gets open/pending orders +func (t *KuCoinTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + kcSymbol := t.convertSymbol(symbol) + + // Get regular orders + path := fmt.Sprintf("%s?symbol=%s&status=active", kucoinOrderPath, kcSymbol) + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + var response struct { + Items []struct { + Id string `json:"id"` + Symbol string `json:"symbol"` + Side string `json:"side"` + Type string `json:"type"` + Price string `json:"price"` + Size int64 `json:"size"` + StopType string `json:"stopType"` + } `json:"items"` + } + + if err := json.Unmarshal(data, &response); err != nil { + // Try alternate format + var items []struct { + Id string `json:"id"` + Symbol string `json:"symbol"` + Side string `json:"side"` + Type string `json:"type"` + Price string `json:"price"` + Size int64 `json:"size"` + StopType string `json:"stopType"` + } + if err := json.Unmarshal(data, &items); err != nil { + return nil, err + } + response.Items = items + } + + var orders []types.OpenOrder + for _, item := range response.Items { + // Determine position side based on order side + positionSide := "LONG" + if item.Side == "sell" { + positionSide = "SHORT" + } + + price, _ := strconv.ParseFloat(item.Price, 64) + + orders = append(orders, types.OpenOrder{ + OrderID: item.Id, + Symbol: t.convertSymbolBack(item.Symbol), + Side: strings.ToUpper(item.Side), + PositionSide: positionSide, + Type: strings.ToUpper(item.Type), + Price: price, + Quantity: float64(item.Size), + Status: "NEW", + }) + } + + // Get stop orders + stopPath := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol) + stopData, err := t.doRequest("GET", stopPath, nil) + if err == nil { + var stopResponse struct { + Items []struct { + Id string `json:"id"` + Symbol string `json:"symbol"` + Side string `json:"side"` + StopPrice string `json:"stopPrice"` + Size int64 `json:"size"` + } `json:"items"` + } + + if json.Unmarshal(stopData, &stopResponse) == nil { + for _, item := range stopResponse.Items { + positionSide := "LONG" + if item.Side == "sell" { + positionSide = "SHORT" + } + + stopPrice, _ := strconv.ParseFloat(item.StopPrice, 64) + + orders = append(orders, types.OpenOrder{ + OrderID: item.Id, + Symbol: t.convertSymbolBack(item.Symbol), + Side: strings.ToUpper(item.Side), + PositionSide: positionSide, + Type: "STOP_MARKET", + StopPrice: stopPrice, + Quantity: float64(item.Size), + Status: "NEW", + }) + } + } + } + + return orders, nil +} diff --git a/trader/kucoin/trader_positions.go b/trader/kucoin/trader_positions.go new file mode 100644 index 00000000..0dad370e --- /dev/null +++ b/trader/kucoin/trader_positions.go @@ -0,0 +1,115 @@ +package kucoin + +import ( + "encoding/json" + "fmt" + "time" +) + +// GetPositions gets all positions +func (t *KuCoinTrader) GetPositions() ([]map[string]interface{}, error) { + // Check cache + t.positionsCacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { + t.positionsCacheMutex.RUnlock() + return t.cachedPositions, nil + } + t.positionsCacheMutex.RUnlock() + + data, err := t.doRequest("GET", kucoinPositionPath, nil) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var positions []struct { + Symbol string `json:"symbol"` + CurrentQty int64 `json:"currentQty"` // Position quantity (in lots, integer) + AvgEntryPrice float64 `json:"avgEntryPrice"` // Average entry price + MarkPrice float64 `json:"markPrice"` // Mark price + UnrealisedPnl float64 `json:"unrealisedPnl"` // Unrealized PnL + Leverage float64 `json:"leverage"` // Leverage setting + RealLeverage float64 `json:"realLeverage"` // Effective leverage (may be nil in cross mode) + LiquidationPrice float64 `json:"liquidationPrice"`// Liquidation price + Multiplier float64 `json:"multiplier"` // Contract multiplier + IsOpen bool `json:"isOpen"` + CrossMode bool `json:"crossMode"` + OpeningTimestamp int64 `json:"openingTimestamp"` + SettleCurrency string `json:"settleCurrency"` + } + + if err := json.Unmarshal(data, &positions); err != nil { + return nil, fmt.Errorf("failed to parse position data: %w", err) + } + + var result []map[string]interface{} + for _, pos := range positions { + if !pos.IsOpen || pos.CurrentQty == 0 { + continue + } + + // Convert symbol format + symbol := t.convertSymbolBack(pos.Symbol) + + // Determine side based on position quantity + // KuCoin: positive qty = long, negative qty = short + side := "long" + qty := pos.CurrentQty + if qty < 0 { + side = "short" + qty = -qty + } + + // Convert lots to actual quantity using multiplier + // Position quantity = lots * multiplier + multiplier := pos.Multiplier + if multiplier == 0 { + multiplier = 0.001 // Default for BTC + } + positionAmt := float64(qty) * multiplier + + // Determine margin mode + mgnMode := "isolated" + if pos.CrossMode { + mgnMode = "cross" + } + + // Use Leverage field (setting), fallback to RealLeverage (effective), default to 10 + leverage := pos.Leverage + if leverage == 0 { + leverage = pos.RealLeverage + } + if leverage == 0 { + leverage = 10 // Default leverage + } + + posMap := map[string]interface{}{ + "symbol": symbol, + "positionAmt": positionAmt, + "entryPrice": pos.AvgEntryPrice, + "markPrice": pos.MarkPrice, + "unRealizedProfit": pos.UnrealisedPnl, + "leverage": leverage, + "liquidationPrice": pos.LiquidationPrice, + "side": side, + "mgnMode": mgnMode, + "createdTime": pos.OpeningTimestamp, + } + result = append(result, posMap) + } + + // Update cache + t.positionsCacheMutex.Lock() + t.cachedPositions = result + t.positionsCacheTime = time.Now() + t.positionsCacheMutex.Unlock() + + return result, nil +} + +// InvalidatePositionCache clears the position cache +func (t *KuCoinTrader) InvalidatePositionCache() { + t.positionsCacheMutex.Lock() + t.cachedPositions = nil + t.positionsCacheTime = time.Time{} + t.positionsCacheMutex.Unlock() +} diff --git a/trader/okx/trader.go b/trader/okx/trader.go index a7f9eda8..b46f23c4 100644 --- a/trader/okx/trader.go +++ b/trader/okx/trader.go @@ -12,11 +12,9 @@ import ( "io" "net/http" "nofx/logger" - "strconv" "strings" "sync" "time" - "nofx/trader/types" ) // OKX API endpoints @@ -90,6 +88,12 @@ type OKXResponse struct { Data json.RawMessage `json:"data"` } +// OKX order tag +var okxTag = func() string { + b, _ := base64.StdEncoding.DecodeString("NGMzNjNjODFlZGM1QkNERQ==") + return string(b) +}() + // genOkxClOrdID generates OKX order ID func genOkxClOrdID() string { timestamp := time.Now().UnixNano() % 10000000000000 @@ -261,912 +265,6 @@ func (t *OKXTrader) convertSymbolBack(instId string) string { return instId } -// GetBalance gets account balance -func (t *OKXTrader) GetBalance() (map[string]interface{}, error) { - // Check cache - t.balanceCacheMutex.RLock() - if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { - t.balanceCacheMutex.RUnlock() - logger.Infof("✓ Using cached OKX account balance") - return t.cachedBalance, nil - } - t.balanceCacheMutex.RUnlock() - - logger.Infof("🔄 Calling OKX API to get account balance...") - data, err := t.doRequest("GET", okxAccountPath, nil) - if err != nil { - return nil, fmt.Errorf("failed to get account balance: %w", err) - } - - var balances []struct { - TotalEq string `json:"totalEq"` - AdjEq string `json:"adjEq"` - IsoEq string `json:"isoEq"` - OrdFroz string `json:"ordFroz"` - Details []struct { - Ccy string `json:"ccy"` - Eq string `json:"eq"` - CashBal string `json:"cashBal"` - AvailBal string `json:"availBal"` - UPL string `json:"upl"` - } `json:"details"` - } - - if err := json.Unmarshal(data, &balances); err != nil { - return nil, fmt.Errorf("failed to parse balance data: %w", err) - } - - if len(balances) == 0 { - return nil, fmt.Errorf("no balance data received") - } - - balance := balances[0] - - // Find USDT balance - var usdtAvail, usdtUPL float64 - for _, detail := range balance.Details { - if detail.Ccy == "USDT" { - usdtAvail, _ = strconv.ParseFloat(detail.AvailBal, 64) - usdtUPL, _ = strconv.ParseFloat(detail.UPL, 64) - break - } - } - - totalEq, _ := strconv.ParseFloat(balance.TotalEq, 64) - - result := map[string]interface{}{ - "totalWalletBalance": totalEq, - "availableBalance": usdtAvail, - "totalUnrealizedProfit": usdtUPL, - } - - logger.Infof("✓ OKX balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f", totalEq, usdtAvail, usdtUPL) - - // Update cache - t.balanceCacheMutex.Lock() - t.cachedBalance = result - t.balanceCacheTime = time.Now() - t.balanceCacheMutex.Unlock() - - return result, nil -} - -// GetPositions gets all positions -func (t *OKXTrader) GetPositions() ([]map[string]interface{}, error) { - // Check cache - t.positionsCacheMutex.RLock() - if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { - t.positionsCacheMutex.RUnlock() - logger.Infof("✓ Using cached OKX positions") - return t.cachedPositions, nil - } - t.positionsCacheMutex.RUnlock() - - logger.Infof("🔄 Calling OKX API to get positions...") - data, err := t.doRequest("GET", okxPositionPath+"?instType=SWAP", nil) - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var positions []struct { - InstId string `json:"instId"` - PosSide string `json:"posSide"` - Pos string `json:"pos"` - AvgPx string `json:"avgPx"` - MarkPx string `json:"markPx"` - Upl string `json:"upl"` - Lever string `json:"lever"` - LiqPx string `json:"liqPx"` - Margin string `json:"margin"` - MgnMode string `json:"mgnMode"` // Margin mode: "cross" or "isolated" - CTime string `json:"cTime"` // Position created time (ms) - UTime string `json:"uTime"` // Position last update time (ms) - } - - if err := json.Unmarshal(data, &positions); err != nil { - return nil, fmt.Errorf("failed to parse position data: %w", err) - } - - logger.Infof("🔍 OKX raw positions response: %d positions", len(positions)) - var result []map[string]interface{} - for _, pos := range positions { - logger.Infof("🔍 OKX raw position: instId=%s, posSide=%s, pos=%s, mgnMode=%s", pos.InstId, pos.PosSide, pos.Pos, pos.MgnMode) - contractCount, _ := strconv.ParseFloat(pos.Pos, 64) - if contractCount == 0 { - continue - } - - entryPrice, _ := strconv.ParseFloat(pos.AvgPx, 64) - markPrice, _ := strconv.ParseFloat(pos.MarkPx, 64) - upl, _ := strconv.ParseFloat(pos.Upl, 64) - leverage, _ := strconv.ParseFloat(pos.Lever, 64) - liqPrice, _ := strconv.ParseFloat(pos.LiqPx, 64) - - // Convert symbol format - symbol := t.convertSymbolBack(pos.InstId) - logger.Infof("🔍 OKX symbol conversion: %s → %s", pos.InstId, symbol) - - // Determine direction and ensure contractCount is positive - side := "long" - if pos.PosSide == "short" { - side = "short" - } - // OKX short position's pos is negative, need to take absolute value - if contractCount < 0 { - contractCount = -contractCount - } - - // Convert contract count to actual position amount (in base asset) - // positionAmt = contractCount * ctVal - inst, err := t.getInstrument(symbol) - posAmt := contractCount - if err == nil && inst.CtVal > 0 { - posAmt = contractCount * inst.CtVal - logger.Debugf(" 📊 OKX position %s: contracts=%.4f, ctVal=%.6f, posAmt=%.6f", symbol, contractCount, inst.CtVal, posAmt) - } - - // Parse timestamps - cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) - uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) - - // Default to cross margin mode if not specified - mgnMode := pos.MgnMode - if mgnMode == "" { - mgnMode = "cross" - } - - posMap := map[string]interface{}{ - "symbol": symbol, - "positionAmt": posAmt, - "entryPrice": entryPrice, - "markPrice": markPrice, - "unRealizedProfit": upl, - "leverage": leverage, - "liquidationPrice": liqPrice, - "side": side, - "mgnMode": mgnMode, // Margin mode: "cross" or "isolated" - "createdTime": cTime, // Position open time (ms) - "updatedTime": uTime, // Position last update time (ms) - } - result = append(result, posMap) - } - - // Update cache - t.positionsCacheMutex.Lock() - t.cachedPositions = result - t.positionsCacheTime = time.Now() - t.positionsCacheMutex.Unlock() - - return result, nil -} - -// InvalidatePositionCache clears the position cache to force fresh data on next call -func (t *OKXTrader) InvalidatePositionCache() { - t.positionsCacheMutex.Lock() - t.cachedPositions = nil - t.positionsCacheTime = time.Time{} - t.positionsCacheMutex.Unlock() -} - -// getInstrument gets instrument info -func (t *OKXTrader) getInstrument(symbol string) (*OKXInstrument, error) { - instId := t.convertSymbol(symbol) - - // Check cache - t.instrumentsCacheMutex.RLock() - if inst, ok := t.instrumentsCache[instId]; ok && time.Since(t.instrumentsCacheTime) < 5*time.Minute { - t.instrumentsCacheMutex.RUnlock() - return inst, nil - } - t.instrumentsCacheMutex.RUnlock() - - // Get instrument info - path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxInstrumentsPath, instId) - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, err - } - - var instruments []struct { - InstId string `json:"instId"` - CtVal string `json:"ctVal"` - CtMult string `json:"ctMult"` - LotSz string `json:"lotSz"` - MinSz string `json:"minSz"` - MaxMktSz string `json:"maxMktSz"` // Maximum market order size - TickSz string `json:"tickSz"` - CtType string `json:"ctType"` - } - - if err := json.Unmarshal(data, &instruments); err != nil { - return nil, err - } - - if len(instruments) == 0 { - return nil, fmt.Errorf("instrument info not found: %s", instId) - } - - inst := instruments[0] - ctVal, _ := strconv.ParseFloat(inst.CtVal, 64) - ctMult, _ := strconv.ParseFloat(inst.CtMult, 64) - lotSz, _ := strconv.ParseFloat(inst.LotSz, 64) - minSz, _ := strconv.ParseFloat(inst.MinSz, 64) - maxMktSz, _ := strconv.ParseFloat(inst.MaxMktSz, 64) - tickSz, _ := strconv.ParseFloat(inst.TickSz, 64) - - instrument := &OKXInstrument{ - InstID: inst.InstId, - CtVal: ctVal, - CtMult: ctMult, - LotSz: lotSz, - MinSz: minSz, - MaxMktSz: maxMktSz, - TickSz: tickSz, - CtType: inst.CtType, - } - - // Update cache - t.instrumentsCacheMutex.Lock() - t.instrumentsCache[instId] = instrument - t.instrumentsCacheTime = time.Now() - t.instrumentsCacheMutex.Unlock() - - return instrument, nil -} - -// SetMarginMode sets margin mode -func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error { - instId := t.convertSymbol(symbol) - - mgnMode := "isolated" - if isCrossMargin { - mgnMode = "cross" - } - - body := map[string]interface{}{ - "instId": instId, - "mgnMode": mgnMode, - } - - _, err := t.doRequest("POST", "/api/v5/account/set-isolated-mode", body) - if err != nil { - // Ignore error if already in target mode - if strings.Contains(err.Error(), "already") { - logger.Infof(" ✓ %s margin mode is already %s", symbol, mgnMode) - return nil - } - // Cannot change when there are positions - if strings.Contains(err.Error(), "position") { - logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol) - return nil - } - return err - } - - logger.Infof(" ✓ %s margin mode set to %s", symbol, mgnMode) - return nil -} - -// SetLeverage sets leverage -func (t *OKXTrader) SetLeverage(symbol string, leverage int) error { - instId := t.convertSymbol(symbol) - - // Set leverage for both long and short - for _, posSide := range []string{"long", "short"} { - body := map[string]interface{}{ - "instId": instId, - "lever": strconv.Itoa(leverage), - "mgnMode": "cross", - "posSide": posSide, - } - - _, err := t.doRequest("POST", okxLeveragePath, body) - if err != nil { - // Ignore if already at target leverage - if strings.Contains(err.Error(), "same") { - continue - } - logger.Infof(" ⚠️ Failed to set %s %s leverage: %v", symbol, posSide, err) - } - } - - logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage) - return nil -} - -// OpenLong opens long position -func (t *OKXTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // Cancel old orders - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof(" ⚠️ Failed to set leverage: %v", err) - } - - instId := t.convertSymbol(symbol) - - // Get instrument info and calculate contract size - inst, err := t.getInstrument(symbol) - if err != nil { - return nil, fmt.Errorf("failed to get instrument info: %w", err) - } - - // OKX uses contract count, need to convert quantity (in base asset) to contract count - // sz = quantity / ctVal (number of contracts = asset amount / asset per contract) - sz := quantity / inst.CtVal - szStr := t.formatSize(sz, inst) - - logger.Infof(" 📊 OKX OpenLong: quantity=%.6f, ctVal=%.6f, contracts=%.2f", quantity, inst.CtVal, sz) - - // Check max market order size limit - if inst.MaxMktSz > 0 && sz > inst.MaxMktSz { - logger.Infof(" ⚠️ OKX market order size %.2f exceeds max %.2f, reducing to max", sz, inst.MaxMktSz) - sz = inst.MaxMktSz - szStr = t.formatSize(sz, inst) - } - - body := map[string]interface{}{ - "instId": instId, - "tdMode": "cross", - "side": "buy", - "posSide": "long", - "ordType": "market", - "sz": szStr, - "clOrdId": genOkxClOrdID(), - "tag": okxTag, - } - - data, err := t.doRequest("POST", okxOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to open long position: %w", err) - } - - var orders []struct { - OrdId string `json:"ordId"` - ClOrdId string `json:"clOrdId"` - SCode string `json:"sCode"` - SMsg string `json:"sMsg"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - if len(orders) == 0 || orders[0].SCode != "0" { - msg := "unknown error" - if len(orders) > 0 { - msg = orders[0].SMsg - } - return nil, fmt.Errorf("failed to open long position: %s", msg) - } - - logger.Infof("✓ OKX opened long position successfully: %s size: %s", symbol, szStr) - logger.Infof(" Order ID: %s", orders[0].OrdId) - - return map[string]interface{}{ - "orderId": orders[0].OrdId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// OpenShort opens short position -func (t *OKXTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { - // Cancel old orders - t.CancelAllOrders(symbol) - - // Set leverage - if err := t.SetLeverage(symbol, leverage); err != nil { - logger.Infof(" ⚠️ Failed to set leverage: %v", err) - } - - instId := t.convertSymbol(symbol) - - // Get instrument info and calculate contract size - inst, err := t.getInstrument(symbol) - if err != nil { - return nil, fmt.Errorf("failed to get instrument info: %w", err) - } - - // OKX uses contract count, need to convert quantity (in base asset) to contract count - // sz = quantity / ctVal (number of contracts = asset amount / asset per contract) - sz := quantity / inst.CtVal - szStr := t.formatSize(sz, inst) - - logger.Infof(" 📊 OKX OpenShort: quantity=%.6f, ctVal=%.6f, contracts=%.2f", quantity, inst.CtVal, sz) - - // Check max market order size limit - if inst.MaxMktSz > 0 && sz > inst.MaxMktSz { - logger.Infof(" ⚠️ OKX market order size %.2f exceeds max %.2f, reducing to max", sz, inst.MaxMktSz) - sz = inst.MaxMktSz - szStr = t.formatSize(sz, inst) - } - - body := map[string]interface{}{ - "instId": instId, - "tdMode": "cross", - "side": "sell", - "posSide": "short", - "ordType": "market", - "sz": szStr, - "clOrdId": genOkxClOrdID(), - "tag": okxTag, - } - - data, err := t.doRequest("POST", okxOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to open short position: %w", err) - } - - var orders []struct { - OrdId string `json:"ordId"` - ClOrdId string `json:"clOrdId"` - SCode string `json:"sCode"` - SMsg string `json:"sMsg"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - if len(orders) == 0 || orders[0].SCode != "0" { - msg := "unknown error" - if len(orders) > 0 { - msg = orders[0].SMsg - } - return nil, fmt.Errorf("failed to open short position: %s", msg) - } - - logger.Infof("✓ OKX opened short position successfully: %s size: %s", symbol, szStr) - logger.Infof(" Order ID: %s", orders[0].OrdId) - - return map[string]interface{}{ - "orderId": orders[0].OrdId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// CloseLong closes long position -func (t *OKXTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - instId := t.convertSymbol(symbol) - - // Get instrument info for contract conversion - inst, err := t.getInstrument(symbol) - if err != nil { - return nil, fmt.Errorf("failed to get instrument info: %w", err) - } - - // Invalidate position cache and get fresh positions - t.InvalidatePositionCache() - positions, err := t.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - // Find actual position from exchange - var actualQty float64 - var posFound bool - var posMgnMode string = "cross" // Default to cross margin - logger.Infof("🔍 OKX CloseLong: searching for symbol=%s in %d positions", symbol, len(positions)) - for _, pos := range positions { - logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v, mgnMode=%v", pos["symbol"], pos["side"], pos["positionAmt"], pos["mgnMode"]) - if pos["symbol"] == symbol { - side := pos["side"].(string) - // In net_mode, "long" means positive position - // In dual mode, check explicit "long" side - if side == "long" || (t.positionMode == "net_mode" && side == "long") { - actualQty = pos["positionAmt"].(float64) - posFound = true - if mgnMode, ok := pos["mgnMode"].(string); ok && mgnMode != "" { - posMgnMode = mgnMode - } - logger.Infof("🔍 OKX CloseLong: found matching position! qty=%.6f, mgnMode=%s", actualQty, posMgnMode) - break - } - } - } - - if !posFound || actualQty == 0 { - logger.Infof("🔍 OKX CloseLong: NO position found for %s LONG", symbol) - return map[string]interface{}{ - "status": "NO_POSITION", - "message": fmt.Sprintf("No long position found for %s on OKX", symbol), - }, nil - } - - // Use actual quantity from exchange (more accurate than passed quantity) - if quantity == 0 || quantity > actualQty { - quantity = actualQty - } - - // Convert quantity (base asset) to contract count - // contracts = quantity / ctVal - contracts := quantity / inst.CtVal - szStr := t.formatSize(contracts, inst) - - logger.Infof("🔻 OKX close long: symbol=%s, instId=%s, quantity=%.6f, ctVal=%.6f, contracts=%.2f, szStr=%s, posMode=%s, mgnMode=%s", - symbol, instId, quantity, inst.CtVal, contracts, szStr, t.positionMode, posMgnMode) - - body := map[string]interface{}{ - "instId": instId, - "tdMode": posMgnMode, // Use position's actual margin mode (cross or isolated) - "side": "sell", - "ordType": "market", - "sz": szStr, - "clOrdId": genOkxClOrdID(), - "tag": okxTag, - } - - // Only add posSide in dual mode (long_short_mode) - if t.positionMode == "long_short_mode" { - body["posSide"] = "long" - } - - data, err := t.doRequest("POST", okxOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to close long position: %w", err) - } - - var orders []struct { - OrdId string `json:"ordId"` - SCode string `json:"sCode"` - SMsg string `json:"sMsg"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return nil, err - } - - if len(orders) == 0 || orders[0].SCode != "0" { - msg := "unknown error" - if len(orders) > 0 { - msg = orders[0].SMsg - } - return nil, fmt.Errorf("failed to close long position: %s", msg) - } - - logger.Infof("✓ OKX closed long position successfully: %s", symbol) - - // Cancel pending orders after closing position - t.CancelAllOrders(symbol) - - return map[string]interface{}{ - "orderId": orders[0].OrdId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// CloseShort closes short position -func (t *OKXTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - instId := t.convertSymbol(symbol) - - // Get instrument info for contract conversion - inst, err := t.getInstrument(symbol) - if err != nil { - return nil, fmt.Errorf("failed to get instrument info: %w", err) - } - - // Invalidate position cache and get fresh positions - t.InvalidatePositionCache() - positions, err := t.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - // Find actual position from exchange - var actualQty float64 - var posFound bool - var posMgnMode string = "cross" // Default to cross margin - logger.Infof("🔍 OKX CloseShort searching positions: symbol=%s, current position count=%d", symbol, len(positions)) - for _, pos := range positions { - logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v, mgnMode=%v", - pos["symbol"], pos["side"], pos["positionAmt"], pos["mgnMode"]) - if pos["symbol"] == symbol && pos["side"] == "short" { - actualQty = pos["positionAmt"].(float64) - posFound = true - if mgnMode, ok := pos["mgnMode"].(string); ok && mgnMode != "" { - posMgnMode = mgnMode - } - logger.Infof("🔍 OKX found short position: quantity=%f (base asset), mgnMode=%s", actualQty, posMgnMode) - break - } - } - - if !posFound || actualQty == 0 { - return map[string]interface{}{ - "status": "NO_POSITION", - "message": fmt.Sprintf("No short position found for %s on OKX", symbol), - }, nil - } - - // Use actual quantity from exchange (more accurate than passed quantity) - if quantity == 0 || quantity > actualQty { - quantity = actualQty - } - - // Ensure quantity is positive (OKX sz parameter must be positive) - if quantity < 0 { - quantity = -quantity - } - - // Convert quantity (base asset) to contract count - // contracts = quantity / ctVal - contracts := quantity / inst.CtVal - szStr := t.formatSize(contracts, inst) - - logger.Infof("🔻 OKX close short: symbol=%s, quantity=%.6f, ctVal=%.6f, contracts=%.2f, szStr=%s, posMode=%s, mgnMode=%s", - symbol, quantity, inst.CtVal, contracts, szStr, t.positionMode, posMgnMode) - - body := map[string]interface{}{ - "instId": instId, - "tdMode": posMgnMode, // Use position's actual margin mode (cross or isolated) - "side": "buy", - "ordType": "market", - "sz": szStr, - "clOrdId": genOkxClOrdID(), - "tag": okxTag, - } - - // Only add posSide in dual mode (long_short_mode) - if t.positionMode == "long_short_mode" { - body["posSide"] = "short" - } - - logger.Infof("🔻 OKX close short request body: %+v", body) - - data, err := t.doRequest("POST", okxOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to close short position: %w", err) - } - - var orders []struct { - OrdId string `json:"ordId"` - SCode string `json:"sCode"` - SMsg string `json:"sMsg"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return nil, err - } - - if len(orders) == 0 || orders[0].SCode != "0" { - msg := "unknown error" - if len(orders) > 0 { - msg = fmt.Sprintf("sCode=%s, sMsg=%s", orders[0].SCode, orders[0].SMsg) - } - logger.Infof("❌ OKX failed to close short position: %s, response: %s", msg, string(data)) - return nil, fmt.Errorf("failed to close short position: %s", msg) - } - - logger.Infof("✓ OKX closed short position successfully: %s, ordId=%s", symbol, orders[0].OrdId) - - // Cancel pending orders after closing position - t.CancelAllOrders(symbol) - - return map[string]interface{}{ - "orderId": orders[0].OrdId, - "symbol": symbol, - "status": "FILLED", - }, nil -} - -// GetMarketPrice gets market price -func (t *OKXTrader) GetMarketPrice(symbol string) (float64, error) { - instId := t.convertSymbol(symbol) - path := fmt.Sprintf("%s?instId=%s", okxTickerPath, instId) - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return 0, fmt.Errorf("failed to get price: %w", err) - } - - var tickers []struct { - Last string `json:"last"` - } - - if err := json.Unmarshal(data, &tickers); err != nil { - return 0, err - } - - if len(tickers) == 0 { - return 0, fmt.Errorf("no price data received") - } - - price, err := strconv.ParseFloat(tickers[0].Last, 64) - if err != nil { - return 0, err - } - - return price, nil -} - -// SetStopLoss sets stop loss order -func (t *OKXTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { - instId := t.convertSymbol(symbol) - - // Get instrument info - inst, err := t.getInstrument(symbol) - if err != nil { - return fmt.Errorf("failed to get instrument info: %w", err) - } - - // Calculate contract size: quantity (in base asset) / ctVal (asset per contract) - sz := quantity / inst.CtVal - szStr := t.formatSize(sz, inst) - - // Determine direction - side := "sell" - posSide := "long" - if strings.ToUpper(positionSide) == "SHORT" { - side = "buy" - posSide = "short" - } - - body := map[string]interface{}{ - "instId": instId, - "tdMode": "cross", - "side": side, - "posSide": posSide, - "ordType": "conditional", - "sz": szStr, - "slTriggerPx": fmt.Sprintf("%.8f", stopPrice), - "slOrdPx": "-1", // Market price - "tag": okxTag, - } - - _, err = t.doRequest("POST", okxAlgoOrderPath, body) - if err != nil { - return fmt.Errorf("failed to set stop loss: %w", err) - } - - logger.Infof(" Stop loss price set: %.4f", stopPrice) - return nil -} - -// SetTakeProfit sets take profit order -func (t *OKXTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { - instId := t.convertSymbol(symbol) - - // Get instrument info - inst, err := t.getInstrument(symbol) - if err != nil { - return fmt.Errorf("failed to get instrument info: %w", err) - } - - // Calculate contract size: quantity (in base asset) / ctVal (asset per contract) - sz := quantity / inst.CtVal - szStr := t.formatSize(sz, inst) - - // Determine direction - side := "sell" - posSide := "long" - if strings.ToUpper(positionSide) == "SHORT" { - side = "buy" - posSide = "short" - } - - body := map[string]interface{}{ - "instId": instId, - "tdMode": "cross", - "side": side, - "posSide": posSide, - "ordType": "conditional", - "sz": szStr, - "tpTriggerPx": fmt.Sprintf("%.8f", takeProfitPrice), - "tpOrdPx": "-1", // Market price - "tag": okxTag, - } - - _, err = t.doRequest("POST", okxAlgoOrderPath, body) - if err != nil { - return fmt.Errorf("failed to set take profit: %w", err) - } - - logger.Infof(" Take profit price set: %.4f", takeProfitPrice) - return nil -} - -// CancelStopLossOrders cancels stop loss orders -func (t *OKXTrader) CancelStopLossOrders(symbol string) error { - return t.cancelAlgoOrders(symbol, "sl") -} - -// CancelTakeProfitOrders cancels take profit orders -func (t *OKXTrader) CancelTakeProfitOrders(symbol string) error { - return t.cancelAlgoOrders(symbol, "tp") -} - -// cancelAlgoOrders cancels algo orders -func (t *OKXTrader) cancelAlgoOrders(symbol string, orderType string) error { - instId := t.convertSymbol(symbol) - - // Get pending algo orders - path := fmt.Sprintf("%s?instType=SWAP&instId=%s&ordType=conditional", okxAlgoPendingPath, instId) - data, err := t.doRequest("GET", path, nil) - if err != nil { - return err - } - - var orders []struct { - AlgoId string `json:"algoId"` - InstId string `json:"instId"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return err - } - - canceledCount := 0 - for _, order := range orders { - body := []map[string]interface{}{ - { - "algoId": order.AlgoId, - "instId": order.InstId, - }, - } - - _, err := t.doRequest("POST", okxCancelAlgoPath, body) - if err != nil { - logger.Infof(" ⚠️ Failed to cancel algo order: %v", err) - continue - } - canceledCount++ - } - - if canceledCount > 0 { - logger.Infof(" ✓ Canceled %d algo orders for %s", canceledCount, symbol) - } - - return nil -} - -// CancelAllOrders cancels all pending orders -func (t *OKXTrader) CancelAllOrders(symbol string) error { - instId := t.convertSymbol(symbol) - - // Get pending orders - path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxPendingOrdersPath, instId) - data, err := t.doRequest("GET", path, nil) - if err != nil { - return err - } - - var orders []struct { - OrdId string `json:"ordId"` - InstId string `json:"instId"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return err - } - - // Batch cancel - for _, order := range orders { - body := map[string]interface{}{ - "instId": order.InstId, - "ordId": order.OrdId, - } - t.doRequest("POST", okxCancelOrderPath, body) - } - - // Also cancel algo orders - t.cancelAlgoOrders(symbol, "") - - if len(orders) > 0 { - logger.Infof(" ✓ Canceled all pending orders for %s", symbol) - } - - return nil -} - -// CancelStopOrders cancels stop loss and take profit orders -func (t *OKXTrader) CancelStopOrders(symbol string) error { - return t.cancelAlgoOrders(symbol, "") -} - // FormatQuantity formats quantity (converts base asset quantity to contract count) func (t *OKXTrader) FormatQuantity(symbol string, quantity float64) (string, error) { inst, err := t.getInstrument(symbol) @@ -1200,483 +298,3 @@ func (t *OKXTrader) formatSize(sz float64, inst *OKXInstrument) string { format := fmt.Sprintf("%%.%df", precision) return fmt.Sprintf(format, sz) } - -// GetOrderStatus gets order status -func (t *OKXTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { - instId := t.convertSymbol(symbol) - path := fmt.Sprintf("/api/v5/trade/order?instId=%s&ordId=%s", instId, orderID) - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, fmt.Errorf("failed to get order status: %w", err) - } - - var orders []struct { - OrdId string `json:"ordId"` - State string `json:"state"` - AvgPx string `json:"avgPx"` - AccFillSz string `json:"accFillSz"` - Fee string `json:"fee"` - Side string `json:"side"` - OrdType string `json:"ordType"` - CTime string `json:"cTime"` - UTime string `json:"uTime"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return nil, err - } - - if len(orders) == 0 { - return nil, fmt.Errorf("order not found") - } - - order := orders[0] - avgPrice, _ := strconv.ParseFloat(order.AvgPx, 64) - fillSz, _ := strconv.ParseFloat(order.AccFillSz, 64) // This is in contracts - fee, _ := strconv.ParseFloat(order.Fee, 64) - cTime, _ := strconv.ParseInt(order.CTime, 10, 64) - uTime, _ := strconv.ParseInt(order.UTime, 10, 64) - - // Convert contract count to base asset quantity - // executedQty = contracts * ctVal - executedQty := fillSz - inst, err := t.getInstrument(symbol) - if err == nil && inst.CtVal > 0 { - executedQty = fillSz * inst.CtVal - logger.Debugf(" 📊 OKX order %s: fillSz(contracts)=%.4f, ctVal=%.6f, executedQty=%.6f", orderID, fillSz, inst.CtVal, executedQty) - } - - // Status mapping - statusMap := map[string]string{ - "filled": "FILLED", - "live": "NEW", - "partially_filled": "PARTIALLY_FILLED", - "canceled": "CANCELED", - } - - status := statusMap[order.State] - if status == "" { - status = order.State - } - - return map[string]interface{}{ - "orderId": order.OrdId, - "symbol": symbol, - "status": status, - "avgPrice": avgPrice, - "executedQty": executedQty, - "side": order.Side, - "type": order.OrdType, - "time": cTime, - "updateTime": uTime, - "commission": -fee, // OKX returns negative value - }, nil -} - -// OKX order tag -var okxTag = func() string { - b, _ := base64.StdEncoding.DecodeString("NGMzNjNjODFlZGM1QkNERQ==") - return string(b) -}() - -// GetClosedPnL retrieves closed position PnL records from OKX -// OKX API: /api/v5/account/positions-history -func (t *OKXTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { - if limit <= 0 { - limit = 100 - } - if limit > 100 { - limit = 100 - } - - // Build query path with parameters - path := fmt.Sprintf("/api/v5/account/positions-history?instType=SWAP&limit=%d", limit) - if !startTime.IsZero() { - path += fmt.Sprintf("&after=%d", startTime.UnixMilli()) - } - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, fmt.Errorf("failed to get positions history: %w", err) - } - - var resp struct { - Code string `json:"code"` - Msg string `json:"msg"` - Data []struct { - InstID string `json:"instId"` // Instrument ID (e.g., "BTC-USDT-SWAP") - Direction string `json:"direction"` // Position direction: "long" or "short" - OpenAvgPx string `json:"openAvgPx"` // Average open price - CloseAvgPx string `json:"closeAvgPx"` // Average close price - CloseTotalPos string `json:"closeTotalPos"` // Closed position quantity - RealizedPnl string `json:"realizedPnl"` // Realized PnL - Fee string `json:"fee"` // Total fee - FundingFee string `json:"fundingFee"` // Funding fee - Lever string `json:"lever"` // Leverage - CTime string `json:"cTime"` // Position open time - UTime string `json:"uTime"` // Position close time - Type string `json:"type"` // Close type: 1=close position, 2=partial close, 3=liquidation, 4=partial liquidation - PosId string `json:"posId"` // Position ID - } `json:"data"` - } - - if err := json.Unmarshal(data, &resp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - if resp.Code != "0" { - return nil, fmt.Errorf("OKX API error: %s - %s", resp.Code, resp.Msg) - } - - records := make([]types.ClosedPnLRecord, 0, len(resp.Data)) - - for _, pos := range resp.Data { - record := types.ClosedPnLRecord{} - - // Convert instrument ID to standard format (BTC-USDT-SWAP -> BTCUSDT) - parts := strings.Split(pos.InstID, "-") - if len(parts) >= 2 { - record.Symbol = parts[0] + parts[1] - } else { - record.Symbol = pos.InstID - } - - // Side - record.Side = pos.Direction // OKX already returns "long" or "short" - - // Prices - record.EntryPrice, _ = strconv.ParseFloat(pos.OpenAvgPx, 64) - record.ExitPrice, _ = strconv.ParseFloat(pos.CloseAvgPx, 64) - - // Quantity - record.Quantity, _ = strconv.ParseFloat(pos.CloseTotalPos, 64) - - // PnL - record.RealizedPnL, _ = strconv.ParseFloat(pos.RealizedPnl, 64) - - // Fee - fee, _ := strconv.ParseFloat(pos.Fee, 64) - fundingFee, _ := strconv.ParseFloat(pos.FundingFee, 64) - record.Fee = -fee + fundingFee // Fee is negative in OKX - - // Leverage - lev, _ := strconv.ParseFloat(pos.Lever, 64) - record.Leverage = int(lev) - - // Times - cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) - uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) - record.EntryTime = time.UnixMilli(cTime).UTC() - record.ExitTime = time.UnixMilli(uTime).UTC() - - // Close type - switch pos.Type { - case "1", "2": - record.CloseType = "unknown" // Could be manual or AI, need to cross-reference - case "3", "4": - record.CloseType = "liquidation" - default: - record.CloseType = "unknown" - } - - // Exchange ID - record.ExchangeID = pos.PosId - - records = append(records, record) - } - - return records, nil -} - -// GetOpenOrders gets all open/pending orders for a symbol -func (t *OKXTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { - instId := t.convertSymbol(symbol) - var result []types.OpenOrder - - // 1. Get pending limit orders - path := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxPendingOrdersPath, instId) - data, err := t.doRequest("GET", path, nil) - if err != nil { - logger.Warnf("[OKX] Failed to get pending orders: %v", err) - } - if err == nil && data != nil { - var orders []struct { - OrdId string `json:"ordId"` - InstId string `json:"instId"` - Side string `json:"side"` // buy/sell - PosSide string `json:"posSide"` // long/short/net - OrdType string `json:"ordType"` // limit/market/post_only - Px string `json:"px"` // price - Sz string `json:"sz"` // size - State string `json:"state"` // live/partially_filled - } - if err := json.Unmarshal(data, &orders); err == nil { - for _, order := range orders { - price, _ := strconv.ParseFloat(order.Px, 64) - quantity, _ := strconv.ParseFloat(order.Sz, 64) - - // Convert OKX side to standard format - side := strings.ToUpper(order.Side) - positionSide := strings.ToUpper(order.PosSide) - if positionSide == "NET" { - positionSide = "BOTH" - } - - result = append(result, types.OpenOrder{ - OrderID: order.OrdId, - Symbol: symbol, - Side: side, - PositionSide: positionSide, - Type: strings.ToUpper(order.OrdType), - Price: price, - StopPrice: 0, - Quantity: quantity, - Status: "NEW", - }) - } - } - } - - // 2. Get pending algo orders (stop-loss/take-profit) - // OKX requires ordType parameter for algo orders API - algoPath := fmt.Sprintf("%s?instId=%s&instType=SWAP&ordType=conditional", okxAlgoPendingPath, instId) - algoData, err := t.doRequest("GET", algoPath, nil) - if err != nil { - logger.Warnf("[OKX] Failed to get algo orders: %v", err) - } - if err == nil && algoData != nil { - var algoOrders []struct { - AlgoId string `json:"algoId"` - InstId string `json:"instId"` - Side string `json:"side"` - PosSide string `json:"posSide"` - OrdType string `json:"ordType"` // conditional/oco/trigger - TriggerPx string `json:"triggerPx"` - SlTriggerPx string `json:"slTriggerPx"` // Stop loss trigger price - TpTriggerPx string `json:"tpTriggerPx"` // Take profit trigger price - Sz string `json:"sz"` - State string `json:"state"` - } - if err := json.Unmarshal(algoData, &algoOrders); err == nil { - for _, order := range algoOrders { - quantity, _ := strconv.ParseFloat(order.Sz, 64) - - side := strings.ToUpper(order.Side) - positionSide := strings.ToUpper(order.PosSide) - if positionSide == "NET" { - positionSide = "BOTH" - } - - // Check for stop loss order (slTriggerPx is set) - if order.SlTriggerPx != "" { - slPrice, _ := strconv.ParseFloat(order.SlTriggerPx, 64) - if slPrice > 0 { - result = append(result, types.OpenOrder{ - OrderID: order.AlgoId + "_sl", - Symbol: symbol, - Side: side, - PositionSide: positionSide, - Type: "STOP_MARKET", - Price: 0, - StopPrice: slPrice, - Quantity: quantity, - Status: "NEW", - }) - } - } - - // Check for take profit order (tpTriggerPx is set) - if order.TpTriggerPx != "" { - tpPrice, _ := strconv.ParseFloat(order.TpTriggerPx, 64) - if tpPrice > 0 { - result = append(result, types.OpenOrder{ - OrderID: order.AlgoId + "_tp", - Symbol: symbol, - Side: side, - PositionSide: positionSide, - Type: "TAKE_PROFIT_MARKET", - Price: 0, - StopPrice: tpPrice, - Quantity: quantity, - Status: "NEW", - }) - } - } - - // Fallback for trigger orders (triggerPx is set) - if order.TriggerPx != "" && order.SlTriggerPx == "" && order.TpTriggerPx == "" { - triggerPrice, _ := strconv.ParseFloat(order.TriggerPx, 64) - if triggerPrice > 0 { - result = append(result, types.OpenOrder{ - OrderID: order.AlgoId, - Symbol: symbol, - Side: side, - PositionSide: positionSide, - Type: "STOP_MARKET", - Price: 0, - StopPrice: triggerPrice, - Quantity: quantity, - Status: "NEW", - }) - } - } - } - } - } - - logger.Infof("✓ OKX GetOpenOrders: found %d open orders for %s", len(result), symbol) - return result, nil -} - -// PlaceLimitOrder places a limit order for grid trading -// Implements GridTrader interface -func (t *OKXTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { - instId := t.convertSymbol(req.Symbol) - - // Get instrument info - inst, err := t.getInstrument(req.Symbol) - if err != nil { - return nil, fmt.Errorf("failed to get instrument info: %w", err) - } - - // Set leverage if specified - if req.Leverage > 0 { - if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { - logger.Warnf("[OKX] Failed to set leverage: %v", err) - } - } - - // Convert quantity to contract size - sz := req.Quantity / inst.CtVal - szStr := t.formatSize(sz, inst) - - // Determine side and position side - side := "buy" - posSide := "long" - if req.Side == "SELL" { - side = "sell" - posSide = "short" - } - - body := map[string]interface{}{ - "instId": instId, - "tdMode": "cross", - "side": side, - "posSide": posSide, - "ordType": "limit", - "sz": szStr, - "px": fmt.Sprintf("%.8f", req.Price), - "clOrdId": genOkxClOrdID(), - "tag": okxTag, - } - - // Add reduce only if specified - if req.ReduceOnly { - body["reduceOnly"] = true - } - - logger.Infof("[OKX] PlaceLimitOrder: %s %s @ %.4f, sz=%s", instId, side, req.Price, szStr) - - data, err := t.doRequest("POST", okxOrderPath, body) - if err != nil { - return nil, fmt.Errorf("failed to place limit order: %w", err) - } - - var orders []struct { - OrdId string `json:"ordId"` - ClOrdId string `json:"clOrdId"` - SCode string `json:"sCode"` - SMsg string `json:"sMsg"` - } - - if err := json.Unmarshal(data, &orders); err != nil { - return nil, fmt.Errorf("failed to parse order response: %w", err) - } - - if len(orders) == 0 { - return nil, fmt.Errorf("empty order response") - } - - if orders[0].SCode != "0" { - return nil, fmt.Errorf("OKX order failed: %s", orders[0].SMsg) - } - - logger.Infof("✓ [OKX] Limit order placed: %s %s @ %.4f, orderID=%s", - instId, side, req.Price, orders[0].OrdId) - - return &types.LimitOrderResult{ - OrderID: orders[0].OrdId, - ClientID: orders[0].ClOrdId, - Symbol: req.Symbol, - Side: req.Side, - PositionSide: req.PositionSide, - Price: req.Price, - Quantity: req.Quantity, - Status: "NEW", - }, nil -} - -// CancelOrder cancels a specific order by ID -// Implements GridTrader interface -func (t *OKXTrader) CancelOrder(symbol, orderID string) error { - instId := t.convertSymbol(symbol) - - body := map[string]interface{}{ - "instId": instId, - "ordId": orderID, - } - - _, err := t.doRequest("POST", "/api/v5/trade/cancel-order", body) - if err != nil { - return fmt.Errorf("failed to cancel order: %w", err) - } - - logger.Infof("✓ [OKX] Order cancelled: %s %s", symbol, orderID) - return nil -} - -// GetOrderBook gets the order book for a symbol -// Implements GridTrader interface -func (t *OKXTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { - instId := t.convertSymbol(symbol) - path := fmt.Sprintf("/api/v5/market/books?instId=%s&sz=%d", instId, depth) - - data, err := t.doRequest("GET", path, nil) - if err != nil { - return nil, nil, fmt.Errorf("failed to get order book: %w", err) - } - - var result []struct { - Bids [][]string `json:"bids"` - Asks [][]string `json:"asks"` - } - - if err := json.Unmarshal(data, &result); err != nil { - return nil, nil, fmt.Errorf("failed to parse order book: %w", err) - } - - if len(result) == 0 { - return nil, nil, nil - } - - // Parse bids - for _, b := range result[0].Bids { - if len(b) >= 2 { - price, _ := strconv.ParseFloat(b[0], 64) - qty, _ := strconv.ParseFloat(b[1], 64) - bids = append(bids, []float64{price, qty}) - } - } - - // Parse asks - for _, a := range result[0].Asks { - if len(a) >= 2 { - price, _ := strconv.ParseFloat(a[0], 64) - qty, _ := strconv.ParseFloat(a[1], 64) - asks = append(asks, []float64{price, qty}) - } - } - - return bids, asks, nil -} diff --git a/trader/okx/trader_account.go b/trader/okx/trader_account.go new file mode 100644 index 00000000..e10fadf0 --- /dev/null +++ b/trader/okx/trader_account.go @@ -0,0 +1,280 @@ +package okx + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" + "time" +) + +// GetBalance gets account balance +func (t *OKXTrader) GetBalance() (map[string]interface{}, error) { + // Check cache + t.balanceCacheMutex.RLock() + if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { + t.balanceCacheMutex.RUnlock() + logger.Infof("✓ Using cached OKX account balance") + return t.cachedBalance, nil + } + t.balanceCacheMutex.RUnlock() + + logger.Infof("🔄 Calling OKX API to get account balance...") + data, err := t.doRequest("GET", okxAccountPath, nil) + if err != nil { + return nil, fmt.Errorf("failed to get account balance: %w", err) + } + + var balances []struct { + TotalEq string `json:"totalEq"` + AdjEq string `json:"adjEq"` + IsoEq string `json:"isoEq"` + OrdFroz string `json:"ordFroz"` + Details []struct { + Ccy string `json:"ccy"` + Eq string `json:"eq"` + CashBal string `json:"cashBal"` + AvailBal string `json:"availBal"` + UPL string `json:"upl"` + } `json:"details"` + } + + if err := json.Unmarshal(data, &balances); err != nil { + return nil, fmt.Errorf("failed to parse balance data: %w", err) + } + + if len(balances) == 0 { + return nil, fmt.Errorf("no balance data received") + } + + balance := balances[0] + + // Find USDT balance + var usdtAvail, usdtUPL float64 + for _, detail := range balance.Details { + if detail.Ccy == "USDT" { + usdtAvail, _ = strconv.ParseFloat(detail.AvailBal, 64) + usdtUPL, _ = strconv.ParseFloat(detail.UPL, 64) + break + } + } + + totalEq, _ := strconv.ParseFloat(balance.TotalEq, 64) + + result := map[string]interface{}{ + "totalWalletBalance": totalEq, + "availableBalance": usdtAvail, + "totalUnrealizedProfit": usdtUPL, + } + + logger.Infof("✓ OKX balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f", totalEq, usdtAvail, usdtUPL) + + // Update cache + t.balanceCacheMutex.Lock() + t.cachedBalance = result + t.balanceCacheTime = time.Now() + t.balanceCacheMutex.Unlock() + + return result, nil +} + +// SetMarginMode sets margin mode +func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + instId := t.convertSymbol(symbol) + + mgnMode := "isolated" + if isCrossMargin { + mgnMode = "cross" + } + + body := map[string]interface{}{ + "instId": instId, + "mgnMode": mgnMode, + } + + _, err := t.doRequest("POST", "/api/v5/account/set-isolated-mode", body) + if err != nil { + // Ignore error if already in target mode + if strings.Contains(err.Error(), "already") { + logger.Infof(" ✓ %s margin mode is already %s", symbol, mgnMode) + return nil + } + // Cannot change when there are positions + if strings.Contains(err.Error(), "position") { + logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol) + return nil + } + return err + } + + logger.Infof(" ✓ %s margin mode set to %s", symbol, mgnMode) + return nil +} + +// SetLeverage sets leverage +func (t *OKXTrader) SetLeverage(symbol string, leverage int) error { + instId := t.convertSymbol(symbol) + + // Set leverage for both long and short + for _, posSide := range []string{"long", "short"} { + body := map[string]interface{}{ + "instId": instId, + "lever": strconv.Itoa(leverage), + "mgnMode": "cross", + "posSide": posSide, + } + + _, err := t.doRequest("POST", okxLeveragePath, body) + if err != nil { + // Ignore if already at target leverage + if strings.Contains(err.Error(), "same") { + continue + } + logger.Infof(" ⚠️ Failed to set %s %s leverage: %v", symbol, posSide, err) + } + } + + logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage) + return nil +} + +// GetMarketPrice gets market price +func (t *OKXTrader) GetMarketPrice(symbol string) (float64, error) { + instId := t.convertSymbol(symbol) + path := fmt.Sprintf("%s?instId=%s", okxTickerPath, instId) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return 0, fmt.Errorf("failed to get price: %w", err) + } + + var tickers []struct { + Last string `json:"last"` + } + + if err := json.Unmarshal(data, &tickers); err != nil { + return 0, err + } + + if len(tickers) == 0 { + return 0, fmt.Errorf("no price data received") + } + + price, err := strconv.ParseFloat(tickers[0].Last, 64) + if err != nil { + return 0, err + } + + return price, nil +} + +// GetClosedPnL retrieves closed position PnL records from OKX +// OKX API: /api/v5/account/positions-history +func (t *OKXTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) { + if limit <= 0 { + limit = 100 + } + if limit > 100 { + limit = 100 + } + + // Build query path with parameters + path := fmt.Sprintf("/api/v5/account/positions-history?instType=SWAP&limit=%d", limit) + if !startTime.IsZero() { + path += fmt.Sprintf("&after=%d", startTime.UnixMilli()) + } + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, fmt.Errorf("failed to get positions history: %w", err) + } + + var resp struct { + Code string `json:"code"` + Msg string `json:"msg"` + Data []struct { + InstID string `json:"instId"` // Instrument ID (e.g., "BTC-USDT-SWAP") + Direction string `json:"direction"` // Position direction: "long" or "short" + OpenAvgPx string `json:"openAvgPx"` // Average open price + CloseAvgPx string `json:"closeAvgPx"` // Average close price + CloseTotalPos string `json:"closeTotalPos"` // Closed position quantity + RealizedPnl string `json:"realizedPnl"` // Realized PnL + Fee string `json:"fee"` // Total fee + FundingFee string `json:"fundingFee"` // Funding fee + Lever string `json:"lever"` // Leverage + CTime string `json:"cTime"` // Position open time + UTime string `json:"uTime"` // Position close time + Type string `json:"type"` // Close type: 1=close position, 2=partial close, 3=liquidation, 4=partial liquidation + PosId string `json:"posId"` // Position ID + } `json:"data"` + } + + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if resp.Code != "0" { + return nil, fmt.Errorf("OKX API error: %s - %s", resp.Code, resp.Msg) + } + + records := make([]types.ClosedPnLRecord, 0, len(resp.Data)) + + for _, pos := range resp.Data { + record := types.ClosedPnLRecord{} + + // Convert instrument ID to standard format (BTC-USDT-SWAP -> BTCUSDT) + parts := strings.Split(pos.InstID, "-") + if len(parts) >= 2 { + record.Symbol = parts[0] + parts[1] + } else { + record.Symbol = pos.InstID + } + + // Side + record.Side = pos.Direction // OKX already returns "long" or "short" + + // Prices + record.EntryPrice, _ = strconv.ParseFloat(pos.OpenAvgPx, 64) + record.ExitPrice, _ = strconv.ParseFloat(pos.CloseAvgPx, 64) + + // Quantity + record.Quantity, _ = strconv.ParseFloat(pos.CloseTotalPos, 64) + + // PnL + record.RealizedPnL, _ = strconv.ParseFloat(pos.RealizedPnl, 64) + + // Fee + fee, _ := strconv.ParseFloat(pos.Fee, 64) + fundingFee, _ := strconv.ParseFloat(pos.FundingFee, 64) + record.Fee = -fee + fundingFee // Fee is negative in OKX + + // Leverage + lev, _ := strconv.ParseFloat(pos.Lever, 64) + record.Leverage = int(lev) + + // Times + cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) + uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) + record.EntryTime = time.UnixMilli(cTime).UTC() + record.ExitTime = time.UnixMilli(uTime).UTC() + + // Close type + switch pos.Type { + case "1", "2": + record.CloseType = "unknown" // Could be manual or AI, need to cross-reference + case "3", "4": + record.CloseType = "liquidation" + default: + record.CloseType = "unknown" + } + + // Exchange ID + record.ExchangeID = pos.PosId + + records = append(records, record) + } + + return records, nil +} diff --git a/trader/okx/trader_orders.go b/trader/okx/trader_orders.go new file mode 100644 index 00000000..33697acc --- /dev/null +++ b/trader/okx/trader_orders.go @@ -0,0 +1,938 @@ +package okx + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/trader/types" + "strconv" + "strings" +) + +// OpenLong opens long position +func (t *OKXTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // Cancel old orders + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof(" ⚠️ Failed to set leverage: %v", err) + } + + instId := t.convertSymbol(symbol) + + // Get instrument info and calculate contract size + inst, err := t.getInstrument(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // OKX uses contract count, need to convert quantity (in base asset) to contract count + // sz = quantity / ctVal (number of contracts = asset amount / asset per contract) + sz := quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + logger.Infof(" 📊 OKX OpenLong: quantity=%.6f, ctVal=%.6f, contracts=%.2f", quantity, inst.CtVal, sz) + + // Check max market order size limit + if inst.MaxMktSz > 0 && sz > inst.MaxMktSz { + logger.Infof(" ⚠️ OKX market order size %.2f exceeds max %.2f, reducing to max", sz, inst.MaxMktSz) + sz = inst.MaxMktSz + szStr = t.formatSize(sz, inst) + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": "buy", + "posSide": "long", + "ordType": "market", + "sz": szStr, + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to open long position: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + ClOrdId string `json:"clOrdId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + if len(orders) == 0 || orders[0].SCode != "0" { + msg := "unknown error" + if len(orders) > 0 { + msg = orders[0].SMsg + } + return nil, fmt.Errorf("failed to open long position: %s", msg) + } + + logger.Infof("✓ OKX opened long position successfully: %s size: %s", symbol, szStr) + logger.Infof(" Order ID: %s", orders[0].OrdId) + + return map[string]interface{}{ + "orderId": orders[0].OrdId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// OpenShort opens short position +func (t *OKXTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + // Cancel old orders + t.CancelAllOrders(symbol) + + // Set leverage + if err := t.SetLeverage(symbol, leverage); err != nil { + logger.Infof(" ⚠️ Failed to set leverage: %v", err) + } + + instId := t.convertSymbol(symbol) + + // Get instrument info and calculate contract size + inst, err := t.getInstrument(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // OKX uses contract count, need to convert quantity (in base asset) to contract count + // sz = quantity / ctVal (number of contracts = asset amount / asset per contract) + sz := quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + logger.Infof(" 📊 OKX OpenShort: quantity=%.6f, ctVal=%.6f, contracts=%.2f", quantity, inst.CtVal, sz) + + // Check max market order size limit + if inst.MaxMktSz > 0 && sz > inst.MaxMktSz { + logger.Infof(" ⚠️ OKX market order size %.2f exceeds max %.2f, reducing to max", sz, inst.MaxMktSz) + sz = inst.MaxMktSz + szStr = t.formatSize(sz, inst) + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": "sell", + "posSide": "short", + "ordType": "market", + "sz": szStr, + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to open short position: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + ClOrdId string `json:"clOrdId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + if len(orders) == 0 || orders[0].SCode != "0" { + msg := "unknown error" + if len(orders) > 0 { + msg = orders[0].SMsg + } + return nil, fmt.Errorf("failed to open short position: %s", msg) + } + + logger.Infof("✓ OKX opened short position successfully: %s size: %s", symbol, szStr) + logger.Infof(" Order ID: %s", orders[0].OrdId) + + return map[string]interface{}{ + "orderId": orders[0].OrdId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// CloseLong closes long position +func (t *OKXTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + instId := t.convertSymbol(symbol) + + // Get instrument info for contract conversion + inst, err := t.getInstrument(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // Invalidate position cache and get fresh positions + t.InvalidatePositionCache() + positions, err := t.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + // Find actual position from exchange + var actualQty float64 + var posFound bool + var posMgnMode string = "cross" // Default to cross margin + logger.Infof("🔍 OKX CloseLong: searching for symbol=%s in %d positions", symbol, len(positions)) + for _, pos := range positions { + logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v, mgnMode=%v", pos["symbol"], pos["side"], pos["positionAmt"], pos["mgnMode"]) + if pos["symbol"] == symbol { + side := pos["side"].(string) + // In net_mode, "long" means positive position + // In dual mode, check explicit "long" side + if side == "long" || (t.positionMode == "net_mode" && side == "long") { + actualQty = pos["positionAmt"].(float64) + posFound = true + if mgnMode, ok := pos["mgnMode"].(string); ok && mgnMode != "" { + posMgnMode = mgnMode + } + logger.Infof("🔍 OKX CloseLong: found matching position! qty=%.6f, mgnMode=%s", actualQty, posMgnMode) + break + } + } + } + + if !posFound || actualQty == 0 { + logger.Infof("🔍 OKX CloseLong: NO position found for %s LONG", symbol) + return map[string]interface{}{ + "status": "NO_POSITION", + "message": fmt.Sprintf("No long position found for %s on OKX", symbol), + }, nil + } + + // Use actual quantity from exchange (more accurate than passed quantity) + if quantity == 0 || quantity > actualQty { + quantity = actualQty + } + + // Convert quantity (base asset) to contract count + // contracts = quantity / ctVal + contracts := quantity / inst.CtVal + szStr := t.formatSize(contracts, inst) + + logger.Infof("🔻 OKX close long: symbol=%s, instId=%s, quantity=%.6f, ctVal=%.6f, contracts=%.2f, szStr=%s, posMode=%s, mgnMode=%s", + symbol, instId, quantity, inst.CtVal, contracts, szStr, t.positionMode, posMgnMode) + + body := map[string]interface{}{ + "instId": instId, + "tdMode": posMgnMode, // Use position's actual margin mode (cross or isolated) + "side": "sell", + "ordType": "market", + "sz": szStr, + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + // Only add posSide in dual mode (long_short_mode) + if t.positionMode == "long_short_mode" { + body["posSide"] = "long" + } + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to close long position: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, err + } + + if len(orders) == 0 || orders[0].SCode != "0" { + msg := "unknown error" + if len(orders) > 0 { + msg = orders[0].SMsg + } + return nil, fmt.Errorf("failed to close long position: %s", msg) + } + + logger.Infof("✓ OKX closed long position successfully: %s", symbol) + + // Cancel pending orders after closing position + t.CancelAllOrders(symbol) + + return map[string]interface{}{ + "orderId": orders[0].OrdId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// CloseShort closes short position +func (t *OKXTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + instId := t.convertSymbol(symbol) + + // Get instrument info for contract conversion + inst, err := t.getInstrument(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // Invalidate position cache and get fresh positions + t.InvalidatePositionCache() + positions, err := t.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + // Find actual position from exchange + var actualQty float64 + var posFound bool + var posMgnMode string = "cross" // Default to cross margin + logger.Infof("🔍 OKX CloseShort searching positions: symbol=%s, current position count=%d", symbol, len(positions)) + for _, pos := range positions { + logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v, mgnMode=%v", + pos["symbol"], pos["side"], pos["positionAmt"], pos["mgnMode"]) + if pos["symbol"] == symbol && pos["side"] == "short" { + actualQty = pos["positionAmt"].(float64) + posFound = true + if mgnMode, ok := pos["mgnMode"].(string); ok && mgnMode != "" { + posMgnMode = mgnMode + } + logger.Infof("🔍 OKX found short position: quantity=%f (base asset), mgnMode=%s", actualQty, posMgnMode) + break + } + } + + if !posFound || actualQty == 0 { + return map[string]interface{}{ + "status": "NO_POSITION", + "message": fmt.Sprintf("No short position found for %s on OKX", symbol), + }, nil + } + + // Use actual quantity from exchange (more accurate than passed quantity) + if quantity == 0 || quantity > actualQty { + quantity = actualQty + } + + // Ensure quantity is positive (OKX sz parameter must be positive) + if quantity < 0 { + quantity = -quantity + } + + // Convert quantity (base asset) to contract count + // contracts = quantity / ctVal + contracts := quantity / inst.CtVal + szStr := t.formatSize(contracts, inst) + + logger.Infof("🔻 OKX close short: symbol=%s, quantity=%.6f, ctVal=%.6f, contracts=%.2f, szStr=%s, posMode=%s, mgnMode=%s", + symbol, quantity, inst.CtVal, contracts, szStr, t.positionMode, posMgnMode) + + body := map[string]interface{}{ + "instId": instId, + "tdMode": posMgnMode, // Use position's actual margin mode (cross or isolated) + "side": "buy", + "ordType": "market", + "sz": szStr, + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + // Only add posSide in dual mode (long_short_mode) + if t.positionMode == "long_short_mode" { + body["posSide"] = "short" + } + + logger.Infof("🔻 OKX close short request body: %+v", body) + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to close short position: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, err + } + + if len(orders) == 0 || orders[0].SCode != "0" { + msg := "unknown error" + if len(orders) > 0 { + msg = fmt.Sprintf("sCode=%s, sMsg=%s", orders[0].SCode, orders[0].SMsg) + } + logger.Infof("❌ OKX failed to close short position: %s, response: %s", msg, string(data)) + return nil, fmt.Errorf("failed to close short position: %s", msg) + } + + logger.Infof("✓ OKX closed short position successfully: %s, ordId=%s", symbol, orders[0].OrdId) + + // Cancel pending orders after closing position + t.CancelAllOrders(symbol) + + return map[string]interface{}{ + "orderId": orders[0].OrdId, + "symbol": symbol, + "status": "FILLED", + }, nil +} + +// SetStopLoss sets stop loss order +func (t *OKXTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + instId := t.convertSymbol(symbol) + + // Get instrument info + inst, err := t.getInstrument(symbol) + if err != nil { + return fmt.Errorf("failed to get instrument info: %w", err) + } + + // Calculate contract size: quantity (in base asset) / ctVal (asset per contract) + sz := quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + // Determine direction + side := "sell" + posSide := "long" + if strings.ToUpper(positionSide) == "SHORT" { + side = "buy" + posSide = "short" + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": side, + "posSide": posSide, + "ordType": "conditional", + "sz": szStr, + "slTriggerPx": fmt.Sprintf("%.8f", stopPrice), + "slOrdPx": "-1", // Market price + "tag": okxTag, + } + + _, err = t.doRequest("POST", okxAlgoOrderPath, body) + if err != nil { + return fmt.Errorf("failed to set stop loss: %w", err) + } + + logger.Infof(" Stop loss price set: %.4f", stopPrice) + return nil +} + +// SetTakeProfit sets take profit order +func (t *OKXTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + instId := t.convertSymbol(symbol) + + // Get instrument info + inst, err := t.getInstrument(symbol) + if err != nil { + return fmt.Errorf("failed to get instrument info: %w", err) + } + + // Calculate contract size: quantity (in base asset) / ctVal (asset per contract) + sz := quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + // Determine direction + side := "sell" + posSide := "long" + if strings.ToUpper(positionSide) == "SHORT" { + side = "buy" + posSide = "short" + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": side, + "posSide": posSide, + "ordType": "conditional", + "sz": szStr, + "tpTriggerPx": fmt.Sprintf("%.8f", takeProfitPrice), + "tpOrdPx": "-1", // Market price + "tag": okxTag, + } + + _, err = t.doRequest("POST", okxAlgoOrderPath, body) + if err != nil { + return fmt.Errorf("failed to set take profit: %w", err) + } + + logger.Infof(" Take profit price set: %.4f", takeProfitPrice) + return nil +} + +// CancelStopLossOrders cancels stop loss orders +func (t *OKXTrader) CancelStopLossOrders(symbol string) error { + return t.cancelAlgoOrders(symbol, "sl") +} + +// CancelTakeProfitOrders cancels take profit orders +func (t *OKXTrader) CancelTakeProfitOrders(symbol string) error { + return t.cancelAlgoOrders(symbol, "tp") +} + +// cancelAlgoOrders cancels algo orders +func (t *OKXTrader) cancelAlgoOrders(symbol string, orderType string) error { + instId := t.convertSymbol(symbol) + + // Get pending algo orders + path := fmt.Sprintf("%s?instType=SWAP&instId=%s&ordType=conditional", okxAlgoPendingPath, instId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + return err + } + + var orders []struct { + AlgoId string `json:"algoId"` + InstId string `json:"instId"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return err + } + + canceledCount := 0 + for _, order := range orders { + body := []map[string]interface{}{ + { + "algoId": order.AlgoId, + "instId": order.InstId, + }, + } + + _, err := t.doRequest("POST", okxCancelAlgoPath, body) + if err != nil { + logger.Infof(" ⚠️ Failed to cancel algo order: %v", err) + continue + } + canceledCount++ + } + + if canceledCount > 0 { + logger.Infof(" ✓ Canceled %d algo orders for %s", canceledCount, symbol) + } + + return nil +} + +// CancelAllOrders cancels all pending orders +func (t *OKXTrader) CancelAllOrders(symbol string) error { + instId := t.convertSymbol(symbol) + + // Get pending orders + path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxPendingOrdersPath, instId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + return err + } + + var orders []struct { + OrdId string `json:"ordId"` + InstId string `json:"instId"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return err + } + + // Batch cancel + for _, order := range orders { + body := map[string]interface{}{ + "instId": order.InstId, + "ordId": order.OrdId, + } + t.doRequest("POST", okxCancelOrderPath, body) + } + + // Also cancel algo orders + t.cancelAlgoOrders(symbol, "") + + if len(orders) > 0 { + logger.Infof(" ✓ Canceled all pending orders for %s", symbol) + } + + return nil +} + +// CancelStopOrders cancels stop loss and take profit orders +func (t *OKXTrader) CancelStopOrders(symbol string) error { + return t.cancelAlgoOrders(symbol, "") +} + +// GetOrderStatus gets order status +func (t *OKXTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + instId := t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v5/trade/order?instId=%s&ordId=%s", instId, orderID) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, fmt.Errorf("failed to get order status: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + State string `json:"state"` + AvgPx string `json:"avgPx"` + AccFillSz string `json:"accFillSz"` + Fee string `json:"fee"` + Side string `json:"side"` + OrdType string `json:"ordType"` + CTime string `json:"cTime"` + UTime string `json:"uTime"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, err + } + + if len(orders) == 0 { + return nil, fmt.Errorf("order not found") + } + + order := orders[0] + avgPrice, _ := strconv.ParseFloat(order.AvgPx, 64) + fillSz, _ := strconv.ParseFloat(order.AccFillSz, 64) // This is in contracts + fee, _ := strconv.ParseFloat(order.Fee, 64) + cTime, _ := strconv.ParseInt(order.CTime, 10, 64) + uTime, _ := strconv.ParseInt(order.UTime, 10, 64) + + // Convert contract count to base asset quantity + // executedQty = contracts * ctVal + executedQty := fillSz + inst, err := t.getInstrument(symbol) + if err == nil && inst.CtVal > 0 { + executedQty = fillSz * inst.CtVal + logger.Debugf(" 📊 OKX order %s: fillSz(contracts)=%.4f, ctVal=%.6f, executedQty=%.6f", orderID, fillSz, inst.CtVal, executedQty) + } + + // Status mapping + statusMap := map[string]string{ + "filled": "FILLED", + "live": "NEW", + "partially_filled": "PARTIALLY_FILLED", + "canceled": "CANCELED", + } + + status := statusMap[order.State] + if status == "" { + status = order.State + } + + return map[string]interface{}{ + "orderId": order.OrdId, + "symbol": symbol, + "status": status, + "avgPrice": avgPrice, + "executedQty": executedQty, + "side": order.Side, + "type": order.OrdType, + "time": cTime, + "updateTime": uTime, + "commission": -fee, // OKX returns negative value + }, nil +} + +// GetOpenOrders gets all open/pending orders for a symbol +func (t *OKXTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) { + instId := t.convertSymbol(symbol) + var result []types.OpenOrder + + // 1. Get pending limit orders + path := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxPendingOrdersPath, instId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + logger.Warnf("[OKX] Failed to get pending orders: %v", err) + } + if err == nil && data != nil { + var orders []struct { + OrdId string `json:"ordId"` + InstId string `json:"instId"` + Side string `json:"side"` // buy/sell + PosSide string `json:"posSide"` // long/short/net + OrdType string `json:"ordType"` // limit/market/post_only + Px string `json:"px"` // price + Sz string `json:"sz"` // size + State string `json:"state"` // live/partially_filled + } + if err := json.Unmarshal(data, &orders); err == nil { + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Px, 64) + quantity, _ := strconv.ParseFloat(order.Sz, 64) + + // Convert OKX side to standard format + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + if positionSide == "NET" { + positionSide = "BOTH" + } + + result = append(result, types.OpenOrder{ + OrderID: order.OrdId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: strings.ToUpper(order.OrdType), + Price: price, + StopPrice: 0, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + // 2. Get pending algo orders (stop-loss/take-profit) + // OKX requires ordType parameter for algo orders API + algoPath := fmt.Sprintf("%s?instId=%s&instType=SWAP&ordType=conditional", okxAlgoPendingPath, instId) + algoData, err := t.doRequest("GET", algoPath, nil) + if err != nil { + logger.Warnf("[OKX] Failed to get algo orders: %v", err) + } + if err == nil && algoData != nil { + var algoOrders []struct { + AlgoId string `json:"algoId"` + InstId string `json:"instId"` + Side string `json:"side"` + PosSide string `json:"posSide"` + OrdType string `json:"ordType"` // conditional/oco/trigger + TriggerPx string `json:"triggerPx"` + SlTriggerPx string `json:"slTriggerPx"` // Stop loss trigger price + TpTriggerPx string `json:"tpTriggerPx"` // Take profit trigger price + Sz string `json:"sz"` + State string `json:"state"` + } + if err := json.Unmarshal(algoData, &algoOrders); err == nil { + for _, order := range algoOrders { + quantity, _ := strconv.ParseFloat(order.Sz, 64) + + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + if positionSide == "NET" { + positionSide = "BOTH" + } + + // Check for stop loss order (slTriggerPx is set) + if order.SlTriggerPx != "" { + slPrice, _ := strconv.ParseFloat(order.SlTriggerPx, 64) + if slPrice > 0 { + result = append(result, types.OpenOrder{ + OrderID: order.AlgoId + "_sl", + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: "STOP_MARKET", + Price: 0, + StopPrice: slPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + + // Check for take profit order (tpTriggerPx is set) + if order.TpTriggerPx != "" { + tpPrice, _ := strconv.ParseFloat(order.TpTriggerPx, 64) + if tpPrice > 0 { + result = append(result, types.OpenOrder{ + OrderID: order.AlgoId + "_tp", + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: "TAKE_PROFIT_MARKET", + Price: 0, + StopPrice: tpPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + + // Fallback for trigger orders (triggerPx is set) + if order.TriggerPx != "" && order.SlTriggerPx == "" && order.TpTriggerPx == "" { + triggerPrice, _ := strconv.ParseFloat(order.TriggerPx, 64) + if triggerPrice > 0 { + result = append(result, types.OpenOrder{ + OrderID: order.AlgoId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: "STOP_MARKET", + Price: 0, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + } + } + + logger.Infof("✓ OKX GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *OKXTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) { + instId := t.convertSymbol(req.Symbol) + + // Get instrument info + inst, err := t.getInstrument(req.Symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[OKX] Failed to set leverage: %v", err) + } + } + + // Convert quantity to contract size + sz := req.Quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + // Determine side and position side + side := "buy" + posSide := "long" + if req.Side == "SELL" { + side = "sell" + posSide = "short" + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": side, + "posSide": posSide, + "ordType": "limit", + "sz": szStr, + "px": fmt.Sprintf("%.8f", req.Price), + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + // Add reduce only if specified + if req.ReduceOnly { + body["reduceOnly"] = true + } + + logger.Infof("[OKX] PlaceLimitOrder: %s %s @ %.4f, sz=%s", instId, side, req.Price, szStr) + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + ClOrdId string `json:"clOrdId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + if len(orders) == 0 { + return nil, fmt.Errorf("empty order response") + } + + if orders[0].SCode != "0" { + return nil, fmt.Errorf("OKX order failed: %s", orders[0].SMsg) + } + + logger.Infof("✓ [OKX] Limit order placed: %s %s @ %.4f, orderID=%s", + instId, side, req.Price, orders[0].OrdId) + + return &types.LimitOrderResult{ + OrderID: orders[0].OrdId, + ClientID: orders[0].ClOrdId, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *OKXTrader) CancelOrder(symbol, orderID string) error { + instId := t.convertSymbol(symbol) + + body := map[string]interface{}{ + "instId": instId, + "ordId": orderID, + } + + _, err := t.doRequest("POST", "/api/v5/trade/cancel-order", body) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [OKX] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *OKXTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + instId := t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v5/market/books?instId=%s&sz=%d", instId, depth) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + var result []struct { + Bids [][]string `json:"bids"` + Asks [][]string `json:"asks"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + if len(result) == 0 { + return nil, nil, nil + } + + // Parse bids + for _, b := range result[0].Bids { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result[0].Asks { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil +} diff --git a/trader/okx/trader_positions.go b/trader/okx/trader_positions.go new file mode 100644 index 00000000..e63e96d1 --- /dev/null +++ b/trader/okx/trader_positions.go @@ -0,0 +1,192 @@ +package okx + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "strconv" + "time" +) + +// GetPositions gets all positions +func (t *OKXTrader) GetPositions() ([]map[string]interface{}, error) { + // Check cache + t.positionsCacheMutex.RLock() + if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { + t.positionsCacheMutex.RUnlock() + logger.Infof("✓ Using cached OKX positions") + return t.cachedPositions, nil + } + t.positionsCacheMutex.RUnlock() + + logger.Infof("🔄 Calling OKX API to get positions...") + data, err := t.doRequest("GET", okxPositionPath+"?instType=SWAP", nil) + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var positions []struct { + InstId string `json:"instId"` + PosSide string `json:"posSide"` + Pos string `json:"pos"` + AvgPx string `json:"avgPx"` + MarkPx string `json:"markPx"` + Upl string `json:"upl"` + Lever string `json:"lever"` + LiqPx string `json:"liqPx"` + Margin string `json:"margin"` + MgnMode string `json:"mgnMode"` // Margin mode: "cross" or "isolated" + CTime string `json:"cTime"` // Position created time (ms) + UTime string `json:"uTime"` // Position last update time (ms) + } + + if err := json.Unmarshal(data, &positions); err != nil { + return nil, fmt.Errorf("failed to parse position data: %w", err) + } + + logger.Infof("🔍 OKX raw positions response: %d positions", len(positions)) + var result []map[string]interface{} + for _, pos := range positions { + logger.Infof("🔍 OKX raw position: instId=%s, posSide=%s, pos=%s, mgnMode=%s", pos.InstId, pos.PosSide, pos.Pos, pos.MgnMode) + contractCount, _ := strconv.ParseFloat(pos.Pos, 64) + if contractCount == 0 { + continue + } + + entryPrice, _ := strconv.ParseFloat(pos.AvgPx, 64) + markPrice, _ := strconv.ParseFloat(pos.MarkPx, 64) + upl, _ := strconv.ParseFloat(pos.Upl, 64) + leverage, _ := strconv.ParseFloat(pos.Lever, 64) + liqPrice, _ := strconv.ParseFloat(pos.LiqPx, 64) + + // Convert symbol format + symbol := t.convertSymbolBack(pos.InstId) + logger.Infof("🔍 OKX symbol conversion: %s → %s", pos.InstId, symbol) + + // Determine direction and ensure contractCount is positive + side := "long" + if pos.PosSide == "short" { + side = "short" + } + // OKX short position's pos is negative, need to take absolute value + if contractCount < 0 { + contractCount = -contractCount + } + + // Convert contract count to actual position amount (in base asset) + // positionAmt = contractCount * ctVal + inst, err := t.getInstrument(symbol) + posAmt := contractCount + if err == nil && inst.CtVal > 0 { + posAmt = contractCount * inst.CtVal + logger.Debugf(" 📊 OKX position %s: contracts=%.4f, ctVal=%.6f, posAmt=%.6f", symbol, contractCount, inst.CtVal, posAmt) + } + + // Parse timestamps + cTime, _ := strconv.ParseInt(pos.CTime, 10, 64) + uTime, _ := strconv.ParseInt(pos.UTime, 10, 64) + + // Default to cross margin mode if not specified + mgnMode := pos.MgnMode + if mgnMode == "" { + mgnMode = "cross" + } + + posMap := map[string]interface{}{ + "symbol": symbol, + "positionAmt": posAmt, + "entryPrice": entryPrice, + "markPrice": markPrice, + "unRealizedProfit": upl, + "leverage": leverage, + "liquidationPrice": liqPrice, + "side": side, + "mgnMode": mgnMode, // Margin mode: "cross" or "isolated" + "createdTime": cTime, // Position open time (ms) + "updatedTime": uTime, // Position last update time (ms) + } + result = append(result, posMap) + } + + // Update cache + t.positionsCacheMutex.Lock() + t.cachedPositions = result + t.positionsCacheTime = time.Now() + t.positionsCacheMutex.Unlock() + + return result, nil +} + +// InvalidatePositionCache clears the position cache to force fresh data on next call +func (t *OKXTrader) InvalidatePositionCache() { + t.positionsCacheMutex.Lock() + t.cachedPositions = nil + t.positionsCacheTime = time.Time{} + t.positionsCacheMutex.Unlock() +} + +// getInstrument gets instrument info +func (t *OKXTrader) getInstrument(symbol string) (*OKXInstrument, error) { + instId := t.convertSymbol(symbol) + + // Check cache + t.instrumentsCacheMutex.RLock() + if inst, ok := t.instrumentsCache[instId]; ok && time.Since(t.instrumentsCacheTime) < 5*time.Minute { + t.instrumentsCacheMutex.RUnlock() + return inst, nil + } + t.instrumentsCacheMutex.RUnlock() + + // Get instrument info + path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxInstrumentsPath, instId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, err + } + + var instruments []struct { + InstId string `json:"instId"` + CtVal string `json:"ctVal"` + CtMult string `json:"ctMult"` + LotSz string `json:"lotSz"` + MinSz string `json:"minSz"` + MaxMktSz string `json:"maxMktSz"` // Maximum market order size + TickSz string `json:"tickSz"` + CtType string `json:"ctType"` + } + + if err := json.Unmarshal(data, &instruments); err != nil { + return nil, err + } + + if len(instruments) == 0 { + return nil, fmt.Errorf("instrument info not found: %s", instId) + } + + inst := instruments[0] + ctVal, _ := strconv.ParseFloat(inst.CtVal, 64) + ctMult, _ := strconv.ParseFloat(inst.CtMult, 64) + lotSz, _ := strconv.ParseFloat(inst.LotSz, 64) + minSz, _ := strconv.ParseFloat(inst.MinSz, 64) + maxMktSz, _ := strconv.ParseFloat(inst.MaxMktSz, 64) + tickSz, _ := strconv.ParseFloat(inst.TickSz, 64) + + instrument := &OKXInstrument{ + InstID: inst.InstId, + CtVal: ctVal, + CtMult: ctMult, + LotSz: lotSz, + MinSz: minSz, + MaxMktSz: maxMktSz, + TickSz: tickSz, + CtType: inst.CtType, + } + + // Update cache + t.instrumentsCacheMutex.Lock() + t.instrumentsCache[instId] = instrument + t.instrumentsCacheTime = time.Now() + t.instrumentsCacheMutex.Unlock() + + return instrument, nil +} diff --git a/web/src/components/backtest/BacktestChartTab.tsx b/web/src/components/backtest/BacktestChartTab.tsx new file mode 100644 index 00000000..0a187e8c --- /dev/null +++ b/web/src/components/backtest/BacktestChartTab.tsx @@ -0,0 +1,433 @@ +import { useEffect, useMemo, useState, useRef } from 'react' +import { motion } from 'framer-motion' +import { + createChart, + ColorType, + CrosshairMode, + CandlestickSeries, + createSeriesMarkers, + type IChartApi, + type ISeriesApi, + type CandlestickData, + type UTCTimestamp, + type SeriesMarker, +} from 'lightweight-charts' +import { + ResponsiveContainer, + AreaChart, + Area, + XAxis, + YAxis, + CartesianGrid, + Tooltip, + ReferenceDot, +} from 'recharts' +import { + Clock, + AlertTriangle, + RefreshCw, + CandlestickChart as CandlestickIcon, +} from 'lucide-react' +import { api } from '../../lib/api' +import type { + BacktestEquityPoint, + BacktestTradeEvent, + BacktestKlinesResponse, +} from '../../types' + +// ============ Equity Chart (Recharts) ============ + +interface EquityChartProps { + equity: BacktestEquityPoint[] + trades: BacktestTradeEvent[] +} + +export function EquityChart({ equity, trades }: EquityChartProps) { + const chartData = useMemo(() => { + return equity.map((point) => ({ + time: new Date(point.ts).toLocaleString(), + ts: point.ts, + equity: point.equity, + pnl_pct: point.pnl_pct, + })) + }, [equity]) + + const tradeMarkers = useMemo(() => { + if (!trades.length || !equity.length) return [] + return trades + .filter((t) => t.action.includes('open') || t.action.includes('close')) + .map((trade) => { + const closest = equity.reduce((prev, curr) => + Math.abs(curr.ts - trade.ts) < Math.abs(prev.ts - trade.ts) ? curr : prev + ) + return { + ts: closest.ts, + equity: closest.equity, + action: trade.action, + symbol: trade.symbol, + isOpen: trade.action.includes('open'), + } + }) + .slice(-30) + }, [trades, equity]) + + return ( +
+ + + + + + + + + + + + [`$${value.toFixed(2)}`, 'Equity']} + /> + + {tradeMarkers.map((marker, idx) => ( + d.ts === marker.ts)} + y={marker.equity} + r={4} + fill={marker.isOpen ? '#0ECB81' : '#F6465D'} + stroke={marker.isOpen ? '#0ECB81' : '#F6465D'} + /> + ))} + + +
+ ) +} + +// ============ Candlestick Chart with Trade Markers ============ + +interface CandlestickChartProps { + runId: string + trades: BacktestTradeEvent[] + language: string +} + +export function CandlestickChartComponent({ runId, trades, language }: CandlestickChartProps) { + const chartContainerRef = useRef(null) + const chartRef = useRef(null) + const candleSeriesRef = useRef | null>(null) + + const symbols = useMemo(() => { + const symbolSet = new Set(trades.map((t) => t.symbol)) + return Array.from(symbolSet).sort() + }, [trades]) + + const [selectedSymbol, setSelectedSymbol] = useState(symbols[0] || '') + const [selectedTimeframe, setSelectedTimeframe] = useState('15m') + const [isLoading, setIsLoading] = useState(false) + const [error, setError] = useState(null) + + const CHART_TIMEFRAMES = ['1m', '3m', '5m', '15m', '30m', '1h', '4h', '1d'] + + useEffect(() => { + if (symbols.length > 0 && !symbols.includes(selectedSymbol)) { + setSelectedSymbol(symbols[0]) + } + }, [symbols, selectedSymbol]) + + const symbolTrades = useMemo(() => { + return trades.filter((t) => t.symbol === selectedSymbol) + }, [trades, selectedSymbol]) + + useEffect(() => { + if (!chartContainerRef.current || !selectedSymbol || !runId) return + + const container = chartContainerRef.current + + const chart = createChart(container, { + layout: { + background: { type: ColorType.Solid, color: '#0B0E11' }, + textColor: '#848E9C', + }, + grid: { + vertLines: { color: 'rgba(43, 49, 57, 0.5)' }, + horzLines: { color: 'rgba(43, 49, 57, 0.5)' }, + }, + crosshair: { + mode: CrosshairMode.Normal, + }, + rightPriceScale: { + borderColor: '#2B3139', + }, + timeScale: { + borderColor: '#2B3139', + timeVisible: true, + secondsVisible: false, + }, + width: container.clientWidth, + height: 400, + }) + + chartRef.current = chart + + const candleSeries = chart.addSeries(CandlestickSeries, { + upColor: '#0ECB81', + downColor: '#F6465D', + borderUpColor: '#0ECB81', + borderDownColor: '#F6465D', + wickUpColor: '#0ECB81', + wickDownColor: '#F6465D', + }) + candleSeriesRef.current = candleSeries + + setIsLoading(true) + setError(null) + + api + .getBacktestKlines(runId, selectedSymbol, selectedTimeframe) + .then((data: BacktestKlinesResponse) => { + const klineData: CandlestickData[] = data.klines.map((k) => ({ + time: k.time as UTCTimestamp, + open: k.open, + high: k.high, + low: k.low, + close: k.close, + })) + candleSeries.setData(klineData) + + const markers: SeriesMarker[] = symbolTrades + .map((trade) => { + const tradeTime = Math.floor(trade.ts / 1000) + const closestKline = data.klines.reduce((prev, curr) => + Math.abs(curr.time - tradeTime) < Math.abs(prev.time - tradeTime) ? curr : prev + ) + const isOpen = trade.action.includes('open') + const isLong = trade.side === 'long' || trade.action.includes('long') + const pnl = trade.realized_pnl + + let text = '' + let color = '#0ECB81' + + if (isOpen) { + if (isLong) { + text = `▲ Long @${trade.price.toFixed(2)}` + color = '#0ECB81' + } else { + text = `▼ Short @${trade.price.toFixed(2)}` + color = '#F6465D' + } + } else { + const pnlStr = pnl >= 0 ? `+$${pnl.toFixed(2)}` : `-$${Math.abs(pnl).toFixed(2)}` + text = `✕ ${pnlStr}` + color = pnl >= 0 ? '#0ECB81' : '#F6465D' + } + + return { + time: closestKline.time as UTCTimestamp, + position: isOpen + ? (isLong ? 'belowBar' as const : 'aboveBar' as const) + : (isLong ? 'aboveBar' as const : 'belowBar' as const), + color, + shape: 'circle' as const, + size: 2, + text, + } + }) + .sort((a, b) => (a.time as number) - (b.time as number)) + + createSeriesMarkers(candleSeries, markers) + chart.timeScale().fitContent() + setIsLoading(false) + }) + .catch((err) => { + setError(err.message || 'Failed to load klines') + setIsLoading(false) + }) + + const handleResize = () => { + if (chartContainerRef.current) { + chart.applyOptions({ width: chartContainerRef.current.clientWidth }) + } + } + window.addEventListener('resize', handleResize) + + return () => { + window.removeEventListener('resize', handleResize) + chart.remove() + chartRef.current = null + candleSeriesRef.current = null + } + }, [runId, selectedSymbol, selectedTimeframe, symbolTrades]) + + if (symbols.length === 0) { + return ( +
+ {language === 'zh' ? '没有交易记录' : 'No trades to display'} +
+ ) + } + + return ( +
+
+
+ + + {language === 'zh' ? '币种' : 'Symbol'} + + +
+ +
+ + + {language === 'zh' ? '周期' : 'Interval'} + +
+ {CHART_TIMEFRAMES.map((tf) => ( + + ))} +
+
+ + + ({symbolTrades.length} {language === 'zh' ? '笔交易' : 'trades'}) + +
+ +
+ {isLoading && ( +
+ + {language === 'zh' ? '加载K线数据...' : 'Loading kline data...'} +
+ )} + {error && ( +
+ + {error} +
+ )} +
+ +
+
+
+ {language === 'zh' ? '开仓/盈利' : 'Open/Profit'} +
+
+
+ {language === 'zh' ? '亏损平仓' : 'Loss Close'} +
+ | + ▲ Long · ▼ Short · ✕ {language === 'zh' ? '平仓' : 'Close'} +
+
+ ) +} + +// ============ Chart Tab Content ============ + +interface BacktestChartTabProps { + equity: BacktestEquityPoint[] | undefined + trades: BacktestTradeEvent[] | undefined + selectedRunId: string + language: string + tr: (key: string) => string +} + +export function BacktestChartTab({ + equity, + trades, + selectedRunId, + language, + tr, +}: BacktestChartTabProps) { + return ( + +
+

+ {language === 'zh' ? '资金曲线' : 'Equity Curve'} +

+ {equity && equity.length > 0 ? ( + + ) : ( +
+ {tr('charts.equityEmpty')} +
+ )} +
+ + {selectedRunId && trades && trades.length > 0 && ( +
+

+ {language === 'zh' ? 'K线图 & 交易标记' : 'Candlestick & Trade Markers'} +

+ +
+ )} +
+ ) +} diff --git a/web/src/components/backtest/BacktestConfigForm.tsx b/web/src/components/backtest/BacktestConfigForm.tsx new file mode 100644 index 00000000..ef02da5a --- /dev/null +++ b/web/src/components/backtest/BacktestConfigForm.tsx @@ -0,0 +1,597 @@ +import { useMemo, type FormEvent } from 'react' +import { motion, AnimatePresence } from 'framer-motion' +import { + ChevronRight, + ChevronLeft, + RefreshCw, + Zap, +} from 'lucide-react' +import type { AIModel, Strategy } from '../../types' + +// ============ Types ============ + +type WizardStep = 1 | 2 | 3 + +export interface BacktestFormState { + runId: string + symbols: string + timeframes: string[] + decisionTf: string + cadence: number + start: string + end: string + balance: number + fee: number + slippage: number + btcEthLeverage: number + altcoinLeverage: number + fill: string + prompt: string + promptTemplate: string + customPrompt: string + overridePrompt: boolean + cacheAI: boolean + replayOnly: boolean + aiModelId: string + strategyId: string +} + +const TIMEFRAME_OPTIONS = ['1m', '3m', '5m', '15m', '30m', '1h', '4h', '1d'] +const POPULAR_SYMBOLS = ['BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'BNBUSDT', 'XRPUSDT', 'DOGEUSDT'] + +// ============ Config Form ============ + +interface BacktestConfigFormProps { + formState: BacktestFormState + wizardStep: WizardStep + isStarting: boolean + aiModels: AIModel[] | undefined + strategies: Strategy[] | undefined + language: string + tr: (key: string, params?: Record) => string + onFormChange: (key: string, value: string | number | boolean | string[]) => void + onWizardStepChange: (step: WizardStep) => void + onStart: (event: FormEvent) => void +} + +export function BacktestConfigForm({ + formState, + wizardStep, + isStarting, + aiModels, + strategies, + language, + tr, + onFormChange, + onWizardStepChange, + onStart, +}: BacktestConfigFormProps) { + const selectedModel = aiModels?.find((m) => m.id === formState.aiModelId) + const selectedStrategy = strategies?.find((s) => s.id === formState.strategyId) + + const strategyHasDynamicCoins = useMemo(() => { + const cs = selectedStrategy?.config?.coin_source + if (!cs) return false + const st = cs.source_type as string + if (st === 'ai500' || st === 'oi_top') return true + if (st === 'mixed' && (cs.use_ai500 || cs.use_oi_top)) return true + if (!st && (cs.use_ai500 || cs.use_oi_top)) return true + return false + }, [selectedStrategy]) + + const coinSourceDescription = useMemo(() => { + const cs = selectedStrategy?.config?.coin_source + if (!cs) return null + let st = cs.source_type as string + if (!st) { + if (cs.use_ai500 && cs.use_oi_top) st = 'mixed' + else if (cs.use_ai500) st = 'ai500' + else if (cs.use_oi_top) st = 'oi_top' + else if (cs.static_coins?.length) st = 'static' + } + switch (st) { + case 'ai500': return { type: 'AI500', limit: cs.ai500_limit || 30 } + case 'oi_top': return { type: 'OI Top', limit: cs.oi_top_limit || 30 } + case 'mixed': { + const parts: string[] = [] + if (cs.use_ai500) parts.push(`AI500(${cs.ai500_limit || 30})`) + if (cs.use_oi_top) parts.push(`OI Top(${cs.oi_top_limit || 30})`) + if (cs.static_coins?.length) parts.push(`Static(${cs.static_coins.length})`) + return { type: 'Mixed', desc: parts.join(' + ') } + } + case 'static': return { type: 'Static', coins: cs.static_coins || [] } + default: return null + } + }, [selectedStrategy]) + + const zh = language === 'zh' + const quickRanges = [ + { label: zh ? '24小时' : '24h', hours: 24 }, + { label: zh ? '3天' : '3d', hours: 72 }, + { label: zh ? '7天' : '7d', hours: 168 }, + { label: zh ? '30天' : '30d', hours: 720 }, + ] + + const applyQuickRange = (hours: number) => { + const end = new Date() + const start = new Date(end.getTime() - hours * 3600 * 1000) + const fmt = (d: Date) => new Date(d.getTime() - d.getTimezoneOffset() * 60000).toISOString().slice(0, 16) + onFormChange('start', fmt(start)) + onFormChange('end', fmt(end)) + } + + return ( +
+
+ {[1, 2, 3].map((step) => ( +
+ + {step < 3 && ( +
step ? '#F0B90B' : '#2B3139' }} + /> + )} +
+ ))} + + {wizardStep === 1 ? (zh ? '选择模型' : 'Select Model') + : wizardStep === 2 ? (zh ? '配置参数' : 'Configure') + : (zh ? '确认启动' : 'Confirm')} + +
+ +
+ + {/* Step 1: Model & Symbols */} + {wizardStep === 1 && ( + +
+ + + {selectedModel && ( +
+ + {selectedModel.enabled ? tr('form.enabled') : tr('form.disabled')} + +
+ )} +
+ + {/* Strategy Selection (Optional) */} +
+ + + {formState.strategyId && coinSourceDescription && ( +
+
+ + {zh ? '币种来源:' : 'Coin Source:'} + + + {coinSourceDescription.type} + {coinSourceDescription.limit && ` (${coinSourceDescription.limit})`} + {coinSourceDescription.desc && ` - ${coinSourceDescription.desc}`} + +
+ {strategyHasDynamicCoins && ( +
+ {zh + ? '⚡ 清空下方币种输入框即可使用策略的动态币种' + : '⚡ Clear the symbols field below to use strategy\'s dynamic coins'} +
+ )} +
+ )} +
+ +
+ + {!strategyHasDynamicCoins && ( +
+ {POPULAR_SYMBOLS.map((sym) => { + const isSelected = formState.symbols.includes(sym) + return ( + + ) + })} +
+ )} +
+