refactor: split large files and clean up project structure

- Rename experience/ to telemetry/ for clarity
- Split 15+ large Go files (800-2200 lines) into focused modules:
  kernel/engine.go, backtest/runner.go, market/data.go, store/position.go,
  api/handler_trader.go, trader/auto_trader_grid.go, and 9 exchange traders
- Split frontend monoliths: types.ts, api.ts, AITradersPage.tsx, BacktestPage.tsx
  into domain-specific modules with barrel re-exports
- Remove stale files: screenshots, .yml.old, pyproject.toml
- Remove unused scripts/ and cmd/ directories
- Remove broken/outdated test files (network-dependent, stale expectations)
This commit is contained in:
tinkle-community
2026-03-12 12:53:57 +08:00
parent 8e294a5eed
commit cb31782be4
113 changed files with 20423 additions and 25733 deletions
File diff suppressed because it is too large Load Diff
+299
View File
@@ -0,0 +1,299 @@
package aster
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"nofx/logger"
"nofx/trader/types"
"strconv"
"time"
)
// GetBalance Get account balance
func (t *AsterTrader) GetBalance() (map[string]interface{}, error) {
params := make(map[string]interface{})
body, err := t.request("GET", "/fapi/v3/balance", params)
if err != nil {
return nil, err
}
var balances []map[string]interface{}
if err := json.Unmarshal(body, &balances); err != nil {
return nil, err
}
// Find USDT balance
availableBalance := 0.0
crossUnPnl := 0.0
crossWalletBalance := 0.0
foundUSDT := false
for _, bal := range balances {
if asset, ok := bal["asset"].(string); ok && asset == "USDT" {
foundUSDT = true
// Parse Aster fields (reference: https://github.com/asterdex/api-docs)
if avail, ok := bal["availableBalance"].(string); ok {
availableBalance, _ = strconv.ParseFloat(avail, 64)
}
if unpnl, ok := bal["crossUnPnl"].(string); ok {
crossUnPnl, _ = strconv.ParseFloat(unpnl, 64)
}
if cwb, ok := bal["crossWalletBalance"].(string); ok {
crossWalletBalance, _ = strconv.ParseFloat(cwb, 64)
}
break
}
}
if !foundUSDT {
logger.Infof("⚠️ USDT asset record not found!")
}
// Get positions to calculate margin used and real unrealized PnL
positions, err := t.GetPositions()
if err != nil {
logger.Infof("⚠️ Failed to get position information: %v", err)
// fallback: use simple calculation when unable to get positions
return map[string]interface{}{
"totalWalletBalance": crossWalletBalance,
"availableBalance": availableBalance,
"totalUnrealizedProfit": crossUnPnl,
}, nil
}
// Critical fix: accumulate real unrealized PnL from positions
// Aster's crossUnPnl field is inaccurate, need to recalculate from position data
totalMarginUsed := 0.0
realUnrealizedPnl := 0.0
for _, pos := range positions {
markPrice := pos["markPrice"].(float64)
quantity := pos["positionAmt"].(float64)
if quantity < 0 {
quantity = -quantity
}
unrealizedPnl := pos["unRealizedProfit"].(float64)
realUnrealizedPnl += unrealizedPnl
leverage := 10
if lev, ok := pos["leverage"].(float64); ok {
leverage = int(lev)
}
marginUsed := (quantity * markPrice) / float64(leverage)
totalMarginUsed += marginUsed
}
// Aster correct calculation method:
// Total equity = available balance + margin used
// Wallet balance = total equity - unrealized PnL
// Unrealized PnL = calculated from accumulated positions (don't use API's crossUnPnl)
totalEquity := availableBalance + totalMarginUsed
totalWalletBalance := totalEquity - realUnrealizedPnl
return map[string]interface{}{
"totalWalletBalance": totalWalletBalance, // Wallet balance (excluding unrealized PnL)
"availableBalance": availableBalance, // Available balance
"totalUnrealizedProfit": realUnrealizedPnl, // Unrealized PnL (accumulated from positions)
}, nil
}
// GetMarketPrice Get market price
func (t *AsterTrader) GetMarketPrice(symbol string) (float64, error) {
// Use ticker interface to get current price
resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/ticker/price?symbol=%s", t.baseURL, symbol))
if err != nil {
return 0, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return 0, err
}
priceStr, ok := result["price"].(string)
if !ok {
return 0, errors.New("unable to get price")
}
return strconv.ParseFloat(priceStr, 64)
}
// GetClosedPnL gets recent closing trades from Aster
// Note: Aster does NOT have a position history API, only trade history.
// This returns individual closing trades for real-time position closure detection.
func (t *AsterTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
trades, err := t.GetTrades(startTime, limit)
if err != nil {
return nil, err
}
// Filter only closing trades (realizedPnl != 0)
var records []types.ClosedPnLRecord
for _, trade := range trades {
if trade.RealizedPnL == 0 {
continue
}
// Determine side from PositionSide or trade direction
side := "long"
if trade.PositionSide == "SHORT" || trade.PositionSide == "short" {
side = "short"
} else if trade.PositionSide == "BOTH" || trade.PositionSide == "" {
if trade.Side == "SELL" || trade.Side == "Sell" {
side = "long"
} else {
side = "short"
}
}
// Calculate entry price from PnL
var entryPrice float64
if trade.Quantity > 0 {
if side == "long" {
entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity
} else {
entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity
}
}
records = append(records, types.ClosedPnLRecord{
Symbol: trade.Symbol,
Side: side,
EntryPrice: entryPrice,
ExitPrice: trade.Price,
Quantity: trade.Quantity,
RealizedPnL: trade.RealizedPnL,
Fee: trade.Fee,
ExitTime: trade.Time,
EntryTime: trade.Time,
OrderID: trade.TradeID,
ExchangeID: trade.TradeID,
CloseType: "unknown",
})
}
return records, nil
}
// AsterTradeRecord represents a trade from Aster API
type AsterTradeRecord struct {
ID int64 `json:"id"`
Symbol string `json:"symbol"`
OrderID int64 `json:"orderId"`
Side string `json:"side"` // BUY or SELL
PositionSide string `json:"positionSide"` // LONG or SHORT
Price string `json:"price"`
Qty string `json:"qty"`
RealizedPnl string `json:"realizedPnl"`
Commission string `json:"commission"`
Time int64 `json:"time"`
Buyer bool `json:"buyer"`
Maker bool `json:"maker"`
}
// GetTrades retrieves trade history from Aster
func (t *AsterTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) {
if limit <= 0 {
limit = 500
}
// Build request params
params := map[string]interface{}{
"startTime": startTime.UnixMilli(),
"limit": limit,
}
// Use existing request method with signing
body, err := t.request("GET", "/fapi/v3/userTrades", params)
if err != nil {
logger.Infof("⚠️ Aster userTrades API error: %v", err)
return []types.TradeRecord{}, nil
}
var asterTrades []AsterTradeRecord
if err := json.Unmarshal(body, &asterTrades); err != nil {
logger.Infof("⚠️ Failed to parse Aster trades response: %v", err)
return []types.TradeRecord{}, nil
}
// Convert to unified TradeRecord format
var result []types.TradeRecord
for _, at := range asterTrades {
price, _ := strconv.ParseFloat(at.Price, 64)
qty, _ := strconv.ParseFloat(at.Qty, 64)
fee, _ := strconv.ParseFloat(at.Commission, 64)
pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64)
trade := types.TradeRecord{
TradeID: strconv.FormatInt(at.ID, 10),
Symbol: at.Symbol,
Side: at.Side,
PositionSide: at.PositionSide,
Price: price,
Quantity: qty,
RealizedPnL: pnl,
Fee: fee,
Time: time.UnixMilli(at.Time).UTC(),
}
result = append(result, trade)
}
return result, nil
}
// GetOrderBook gets the order book for a symbol
func (t *AsterTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) {
if depth <= 0 {
depth = 20
}
// Aster uses public endpoint (no signature required)
resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/depth?symbol=%s&limit=%d", t.baseURL, symbol, depth))
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch order book: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result struct {
Bids [][]string `json:"bids"` // [[price, qty], ...]
Asks [][]string `json:"asks"` // [[price, qty], ...]
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, nil, fmt.Errorf("failed to parse order book: %w", err)
}
// Convert string arrays to float64 arrays
bids = make([][]float64, len(result.Bids))
for i, bid := range result.Bids {
if len(bid) >= 2 {
price, _ := strconv.ParseFloat(bid[0], 64)
qty, _ := strconv.ParseFloat(bid[1], 64)
bids[i] = []float64{price, qty}
}
}
asks = make([][]float64, len(result.Asks))
for i, ask := range result.Asks {
if len(ask) >= 2 {
price, _ := strconv.ParseFloat(ask[0], 64)
qty, _ := strconv.ParseFloat(ask[1], 64)
asks[i] = []float64{price, qty}
}
}
return bids, asks, nil
}
+787
View File
@@ -0,0 +1,787 @@
package aster
import (
"encoding/json"
"fmt"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
)
// OpenLong Open long position
func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel all pending orders before opening position to prevent position stacking from residual orders
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel pending orders (continuing to open position): %v", err)
}
// Set leverage first (non-fatal if position already exists)
if err := t.SetLeverage(symbol, leverage); err != nil {
// Error -2030: Cannot adjust leverage when position exists
// This is expected when adding to an existing position, continue with current leverage
if strings.Contains(err.Error(), "-2030") {
logger.Infof(" ⚠ Cannot change leverage (position exists), using current leverage: %v", err)
} else {
return nil, fmt.Errorf("failed to set leverage: %w", err)
}
}
// Get current price
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, err
}
// Use limit order to simulate market order (price set slightly higher to ensure execution)
limitPrice := price * 1.01
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(symbol, limitPrice)
if err != nil {
return nil, err
}
formattedQty, err := t.formatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Get precision information
prec, err := t.getPrecision(symbol)
if err != nil {
return nil, err
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)",
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
params := map[string]interface{}{
"symbol": symbol,
"positionSide": "BOTH",
"type": "LIMIT",
"side": "BUY",
"timeInForce": "GTC",
"quantity": qtyStr,
"price": priceStr,
}
body, err := t.request("POST", "/fapi/v3/order", params)
if err != nil {
return nil, err
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return result, nil
}
// OpenShort Open short position
func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel all pending orders before opening position to prevent position stacking from residual orders
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel pending orders (continuing to open position): %v", err)
}
// Set leverage first (non-fatal if position already exists)
if err := t.SetLeverage(symbol, leverage); err != nil {
// Error -2030: Cannot adjust leverage when position exists
// This is expected when adding to an existing position, continue with current leverage
if strings.Contains(err.Error(), "-2030") {
logger.Infof(" ⚠ Cannot change leverage (position exists), using current leverage: %v", err)
} else {
return nil, fmt.Errorf("failed to set leverage: %w", err)
}
}
// Get current price
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, err
}
// Use limit order to simulate market order (price set slightly lower to ensure execution)
limitPrice := price * 0.99
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(symbol, limitPrice)
if err != nil {
return nil, err
}
formattedQty, err := t.formatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Get precision information
prec, err := t.getPrecision(symbol)
if err != nil {
return nil, err
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)",
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
params := map[string]interface{}{
"symbol": symbol,
"positionSide": "BOTH",
"type": "LIMIT",
"side": "SELL",
"timeInForce": "GTC",
"quantity": qtyStr,
"price": priceStr,
}
body, err := t.request("POST", "/fapi/v3/order", params)
if err != nil {
return nil, err
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return result, nil
}
// CloseLong Close long position
func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// If quantity is 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "long" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("no long position found for %s", symbol)
}
logger.Infof(" 📊 Retrieved long position quantity: %.8f", quantity)
}
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, err
}
limitPrice := price * 0.99
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(symbol, limitPrice)
if err != nil {
return nil, err
}
formattedQty, err := t.formatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Get precision information
prec, err := t.getPrecision(symbol)
if err != nil {
return nil, err
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)",
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
params := map[string]interface{}{
"symbol": symbol,
"positionSide": "BOTH",
"type": "LIMIT",
"side": "SELL",
"timeInForce": "GTC",
"quantity": qtyStr,
"price": priceStr,
}
body, err := t.request("POST", "/fapi/v3/order", params)
if err != nil {
return nil, err
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
logger.Infof("✓ Successfully closed long position: %s quantity: %s", symbol, qtyStr)
// Cancel all pending orders for this symbol after closing position (stop-loss/take-profit orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel pending orders: %v", err)
}
return result, nil
}
// CloseShort Close short position
func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// If quantity is 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "short" {
// Aster's GetPositions has already converted short position quantity to positive, use directly
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("no short position found for %s", symbol)
}
logger.Infof(" 📊 Retrieved short position quantity: %.8f", quantity)
}
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, err
}
limitPrice := price * 1.01
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(symbol, limitPrice)
if err != nil {
return nil, err
}
formattedQty, err := t.formatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Get precision information
prec, err := t.getPrecision(symbol)
if err != nil {
return nil, err
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
logger.Infof(" 📏 Precision handling: price %.8f -> %s (precision=%d), quantity %.8f -> %s (precision=%d)",
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
params := map[string]interface{}{
"symbol": symbol,
"positionSide": "BOTH",
"type": "LIMIT",
"side": "BUY",
"timeInForce": "GTC",
"quantity": qtyStr,
"price": priceStr,
}
body, err := t.request("POST", "/fapi/v3/order", params)
if err != nil {
return nil, err
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
logger.Infof("✓ Successfully closed short position: %s quantity: %s", symbol, qtyStr)
// Cancel all pending orders for this symbol after closing position (stop-loss/take-profit orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel pending orders: %v", err)
}
return result, nil
}
// SetStopLoss Set stop loss
func (t *AsterTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
side := "SELL"
if positionSide == "SHORT" {
side = "BUY"
}
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(symbol, stopPrice)
if err != nil {
return err
}
formattedQty, err := t.formatQuantity(symbol, quantity)
if err != nil {
return err
}
// Get precision information
prec, err := t.getPrecision(symbol)
if err != nil {
return err
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
params := map[string]interface{}{
"symbol": symbol,
"positionSide": "BOTH",
"type": "STOP_MARKET",
"side": side,
"stopPrice": priceStr,
"quantity": qtyStr,
"timeInForce": "GTC",
}
_, err = t.request("POST", "/fapi/v3/order", params)
return err
}
// SetTakeProfit Set take profit
func (t *AsterTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
side := "SELL"
if positionSide == "SHORT" {
side = "BUY"
}
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(symbol, takeProfitPrice)
if err != nil {
return err
}
formattedQty, err := t.formatQuantity(symbol, quantity)
if err != nil {
return err
}
// Get precision information
prec, err := t.getPrecision(symbol)
if err != nil {
return err
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
params := map[string]interface{}{
"symbol": symbol,
"positionSide": "BOTH",
"type": "TAKE_PROFIT_MARKET",
"side": side,
"stopPrice": priceStr,
"quantity": qtyStr,
"timeInForce": "GTC",
}
_, err = t.request("POST", "/fapi/v3/order", params)
return err
}
// CancelStopLossOrders Cancel stop-loss orders only (does not affect take-profit orders)
func (t *AsterTrader) CancelStopLossOrders(symbol string) error {
// Get all open orders for this symbol
params := map[string]interface{}{
"symbol": symbol,
}
body, err := t.request("GET", "/fapi/v3/openOrders", params)
if err != nil {
return fmt.Errorf("failed to get open orders: %w", err)
}
var orders []map[string]interface{}
if err := json.Unmarshal(body, &orders); err != nil {
return fmt.Errorf("failed to parse order data: %w", err)
}
// Filter and cancel stop-loss orders (cancel all directions including LONG and SHORT)
canceledCount := 0
var cancelErrors []error
for _, order := range orders {
orderType, _ := order["type"].(string)
// Only cancel stop-loss orders (don't cancel take-profit orders)
if orderType == "STOP_MARKET" || orderType == "STOP" {
orderID, _ := order["orderId"].(float64)
positionSide, _ := order["positionSide"].(string)
cancelParams := map[string]interface{}{
"symbol": symbol,
"orderId": int64(orderID),
}
_, err := t.request("DELETE", "/fapi/v1/order", cancelParams)
if err != nil {
errMsg := fmt.Sprintf("order ID %d: %v", int64(orderID), err)
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
logger.Infof(" ⚠ Failed to cancel stop-loss order: %s", errMsg)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled stop-loss order (order ID: %d, type: %s, direction: %s)", int64(orderID), orderType, positionSide)
}
}
if canceledCount == 0 && len(cancelErrors) == 0 {
logger.Infof(" %s no stop-loss orders to cancel", symbol)
} else if canceledCount > 0 {
logger.Infof(" ✓ Canceled %d stop-loss order(s) for %s", canceledCount, symbol)
}
// Return error if all cancellations failed
if len(cancelErrors) > 0 && canceledCount == 0 {
return fmt.Errorf("failed to cancel stop-loss orders: %v", cancelErrors)
}
return nil
}
// CancelTakeProfitOrders Cancel take-profit orders only (does not affect stop-loss orders)
func (t *AsterTrader) CancelTakeProfitOrders(symbol string) error {
// Get all open orders for this symbol
params := map[string]interface{}{
"symbol": symbol,
}
body, err := t.request("GET", "/fapi/v3/openOrders", params)
if err != nil {
return fmt.Errorf("failed to get open orders: %w", err)
}
var orders []map[string]interface{}
if err := json.Unmarshal(body, &orders); err != nil {
return fmt.Errorf("failed to parse order data: %w", err)
}
// Filter and cancel take-profit orders (cancel all directions including LONG and SHORT)
canceledCount := 0
var cancelErrors []error
for _, order := range orders {
orderType, _ := order["type"].(string)
// Only cancel take-profit orders (don't cancel stop-loss orders)
if orderType == "TAKE_PROFIT_MARKET" || orderType == "TAKE_PROFIT" {
orderID, _ := order["orderId"].(float64)
positionSide, _ := order["positionSide"].(string)
cancelParams := map[string]interface{}{
"symbol": symbol,
"orderId": int64(orderID),
}
_, err := t.request("DELETE", "/fapi/v1/order", cancelParams)
if err != nil {
errMsg := fmt.Sprintf("order ID %d: %v", int64(orderID), err)
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
logger.Infof(" ⚠ Failed to cancel take-profit order: %s", errMsg)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled take-profit order (order ID: %d, type: %s, direction: %s)", int64(orderID), orderType, positionSide)
}
}
if canceledCount == 0 && len(cancelErrors) == 0 {
logger.Infof(" %s no take-profit orders to cancel", symbol)
} else if canceledCount > 0 {
logger.Infof(" ✓ Canceled %d take-profit order(s) for %s", canceledCount, symbol)
}
// Return error if all cancellations failed
if len(cancelErrors) > 0 && canceledCount == 0 {
return fmt.Errorf("failed to cancel take-profit orders: %v", cancelErrors)
}
return nil
}
// CancelAllOrders Cancel all orders
func (t *AsterTrader) CancelAllOrders(symbol string) error {
params := map[string]interface{}{
"symbol": symbol,
}
_, err := t.request("DELETE", "/fapi/v3/allOpenOrders", params)
return err
}
// CancelStopOrders Cancel take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions)
func (t *AsterTrader) CancelStopOrders(symbol string) error {
// Get all open orders for this symbol
params := map[string]interface{}{
"symbol": symbol,
}
body, err := t.request("GET", "/fapi/v3/openOrders", params)
if err != nil {
return fmt.Errorf("failed to get open orders: %w", err)
}
var orders []map[string]interface{}
if err := json.Unmarshal(body, &orders); err != nil {
return fmt.Errorf("failed to parse order data: %w", err)
}
// Filter and cancel take-profit/stop-loss orders
canceledCount := 0
for _, order := range orders {
orderType, _ := order["type"].(string)
// Only cancel stop-loss and take-profit orders
if orderType == "STOP_MARKET" ||
orderType == "TAKE_PROFIT_MARKET" ||
orderType == "STOP" ||
orderType == "TAKE_PROFIT" {
orderID, _ := order["orderId"].(float64)
cancelParams := map[string]interface{}{
"symbol": symbol,
"orderId": int64(orderID),
}
_, err := t.request("DELETE", "/fapi/v3/order", cancelParams)
if err != nil {
logger.Infof(" ⚠ Failed to cancel order %d: %v", int64(orderID), err)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled take-profit/stop-loss order for %s (order ID: %d, type: %s)",
symbol, int64(orderID), orderType)
}
}
if canceledCount == 0 {
logger.Infof(" %s no take-profit/stop-loss orders to cancel", symbol)
} else {
logger.Infof(" ✓ Canceled %d take-profit/stop-loss order(s) for %s", canceledCount, symbol)
}
return nil
}
// FormatQuantity Format quantity (implements Trader interface)
func (t *AsterTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
formatted, err := t.formatQuantity(symbol, quantity)
if err != nil {
return "", err
}
return fmt.Sprintf("%v", formatted), nil
}
// GetOrderStatus Get order status
func (t *AsterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
params := map[string]interface{}{
"symbol": symbol,
"orderId": orderID,
}
body, err := t.request("GET", "/fapi/v3/order", params)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
// Standardize return fields
response := map[string]interface{}{
"orderId": result["orderId"],
"symbol": result["symbol"],
"status": result["status"],
"side": result["side"],
"type": result["type"],
"time": result["time"],
"updateTime": result["updateTime"],
"commission": 0.0, // Aster may require separate query
}
// Parse numeric fields
if avgPrice, ok := result["avgPrice"].(string); ok {
if v, err := strconv.ParseFloat(avgPrice, 64); err == nil {
response["avgPrice"] = v
}
} else if avgPrice, ok := result["avgPrice"].(float64); ok {
response["avgPrice"] = avgPrice
}
if executedQty, ok := result["executedQty"].(string); ok {
if v, err := strconv.ParseFloat(executedQty, 64); err == nil {
response["executedQty"] = v
}
} else if executedQty, ok := result["executedQty"].(float64); ok {
response["executedQty"] = executedQty
}
return response, nil
}
// GetOpenOrders gets all open/pending orders for a symbol
func (t *AsterTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
params := map[string]interface{}{
"symbol": symbol,
}
body, err := t.request("GET", "/fapi/v3/openOrders", params)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
var orders []struct {
OrderID int64 `json:"orderId"`
Symbol string `json:"symbol"`
Side string `json:"side"`
PositionSide string `json:"positionSide"`
Type string `json:"type"`
Price string `json:"price"`
StopPrice string `json:"stopPrice"`
OrigQty string `json:"origQty"`
Status string `json:"status"`
}
if err := json.Unmarshal(body, &orders); err != nil {
return nil, fmt.Errorf("failed to parse open orders: %w", err)
}
var result []types.OpenOrder
for _, order := range orders {
price, _ := strconv.ParseFloat(order.Price, 64)
stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64)
quantity, _ := strconv.ParseFloat(order.OrigQty, 64)
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.OrderID),
Symbol: order.Symbol,
Side: order.Side,
PositionSide: order.PositionSide,
Type: order.Type,
Price: price,
StopPrice: stopPrice,
Quantity: quantity,
Status: order.Status,
})
}
logger.Infof("✓ ASTER GetOpenOrders: found %d open orders for %s", len(result), symbol)
return result, nil
}
// PlaceLimitOrder places a limit order for grid trading
func (t *AsterTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) {
// Format price and quantity to correct precision
formattedPrice, err := t.formatPrice(req.Symbol, req.Price)
if err != nil {
return nil, fmt.Errorf("failed to format price: %w", err)
}
formattedQty, err := t.formatQuantity(req.Symbol, req.Quantity)
if err != nil {
return nil, fmt.Errorf("failed to format quantity: %w", err)
}
// Get precision information
prec, err := t.getPrecision(req.Symbol)
if err != nil {
return nil, fmt.Errorf("failed to get precision: %w", err)
}
// Convert to string with correct precision format
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
// Determine side
side := "BUY"
if req.Side == "SELL" || req.Side == "Sell" || req.Side == "sell" {
side = "SELL"
}
params := map[string]interface{}{
"symbol": req.Symbol,
"positionSide": "BOTH",
"type": "LIMIT",
"side": side,
"timeInForce": "GTC",
"quantity": qtyStr,
"price": priceStr,
}
// Add reduceOnly if specified
if req.ReduceOnly {
params["reduceOnly"] = "true"
}
body, err := t.request("POST", "/fapi/v3/order", params)
if err != nil {
return nil, fmt.Errorf("failed to place limit order: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
// Extract order ID
orderID := ""
if id, ok := result["orderId"].(float64); ok {
orderID = fmt.Sprintf("%.0f", id)
} else if id, ok := result["orderId"].(string); ok {
orderID = id
}
// Extract client order ID
clientOrderID := ""
if cid, ok := result["clientOrderId"].(string); ok {
clientOrderID = cid
}
return &types.LimitOrderResult{
OrderID: orderID,
ClientID: clientOrderID,
Symbol: req.Symbol,
Side: side,
Price: formattedPrice,
Quantity: formattedQty,
Status: "NEW",
}, nil
}
// CancelOrder cancels a specific order by order ID
func (t *AsterTrader) CancelOrder(symbol, orderID string) error {
params := map[string]interface{}{
"symbol": symbol,
"orderId": orderID,
}
_, err := t.request("DELETE", "/fapi/v3/order", params)
if err != nil {
return fmt.Errorf("failed to cancel order %s: %w", orderID, err)
}
return nil
}
+121
View File
@@ -0,0 +1,121 @@
package aster
import (
"encoding/json"
"fmt"
"nofx/logger"
"strconv"
"strings"
)
// GetPositions Get position information
func (t *AsterTrader) GetPositions() ([]map[string]interface{}, error) {
params := make(map[string]interface{})
body, err := t.request("GET", "/fapi/v3/positionRisk", params)
if err != nil {
return nil, err
}
var positions []map[string]interface{}
if err := json.Unmarshal(body, &positions); err != nil {
return nil, err
}
result := []map[string]interface{}{}
for _, pos := range positions {
posAmtStr, ok := pos["positionAmt"].(string)
if !ok {
continue
}
posAmt, _ := strconv.ParseFloat(posAmtStr, 64)
if posAmt == 0 {
continue // Skip empty positions
}
entryPrice, _ := strconv.ParseFloat(pos["entryPrice"].(string), 64)
markPrice, _ := strconv.ParseFloat(pos["markPrice"].(string), 64)
unRealizedProfit, _ := strconv.ParseFloat(pos["unRealizedProfit"].(string), 64)
leverageVal, _ := strconv.ParseFloat(pos["leverage"].(string), 64)
liquidationPrice, _ := strconv.ParseFloat(pos["liquidationPrice"].(string), 64)
// Determine direction (consistent with Binance)
side := "long"
if posAmt < 0 {
side = "short"
posAmt = -posAmt
}
// Return same field names as Binance
result = append(result, map[string]interface{}{
"symbol": pos["symbol"],
"side": side,
"positionAmt": posAmt,
"entryPrice": entryPrice,
"markPrice": markPrice,
"unRealizedProfit": unRealizedProfit,
"leverage": leverageVal,
"liquidationPrice": liquidationPrice,
})
}
return result, nil
}
// SetMarginMode Set margin mode
func (t *AsterTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// Aster supports margin mode settings
// API format similar to Binance: CROSSED (cross margin) / ISOLATED (isolated margin)
marginType := "CROSSED"
if !isCrossMargin {
marginType = "ISOLATED"
}
params := map[string]interface{}{
"symbol": symbol,
"marginType": marginType,
}
// Use request method to call API
_, err := t.request("POST", "/fapi/v3/marginType", params)
if err != nil {
// Ignore error if it indicates no need to change
if strings.Contains(err.Error(), "No need to change") ||
strings.Contains(err.Error(), "Margin type cannot be changed") {
logger.Infof(" ✓ %s margin mode is already %s or cannot be changed due to existing positions", symbol, marginType)
return nil
}
// Detect multi-assets mode (error code -4168)
if strings.Contains(err.Error(), "Multi-Assets mode") ||
strings.Contains(err.Error(), "-4168") ||
strings.Contains(err.Error(), "4168") {
logger.Infof(" ⚠️ %s detected multi-assets mode, forcing cross margin mode", symbol)
logger.Infof(" 💡 Tip: To use isolated margin mode, please disable multi-assets mode on the exchange")
return nil
}
// Detect unified account API
if strings.Contains(err.Error(), "unified") ||
strings.Contains(err.Error(), "portfolio") ||
strings.Contains(err.Error(), "Portfolio") {
logger.Infof(" ❌ %s detected unified account API, cannot perform futures trading", symbol)
return fmt.Errorf("please use 'Spot & Futures Trading' API permission, not 'Unified Account API'")
}
logger.Infof(" ⚠️ Failed to set margin mode: %v", err)
// Don't return error, let trading continue
return nil
}
logger.Infof(" ✓ %s margin mode has been set to %s", symbol, marginType)
return nil
}
// SetLeverage Set leverage multiplier
func (t *AsterTrader) SetLeverage(symbol string, leverage int) error {
params := map[string]interface{}{
"symbol": symbol,
"leverage": leverage,
}
_, err := t.request("POST", "/fapi/v3/leverage", params)
return err
}
+2 -2
View File
@@ -3,7 +3,7 @@ package trader
import (
"fmt"
"math"
"nofx/experience"
"nofx/telemetry"
"nofx/kernel"
"nofx/logger"
"nofx/market"
@@ -345,7 +345,7 @@ func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{},
// Send anonymous trade statistics for experience improvement (async, non-blocking)
// This helps us understand overall product usage across all deployments
experience.TrackTrade(experience.TradeEvent{
telemetry.TrackTrade(telemetry.TradeEvent{
Exchange: at.exchange,
TradeType: action,
Symbol: symbol,
+21 -1232
View File
File diff suppressed because it is too large Load Diff
+485
View File
@@ -0,0 +1,485 @@
package trader
import (
"math"
"nofx/kernel"
"nofx/logger"
"nofx/market"
"nofx/store"
)
// ============================================================================
// Grid Level Calculation and Rebalancing
// ============================================================================
// calculateDefaultBounds calculates default bounds based on price
func (at *AutoTrader) calculateDefaultBounds(price float64, config *store.GridStrategyConfig) {
// Default: +/-3% from current price
multiplier := 0.03 * float64(config.GridCount) / 10
at.gridState.UpperPrice = price * (1 + multiplier)
at.gridState.LowerPrice = price * (1 - multiplier)
}
// calculateATRBounds calculates bounds using ATR
func (at *AutoTrader) calculateATRBounds(price float64, mktData *market.Data, config *store.GridStrategyConfig) {
atr := 0.0
if mktData.LongerTermContext != nil {
atr = mktData.LongerTermContext.ATR14
}
if atr <= 0 {
at.calculateDefaultBounds(price, config)
return
}
multiplier := config.ATRMultiplier
if multiplier <= 0 {
multiplier = 2.0
}
halfRange := atr * multiplier
at.gridState.UpperPrice = price + halfRange
at.gridState.LowerPrice = price - halfRange
}
// initializeGridLevels creates the grid level structure
func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.GridStrategyConfig) {
levels := make([]kernel.GridLevelInfo, config.GridCount)
totalWeight := 0.0
weights := make([]float64, config.GridCount)
// Calculate weights based on distribution
for i := 0; i < config.GridCount; i++ {
switch config.Distribution {
case "gaussian":
// Gaussian distribution - more weight in the middle
center := float64(config.GridCount-1) / 2
sigma := float64(config.GridCount) / 4
weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma))
case "pyramid":
// Pyramid - more weight at bottom
weights[i] = float64(config.GridCount - i)
default: // uniform
weights[i] = 1.0
}
totalWeight += weights[i]
}
// Create levels
for i := 0; i < config.GridCount; i++ {
price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing
allocatedUSD := config.TotalInvestment * weights[i] / totalWeight
// Determine initial side (below current price = buy, above = sell)
side := "buy"
if price > currentPrice {
side = "sell"
}
levels[i] = kernel.GridLevelInfo{
Index: i,
Price: price,
State: "empty",
Side: side,
AllocatedUSD: allocatedUSD,
}
}
at.gridState.Levels = levels
// Apply direction-based side assignment if enabled
if config.EnableDirectionAdjust {
at.applyGridDirection(currentPrice)
}
}
// applyGridDirection adjusts grid level sides based on the current direction
// This redistributes buy/sell levels according to the direction bias ratio
func (at *AutoTrader) applyGridDirection(currentPrice float64) {
config := at.gridState.Config
direction := at.gridState.CurrentDirection
// Get bias ratio from config, default to 0.7 (70%/30%)
biasRatio := config.DirectionBiasRatio
if biasRatio <= 0 || biasRatio > 1 {
biasRatio = 0.7
}
buyRatio, _ := direction.GetBuySellRatio(biasRatio)
// Calculate how many levels should be buy vs sell based on direction
totalLevels := len(at.gridState.Levels)
targetBuyLevels := int(float64(totalLevels) * buyRatio)
// For neutral: use price-based assignment (buy below, sell above)
if direction == market.GridDirectionNeutral {
for i := range at.gridState.Levels {
if at.gridState.Levels[i].Price <= currentPrice {
at.gridState.Levels[i].Side = "buy"
} else {
at.gridState.Levels[i].Side = "sell"
}
}
return
}
// For long/long_bias: more buy levels
// For short/short_bias: more sell levels
switch direction {
case market.GridDirectionLong:
// 100% buy - all levels are buy
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "buy"
}
case market.GridDirectionShort:
// 100% sell - all levels are sell
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "sell"
}
case market.GridDirectionLongBias, market.GridDirectionShortBias:
// Assign sides based on position relative to current price
// For long_bias: keep all below as buy, convert some above to buy
// For short_bias: keep all above as sell, convert some below to sell
buyCount := 0
sellCount := 0
for i := range at.gridState.Levels {
needMoreBuys := buyCount < targetBuyLevels
needMoreSells := sellCount < (totalLevels - targetBuyLevels)
if at.gridState.Levels[i].Price <= currentPrice {
// Level below or at current price
if needMoreBuys {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else {
at.gridState.Levels[i].Side = "sell"
sellCount++
}
} else {
// Level above current price
if needMoreSells && direction == market.GridDirectionShortBias {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else if needMoreBuys && direction == market.GridDirectionLongBias {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else if needMoreSells {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else {
at.gridState.Levels[i].Side = "buy"
buyCount++
}
}
}
}
logger.Infof("[Grid] Applied direction %s: buy_ratio=%.0f%%, levels reconfigured",
direction, buyRatio*100)
}
// checkGridSkew checks if grid is heavily skewed (too many fills on one side)
// Returns: (skewed bool, buyFilledCount int, sellFilledCount int)
func (at *AutoTrader) checkGridSkew() (bool, int, int) {
at.gridState.mu.RLock()
defer at.gridState.mu.RUnlock()
buyFilled := 0
sellFilled := 0
buyEmpty := 0
sellEmpty := 0
for _, level := range at.gridState.Levels {
if level.Side == "buy" {
if level.State == "filled" {
buyFilled++
} else if level.State == "empty" {
buyEmpty++
}
} else {
if level.State == "filled" {
sellFilled++
} else if level.State == "empty" {
sellEmpty++
}
}
}
// Grid is skewed if one side has 3x more fills than the other
// or if one side is completely empty
skewed := false
if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 {
skewed = true // All buys filled, no sells
} else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 {
skewed = true // All sells filled, no buys
} else if buyFilled >= 3*sellFilled && buyFilled > 5 {
skewed = true
} else if sellFilled >= 3*buyFilled && sellFilled > 5 {
skewed = true
}
return skewed, buyFilled, sellFilled
}
// autoAdjustGrid automatically adjusts grid when heavily skewed
func (at *AutoTrader) autoAdjustGrid() {
skewed, buyFilled, sellFilled := at.checkGridSkew()
if !skewed {
return
}
logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...",
buyFilled, sellFilled)
gridConfig := at.config.StrategyConfig.GridConfig
// Get current price
currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol)
if err != nil {
logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err)
return
}
// Check if price is near grid boundary
at.gridState.mu.RLock()
upper := at.gridState.UpperPrice
lower := at.gridState.LowerPrice
at.gridState.mu.RUnlock()
// Only adjust if price has moved significantly (>30% of grid range)
gridRange := upper - lower
midPrice := (upper + lower) / 2
priceDeviation := math.Abs(currentPrice - midPrice)
if priceDeviation < gridRange*0.3 {
return // Price still near center, don't adjust
}
logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice)
// Cancel existing orders first (before taking the lock for state modification)
if err := at.cancelAllGridOrders(); err != nil {
logger.Errorf("[Grid] Failed to cancel orders during auto-adjust: %v", err)
// Continue with adjustment anyway
}
// CRITICAL FIX: Hold lock for the entire adjustment operation to ensure atomicity
at.gridState.mu.Lock()
defer at.gridState.mu.Unlock()
// Preserve filled positions before reinitializing
filledPositions := make(map[int]kernel.GridLevelInfo)
for i, level := range at.gridState.Levels {
if level.State == "filled" {
filledPositions[i] = level
}
}
// CRITICAL FIX: Recalculate grid bounds centered on current price
// Use the same logic as InitializeGrid() - either ATR-based or default percentage
if gridConfig.UseATRBounds {
// Try to get ATR for bound calculation
mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20)
if err != nil {
logger.Warnf("[Grid] Failed to get market data for ATR during adjust: %v, using default bounds", err)
at.calculateDefaultBoundsLocked(currentPrice, gridConfig)
} else {
at.calculateATRBoundsLocked(currentPrice, mktData, gridConfig)
}
} else {
// Use default bounds calculation (scaled by grid count)
at.calculateDefaultBoundsLocked(currentPrice, gridConfig)
}
// Recalculate grid spacing based on new bounds
at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1)
logger.Infof("[Grid] New bounds: $%.2f - $%.2f, spacing: $%.2f",
at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing)
// Initialize new grid levels (without lock since we already hold it)
at.initializeGridLevelsLocked(currentPrice, gridConfig)
// CRITICAL FIX: Restore filled positions - find closest new level for each filled position
for _, filledLevel := range filledPositions {
closestIdx := -1
closestDist := math.MaxFloat64
for i, newLevel := range at.gridState.Levels {
dist := math.Abs(newLevel.Price - filledLevel.PositionEntry)
if dist < closestDist {
closestDist = dist
closestIdx = i
}
}
if closestIdx >= 0 {
// Restore the filled state to the closest level
at.gridState.Levels[closestIdx].State = "filled"
at.gridState.Levels[closestIdx].PositionEntry = filledLevel.PositionEntry
at.gridState.Levels[closestIdx].PositionSize = filledLevel.PositionSize
at.gridState.Levels[closestIdx].UnrealizedPnL = filledLevel.UnrealizedPnL
at.gridState.Levels[closestIdx].OrderID = filledLevel.OrderID
at.gridState.Levels[closestIdx].OrderQuantity = filledLevel.OrderQuantity
logger.Infof("[Grid] Restored filled position at level %d (entry $%.2f)", closestIdx, filledLevel.PositionEntry)
}
}
}
// calculateDefaultBoundsLocked calculates default bounds (caller must hold lock)
func (at *AutoTrader) calculateDefaultBoundsLocked(price float64, config *store.GridStrategyConfig) {
// Default: +/-3% from current price, scaled by grid count
multiplier := 0.03 * float64(config.GridCount) / 10
at.gridState.UpperPrice = price * (1 + multiplier)
at.gridState.LowerPrice = price * (1 - multiplier)
}
// calculateATRBoundsLocked calculates bounds using ATR (caller must hold lock)
func (at *AutoTrader) calculateATRBoundsLocked(price float64, mktData *market.Data, config *store.GridStrategyConfig) {
atr := 0.0
if mktData.LongerTermContext != nil {
atr = mktData.LongerTermContext.ATR14
}
if atr <= 0 {
at.calculateDefaultBoundsLocked(price, config)
return
}
multiplier := config.ATRMultiplier
if multiplier <= 0 {
multiplier = 2.0
}
halfRange := atr * multiplier
at.gridState.UpperPrice = price + halfRange
at.gridState.LowerPrice = price - halfRange
}
// initializeGridLevelsLocked creates the grid level structure (caller must hold lock)
func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *store.GridStrategyConfig) {
levels := make([]kernel.GridLevelInfo, config.GridCount)
totalWeight := 0.0
weights := make([]float64, config.GridCount)
// Calculate weights based on distribution
for i := 0; i < config.GridCount; i++ {
switch config.Distribution {
case "gaussian":
// Gaussian distribution - more weight in the middle
center := float64(config.GridCount-1) / 2
sigma := float64(config.GridCount) / 4
weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma))
case "pyramid":
// Pyramid - more weight at bottom
weights[i] = float64(config.GridCount - i)
default: // uniform
weights[i] = 1.0
}
totalWeight += weights[i]
}
// Create levels
for i := 0; i < config.GridCount; i++ {
price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing
allocatedUSD := config.TotalInvestment * weights[i] / totalWeight
// Determine initial side (below current price = buy, above = sell)
side := "buy"
if price > currentPrice {
side = "sell"
}
levels[i] = kernel.GridLevelInfo{
Index: i,
Price: price,
State: "empty",
Side: side,
AllocatedUSD: allocatedUSD,
}
}
at.gridState.Levels = levels
// Apply direction-based side assignment if enabled (note: caller holds lock)
if config.EnableDirectionAdjust {
at.applyGridDirectionLocked(currentPrice)
}
}
// applyGridDirectionLocked adjusts grid level sides based on the current direction (caller must hold lock)
func (at *AutoTrader) applyGridDirectionLocked(currentPrice float64) {
config := at.gridState.Config
direction := at.gridState.CurrentDirection
// Get bias ratio from config, default to 0.7 (70%/30%)
biasRatio := config.DirectionBiasRatio
if biasRatio <= 0 || biasRatio > 1 {
biasRatio = 0.7
}
buyRatio, _ := direction.GetBuySellRatio(biasRatio)
// For neutral: use price-based assignment (buy below, sell above)
if direction == market.GridDirectionNeutral {
for i := range at.gridState.Levels {
if at.gridState.Levels[i].Price <= currentPrice {
at.gridState.Levels[i].Side = "buy"
} else {
at.gridState.Levels[i].Side = "sell"
}
}
return
}
totalLevels := len(at.gridState.Levels)
targetBuyLevels := int(float64(totalLevels) * buyRatio)
switch direction {
case market.GridDirectionLong:
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "buy"
}
case market.GridDirectionShort:
for i := range at.gridState.Levels {
at.gridState.Levels[i].Side = "sell"
}
case market.GridDirectionLongBias, market.GridDirectionShortBias:
buyCount := 0
sellCount := 0
for i := range at.gridState.Levels {
needMoreBuys := buyCount < targetBuyLevels
needMoreSells := sellCount < (totalLevels - targetBuyLevels)
if at.gridState.Levels[i].Price <= currentPrice {
if needMoreBuys {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else {
at.gridState.Levels[i].Side = "sell"
sellCount++
}
} else {
if needMoreSells && direction == market.GridDirectionShortBias {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else if needMoreBuys && direction == market.GridDirectionLongBias {
at.gridState.Levels[i].Side = "buy"
buyCount++
} else if needMoreSells {
at.gridState.Levels[i].Side = "sell"
sellCount++
} else {
at.gridState.Levels[i].Side = "buy"
buyCount++
}
}
}
}
}
+419
View File
@@ -0,0 +1,419 @@
package trader
import (
"fmt"
"math"
"nofx/kernel"
"nofx/logger"
"time"
)
// ============================================================================
// Grid Order Placement and Management
// ============================================================================
// checkTotalPositionLimit checks if adding a new position would exceed total limits
// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64)
func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) {
gridConfig := at.config.StrategyConfig.GridConfig
// Calculate max allowed total position value
// Total position should not exceed: TotalInvestment * Leverage
maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage)
// Get current position value from exchange
currentPositionValue := 0.0
positions, err := at.trader.GetPositions()
if err == nil {
for _, pos := range positions {
if sym, ok := pos["symbol"].(string); ok && sym == symbol {
if size, ok := pos["positionAmt"].(float64); ok {
if price, ok := pos["markPrice"].(float64); ok {
currentPositionValue = math.Abs(size) * price
} else if entryPrice, ok := pos["entryPrice"].(float64); ok {
currentPositionValue = math.Abs(size) * entryPrice
}
}
}
}
}
// Also count pending orders as potential position
at.gridState.mu.RLock()
pendingValue := 0.0
for _, level := range at.gridState.Levels {
if level.State == "pending" {
pendingValue += level.OrderQuantity * level.Price
}
}
at.gridState.mu.RUnlock()
totalAfterOrder := currentPositionValue + pendingValue + additionalValue
allowed := totalAfterOrder <= maxTotalPositionValue
return allowed, currentPositionValue + pendingValue, maxTotalPositionValue
}
// placeGridLimitOrder places a limit order for grid trading
func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error {
// Check if trader supports GridTrader interface
gridTrader, ok := at.trader.(GridTrader)
if !ok {
// Fallback to adapter
gridTrader = NewGridTraderAdapter(at.trader)
}
gridConfig := at.config.StrategyConfig.GridConfig
// CRITICAL: Validate and cap quantity to prevent excessive position sizes
// This protects against AI miscalculations or leverage misconfigurations
quantity := d.Quantity
if d.Price > 0 && gridConfig.TotalInvestment > 0 {
// Calculate max allowed position value per grid level
// Each level gets proportional share of total investment
maxMarginPerLevel := gridConfig.TotalInvestment / float64(gridConfig.GridCount)
maxPositionValuePerLevel := maxMarginPerLevel * float64(gridConfig.Leverage)
maxQuantityPerLevel := maxPositionValuePerLevel / d.Price
// Also get the level's allocated USD for additional validation
at.gridState.mu.RLock()
var levelAllocatedUSD float64
if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) {
levelAllocatedUSD = at.gridState.Levels[d.LevelIndex].AllocatedUSD
}
at.gridState.mu.RUnlock()
// Use level-specific allocation if available
if levelAllocatedUSD > 0 {
levelMaxPositionValue := levelAllocatedUSD * float64(gridConfig.Leverage)
levelMaxQuantity := levelMaxPositionValue / d.Price
if levelMaxQuantity < maxQuantityPerLevel {
maxQuantityPerLevel = levelMaxQuantity
}
}
// Cap quantity if it exceeds the maximum allowed
if quantity > maxQuantityPerLevel {
logger.Warnf("[Grid] Quantity %.4f exceeds max allowed %.4f (position_value $%.2f > max $%.2f), capping",
quantity, maxQuantityPerLevel, quantity*d.Price, maxPositionValuePerLevel)
quantity = maxQuantityPerLevel
}
// Safety check: ensure position value is reasonable (within 2x of intended max as absolute limit)
positionValue := quantity * d.Price
absoluteMaxValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) * 2 // 2x safety margin
if positionValue > absoluteMaxValue {
logger.Errorf("[Grid] CRITICAL: Position value $%.2f exceeds absolute max $%.2f! Rejecting order.",
positionValue, absoluteMaxValue)
return fmt.Errorf("position value $%.2f exceeds safety limit $%.2f", positionValue, absoluteMaxValue)
}
}
// CRITICAL: Check total position limit before placing order
orderValue := quantity * d.Price
allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue)
if !allowed {
logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.",
currentValue, orderValue, maxValue)
return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue)
}
req := &LimitOrderRequest{
Symbol: d.Symbol,
Side: side,
Price: d.Price,
Quantity: quantity, // Use validated/capped quantity
Leverage: gridConfig.Leverage,
PostOnly: gridConfig.UseMakerOnly,
ReduceOnly: false,
ClientID: fmt.Sprintf("grid-%d-%d", d.LevelIndex, time.Now().UnixNano()%1000000),
}
result, err := gridTrader.PlaceLimitOrder(req)
if err != nil {
return fmt.Errorf("failed to place limit order: %w", err)
}
// Update grid level state
at.gridState.mu.Lock()
if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) {
at.gridState.Levels[d.LevelIndex].State = "pending"
at.gridState.Levels[d.LevelIndex].OrderID = result.OrderID
at.gridState.Levels[d.LevelIndex].OrderQuantity = d.Quantity
at.gridState.OrderBook[result.OrderID] = d.LevelIndex
}
at.gridState.mu.Unlock()
logger.Infof("[Grid] Placed %s limit order at $%.2f, qty=%.4f, level=%d, orderID=%s",
side, d.Price, d.Quantity, d.LevelIndex, result.OrderID)
return nil
}
// cancelGridOrder cancels a specific grid order
func (at *AutoTrader) cancelGridOrder(d *kernel.Decision) error {
gridTrader, ok := at.trader.(GridTrader)
if !ok {
gridTrader = NewGridTraderAdapter(at.trader)
}
if err := gridTrader.CancelOrder(d.Symbol, d.OrderID); err != nil {
return fmt.Errorf("failed to cancel order: %w", err)
}
// Update state
at.gridState.mu.Lock()
if levelIdx, ok := at.gridState.OrderBook[d.OrderID]; ok {
if levelIdx >= 0 && levelIdx < len(at.gridState.Levels) {
at.gridState.Levels[levelIdx].State = "empty"
at.gridState.Levels[levelIdx].OrderID = ""
at.gridState.Levels[levelIdx].OrderQuantity = 0
}
delete(at.gridState.OrderBook, d.OrderID)
}
at.gridState.mu.Unlock()
logger.Infof("[Grid] Cancelled order: %s", d.OrderID)
return nil
}
// cancelAllGridOrders cancels all grid orders
func (at *AutoTrader) cancelAllGridOrders() error {
gridConfig := at.config.StrategyConfig.GridConfig
if err := at.trader.CancelAllOrders(gridConfig.Symbol); err != nil {
return fmt.Errorf("failed to cancel all orders: %w", err)
}
// Reset all pending levels
at.gridState.mu.Lock()
for i := range at.gridState.Levels {
if at.gridState.Levels[i].State == "pending" {
at.gridState.Levels[i].State = "empty"
at.gridState.Levels[i].OrderID = ""
at.gridState.Levels[i].OrderQuantity = 0
}
}
at.gridState.OrderBook = make(map[string]int)
at.gridState.mu.Unlock()
logger.Infof("[Grid] Cancelled all orders")
return nil
}
// pauseGrid pauses grid trading
func (at *AutoTrader) pauseGrid(reason string) error {
at.cancelAllGridOrders()
at.gridState.mu.Lock()
at.gridState.IsPaused = true
at.gridState.mu.Unlock()
logger.Infof("[Grid] Paused: %s", reason)
return nil
}
// resumeGrid resumes grid trading
func (at *AutoTrader) resumeGrid() error {
at.gridState.mu.Lock()
at.gridState.IsPaused = false
at.gridState.mu.Unlock()
logger.Infof("[Grid] Resumed")
return nil
}
// adjustGrid adjusts grid parameters
func (at *AutoTrader) adjustGrid(d *kernel.Decision) error {
// Cancel existing orders first
at.cancelAllGridOrders()
gridConfig := at.config.StrategyConfig.GridConfig
// Get current price
price, err := at.trader.GetMarketPrice(gridConfig.Symbol)
if err != nil {
return fmt.Errorf("failed to get market price: %w", err)
}
// Reinitialize grid levels
at.initializeGridLevels(price, gridConfig)
logger.Infof("[Grid] Adjusted grid bounds around price $%.2f", price)
return nil
}
// syncGridState syncs grid state with exchange
func (at *AutoTrader) syncGridState() {
gridConfig := at.config.StrategyConfig.GridConfig
// Get open orders from exchange
openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol)
if err != nil {
logger.Warnf("[Grid] Failed to get open orders: %v", err)
return
}
// Build set of active order IDs
activeOrderIDs := make(map[string]bool)
for _, order := range openOrders {
activeOrderIDs[order.OrderID] = true
}
// Get current positions to verify fills
positions, err := at.trader.GetPositions()
currentPositionSize := 0.0
if err != nil {
logger.Warnf("[Grid] Failed to get positions for state sync: %v", err)
} else {
for _, pos := range positions {
if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol {
if size, ok := pos["positionAmt"].(float64); ok {
currentPositionSize = size
}
}
}
}
// Update levels based on order status
at.gridState.mu.Lock()
expectedPositionSize := 0.0
for _, level := range at.gridState.Levels {
if level.State == "filled" {
expectedPositionSize += level.PositionSize
}
}
for i := range at.gridState.Levels {
level := &at.gridState.Levels[i]
if level.State == "pending" && level.OrderID != "" {
if !activeOrderIDs[level.OrderID] {
// Order no longer exists - check if position changed to determine fill vs cancel
// This is a heuristic - ideally we'd query order history
// If current position is larger than expected filled positions, this order was likely filled
if math.Abs(currentPositionSize) > math.Abs(expectedPositionSize) {
// Position increased, likely filled
level.State = "filled"
level.PositionEntry = level.Price
level.PositionSize = level.OrderQuantity
at.gridState.TotalTrades++
logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price)
} else {
// Position didn't increase as expected, likely cancelled
level.State = "empty"
level.OrderID = ""
level.OrderQuantity = 0
logger.Infof("[Grid] Level %d order cancelled/expired", i)
}
delete(at.gridState.OrderBook, level.OrderID)
}
}
}
at.gridState.mu.Unlock()
logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders))
// Check stop loss
at.checkAndExecuteStopLoss()
// Check grid skew
at.autoAdjustGrid()
}
// closeAllPositions closes all open positions for the grid symbol
func (at *AutoTrader) closeAllPositions() error {
gridConfig := at.config.StrategyConfig.GridConfig
if gridConfig == nil {
return nil
}
positions, err := at.trader.GetPositions()
if err != nil {
return fmt.Errorf("failed to get positions: %w", err)
}
for _, pos := range positions {
symbol, _ := pos["symbol"].(string)
if symbol != gridConfig.Symbol {
continue
}
size, _ := pos["positionAmt"].(float64)
if size == 0 {
continue
}
if size > 0 {
_, err = at.trader.CloseLong(symbol, size)
} else {
_, err = at.trader.CloseShort(symbol, -size)
}
if err != nil {
logger.Infof("Failed to close position: %v", err)
}
}
return nil
}
// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it
func (at *AutoTrader) checkAndExecuteStopLoss() {
gridConfig := at.config.StrategyConfig.GridConfig
if gridConfig.StopLossPct <= 0 {
return // Stop loss not configured
}
currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol)
if err != nil {
logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err)
return
}
at.gridState.mu.Lock()
defer at.gridState.mu.Unlock()
for i := range at.gridState.Levels {
level := &at.gridState.Levels[i]
if level.State != "filled" || level.PositionEntry <= 0 {
continue
}
// Calculate loss percentage
var lossPct float64
if level.Side == "buy" {
// Long position: loss when price drops
lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100
} else {
// Short position: loss when price rises
lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100
}
// Check if stop loss triggered
if lossPct >= gridConfig.StopLossPct {
logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%",
i, level.PositionEntry, currentPrice, lossPct)
// Close the position
var closeErr error
if level.Side == "buy" {
_, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize)
} else {
_, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize)
}
if closeErr != nil {
logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr)
} else {
level.State = "stopped"
realizedLoss := -lossPct * level.AllocatedUSD / 100
level.UnrealizedPnL = realizedLoss
at.gridState.TotalTrades++
// Update daily PnL tracking (lock already held, update directly)
at.gridState.DailyPnL += realizedLoss
at.gridState.TotalProfit += realizedLoss
logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)",
i, currentPrice, lossPct)
}
}
}
}
+345
View File
@@ -0,0 +1,345 @@
package trader
import (
"fmt"
"math"
"nofx/logger"
"nofx/market"
"time"
)
// ============================================================================
// Regime Detection and Strategy Switching
// ============================================================================
// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action
func (at *AutoTrader) checkBoxBreakout() error {
gridConfig := at.config.StrategyConfig.GridConfig
if gridConfig == nil {
return nil
}
// Get box data
box, err := market.GetBoxData(gridConfig.Symbol)
if err != nil {
logger.Infof("Failed to get box data: %v", err)
return nil // Non-fatal, continue with other checks
}
// Update grid state with box values
at.gridState.mu.Lock()
at.gridState.ShortBoxUpper = box.ShortUpper
at.gridState.ShortBoxLower = box.ShortLower
at.gridState.MidBoxUpper = box.MidUpper
at.gridState.MidBoxLower = box.MidLower
at.gridState.LongBoxUpper = box.LongUpper
at.gridState.LongBoxLower = box.LongLower
at.gridState.mu.Unlock()
// Detect breakout
breakoutLevel, direction := detectBoxBreakout(box)
// Get current breakout state
state := &BreakoutState{
Level: market.BreakoutLevel(at.gridState.BreakoutLevel),
Direction: at.gridState.BreakoutDirection,
ConfirmCount: at.gridState.BreakoutConfirmCount,
}
// Check if breakout is confirmed (3 candles)
confirmed := confirmBreakout(state, breakoutLevel, direction)
// Update grid state
at.gridState.mu.Lock()
at.gridState.BreakoutLevel = string(state.Level)
at.gridState.BreakoutDirection = state.Direction
at.gridState.BreakoutConfirmCount = state.ConfirmCount
at.gridState.mu.Unlock()
if !confirmed {
return nil
}
// Take action based on breakout level
// Use direction-aware action if enabled
enableDirectionAdjust := gridConfig.EnableDirectionAdjust
action := getBreakoutActionWithDirection(breakoutLevel, enableDirectionAdjust)
// If direction adjustment action, determine the new direction
if action == BreakoutActionAdjustDirection {
box, _ := market.GetBoxData(gridConfig.Symbol)
newDirection := determineGridDirection(box, at.gridState.CurrentDirection, breakoutLevel, direction)
return at.executeDirectionAdjustment(newDirection)
}
return at.executeBreakoutAction(action)
}
// executeBreakoutAction executes the appropriate action for a breakout
func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error {
switch action {
case BreakoutActionReducePosition:
// Short box breakout: reduce position to 50%
logger.Infof("Short box breakout confirmed, reducing position to 50%%")
at.gridState.mu.Lock()
at.gridState.PositionReductionPct = 50
at.gridState.mu.Unlock()
return nil
case BreakoutActionPauseGrid:
// Mid box breakout: pause grid + cancel orders
logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders")
at.gridState.mu.Lock()
at.gridState.IsPaused = true
at.gridState.mu.Unlock()
return at.cancelAllGridOrders()
case BreakoutActionCloseAll:
// Long box breakout: pause + cancel + close all
logger.Infof("Long box breakout confirmed, closing all positions")
at.gridState.mu.Lock()
at.gridState.IsPaused = true
at.gridState.mu.Unlock()
if err := at.cancelAllGridOrders(); err != nil {
logger.Infof("Failed to cancel orders: %v", err)
}
return at.closeAllPositions()
case BreakoutActionAdjustDirection:
// Direction adjustment is handled separately via executeDirectionAdjustment
// This case should not be reached, but handle gracefully
logger.Infof("Direction adjustment action received via executeBreakoutAction")
return nil
}
return nil
}
// executeDirectionAdjustment handles grid direction changes based on box breakout
func (at *AutoTrader) executeDirectionAdjustment(newDirection market.GridDirection) error {
at.gridState.mu.RLock()
oldDirection := at.gridState.CurrentDirection
at.gridState.mu.RUnlock()
if oldDirection == newDirection {
return nil // No change needed
}
logger.Infof("[Grid] Direction adjustment: %s -> %s", oldDirection, newDirection)
// Cancel existing orders before adjusting
if err := at.cancelAllGridOrders(); err != nil {
logger.Warnf("[Grid] Failed to cancel orders during direction adjustment: %v", err)
}
// Apply the new direction
return at.adjustGridDirection(newDirection)
}
// adjustGridDirection handles runtime direction adjustment when breakout is detected
func (at *AutoTrader) adjustGridDirection(newDirection market.GridDirection) error {
at.gridState.mu.Lock()
defer at.gridState.mu.Unlock()
oldDirection := at.gridState.CurrentDirection
if oldDirection == newDirection {
return nil // No change needed
}
at.gridState.CurrentDirection = newDirection
at.gridState.DirectionChangedAt = time.Now()
at.gridState.DirectionChangeCount++
logger.Infof("[Grid] Direction changed: %s -> %s (change count: %d)",
oldDirection, newDirection, at.gridState.DirectionChangeCount)
// Get current price for recalculation
currentPrice, err := at.trader.GetMarketPrice(at.gridState.Config.Symbol)
if err != nil {
return fmt.Errorf("failed to get market price: %w", err)
}
// Reapply direction to grid levels
at.applyGridDirection(currentPrice)
return nil
}
// checkFalseBreakoutRecovery checks if price has returned to box after breakout
func (at *AutoTrader) checkFalseBreakoutRecovery() error {
gridConfig := at.config.StrategyConfig.GridConfig
if gridConfig == nil {
return nil
}
at.gridState.mu.RLock()
breakoutLevel := at.gridState.BreakoutLevel
isPaused := at.gridState.IsPaused
positionReduction := at.gridState.PositionReductionPct
currentDirection := at.gridState.CurrentDirection
at.gridState.mu.RUnlock()
// Only check if we had a breakout or non-neutral direction
needsRecoveryCheck := breakoutLevel != string(market.BreakoutNone) ||
positionReduction != 0 ||
isPaused ||
(gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral)
if !needsRecoveryCheck {
return nil
}
// Get current box data
box, err := market.GetBoxData(gridConfig.Symbol)
if err != nil {
return nil
}
// Check if price is back inside the long box
if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper {
logger.Infof("Price returned to box, recovering with 50%% position")
at.gridState.mu.Lock()
at.gridState.BreakoutLevel = string(market.BreakoutNone)
at.gridState.BreakoutDirection = ""
at.gridState.BreakoutConfirmCount = 0
at.gridState.PositionReductionPct = 50 // Recover at 50%
at.gridState.IsPaused = false
at.gridState.mu.Unlock()
}
// Check for direction recovery toward neutral (if direction adjustment is enabled)
if gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral {
if shouldRecoverDirection(box, currentDirection) {
newDirection := determineRecoveryDirection(box.CurrentPrice, box, currentDirection)
if newDirection != currentDirection {
logger.Infof("[Grid] Direction recovery: %s -> %s (price back in short box)",
currentDirection, newDirection)
at.adjustGridDirection(newDirection)
}
}
}
return nil
}
// GetGridRiskInfo returns current risk information for frontend display
func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo {
gridConfig := at.config.StrategyConfig.GridConfig
if gridConfig == nil {
return &GridRiskInfo{}
}
at.gridState.mu.RLock()
defer at.gridState.mu.RUnlock()
// Get current price
currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol)
// Calculate effective leverage
totalInvestment := gridConfig.TotalInvestment
leverage := gridConfig.Leverage
// Get current position value
positions, _ := at.trader.GetPositions()
var currentPositionValue float64
var currentPositionSize float64
for _, pos := range positions {
if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol {
size, _ := pos["positionAmt"].(float64)
entry, _ := pos["entryPrice"].(float64)
currentPositionValue = math.Abs(size * entry)
currentPositionSize = size
break
}
}
effectiveLeverage := 0.0
if totalInvestment > 0 {
effectiveLeverage = currentPositionValue / totalInvestment
}
// Calculate max position based on regime
regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel)
if regimeLevel == "" {
regimeLevel = market.RegimeLevelStandard
}
// Use default position limit since GridStrategyConfig doesn't have regime-specific limits
// Default is 70% for standard regime
maxPositionPct := 70.0
switch regimeLevel {
case market.RegimeLevelNarrow:
maxPositionPct = 40.0
case market.RegimeLevelStandard:
maxPositionPct = 70.0
case market.RegimeLevelWide:
maxPositionPct = 60.0
case market.RegimeLevelVolatile:
maxPositionPct = 40.0
}
maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage)
// Use default leverage limits since GridStrategyConfig doesn't have regime-specific limits
recommendedLeverage := leverage
switch regimeLevel {
case market.RegimeLevelNarrow:
recommendedLeverage = min(leverage, 2)
case market.RegimeLevelStandard:
recommendedLeverage = min(leverage, 4)
case market.RegimeLevelWide:
recommendedLeverage = min(leverage, 3)
case market.RegimeLevelVolatile:
recommendedLeverage = min(leverage, 2)
}
// Calculate liquidation distance and price only when there's a position
var liquidationDistance float64
var liquidationPrice float64
if currentPositionSize != 0 && currentPrice > 0 {
liquidationDistance = 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max
if currentPositionSize > 0 {
// Long position: liquidation below entry
liquidationPrice = currentPrice * (1 - liquidationDistance/100)
} else {
// Short position: liquidation above entry
liquidationPrice = currentPrice * (1 + liquidationDistance/100)
}
}
positionPercent := 0.0
if maxPosition > 0 {
positionPercent = currentPositionValue / maxPosition * 100
}
return &GridRiskInfo{
CurrentLeverage: leverage,
EffectiveLeverage: effectiveLeverage,
RecommendedLeverage: recommendedLeverage,
CurrentPosition: currentPositionValue,
MaxPosition: maxPosition,
PositionPercent: positionPercent,
LiquidationPrice: liquidationPrice,
LiquidationDistance: liquidationDistance,
RegimeLevel: string(regimeLevel),
ShortBoxUpper: at.gridState.ShortBoxUpper,
ShortBoxLower: at.gridState.ShortBoxLower,
MidBoxUpper: at.gridState.MidBoxUpper,
MidBoxLower: at.gridState.MidBoxLower,
LongBoxUpper: at.gridState.LongBoxUpper,
LongBoxLower: at.gridState.LongBoxLower,
CurrentPrice: currentPrice,
BreakoutLevel: at.gridState.BreakoutLevel,
BreakoutDirection: at.gridState.BreakoutDirection,
CurrentGridDirection: string(at.gridState.CurrentDirection),
DirectionChangeCount: at.gridState.DirectionChangeCount,
EnableDirectionAdjust: gridConfig.EnableDirectionAdjust,
}
}
+8 -1315
View File
File diff suppressed because it is too large Load Diff
+291
View File
@@ -0,0 +1,291 @@
package binance
import (
"context"
"fmt"
"nofx/logger"
"nofx/trader/types"
"strconv"
"time"
)
// GetBalance gets account balance (with cache)
func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) {
// First check if cache is valid
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
cacheAge := time.Since(t.balanceCacheTime)
t.balanceCacheMutex.RUnlock()
logger.Infof("✓ Using cached account balance (cache age: %.1f seconds ago)", cacheAge.Seconds())
return t.cachedBalance, nil
}
t.balanceCacheMutex.RUnlock()
// Cache expired or doesn't exist, call API
logger.Infof("🔄 Cache expired, calling Binance API to get account balance...")
account, err := t.client.NewGetAccountService().Do(context.Background())
if err != nil {
logger.Infof("❌ Binance API call failed: %v", err)
return nil, fmt.Errorf("failed to get account info: %w", err)
}
result := make(map[string]interface{})
result["totalWalletBalance"], _ = strconv.ParseFloat(account.TotalWalletBalance, 64)
result["availableBalance"], _ = strconv.ParseFloat(account.AvailableBalance, 64)
result["totalUnrealizedProfit"], _ = strconv.ParseFloat(account.TotalUnrealizedProfit, 64)
logger.Infof("✓ Binance API returned: total balance=%s, available=%s, unrealized PnL=%s",
account.TotalWalletBalance,
account.AvailableBalance,
account.TotalUnrealizedProfit)
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
// GetClosedPnL retrieves recent closing trades from Binance Futures
// Note: Binance does NOT have a position history API, only trade history.
// This returns individual closing trades (realizedPnl != 0) for real-time position closure detection.
// NOT suitable for historical position reconstruction - use only for matching recent closures.
func (t *FuturesTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
trades, err := t.GetTrades(startTime, limit)
if err != nil {
return nil, err
}
// Filter only closing trades (realizedPnl != 0) and convert to ClosedPnLRecord
var records []types.ClosedPnLRecord
for _, trade := range trades {
if trade.RealizedPnL == 0 {
continue // Skip opening trades
}
// Determine side from trade
side := "long"
if trade.PositionSide == "SHORT" || trade.PositionSide == "short" {
side = "short"
} else if trade.PositionSide == "BOTH" || trade.PositionSide == "" {
// One-way mode: selling closes long, buying closes short
if trade.Side == "SELL" || trade.Side == "Sell" {
side = "long"
} else {
side = "short"
}
}
// Calculate entry price from PnL (mathematically accurate for this trade)
var entryPrice float64
if trade.Quantity > 0 {
if side == "long" {
entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity
} else {
entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity
}
}
records = append(records, types.ClosedPnLRecord{
Symbol: trade.Symbol,
Side: side,
EntryPrice: entryPrice,
ExitPrice: trade.Price,
Quantity: trade.Quantity,
RealizedPnL: trade.RealizedPnL,
Fee: trade.Fee,
ExitTime: trade.Time,
EntryTime: trade.Time, // Approximate
OrderID: trade.TradeID,
ExchangeID: trade.TradeID,
CloseType: "unknown",
})
}
return records, nil
}
// GetTrades retrieves trade history from Binance Futures using Income API
// Note: Income API has delays (~minutes), for real-time use GetTradesForSymbol instead
func (t *FuturesTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
// Use Income API to get REALIZED_PNL records (all symbols)
incomes, err := t.client.NewGetIncomeHistoryService().
IncomeType("REALIZED_PNL").
StartTime(startTime.UnixMilli()).
Limit(int64(limit)).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get income history: %w", err)
}
var trades []types.TradeRecord
for _, income := range incomes {
pnl, _ := strconv.ParseFloat(income.Income, 64)
if pnl == 0 {
continue // Skip zero PnL records
}
// Income API doesn't provide full trade details, create a minimal record
// This is mainly used for detecting recent closures, not historical reconstruction
trade := types.TradeRecord{
TradeID: strconv.FormatInt(income.TranID, 10),
Symbol: income.Symbol,
RealizedPnL: pnl,
Time: time.UnixMilli(income.Time).UTC(),
// Note: Income API doesn't provide price, quantity, side, fee
// For accurate data, use GetTradesForSymbol with specific symbol
}
trades = append(trades, trade)
}
return trades, nil
}
// GetTradesForSymbol retrieves trade history for a specific symbol
// This is more reliable than using Income API which may have delays
func (t *FuturesTrader) GetTradesForSymbol(symbol string, startTime time.Time, limit int) ([]types.TradeRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
accountTrades, err := t.client.NewListAccountTradeService().
Symbol(symbol).
StartTime(startTime.UnixMilli()).
Limit(limit).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get trade history for %s: %w", symbol, err)
}
var trades []types.TradeRecord
for _, at := range accountTrades {
price, _ := strconv.ParseFloat(at.Price, 64)
qty, _ := strconv.ParseFloat(at.Quantity, 64)
fee, _ := strconv.ParseFloat(at.Commission, 64)
pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64)
trade := types.TradeRecord{
TradeID: strconv.FormatInt(at.ID, 10),
Symbol: at.Symbol,
Side: string(at.Side),
PositionSide: string(at.PositionSide),
Price: price,
Quantity: qty,
RealizedPnL: pnl,
Fee: fee,
Time: time.UnixMilli(at.Time).UTC(),
}
trades = append(trades, trade)
}
return trades, nil
}
// GetTradesForSymbolFromID retrieves trade history for a specific symbol starting from a given trade ID
// This is used for incremental sync - only fetch new trades since last sync
func (t *FuturesTrader) GetTradesForSymbolFromID(symbol string, fromID int64, limit int) ([]types.TradeRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
accountTrades, err := t.client.NewListAccountTradeService().
Symbol(symbol).
FromID(fromID).
Limit(limit).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get trade history for %s from ID %d: %w", symbol, fromID, err)
}
var trades []types.TradeRecord
for _, at := range accountTrades {
price, _ := strconv.ParseFloat(at.Price, 64)
qty, _ := strconv.ParseFloat(at.Quantity, 64)
fee, _ := strconv.ParseFloat(at.Commission, 64)
pnl, _ := strconv.ParseFloat(at.RealizedPnl, 64)
trade := types.TradeRecord{
TradeID: strconv.FormatInt(at.ID, 10),
Symbol: at.Symbol,
Side: string(at.Side),
PositionSide: string(at.PositionSide),
Price: price,
Quantity: qty,
RealizedPnL: pnl,
Fee: fee,
Time: time.UnixMilli(at.Time).UTC(),
}
trades = append(trades, trade)
}
return trades, nil
}
// GetCommissionSymbols returns symbols that have new commission records since lastSyncTime
// COMMISSION income is generated for every trade, so this is more reliable than REALIZED_PNL
func (t *FuturesTrader) GetCommissionSymbols(lastSyncTime time.Time) ([]string, error) {
incomes, err := t.client.NewGetIncomeHistoryService().
IncomeType("COMMISSION").
StartTime(lastSyncTime.UnixMilli()).
Limit(1000).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get commission history: %w", err)
}
symbolMap := make(map[string]bool)
for _, income := range incomes {
if income.Symbol != "" {
symbolMap[income.Symbol] = true
}
}
var symbols []string
for symbol := range symbolMap {
symbols = append(symbols, symbol)
}
return symbols, nil
}
// GetPnLSymbols returns symbols that have REALIZED_PNL records since lastSyncTime
// This is a fallback when COMMISSION detection fails (VIP users, BNB fee discount)
func (t *FuturesTrader) GetPnLSymbols(lastSyncTime time.Time) ([]string, error) {
incomes, err := t.client.NewGetIncomeHistoryService().
IncomeType("REALIZED_PNL").
StartTime(lastSyncTime.UnixMilli()).
Limit(1000).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get PnL history: %w", err)
}
symbolMap := make(map[string]bool)
for _, income := range incomes {
if income.Symbol != "" {
symbolMap[income.Symbol] = true
}
}
var symbols []string
for symbol := range symbolMap {
symbols = append(symbols, symbol)
}
return symbols, nil
}
+758
View File
@@ -0,0 +1,758 @@
package binance
import (
"context"
"fmt"
"nofx/logger"
"nofx/trader/types"
"strconv"
"github.com/adshao/go-binance/v2/futures"
)
// OpenLong opens a long position
func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// First cancel all pending orders for this symbol (clean up old stop-loss and take-profit orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel old pending orders (may not have any): %v", err)
}
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
return nil, err
}
// Note: Margin mode should be set by the caller (AutoTrader) before opening position via SetMarginMode
// Format quantity to correct precision
quantityStr, err := t.FormatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Check if formatted quantity is 0 (prevent rounding errors)
quantityFloat, parseErr := strconv.ParseFloat(quantityStr, 64)
if parseErr != nil || quantityFloat <= 0 {
return nil, fmt.Errorf("position size too small, rounded to 0 (original: %.8f → formatted: %s). Suggest increasing position amount or selecting a lower-priced coin", quantity, quantityStr)
}
// Check minimum notional value (Binance requires at least 10 USDT)
if err := t.CheckMinNotional(symbol, quantityFloat); err != nil {
return nil, err
}
// Create market buy order (using br ID)
order, err := t.client.NewCreateOrderService().
Symbol(symbol).
Side(futures.SideTypeBuy).
PositionSide(futures.PositionSideTypeLong).
Type(futures.OrderTypeMarket).
Quantity(quantityStr).
NewClientOrderID(getBrOrderID()).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
logger.Infof("✓ Opened long position successfully: %s quantity: %s", symbol, quantityStr)
logger.Infof(" Order ID: %d", order.OrderID)
result := make(map[string]interface{})
result["orderId"] = order.OrderID
result["symbol"] = order.Symbol
result["status"] = order.Status
return result, nil
}
// OpenShort opens a short position
func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// First cancel all pending orders for this symbol (clean up old stop-loss and take-profit orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel old pending orders (may not have any): %v", err)
}
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
return nil, err
}
// Note: Margin mode should be set by the caller (AutoTrader) before opening position via SetMarginMode
// Format quantity to correct precision
quantityStr, err := t.FormatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Check if formatted quantity is 0 (prevent rounding errors)
quantityFloat, parseErr := strconv.ParseFloat(quantityStr, 64)
if parseErr != nil || quantityFloat <= 0 {
return nil, fmt.Errorf("position size too small, rounded to 0 (original: %.8f → formatted: %s). Suggest increasing position amount or selecting a lower-priced coin", quantity, quantityStr)
}
// Check minimum notional value (Binance requires at least 10 USDT)
if err := t.CheckMinNotional(symbol, quantityFloat); err != nil {
return nil, err
}
// Create market sell order (using br ID)
order, err := t.client.NewCreateOrderService().
Symbol(symbol).
Side(futures.SideTypeSell).
PositionSide(futures.PositionSideTypeShort).
Type(futures.OrderTypeMarket).
Quantity(quantityStr).
NewClientOrderID(getBrOrderID()).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
logger.Infof("✓ Opened short position successfully: %s quantity: %s", symbol, quantityStr)
logger.Infof(" Order ID: %d", order.OrderID)
result := make(map[string]interface{})
result["orderId"] = order.OrderID
result["symbol"] = order.Symbol
result["status"] = order.Status
return result, nil
}
// CloseLong closes a long position
func (t *FuturesTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// If quantity is 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "long" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("no long position found for %s", symbol)
}
}
// Format quantity
quantityStr, err := t.FormatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Create market sell order (close long, using br ID)
order, err := t.client.NewCreateOrderService().
Symbol(symbol).
Side(futures.SideTypeSell).
PositionSide(futures.PositionSideTypeLong).
Type(futures.OrderTypeMarket).
Quantity(quantityStr).
NewClientOrderID(getBrOrderID()).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
logger.Infof("✓ Closed long position successfully: %s quantity: %s", symbol, quantityStr)
// After closing position, cancel all pending orders for this symbol (stop-loss and take-profit orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel pending orders: %v", err)
}
result := make(map[string]interface{})
result["orderId"] = order.OrderID
result["symbol"] = order.Symbol
result["status"] = order.Status
return result, nil
}
// CloseShort closes a short position
func (t *FuturesTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// If quantity is 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "short" {
quantity = -pos["positionAmt"].(float64) // Short position quantity is negative, take absolute value
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("no short position found for %s", symbol)
}
}
// Format quantity
quantityStr, err := t.FormatQuantity(symbol, quantity)
if err != nil {
return nil, err
}
// Create market buy order (close short, using br ID)
order, err := t.client.NewCreateOrderService().
Symbol(symbol).
Side(futures.SideTypeBuy).
PositionSide(futures.PositionSideTypeShort).
Type(futures.OrderTypeMarket).
Quantity(quantityStr).
NewClientOrderID(getBrOrderID()).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
logger.Infof("✓ Closed short position successfully: %s quantity: %s", symbol, quantityStr)
// After closing position, cancel all pending orders for this symbol (stop-loss and take-profit orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ Failed to cancel pending orders: %v", err)
}
result := make(map[string]interface{})
result["orderId"] = order.OrderID
result["symbol"] = order.Symbol
result["status"] = order.Status
return result, nil
}
// CancelStopLossOrders cancels only stop-loss orders (doesn't affect take-profit orders)
// Now uses both legacy API and new Algo Order API
func (t *FuturesTrader) CancelStopLossOrders(symbol string) error {
canceledCount := 0
var cancelErrors []error
// 1. Cancel legacy stop-loss orders
orders, err := t.client.NewListOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err == nil {
for _, order := range orders {
orderType := string(order.Type)
// Only cancel stop-loss orders (don't cancel take-profit orders)
// Use string comparison since OrderType constants were removed in v2.8.9
if orderType == "STOP_MARKET" || orderType == "STOP" {
_, err := t.client.NewCancelOrderService().
Symbol(symbol).
OrderID(order.OrderID).
Do(context.Background())
if err != nil {
errMsg := fmt.Sprintf("Order ID %d: %v", order.OrderID, err)
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
logger.Infof(" ⚠ Failed to cancel legacy stop-loss order: %s", errMsg)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled legacy stop-loss order (Order ID: %d, Type: %s, Side: %s)", order.OrderID, orderType, order.PositionSide)
}
}
}
// 2. Cancel Algo stop-loss orders
algoOrders, err := t.client.NewListOpenAlgoOrdersService().
Symbol(symbol).
Do(context.Background())
if err == nil {
for _, algoOrder := range algoOrders {
// Only cancel stop-loss orders
if algoOrder.OrderType == futures.AlgoOrderTypeStopMarket || algoOrder.OrderType == futures.AlgoOrderTypeStop {
_, err := t.client.NewCancelAlgoOrderService().
AlgoID(algoOrder.AlgoId).
Do(context.Background())
if err != nil {
errMsg := fmt.Sprintf("Algo ID %d: %v", algoOrder.AlgoId, err)
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
logger.Infof(" ⚠ Failed to cancel Algo stop-loss order: %s", errMsg)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled Algo stop-loss order (Algo ID: %d, Type: %s)", algoOrder.AlgoId, algoOrder.OrderType)
}
}
}
if canceledCount == 0 && len(cancelErrors) == 0 {
logger.Infof(" %s has no stop-loss orders to cancel", symbol)
} else if canceledCount > 0 {
logger.Infof(" ✓ Canceled %d stop-loss order(s) for %s", canceledCount, symbol)
}
// If all cancellations failed, return error
if len(cancelErrors) > 0 && canceledCount == 0 {
return fmt.Errorf("failed to cancel stop-loss orders: %v", cancelErrors)
}
return nil
}
// CancelTakeProfitOrders cancels only take-profit orders (doesn't affect stop-loss orders)
// Now uses both legacy API and new Algo Order API
func (t *FuturesTrader) CancelTakeProfitOrders(symbol string) error {
canceledCount := 0
var cancelErrors []error
// 1. Cancel legacy take-profit orders
orders, err := t.client.NewListOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err == nil {
for _, order := range orders {
orderType := string(order.Type)
// Only cancel take-profit orders (don't cancel stop-loss orders)
// Use string comparison since OrderType constants were removed in v2.8.9
if orderType == "TAKE_PROFIT_MARKET" || orderType == "TAKE_PROFIT" {
_, err := t.client.NewCancelOrderService().
Symbol(symbol).
OrderID(order.OrderID).
Do(context.Background())
if err != nil {
errMsg := fmt.Sprintf("Order ID %d: %v", order.OrderID, err)
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
logger.Infof(" ⚠ Failed to cancel legacy take-profit order: %s", errMsg)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled legacy take-profit order (Order ID: %d, Type: %s, Side: %s)", order.OrderID, orderType, order.PositionSide)
}
}
}
// 2. Cancel Algo take-profit orders
algoOrders, err := t.client.NewListOpenAlgoOrdersService().
Symbol(symbol).
Do(context.Background())
if err == nil {
for _, algoOrder := range algoOrders {
// Only cancel take-profit orders
if algoOrder.OrderType == futures.AlgoOrderTypeTakeProfitMarket || algoOrder.OrderType == futures.AlgoOrderTypeTakeProfit {
_, err := t.client.NewCancelAlgoOrderService().
AlgoID(algoOrder.AlgoId).
Do(context.Background())
if err != nil {
errMsg := fmt.Sprintf("Algo ID %d: %v", algoOrder.AlgoId, err)
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
logger.Infof(" ⚠ Failed to cancel Algo take-profit order: %s", errMsg)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled Algo take-profit order (Algo ID: %d, Type: %s)", algoOrder.AlgoId, algoOrder.OrderType)
}
}
}
if canceledCount == 0 && len(cancelErrors) == 0 {
logger.Infof(" %s has no take-profit orders to cancel", symbol)
} else if canceledCount > 0 {
logger.Infof(" ✓ Canceled %d take-profit order(s) for %s", canceledCount, symbol)
}
// If all cancellations failed, return error
if len(cancelErrors) > 0 && canceledCount == 0 {
return fmt.Errorf("failed to cancel take-profit orders: %v", cancelErrors)
}
return nil
}
// CancelAllOrders cancels all pending orders for this symbol
// Now uses both legacy API and new Algo Order API
func (t *FuturesTrader) CancelAllOrders(symbol string) error {
// 1. Cancel all legacy orders
err := t.client.NewCancelAllOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err != nil {
logger.Infof(" ⚠ Failed to cancel legacy orders: %v", err)
} else {
logger.Infof(" ✓ Canceled all legacy pending orders for %s", symbol)
}
// 2. Cancel all Algo orders
err = t.client.NewCancelAllAlgoOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err != nil {
// Ignore "no algo orders" error
if !contains(err.Error(), "no algo") && !contains(err.Error(), "No algo") {
logger.Infof(" ⚠ Failed to cancel Algo orders: %v", err)
}
} else {
logger.Infof(" ✓ Canceled all Algo orders for %s", symbol)
}
return nil
}
// PlaceLimitOrder places a limit order for grid trading
// This implements the GridTrader interface for FuturesTrader
func (t *FuturesTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) {
// Format quantity to correct precision
quantityStr, err := t.FormatQuantity(req.Symbol, req.Quantity)
if err != nil {
return nil, fmt.Errorf("failed to format quantity: %w", err)
}
// Format price to correct precision
priceStr, err := t.FormatPrice(req.Symbol, req.Price)
if err != nil {
return nil, fmt.Errorf("failed to format price: %w", err)
}
// Set leverage if specified
if req.Leverage > 0 {
if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil {
logger.Warnf("Failed to set leverage: %v", err)
}
}
// Determine side and position side
var side futures.SideType
var positionSide futures.PositionSideType
if req.Side == "BUY" {
side = futures.SideTypeBuy
positionSide = futures.PositionSideTypeLong
} else {
side = futures.SideTypeSell
positionSide = futures.PositionSideTypeShort
}
// Build order service with broker ID
orderService := t.client.NewCreateOrderService().
Symbol(req.Symbol).
Side(side).
PositionSide(positionSide).
Type(futures.OrderTypeLimit).
TimeInForce(futures.TimeInForceTypeGTC).
Quantity(quantityStr).
Price(priceStr).
NewClientOrderID(getBrOrderID())
// Execute order
order, err := orderService.Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to place limit order: %w", err)
}
logger.Infof("✓ [Grid] Placed limit order: %s %s %s @ %s, qty=%s, orderID=%d",
req.Symbol, req.Side, positionSide, priceStr, quantityStr, order.OrderID)
return &types.LimitOrderResult{
OrderID: fmt.Sprintf("%d", order.OrderID),
ClientID: order.ClientOrderID,
Symbol: order.Symbol,
Side: string(order.Side),
PositionSide: string(order.PositionSide),
Price: req.Price,
Quantity: req.Quantity,
Status: string(order.Status),
}, nil
}
// CancelOrder cancels a specific order by ID
// This implements the GridTrader interface for FuturesTrader
func (t *FuturesTrader) CancelOrder(symbol, orderID string) error {
// Parse order ID to int64
orderIDInt, err := strconv.ParseInt(orderID, 10, 64)
if err != nil {
return fmt.Errorf("invalid order ID: %w", err)
}
_, err = t.client.NewCancelOrderService().
Symbol(symbol).
OrderID(orderIDInt).
Do(context.Background())
if err != nil {
return fmt.Errorf("failed to cancel order: %w", err)
}
logger.Infof("✓ [Grid] Cancelled order: %s/%s", symbol, orderID)
return nil
}
// GetOrderBook gets the order book for a symbol
// This implements the GridTrader interface for FuturesTrader
func (t *FuturesTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) {
book, err := t.client.NewDepthService().
Symbol(symbol).
Limit(depth).
Do(context.Background())
if err != nil {
return nil, nil, fmt.Errorf("failed to get order book: %w", err)
}
// Convert bids
bids = make([][]float64, len(book.Bids))
for i, bid := range book.Bids {
price, _ := strconv.ParseFloat(bid.Price, 64)
qty, _ := strconv.ParseFloat(bid.Quantity, 64)
bids[i] = []float64{price, qty}
}
// Convert asks
asks = make([][]float64, len(book.Asks))
for i, ask := range book.Asks {
price, _ := strconv.ParseFloat(ask.Price, 64)
qty, _ := strconv.ParseFloat(ask.Quantity, 64)
asks[i] = []float64{price, qty}
}
return bids, asks, nil
}
// CancelStopOrders cancels take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions)
// Now uses both legacy API and new Algo Order API (Binance migrated stop orders to Algo system)
func (t *FuturesTrader) CancelStopOrders(symbol string) error {
canceledCount := 0
// 1. Cancel legacy stop orders (for backward compatibility)
orders, err := t.client.NewListOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err == nil {
for _, order := range orders {
orderType := string(order.Type)
// Only cancel stop-loss and take-profit orders
// Use string comparison since OrderType constants were removed in v2.8.9
if orderType == "STOP_MARKET" ||
orderType == "TAKE_PROFIT_MARKET" ||
orderType == "STOP" ||
orderType == "TAKE_PROFIT" {
_, err := t.client.NewCancelOrderService().
Symbol(symbol).
OrderID(order.OrderID).
Do(context.Background())
if err != nil {
logger.Infof(" ⚠ Failed to cancel legacy order %d: %v", order.OrderID, err)
continue
}
canceledCount++
logger.Infof(" ✓ Canceled legacy stop order for %s (Order ID: %d, Type: %s)",
symbol, order.OrderID, orderType)
}
}
}
// 2. Cancel Algo orders (new API)
err = t.client.NewCancelAllAlgoOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err != nil {
// Ignore "no algo orders" error
if !contains(err.Error(), "no algo") && !contains(err.Error(), "No algo") {
logger.Infof(" ⚠ Failed to cancel Algo orders: %v", err)
}
} else {
logger.Infof(" ✓ Canceled all Algo orders for %s", symbol)
canceledCount++
}
if canceledCount == 0 {
logger.Infof(" %s has no take-profit/stop-loss orders to cancel", symbol)
}
return nil
}
// GetOpenOrders gets all open/pending orders for a symbol
func (t *FuturesTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
var result []types.OpenOrder
// 1. Get legacy open orders
orders, err := t.client.NewListOpenOrdersService().
Symbol(symbol).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
for _, order := range orders {
price, _ := strconv.ParseFloat(order.Price, 64)
stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64)
quantity, _ := strconv.ParseFloat(order.OrigQuantity, 64)
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.OrderID),
Symbol: order.Symbol,
Side: string(order.Side),
PositionSide: string(order.PositionSide),
Type: string(order.Type),
Price: price,
StopPrice: stopPrice,
Quantity: quantity,
Status: string(order.Status),
})
}
// 2. Get Algo orders (new API for stop-loss/take-profit)
algoOrders, err := t.client.NewListOpenAlgoOrdersService().
Symbol(symbol).
Do(context.Background())
if err == nil {
for _, algoOrder := range algoOrders {
triggerPrice, _ := strconv.ParseFloat(algoOrder.TriggerPrice, 64)
quantity, _ := strconv.ParseFloat(algoOrder.Quantity, 64)
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", algoOrder.AlgoId),
Symbol: algoOrder.Symbol,
Side: string(algoOrder.Side),
PositionSide: string(algoOrder.PositionSide),
Type: string(algoOrder.OrderType),
Price: 0, // Algo orders use stop price
StopPrice: triggerPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
return result, nil
}
// SetStopLoss sets stop-loss order using new Algo Order API
// Binance has migrated stop orders to Algo Order system (error -4120 STOP_ORDER_SWITCH_ALGO)
func (t *FuturesTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
var side futures.SideType
var posSide futures.PositionSideType
if positionSide == "LONG" {
side = futures.SideTypeSell
posSide = futures.PositionSideTypeLong
} else {
side = futures.SideTypeBuy
posSide = futures.PositionSideTypeShort
}
// Use new Algo Order API
_, err := t.client.NewCreateAlgoOrderService().
Symbol(symbol).
Side(side).
PositionSide(posSide).
Type(futures.AlgoOrderTypeStopMarket).
TriggerPrice(fmt.Sprintf("%.8f", stopPrice)).
WorkingType(futures.WorkingTypeContractPrice).
ClosePosition(true).
ClientAlgoId(getBrOrderID()).
Do(context.Background())
if err != nil {
return fmt.Errorf("failed to set stop-loss: %w", err)
}
logger.Infof(" Stop-loss price set (Algo Order): %.4f", stopPrice)
return nil
}
// SetTakeProfit sets take-profit order using new Algo Order API
// Binance has migrated stop orders to Algo Order system (error -4120 STOP_ORDER_SWITCH_ALGO)
func (t *FuturesTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
var side futures.SideType
var posSide futures.PositionSideType
if positionSide == "LONG" {
side = futures.SideTypeSell
posSide = futures.PositionSideTypeLong
} else {
side = futures.SideTypeBuy
posSide = futures.PositionSideTypeShort
}
// Use new Algo Order API
_, err := t.client.NewCreateAlgoOrderService().
Symbol(symbol).
Side(side).
PositionSide(posSide).
Type(futures.AlgoOrderTypeTakeProfitMarket).
TriggerPrice(fmt.Sprintf("%.8f", takeProfitPrice)).
WorkingType(futures.WorkingTypeContractPrice).
ClosePosition(true).
ClientAlgoId(getBrOrderID()).
Do(context.Background())
if err != nil {
return fmt.Errorf("failed to set take-profit: %w", err)
}
logger.Infof(" Take-profit price set (Algo Order): %.4f", takeProfitPrice)
return nil
}
// GetOrderStatus gets order status
func (t *FuturesTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
// Convert orderID to int64
orderIDInt, err := strconv.ParseInt(orderID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid order ID: %s", orderID)
}
order, err := t.client.NewGetOrderService().
Symbol(symbol).
OrderID(orderIDInt).
Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
// Parse execution price
avgPrice, _ := strconv.ParseFloat(order.AvgPrice, 64)
executedQty, _ := strconv.ParseFloat(order.ExecutedQuantity, 64)
result := map[string]interface{}{
"orderId": order.OrderID,
"symbol": order.Symbol,
"status": string(order.Status),
"avgPrice": avgPrice,
"executedQty": executedQty,
"side": string(order.Side),
"type": string(order.Type),
"time": order.Time,
"updateTime": order.UpdateTime,
}
// Binance futures commission fee needs to be obtained through GetUserTrades, not retrieved here for now
// Can be obtained later through WebSocket or separate query
result["commission"] = 0.0
return result, nil
}
+290
View File
@@ -0,0 +1,290 @@
package binance
import (
"context"
"fmt"
"nofx/logger"
"strconv"
"time"
"github.com/adshao/go-binance/v2/futures"
)
// GetPositions gets all positions (with cache)
func (t *FuturesTrader) GetPositions() ([]map[string]interface{}, error) {
// First check if cache is valid
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
cacheAge := time.Since(t.positionsCacheTime)
t.positionsCacheMutex.RUnlock()
logger.Infof("✓ Using cached position information (cache age: %.1f seconds ago)", cacheAge.Seconds())
return t.cachedPositions, nil
}
t.positionsCacheMutex.RUnlock()
// Cache expired or doesn't exist, call API
logger.Infof("🔄 Cache expired, calling Binance API to get position information...")
positions, err := t.client.NewGetPositionRiskService().Do(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
posAmt, _ := strconv.ParseFloat(pos.PositionAmt, 64)
if posAmt == 0 {
continue // Skip positions with zero amount
}
posMap := make(map[string]interface{})
posMap["symbol"] = pos.Symbol
posMap["positionAmt"], _ = strconv.ParseFloat(pos.PositionAmt, 64)
posMap["entryPrice"], _ = strconv.ParseFloat(pos.EntryPrice, 64)
posMap["markPrice"], _ = strconv.ParseFloat(pos.MarkPrice, 64)
posMap["unRealizedProfit"], _ = strconv.ParseFloat(pos.UnRealizedProfit, 64)
posMap["leverage"], _ = strconv.ParseFloat(pos.Leverage, 64)
posMap["liquidationPrice"], _ = strconv.ParseFloat(pos.LiquidationPrice, 64)
// Note: Binance SDK doesn't expose updateTime field, will fallback to local tracking
// Determine direction
if posAmt > 0 {
posMap["side"] = "long"
} else {
posMap["side"] = "short"
}
result = append(result, posMap)
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// SetMarginMode sets margin mode
func (t *FuturesTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
var marginType futures.MarginType
if isCrossMargin {
marginType = futures.MarginTypeCrossed
} else {
marginType = futures.MarginTypeIsolated
}
// Try to set margin mode
err := t.client.NewChangeMarginTypeService().
Symbol(symbol).
MarginType(marginType).
Do(context.Background())
marginModeStr := "Cross Margin"
if !isCrossMargin {
marginModeStr = "Isolated Margin"
}
if err != nil {
// If error message contains "No need to change", margin mode is already set to target value
if contains(err.Error(), "No need to change margin type") {
logger.Infof(" ✓ %s margin mode is already %s", symbol, marginModeStr)
return nil
}
// If there is an open position, margin mode cannot be changed, but this doesn't affect trading
if contains(err.Error(), "Margin type cannot be changed if there exists position") {
logger.Infof(" ⚠️ %s has open positions, cannot change margin mode, continuing with current mode", symbol)
return nil
}
// Detect Multi-Assets mode (error code -4168)
if contains(err.Error(), "Multi-Assets mode") || contains(err.Error(), "-4168") || contains(err.Error(), "4168") {
logger.Infof(" ⚠️ %s detected Multi-Assets mode, forcing Cross Margin mode", symbol)
logger.Infof(" 💡 Tip: To use Isolated Margin mode, please disable Multi-Assets mode in Binance")
return nil
}
// Detect Unified Account API (Portfolio Margin)
if contains(err.Error(), "unified") || contains(err.Error(), "portfolio") || contains(err.Error(), "Portfolio") {
logger.Infof(" ❌ %s detected Unified Account API, unable to trade futures", symbol)
return fmt.Errorf("please use 'Spot & Futures Trading' API permission, do not use 'Unified Account API'")
}
logger.Infof(" ⚠️ Failed to set margin mode: %v", err)
// Don't return error, let trading continue
return nil
}
logger.Infof(" ✓ %s margin mode set to %s", symbol, marginModeStr)
return nil
}
// SetLeverage sets leverage (with smart detection and cooldown period)
func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error {
// First try to get current leverage (from position information)
currentLeverage := 0
positions, err := t.GetPositions()
if err == nil {
for _, pos := range positions {
if pos["symbol"] == symbol {
if lev, ok := pos["leverage"].(float64); ok {
currentLeverage = int(lev)
break
}
}
}
}
// If current leverage is already the target leverage, skip
if currentLeverage == leverage && currentLeverage > 0 {
logger.Infof(" ✓ %s leverage is already %dx, no need to change", symbol, leverage)
return nil
}
// Change leverage
_, err = t.client.NewChangeLeverageService().
Symbol(symbol).
Leverage(leverage).
Do(context.Background())
if err != nil {
// If error message contains "No need to change", leverage is already the target value
if contains(err.Error(), "No need to change") {
logger.Infof(" ✓ %s leverage is already %dx", symbol, leverage)
return nil
}
return fmt.Errorf("failed to set leverage: %w", err)
}
logger.Infof(" ✓ %s leverage changed to %dx", symbol, leverage)
// Wait 5 seconds after changing leverage (to avoid cooldown period errors)
logger.Infof(" ⏱ Waiting 5 seconds for cooldown period...")
time.Sleep(5 * time.Second)
return nil
}
// GetMarketPrice gets market price
func (t *FuturesTrader) GetMarketPrice(symbol string) (float64, error) {
prices, err := t.client.NewListPricesService().Symbol(symbol).Do(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get price: %w", err)
}
if len(prices) == 0 {
return 0, fmt.Errorf("price not found")
}
price, err := strconv.ParseFloat(prices[0].Price, 64)
if err != nil {
return 0, err
}
return price, nil
}
// CalculatePositionSize calculates position size
func (t *FuturesTrader) CalculatePositionSize(balance, riskPercent, price float64, leverage int) float64 {
riskAmount := balance * (riskPercent / 100.0)
positionValue := riskAmount * float64(leverage)
quantity := positionValue / price
return quantity
}
// GetMinNotional gets minimum notional value (Binance requirement)
func (t *FuturesTrader) GetMinNotional(symbol string) float64 {
// Use conservative default value of 10 USDT to ensure order passes exchange validation
return 10.0
}
// CheckMinNotional checks if order meets minimum notional value requirement
func (t *FuturesTrader) CheckMinNotional(symbol string, quantity float64) error {
price, err := t.GetMarketPrice(symbol)
if err != nil {
return fmt.Errorf("failed to get market price: %w", err)
}
notionalValue := quantity * price
minNotional := t.GetMinNotional(symbol)
if notionalValue < minNotional {
return fmt.Errorf(
"order amount %.2f USDT is below minimum requirement %.2f USDT (quantity: %.4f, price: %.4f)",
notionalValue, minNotional, quantity, price,
)
}
return nil
}
// GetSymbolPrecision gets the quantity precision for a trading pair
func (t *FuturesTrader) GetSymbolPrecision(symbol string) (int, error) {
exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get trading rules: %w", err)
}
for _, s := range exchangeInfo.Symbols {
if s.Symbol == symbol {
// Get precision from LOT_SIZE filter
for _, filter := range s.Filters {
if filter["filterType"] == "LOT_SIZE" {
stepSize := filter["stepSize"].(string)
precision := calculatePrecision(stepSize)
logger.Infof(" %s quantity precision: %d (stepSize: %s)", symbol, precision, stepSize)
return precision, nil
}
}
}
}
logger.Infof(" ⚠ %s precision information not found, using default precision 3", symbol)
return 3, nil // Default precision is 3
}
// FormatQuantity formats quantity to correct precision
func (t *FuturesTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
precision, err := t.GetSymbolPrecision(symbol)
if err != nil {
// If retrieval fails, use default format
return fmt.Sprintf("%.3f", quantity), nil
}
format := fmt.Sprintf("%%.%df", precision)
return fmt.Sprintf(format, quantity), nil
}
// GetSymbolPricePrecision gets the price precision for a trading pair
func (t *FuturesTrader) GetSymbolPricePrecision(symbol string) (int, error) {
exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get trading rules: %w", err)
}
for _, s := range exchangeInfo.Symbols {
if s.Symbol == symbol {
// Get precision from PRICE_FILTER filter
for _, filter := range s.Filters {
if filter["filterType"] == "PRICE_FILTER" {
tickSize := filter["tickSize"].(string)
precision := calculatePrecision(tickSize)
return precision, nil
}
}
}
}
// Default to 2 decimal places for price
return 2, nil
}
// FormatPrice formats price to correct precision
func (t *FuturesTrader) FormatPrice(symbol string, price float64) (string, error) {
precision, err := t.GetSymbolPricePrecision(symbol)
if err != nil {
// If retrieval fails, use default format
return fmt.Sprintf("%.2f", price), nil
}
format := fmt.Sprintf("%%.%df", precision)
return fmt.Sprintf(format, price), nil
}
+22 -1064
View File
File diff suppressed because it is too large Load Diff
+199
View File
@@ -0,0 +1,199 @@
package bitget
import (
"encoding/json"
"fmt"
"nofx/logger"
"strconv"
"strings"
"time"
)
// GetBalance gets account balance
func (t *BitgetTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
t.balanceCacheMutex.RUnlock()
return t.cachedBalance, nil
}
t.balanceCacheMutex.RUnlock()
params := map[string]interface{}{
"productType": "USDT-FUTURES",
}
data, err := t.doRequest("GET", bitgetAccountPath, params)
if err != nil {
return nil, fmt.Errorf("failed to get account balance: %w", err)
}
var accounts []struct {
MarginCoin string `json:"marginCoin"`
Available string `json:"available"` // Available balance
AccountEquity string `json:"accountEquity"` // Total equity
UsdtEquity string `json:"usdtEquity"` // USDT equity
UnrealizedPL string `json:"unrealizedPL"` // Unrealized P&L
}
if err := json.Unmarshal(data, &accounts); err != nil {
return nil, fmt.Errorf("failed to parse balance data: %w, raw: %s", err, string(data))
}
var totalEquity, availableBalance, unrealizedPnL float64
for _, acc := range accounts {
if acc.MarginCoin == "USDT" {
totalEquity, _ = strconv.ParseFloat(acc.AccountEquity, 64)
availableBalance, _ = strconv.ParseFloat(acc.Available, 64)
unrealizedPnL, _ = strconv.ParseFloat(acc.UnrealizedPL, 64)
logger.Infof("✓ [Bitget] Balance: equity=%.2f, available=%.2f", totalEquity, availableBalance)
break
}
}
result := map[string]interface{}{
"totalWalletBalance": totalEquity - unrealizedPnL,
"availableBalance": availableBalance,
"totalUnrealizedProfit": unrealizedPnL,
"total_equity": totalEquity,
}
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
// SetMarginMode sets margin mode
func (t *BitgetTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
symbol = t.convertSymbol(symbol)
marginMode := "isolated"
if isCrossMargin {
marginMode = "crossed"
}
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginCoin": "USDT",
"marginMode": marginMode,
}
_, err := t.doRequest("POST", bitgetMarginModePath, body)
if err != nil {
if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") {
return nil
}
if strings.Contains(err.Error(), "position") {
logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol)
return nil
}
return err
}
logger.Infof(" ✓ %s margin mode set to %s", symbol, marginMode)
return nil
}
// SetLeverage sets leverage
func (t *BitgetTrader) SetLeverage(symbol string, leverage int) error {
symbol = t.convertSymbol(symbol)
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginCoin": "USDT",
"leverage": fmt.Sprintf("%d", leverage),
}
_, err := t.doRequest("POST", bitgetLeveragePath, body)
if err != nil {
if strings.Contains(err.Error(), "same") {
return nil
}
logger.Infof(" ⚠️ Failed to set %s leverage: %v", symbol, err)
return err
}
logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage)
return nil
}
// GetMarketPrice gets market price
func (t *BitgetTrader) GetMarketPrice(symbol string) (float64, error) {
symbol = t.convertSymbol(symbol)
params := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
}
data, err := t.doRequest("GET", bitgetTickerPath, params)
if err != nil {
return 0, fmt.Errorf("failed to get price: %w", err)
}
var tickers []struct {
LastPr string `json:"lastPr"`
}
if err := json.Unmarshal(data, &tickers); err != nil {
return 0, err
}
if len(tickers) == 0 {
return 0, fmt.Errorf("no price data received")
}
price, err := strconv.ParseFloat(tickers[0].LastPr, 64)
if err != nil {
return 0, err
}
return price, nil
}
// GetOrderBook gets the order book for a symbol
// Implements GridTrader interface
func (t *BitgetTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) {
symbol = t.convertSymbol(symbol)
path := fmt.Sprintf("/api/v2/mix/market/depth?symbol=%s&productType=USDT-FUTURES&limit=%d", symbol, depth)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to get order book: %w", err)
}
var result struct {
Bids [][]string `json:"bids"`
Asks [][]string `json:"asks"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, nil, fmt.Errorf("failed to parse order book: %w", err)
}
// Parse bids
for _, b := range result.Bids {
if len(b) >= 2 {
price, _ := strconv.ParseFloat(b[0], 64)
qty, _ := strconv.ParseFloat(b[1], 64)
bids = append(bids, []float64{price, qty})
}
}
// Parse asks
for _, a := range result.Asks {
if len(a) >= 2 {
price, _ := strconv.ParseFloat(a[0], 64)
qty, _ := strconv.ParseFloat(a[1], 64)
asks = append(asks, []float64{price, qty})
}
}
return bids, asks, nil
}
+711
View File
@@ -0,0 +1,711 @@
package bitget
import (
"encoding/json"
"fmt"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
)
// OpenLong opens long position
func (t *BitgetTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// Cancel old orders first
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof(" ⚠️ Failed to set leverage: %v", err)
}
// Format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"side": "buy",
"orderType": "market",
"size": qtyStr,
"clientOid": genBitgetClientOid(),
}
logger.Infof(" 📊 Bitget OpenLong: symbol=%s, qty=%s, leverage=%d", symbol, qtyStr, leverage)
data, err := t.doRequest("POST", bitgetOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
var order struct {
OrderId string `json:"orderId"`
ClientOid string `json:"clientOid"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
// Clear cache
t.clearCache()
logger.Infof("✓ Bitget opened long position successfully: %s", symbol)
return map[string]interface{}{
"orderId": order.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// OpenShort opens short position
func (t *BitgetTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// Cancel old orders first
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof(" ⚠️ Failed to set leverage: %v", err)
}
// Format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"side": "sell",
"orderType": "market",
"size": qtyStr,
"clientOid": genBitgetClientOid(),
}
logger.Infof(" 📊 Bitget OpenShort: symbol=%s, qty=%s, leverage=%d", symbol, qtyStr, leverage)
data, err := t.doRequest("POST", bitgetOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
var order struct {
OrderId string `json:"orderId"`
ClientOid string `json:"clientOid"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
// Clear cache
t.clearCache()
logger.Infof("✓ Bitget opened short position successfully: %s", symbol)
return map[string]interface{}{
"orderId": order.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// CloseLong closes long position
func (t *BitgetTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// If quantity is 0, get current position
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "long" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("long position not found for %s", symbol)
}
}
// Format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"side": "sell",
"orderType": "market",
"size": qtyStr,
"reduceOnly": "YES",
"clientOid": genBitgetClientOid(),
}
logger.Infof(" 📊 Bitget CloseLong: symbol=%s, qty=%s", symbol, qtyStr)
data, err := t.doRequest("POST", bitgetOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
var order struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, err
}
// Clear cache
t.clearCache()
logger.Infof("✓ Bitget closed long position successfully: %s", symbol)
return map[string]interface{}{
"orderId": order.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// CloseShort closes short position
func (t *BitgetTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// If quantity is 0, get current position
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "short" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("short position not found for %s", symbol)
}
}
// Ensure quantity is positive
if quantity < 0 {
quantity = -quantity
}
// Format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"side": "buy",
"orderType": "market",
"size": qtyStr,
"reduceOnly": "YES",
"clientOid": genBitgetClientOid(),
}
logger.Infof(" 📊 Bitget CloseShort: symbol=%s, qty=%s", symbol, qtyStr)
data, err := t.doRequest("POST", bitgetOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
var order struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, err
}
// Clear cache
t.clearCache()
logger.Infof("✓ Bitget closed short position successfully: %s", symbol)
return map[string]interface{}{
"orderId": order.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// SetStopLoss sets stop loss order
func (t *BitgetTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
// Bitget V2 uses plan order for stop loss
symbol = t.convertSymbol(symbol)
side := "sell"
holdSide := "long"
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
holdSide = "short"
}
qtyStr, _ := t.FormatQuantity(symbol, quantity)
body := map[string]interface{}{
"planType": "loss_plan",
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"triggerPrice": fmt.Sprintf("%.8f", stopPrice),
"triggerType": "mark_price",
"side": side,
"tradeSide": "close",
"orderType": "market",
"size": qtyStr,
"holdSide": holdSide,
"clientOid": genBitgetClientOid(),
}
_, err := t.doRequest("POST", "/api/v2/mix/order/place-plan-order", body)
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
logger.Infof(" ✓ [Bitget] Stop loss set: %s @ %.4f", symbol, stopPrice)
return nil
}
// SetTakeProfit sets take profit order
func (t *BitgetTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
// Bitget V2 uses plan order for take profit
symbol = t.convertSymbol(symbol)
side := "sell"
holdSide := "long"
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
holdSide = "short"
}
qtyStr, _ := t.FormatQuantity(symbol, quantity)
body := map[string]interface{}{
"planType": "profit_plan",
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"triggerPrice": fmt.Sprintf("%.8f", takeProfitPrice),
"triggerType": "mark_price",
"side": side,
"tradeSide": "close",
"orderType": "market",
"size": qtyStr,
"holdSide": holdSide,
"clientOid": genBitgetClientOid(),
}
_, err := t.doRequest("POST", "/api/v2/mix/order/place-plan-order", body)
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
logger.Infof(" ✓ [Bitget] Take profit set: %s @ %.4f", symbol, takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *BitgetTrader) CancelStopLossOrders(symbol string) error {
return t.cancelPlanOrders(symbol, "loss_plan")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *BitgetTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelPlanOrders(symbol, "profit_plan")
}
// cancelPlanOrders cancels plan orders
func (t *BitgetTrader) cancelPlanOrders(symbol string, planType string) error {
symbol = t.convertSymbol(symbol)
// Get pending plan orders
params := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"planType": planType,
}
data, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", params)
if err != nil {
return err
}
var orders struct {
EntrustedList []struct {
OrderId string `json:"orderId"`
} `json:"entrustedList"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return err
}
// Cancel each order
for _, order := range orders.EntrustedList {
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginCoin": "USDT",
"orderId": order.OrderId,
}
t.doRequest("POST", "/api/v2/mix/order/cancel-plan-order", body)
}
return nil
}
// CancelAllOrders cancels all pending orders
func (t *BitgetTrader) CancelAllOrders(symbol string) error {
symbol = t.convertSymbol(symbol)
// Get pending orders
params := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
}
data, err := t.doRequest("GET", bitgetPendingPath, params)
if err != nil {
return err
}
var orders struct {
EntrustedList []struct {
OrderId string `json:"orderId"`
} `json:"entrustedList"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return err
}
// Cancel each order
for _, order := range orders.EntrustedList {
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginCoin": "USDT",
"orderId": order.OrderId,
}
t.doRequest("POST", bitgetCancelOrderPath, body)
}
// Also cancel plan orders
t.cancelPlanOrders(symbol, "loss_plan")
t.cancelPlanOrders(symbol, "profit_plan")
return nil
}
// CancelStopOrders cancels stop loss and take profit orders
func (t *BitgetTrader) CancelStopOrders(symbol string) error {
t.CancelStopLossOrders(symbol)
t.CancelTakeProfitOrders(symbol)
return nil
}
// GetOrderStatus gets order status
func (t *BitgetTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
params := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"orderId": orderID,
}
data, err := t.doRequest("GET", "/api/v2/mix/order/detail", params)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var order struct {
OrderId string `json:"orderId"`
State string `json:"state"` // filled, canceled, partially_filled, new
PriceAvg string `json:"priceAvg"` // Average fill price
BaseVolume string `json:"baseVolume"` // Filled quantity
Fee string `json:"fee"` // Fee
Side string `json:"side"`
OrderType string `json:"orderType"`
CTime string `json:"cTime"`
UTime string `json:"uTime"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, err
}
avgPrice, _ := strconv.ParseFloat(order.PriceAvg, 64)
fillQty, _ := strconv.ParseFloat(order.BaseVolume, 64)
fee, _ := strconv.ParseFloat(order.Fee, 64)
cTime, _ := strconv.ParseInt(order.CTime, 10, 64)
uTime, _ := strconv.ParseInt(order.UTime, 10, 64)
// Status mapping
statusMap := map[string]string{
"filled": "FILLED",
"new": "NEW",
"partially_filled": "PARTIALLY_FILLED",
"canceled": "CANCELED",
}
status := statusMap[order.State]
if status == "" {
status = order.State
}
return map[string]interface{}{
"orderId": order.OrderId,
"symbol": symbol,
"status": status,
"avgPrice": avgPrice,
"executedQty": fillQty,
"side": order.Side,
"type": order.OrderType,
"time": cTime,
"updateTime": uTime,
"commission": -fee,
}, nil
}
// GetOpenOrders gets all open/pending orders for a symbol
func (t *BitgetTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
symbol = t.convertSymbol(symbol)
var result []types.OpenOrder
// 1. Get pending limit orders
params := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
}
data, err := t.doRequest("GET", bitgetPendingPath, params)
if err != nil {
logger.Warnf("[Bitget] Failed to get pending orders: %v", err)
}
if err == nil && data != nil {
var orders struct {
EntrustedList []struct {
OrderId string `json:"orderId"`
Symbol string `json:"symbol"`
Side string `json:"side"` // buy/sell
TradeSide string `json:"tradeSide"` // open/close
PosSide string `json:"posSide"` // long/short
OrderType string `json:"orderType"` // limit/market
Price string `json:"price"`
Size string `json:"size"`
State string `json:"state"`
} `json:"entrustedList"`
}
if err := json.Unmarshal(data, &orders); err == nil {
for _, order := range orders.EntrustedList {
price, _ := strconv.ParseFloat(order.Price, 64)
quantity, _ := strconv.ParseFloat(order.Size, 64)
// Convert side to standard format
side := strings.ToUpper(order.Side)
positionSide := strings.ToUpper(order.PosSide)
result = append(result, types.OpenOrder{
OrderID: order.OrderId,
Symbol: symbol,
Side: side,
PositionSide: positionSide,
Type: strings.ToUpper(order.OrderType),
Price: price,
StopPrice: 0,
Quantity: quantity,
Status: "NEW",
})
}
}
}
// 2. Get pending plan orders (stop-loss/take-profit)
// Bitget V2 API requires planType parameter: profit_loss for SL/TP orders
planParams := map[string]interface{}{
"productType": "USDT-FUTURES",
"planType": "profit_loss",
}
planData, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", planParams)
if err != nil {
logger.Warnf("[Bitget] Failed to get plan orders: %v", err)
}
if err == nil && planData != nil {
var planOrders struct {
EntrustedList []struct {
OrderId string `json:"orderId"`
Symbol string `json:"symbol"`
Side string `json:"side"`
PosSide string `json:"posSide"`
PlanType string `json:"planType"` // pos_loss, pos_profit
TriggerPrice string `json:"triggerPrice"`
StopLossTriggerPrice string `json:"stopLossTriggerPrice"`
StopSurplusTriggerPrice string `json:"stopSurplusTriggerPrice"`
Size string `json:"size"`
PlanStatus string `json:"planStatus"`
} `json:"entrustedList"`
}
if err := json.Unmarshal(planData, &planOrders); err == nil {
for _, order := range planOrders.EntrustedList {
// Filter by symbol if specified
if symbol != "" && order.Symbol != symbol {
continue
}
// Determine trigger price based on plan type
var triggerPrice float64
orderType := "STOP_MARKET"
if order.PlanType == "pos_profit" {
// Take profit order
orderType = "TAKE_PROFIT_MARKET"
if order.StopSurplusTriggerPrice != "" {
triggerPrice, _ = strconv.ParseFloat(order.StopSurplusTriggerPrice, 64)
} else {
triggerPrice, _ = strconv.ParseFloat(order.TriggerPrice, 64)
}
} else {
// Stop loss order (pos_loss)
if order.StopLossTriggerPrice != "" {
triggerPrice, _ = strconv.ParseFloat(order.StopLossTriggerPrice, 64)
} else {
triggerPrice, _ = strconv.ParseFloat(order.TriggerPrice, 64)
}
}
quantity, _ := strconv.ParseFloat(order.Size, 64)
side := strings.ToUpper(order.Side)
positionSide := strings.ToUpper(order.PosSide)
result = append(result, types.OpenOrder{
OrderID: order.OrderId,
Symbol: order.Symbol,
Side: side,
PositionSide: positionSide,
Type: orderType,
Price: 0,
StopPrice: triggerPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
}
logger.Infof("✓ BITGET GetOpenOrders: found %d open orders for %s", len(result), symbol)
return result, nil
}
// PlaceLimitOrder places a limit order for grid trading
// Implements GridTrader interface
func (t *BitgetTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) {
symbol := t.convertSymbol(req.Symbol)
// Set leverage if specified
if req.Leverage > 0 {
if err := t.SetLeverage(symbol, req.Leverage); err != nil {
logger.Warnf("[Bitget] Failed to set leverage: %v", err)
}
}
// Format quantity
qtyStr, _ := t.FormatQuantity(symbol, req.Quantity)
// Determine side
side := "buy"
if req.Side == "SELL" {
side = "sell"
}
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"marginMode": "crossed",
"marginCoin": "USDT",
"side": side,
"orderType": "limit",
"size": qtyStr,
"price": fmt.Sprintf("%.8f", req.Price),
"force": "GTC", // Good Till Cancel
"clientOid": genBitgetClientOid(),
}
// Add reduce only if specified
if req.ReduceOnly {
body["reduceOnly"] = "YES"
}
logger.Infof("[Bitget] PlaceLimitOrder: %s %s @ %.4f, qty=%s", symbol, side, req.Price, qtyStr)
data, err := t.doRequest("POST", bitgetOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to place limit order: %w", err)
}
var order struct {
OrderId string `json:"orderId"`
ClientOid string `json:"clientOid"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ [Bitget] Limit order placed: %s %s @ %.4f, orderID=%s",
symbol, side, req.Price, order.OrderId)
return &types.LimitOrderResult{
OrderID: order.OrderId,
ClientID: order.ClientOid,
Symbol: req.Symbol,
Side: req.Side,
PositionSide: req.PositionSide,
Price: req.Price,
Quantity: req.Quantity,
Status: "NEW",
}, nil
}
// CancelOrder cancels a specific order by ID
// Implements GridTrader interface
func (t *BitgetTrader) CancelOrder(symbol, orderID string) error {
symbol = t.convertSymbol(symbol)
body := map[string]interface{}{
"symbol": symbol,
"productType": "USDT-FUTURES",
"orderId": orderID,
}
_, err := t.doRequest("POST", "/api/v2/mix/order/cancel-order", body)
if err != nil {
return fmt.Errorf("failed to cancel order: %w", err)
}
logger.Infof("✓ [Bitget] Order cancelled: %s %s", symbol, orderID)
return nil
}
+160
View File
@@ -0,0 +1,160 @@
package bitget
import (
"encoding/json"
"fmt"
"nofx/trader/types"
"strconv"
"time"
)
// GetPositions gets all positions
func (t *BitgetTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
t.positionsCacheMutex.RUnlock()
return t.cachedPositions, nil
}
t.positionsCacheMutex.RUnlock()
params := map[string]interface{}{
"productType": "USDT-FUTURES",
"marginCoin": "USDT",
}
data, err := t.doRequest("GET", bitgetPositionPath, params)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var positions []struct {
Symbol string `json:"symbol"`
HoldSide string `json:"holdSide"` // long, short
OpenPriceAvg string `json:"openPriceAvg"` // Average entry price
MarkPrice string `json:"markPrice"` // Mark price
Total string `json:"total"` // Total position size
Available string `json:"available"` // Available to close
UnrealizedPL string `json:"unrealizedPL"` // Unrealized P&L
Leverage string `json:"leverage"` // Leverage
LiquidationPrice string `json:"liquidationPrice"` // Liquidation price
MarginSize string `json:"marginSize"` // Position margin
CTime string `json:"cTime"` // Create time
UTime string `json:"uTime"` // Update time
}
if err := json.Unmarshal(data, &positions); err != nil {
return nil, fmt.Errorf("failed to parse position data: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
total, _ := strconv.ParseFloat(pos.Total, 64)
if total == 0 {
continue
}
entryPrice, _ := strconv.ParseFloat(pos.OpenPriceAvg, 64)
markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64)
unrealizedPnL, _ := strconv.ParseFloat(pos.UnrealizedPL, 64)
leverage, _ := strconv.ParseFloat(pos.Leverage, 64)
liqPrice, _ := strconv.ParseFloat(pos.LiquidationPrice, 64)
cTime, _ := strconv.ParseInt(pos.CTime, 10, 64)
uTime, _ := strconv.ParseInt(pos.UTime, 10, 64)
// Normalize side
side := "long"
if pos.HoldSide == "short" {
side = "short"
}
posMap := map[string]interface{}{
"symbol": pos.Symbol,
"positionAmt": total,
"entryPrice": entryPrice,
"markPrice": markPrice,
"unRealizedProfit": unrealizedPnL,
"leverage": leverage,
"liquidationPrice": liqPrice,
"side": side,
"createdTime": cTime,
"updatedTime": uTime,
}
result = append(result, posMap)
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// GetClosedPnL retrieves closed position PnL records
func (t *BitgetTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 100 {
limit = 100
}
params := map[string]interface{}{
"productType": "USDT-FUTURES",
"startTime": fmt.Sprintf("%d", startTime.UnixMilli()),
"limit": fmt.Sprintf("%d", limit),
}
data, err := t.doRequest("GET", "/api/v2/mix/position/history-position", params)
if err != nil {
return nil, fmt.Errorf("failed to get positions history: %w", err)
}
var resp struct {
List []struct {
Symbol string `json:"symbol"`
HoldSide string `json:"holdSide"`
OpenPriceAvg string `json:"openPriceAvg"`
ClosePriceAvg string `json:"closePriceAvg"`
CloseVol string `json:"closeVol"`
AchievedProfits string `json:"achievedProfits"`
TotalFee string `json:"totalFee"`
Leverage string `json:"leverage"`
CTime string `json:"cTime"`
UTime string `json:"uTime"`
} `json:"list"`
}
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
records := make([]types.ClosedPnLRecord, 0, len(resp.List))
for _, pos := range resp.List {
record := types.ClosedPnLRecord{
Symbol: pos.Symbol,
Side: pos.HoldSide,
}
record.EntryPrice, _ = strconv.ParseFloat(pos.OpenPriceAvg, 64)
record.ExitPrice, _ = strconv.ParseFloat(pos.ClosePriceAvg, 64)
record.Quantity, _ = strconv.ParseFloat(pos.CloseVol, 64)
record.RealizedPnL, _ = strconv.ParseFloat(pos.AchievedProfits, 64)
fee, _ := strconv.ParseFloat(pos.TotalFee, 64)
record.Fee = -fee
lev, _ := strconv.ParseFloat(pos.Leverage, 64)
record.Leverage = int(lev)
cTime, _ := strconv.ParseInt(pos.CTime, 10, 64)
uTime, _ := strconv.ParseInt(pos.UTime, 10, 64)
record.EntryTime = time.UnixMilli(cTime).UTC()
record.ExitTime = time.UnixMilli(uTime).UTC()
record.CloseType = "unknown"
records = append(records, record)
}
return records, nil
}
File diff suppressed because it is too large Load Diff
+236
View File
@@ -0,0 +1,236 @@
package bybit
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"nofx/trader/types"
"strconv"
"time"
)
// GetBalance retrieves account balance
func (t *BybitTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
balance := t.cachedBalance
t.balanceCacheMutex.RUnlock()
return balance, nil
}
t.balanceCacheMutex.RUnlock()
// Call API
params := map[string]interface{}{
"accountType": "UNIFIED",
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetAccountWallet(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get Bybit balance: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg)
}
// Extract balance information
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("Bybit balance return format error")
}
list, _ := resultData["list"].([]interface{})
var totalEquity, availableBalance, totalWalletBalance, totalPerpUPL float64 = 0, 0, 0, 0
if len(list) > 0 {
account, _ := list[0].(map[string]interface{})
if equityStr, ok := account["totalEquity"].(string); ok {
totalEquity, _ = strconv.ParseFloat(equityStr, 64)
}
if availStr, ok := account["totalAvailableBalance"].(string); ok {
availableBalance, _ = strconv.ParseFloat(availStr, 64)
}
// Bybit UNIFIED account wallet balance field
if walletStr, ok := account["totalWalletBalance"].(string); ok {
totalWalletBalance, _ = strconv.ParseFloat(walletStr, 64)
}
// Bybit perpetual contract unrealized PnL
if uplStr, ok := account["totalPerpUPL"].(string); ok {
totalPerpUPL, _ = strconv.ParseFloat(uplStr, 64)
}
}
// If no totalWalletBalance, use totalEquity
if totalWalletBalance == 0 {
totalWalletBalance = totalEquity
}
balance := map[string]interface{}{
"totalEquity": totalEquity,
"totalWalletBalance": totalWalletBalance,
"availableBalance": availableBalance,
"totalUnrealizedProfit": totalPerpUPL,
"balance": totalEquity, // Compatible with other exchange formats
}
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = balance
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return balance, nil
}
// GetClosedPnL retrieves closed position PnL records from Bybit via direct HTTP API
func (t *BybitTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
// The Bybit SDK doesn't expose the closed-pnl endpoint, use direct HTTP call
return t.getClosedPnLViaHTTP(startTime, limit)
}
// getClosedPnLViaHTTP makes direct HTTP call to Bybit API for closed PnL with proper signing
func (t *BybitTrader) getClosedPnLViaHTTP(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
// Build query string
queryParams := fmt.Sprintf("category=linear&startTime=%d&limit=%d", startTime.UnixMilli(), limit)
url := "https://api.bybit.com/v5/position/closed-pnl?" + queryParams
// Generate timestamp
timestamp := fmt.Sprintf("%d", time.Now().UnixMilli())
recvWindow := "5000"
// Build signature payload: timestamp + api_key + recv_window + queryString
signPayload := timestamp + t.apiKey + recvWindow + queryParams
// Generate HMAC-SHA256 signature
h := hmac.New(sha256.New, []byte(t.secretKey))
h.Write([]byte(signPayload))
signature := hex.EncodeToString(h.Sum(nil))
// Create request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Add Bybit V5 API headers
req.Header.Set("X-BAPI-API-KEY", t.apiKey)
req.Header.Set("X-BAPI-SIGN", signature)
req.Header.Set("X-BAPI-SIGN-TYPE", "2")
req.Header.Set("X-BAPI-TIMESTAMP", timestamp)
req.Header.Set("X-BAPI-RECV-WINDOW", recvWindow)
req.Header.Set("Content-Type", "application/json")
// Use http.DefaultClient for the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Bybit API: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
var result struct {
RetCode int `json:"retCode"`
RetMsg string `json:"retMsg"`
Result map[string]interface{} `json:"result"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg)
}
return t.parseClosedPnLResult(result.Result)
}
// parseClosedPnLResult parses the closed PnL result from Bybit API
func (t *BybitTrader) parseClosedPnLResult(resultData interface{}) ([]types.ClosedPnLRecord, error) {
data, ok := resultData.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid result format")
}
list, _ := data["list"].([]interface{})
var records []types.ClosedPnLRecord
for _, item := range list {
pnl, ok := item.(map[string]interface{})
if !ok {
continue
}
// Parse fields
symbol, _ := pnl["symbol"].(string)
side, _ := pnl["side"].(string)
orderId, _ := pnl["orderId"].(string)
avgEntryPriceStr, _ := pnl["avgEntryPrice"].(string)
avgExitPriceStr, _ := pnl["avgExitPrice"].(string)
qtyStr, _ := pnl["qty"].(string)
closedPnLStr, _ := pnl["closedPnl"].(string)
cumEntryValueStr, _ := pnl["cumEntryValue"].(string)
cumExitValueStr, _ := pnl["cumExitValue"].(string)
leverageStr, _ := pnl["leverage"].(string)
createdTimeStr, _ := pnl["createdTime"].(string)
updatedTimeStr, _ := pnl["updatedTime"].(string)
avgEntryPrice, _ := strconv.ParseFloat(avgEntryPriceStr, 64)
avgExitPrice, _ := strconv.ParseFloat(avgExitPriceStr, 64)
qty, _ := strconv.ParseFloat(qtyStr, 64)
closedPnL, _ := strconv.ParseFloat(closedPnLStr, 64)
leverage, _ := strconv.ParseInt(leverageStr, 10, 64)
createdTime, _ := strconv.ParseInt(createdTimeStr, 10, 64)
updatedTime, _ := strconv.ParseInt(updatedTimeStr, 10, 64)
// Calculate approximate fee from value difference
cumEntryValue, _ := strconv.ParseFloat(cumEntryValueStr, 64)
cumExitValue, _ := strconv.ParseFloat(cumExitValueStr, 64)
expectedPnL := cumExitValue - cumEntryValue
if side == "Sell" {
expectedPnL = cumEntryValue - cumExitValue
}
fee := expectedPnL - closedPnL
if fee < 0 {
fee = 0
}
// Normalize side
normalizedSide := "long"
if side == "Sell" {
normalizedSide = "short"
}
record := types.ClosedPnLRecord{
Symbol: symbol,
Side: normalizedSide,
EntryPrice: avgEntryPrice,
ExitPrice: avgExitPrice,
Quantity: qty,
RealizedPnL: closedPnL,
Fee: fee,
Leverage: int(leverage),
EntryTime: time.UnixMilli(createdTime).UTC(),
ExitTime: time.UnixMilli(updatedTime).UTC(),
OrderID: orderId,
CloseType: "unknown", // Bybit doesn't provide close type directly
ExchangeID: orderId, // Use orderId as exchange ID
}
records = append(records, record)
}
return records, nil
}
+741
View File
@@ -0,0 +1,741 @@
package bybit
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
)
// OpenLong opens a long position
func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
logger.Infof("[Bybit] ===== OpenLong called: symbol=%s, qty=%.6f, leverage=%d =====", symbol, quantity, leverage)
// First cancel all pending orders for this symbol (clean up old orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] Failed to cancel old pending orders: %v", err)
}
// Also cancel conditional orders (stop-loss/take-profit) - Bybit keeps them separate
if err := t.CancelStopOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] Failed to cancel old stop orders: %v", err)
}
// Set leverage first
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err)
}
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": "Buy",
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0, // One-way position mode
}
logger.Infof("[Bybit] OpenLong placing order: %+v", params)
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit open long failed: %w", err)
}
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// OpenShort opens a short position
func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
logger.Infof("[Bybit] ===== OpenShort called: symbol=%s, qty=%.6f, leverage=%d =====", symbol, quantity, leverage)
// First cancel all pending orders for this symbol (clean up old orders)
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] Failed to cancel old pending orders: %v", err)
}
// Also cancel conditional orders (stop-loss/take-profit) - Bybit keeps them separate
if err := t.CancelStopOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] Failed to cancel old stop orders: %v", err)
}
// Set leverage first
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err)
}
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": "Sell",
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0, // One-way position mode
}
logger.Infof("[Bybit] OpenShort placing order: %+v", params)
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit open short failed: %w", err)
}
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// CloseLong closes a long position
func (t *BybitTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// If quantity = 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
side, _ := pos["side"].(string)
if pos["symbol"] == symbol && strings.ToLower(side) == "long" {
quantity = pos["positionAmt"].(float64)
break
}
}
}
if quantity <= 0 {
return nil, fmt.Errorf("no long position to close")
}
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": "Sell", // Close long with Sell
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0,
"reduceOnly": true,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit close long failed: %w", err)
}
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// CloseShort closes a short position
func (t *BybitTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// If quantity = 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
side, _ := pos["side"].(string)
if pos["symbol"] == symbol && strings.ToLower(side) == "short" {
quantity = -pos["positionAmt"].(float64) // Short position is negative
break
}
}
}
if quantity <= 0 {
return nil, fmt.Errorf("no short position to close")
}
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": "Buy", // Close short with Buy
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0,
"reduceOnly": true,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit close short failed: %w", err)
}
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// SetLeverage sets leverage
func (t *BybitTrader) SetLeverage(symbol string, leverage int) error {
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"buyLeverage": fmt.Sprintf("%d", leverage),
"sellLeverage": fmt.Sprintf("%d", leverage),
}
result, err := t.client.NewUtaBybitServiceWithParams(params).SetPositionLeverage(context.Background())
if err != nil {
// If leverage is already at target value, Bybit will return an error, ignore this case
if strings.Contains(err.Error(), "leverage not modified") {
return nil
}
return fmt.Errorf("failed to set leverage: %w", err)
}
if result.RetCode != 0 && result.RetCode != 110043 { // 110043 = leverage not modified
return fmt.Errorf("failed to set leverage: %s", result.RetMsg)
}
return nil
}
// SetMarginMode sets position margin mode
func (t *BybitTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
tradeMode := 1 // Isolated margin
if isCrossMargin {
tradeMode = 0 // Cross margin
}
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"tradeMode": tradeMode,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).SwitchPositionMargin(context.Background())
if err != nil {
if strings.Contains(err.Error(), "Cross/isolated margin mode is not modified") {
return nil
}
return fmt.Errorf("failed to set margin mode: %w", err)
}
if result.RetCode != 0 && result.RetCode != 110026 { // already in target mode
return fmt.Errorf("failed to set margin mode: %s", result.RetMsg)
}
return nil
}
// GetMarketPrice retrieves market price
func (t *BybitTrader) GetMarketPrice(symbol string) (float64, error) {
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetMarketTickers(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get market price: %w", err)
}
if result.RetCode != 0 {
return 0, fmt.Errorf("API error: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return 0, fmt.Errorf("return format error")
}
list, _ := resultData["list"].([]interface{})
if len(list) == 0 {
return 0, fmt.Errorf("price data not found for %s", symbol)
}
ticker, _ := list[0].(map[string]interface{})
lastPriceStr, _ := ticker["lastPrice"].(string)
lastPrice, err := strconv.ParseFloat(lastPriceStr, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse price: %w", err)
}
return lastPrice, nil
}
// SetStopLoss sets stop loss order
func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
side := "Sell" // LONG stop loss uses Sell
if positionSide == "SHORT" {
side = "Buy" // SHORT stop loss uses Buy
}
// Get current price to determine triggerDirection
currentPrice, err := t.GetMarketPrice(symbol)
if err != nil {
return err
}
triggerDirection := 2 // Price fall trigger (default long stop loss)
if stopPrice > currentPrice {
triggerDirection = 1 // Price rise trigger (short stop loss)
}
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": side,
"orderType": "Market",
"qty": qtyStr,
"triggerPrice": fmt.Sprintf("%v", stopPrice),
"triggerDirection": triggerDirection,
"triggerBy": "LastPrice",
"reduceOnly": true,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
if result.RetCode != 0 {
return fmt.Errorf("failed to set stop loss: %s", result.RetMsg)
}
logger.Infof(" ✓ [Bybit] Stop loss order set: %s @ %.2f", symbol, stopPrice)
return nil
}
// SetTakeProfit sets take profit order
func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
side := "Sell" // LONG take profit uses Sell
if positionSide == "SHORT" {
side = "Buy" // SHORT take profit uses Buy
}
// Get current price to determine triggerDirection
currentPrice, err := t.GetMarketPrice(symbol)
if err != nil {
return err
}
triggerDirection := 1 // Price rise trigger (default long take profit)
if takeProfitPrice < currentPrice {
triggerDirection = 2 // Price fall trigger (short take profit)
}
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": side,
"orderType": "Market",
"qty": qtyStr,
"triggerPrice": fmt.Sprintf("%v", takeProfitPrice),
"triggerDirection": triggerDirection,
"triggerBy": "LastPrice",
"reduceOnly": true,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
if result.RetCode != 0 {
return fmt.Errorf("failed to set take profit: %s", result.RetMsg)
}
logger.Infof(" ✓ [Bybit] Take profit order set: %s @ %.2f", symbol, takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *BybitTrader) CancelStopLossOrders(symbol string) error {
return t.cancelConditionalOrders(symbol, "StopLoss")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *BybitTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelConditionalOrders(symbol, "TakeProfit")
}
// CancelAllOrders cancels all pending orders
func (t *BybitTrader) CancelAllOrders(symbol string) error {
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
}
_, err := t.client.NewUtaBybitServiceWithParams(params).CancelAllOrders(context.Background())
if err != nil {
return fmt.Errorf("failed to cancel all orders: %w", err)
}
return nil
}
// CancelStopOrders cancels all stop loss and take profit orders
func (t *BybitTrader) CancelStopOrders(symbol string) error {
if err := t.CancelStopLossOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] Failed to cancel stop loss orders: %v", err)
}
if err := t.CancelTakeProfitOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] Failed to cancel take profit orders: %v", err)
}
return nil
}
func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) error {
// First get all conditional orders
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"orderFilter": "StopOrder", // Conditional orders
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background())
if err != nil {
return fmt.Errorf("failed to get conditional orders: %w", err)
}
if result.RetCode != 0 {
return nil // No orders
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil
}
list, _ := resultData["list"].([]interface{})
// Cancel matching orders
for _, item := range list {
order, ok := item.(map[string]interface{})
if !ok {
continue
}
orderId, _ := order["orderId"].(string)
stopOrderType, _ := order["stopOrderType"].(string)
// Filter by type
shouldCancel := false
if orderType == "StopLoss" && (stopOrderType == "StopLoss" || stopOrderType == "Stop") {
shouldCancel = true
}
if orderType == "TakeProfit" && (stopOrderType == "TakeProfit" || stopOrderType == "PartialTakeProfit") {
shouldCancel = true
}
if shouldCancel && orderId != "" {
cancelParams := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"orderId": orderId,
}
t.client.NewUtaBybitServiceWithParams(cancelParams).CancelOrder(context.Background())
}
}
return nil
}
// GetOrderStatus retrieves order status
func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"orderId": orderID,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetOrderHistory(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("API error: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("return format error")
}
list, _ := resultData["list"].([]interface{})
if len(list) == 0 {
return nil, fmt.Errorf("order %s not found", orderID)
}
order, _ := list[0].(map[string]interface{})
// Parse order data
status, _ := order["orderStatus"].(string)
avgPriceStr, _ := order["avgPrice"].(string)
cumExecQtyStr, _ := order["cumExecQty"].(string)
cumExecFeeStr, _ := order["cumExecFee"].(string)
avgPrice, _ := strconv.ParseFloat(avgPriceStr, 64)
executedQty, _ := strconv.ParseFloat(cumExecQtyStr, 64)
commission, _ := strconv.ParseFloat(cumExecFeeStr, 64)
// Convert status to unified format
unifiedStatus := status
switch status {
case "Filled":
unifiedStatus = "FILLED"
case "New", "Created":
unifiedStatus = "NEW"
case "Cancelled", "Rejected":
unifiedStatus = "CANCELED"
case "PartiallyFilled":
unifiedStatus = "PARTIALLY_FILLED"
}
return map[string]interface{}{
"orderId": orderID,
"status": unifiedStatus,
"avgPrice": avgPrice,
"executedQty": executedQty,
"commission": commission,
}, nil
}
// GetOpenOrders gets all open/pending orders for a symbol
func (t *BybitTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
var result []types.OpenOrder
// Get conditional orders (stop-loss, take-profit)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"orderFilter": "StopOrder",
}
resp, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
if resp.RetCode == 0 {
resultData, ok := resp.Result.(map[string]interface{})
if ok {
list, _ := resultData["list"].([]interface{})
for _, item := range list {
order, ok := item.(map[string]interface{})
if !ok {
continue
}
orderId, _ := order["orderId"].(string)
sym, _ := order["symbol"].(string)
side, _ := order["side"].(string)
orderType, _ := order["orderType"].(string)
stopOrderType, _ := order["stopOrderType"].(string)
triggerPrice, _ := order["triggerPrice"].(string)
qty, _ := order["qty"].(string)
price, _ := strconv.ParseFloat(triggerPrice, 64)
quantity, _ := strconv.ParseFloat(qty, 64)
// Determine type based on stopOrderType
displayType := orderType
if stopOrderType != "" {
displayType = stopOrderType
}
result = append(result, types.OpenOrder{
OrderID: orderId,
Symbol: sym,
Side: side,
PositionSide: "", // Bybit doesn't use positionSide for UTA
Type: displayType,
Price: 0,
StopPrice: price,
Quantity: quantity,
Status: "NEW",
})
}
}
}
return result, nil
}
// PlaceLimitOrder places a limit order for grid trading
// Implements GridTrader interface
func (t *BybitTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) {
// Format quantity
qtyStr, err := t.FormatQuantity(req.Symbol, req.Quantity)
if err != nil {
return nil, fmt.Errorf("failed to format quantity: %w", err)
}
// Format price
priceStr := fmt.Sprintf("%.8f", req.Price)
// Set leverage if specified
if req.Leverage > 0 {
if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil {
logger.Warnf("[Bybit] Failed to set leverage: %v", err)
}
}
// Determine side
side := "Buy"
if req.Side == "SELL" {
side = "Sell"
}
params := map[string]interface{}{
"category": "linear",
"symbol": req.Symbol,
"side": side,
"orderType": "Limit",
"qty": qtyStr,
"price": priceStr,
"timeInForce": "GTC", // Good Till Cancel
"positionIdx": 0, // One-way position mode
}
// Add reduce only if specified
if req.ReduceOnly {
params["reduceOnly"] = true
}
logger.Infof("[Bybit] PlaceLimitOrder: %s %s @ %s, qty=%s", req.Symbol, side, priceStr, qtyStr)
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to place limit order: %w", err)
}
// Parse result
orderID := ""
if result.RetCode == 0 {
if resultData, ok := result.Result.(map[string]interface{}); ok {
if id, ok := resultData["orderId"].(string); ok {
orderID = id
}
}
} else {
return nil, fmt.Errorf("Bybit order failed: %s", result.RetMsg)
}
logger.Infof("✓ [Bybit] Limit order placed: %s %s @ %s, qty=%s, orderID=%s",
req.Symbol, side, priceStr, qtyStr, orderID)
return &types.LimitOrderResult{
OrderID: orderID,
ClientID: req.ClientID,
Symbol: req.Symbol,
Side: req.Side,
PositionSide: req.PositionSide,
Price: req.Price,
Quantity: req.Quantity,
Status: "NEW",
}, nil
}
// CancelOrder cancels a specific order by ID
// Implements GridTrader interface
func (t *BybitTrader) CancelOrder(symbol, orderID string) error {
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"orderId": orderID,
}
result, err := t.client.NewUtaBybitServiceWithParams(params).CancelOrder(context.Background())
if err != nil {
return fmt.Errorf("failed to cancel order: %w", err)
}
if result.RetCode != 0 {
return fmt.Errorf("Bybit cancel order failed: %s", result.RetMsg)
}
logger.Infof("✓ [Bybit] Order cancelled: %s %s", symbol, orderID)
return nil
}
// GetOrderBook gets the order book for a symbol
// Implements GridTrader interface
func (t *BybitTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) {
if depth <= 0 {
depth = 25
}
// Use HTTP request directly since the SDK doesn't expose GetOrderbook
url := fmt.Sprintf("https://api.bybit.com/v5/market/orderbook?category=linear&symbol=%s&limit=%d", symbol, depth)
resp, err := http.Get(url)
if err != nil {
return nil, nil, fmt.Errorf("failed to get order book: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result struct {
RetCode int `json:"retCode"`
RetMsg string `json:"retMsg"`
Result struct {
S string `json:"s"` // symbol
B [][]string `json:"b"` // bids [[price, size], ...]
A [][]string `json:"a"` // asks [[price, size], ...]
} `json:"result"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, nil, fmt.Errorf("failed to parse order book: %w", err)
}
if result.RetCode != 0 {
return nil, nil, fmt.Errorf("Bybit get orderbook failed: %s", result.RetMsg)
}
// Parse bids
for _, b := range result.Result.B {
if len(b) >= 2 {
price, _ := strconv.ParseFloat(b[0], 64)
qty, _ := strconv.ParseFloat(b[1], 64)
bids = append(bids, []float64{price, qty})
}
}
// Parse asks
for _, a := range result.Result.A {
if len(a) >= 2 {
price, _ := strconv.ParseFloat(a[0], 64)
qty, _ := strconv.ParseFloat(a[1], 64)
asks = append(asks, []float64{price, qty})
}
}
return bids, asks, nil
}
+125
View File
@@ -0,0 +1,125 @@
package bybit
import (
"context"
"fmt"
"nofx/logger"
"strconv"
"strings"
"time"
)
// GetPositions retrieves all positions
func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
positions := t.cachedPositions
t.positionsCacheMutex.RUnlock()
return positions, nil
}
t.positionsCacheMutex.RUnlock()
// Call API
params := map[string]interface{}{
"category": "linear",
"settleCoin": "USDT",
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetPositionList(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get Bybit positions: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("Bybit positions return format error")
}
list, _ := resultData["list"].([]interface{})
var positions []map[string]interface{}
for _, item := range list {
pos, ok := item.(map[string]interface{})
if !ok {
continue
}
sizeStr, _ := pos["size"].(string)
size, _ := strconv.ParseFloat(sizeStr, 64)
// Skip empty positions
if size == 0 {
continue
}
entryPriceStr, _ := pos["avgPrice"].(string)
entryPrice, _ := strconv.ParseFloat(entryPriceStr, 64)
unrealisedPnlStr, _ := pos["unrealisedPnl"].(string)
unrealisedPnl, _ := strconv.ParseFloat(unrealisedPnlStr, 64)
leverageStr, _ := pos["leverage"].(string)
leverage, _ := strconv.ParseFloat(leverageStr, 64)
// Mark price
markPriceStr, _ := pos["markPrice"].(string)
markPrice, _ := strconv.ParseFloat(markPriceStr, 64)
// Liquidation price
liqPriceStr, _ := pos["liqPrice"].(string)
liqPrice, _ := strconv.ParseFloat(liqPriceStr, 64)
// Position created/updated time (milliseconds timestamp)
createdTimeStr, _ := pos["createdTime"].(string)
createdTime, _ := strconv.ParseInt(createdTimeStr, 10, 64)
updatedTimeStr, _ := pos["updatedTime"].(string)
updatedTime, _ := strconv.ParseInt(updatedTimeStr, 10, 64)
positionSide, _ := pos["side"].(string) // Buy = long, Sell = short
// Log raw position data for debugging
logger.Infof("[Bybit] GetPositions raw: symbol=%v, side=%s, size=%v", pos["symbol"], positionSide, sizeStr)
// Convert to unified format (use lowercase for consistency with other exchanges)
// Bybit returns "Buy" for long, "Sell" for short
side := "long"
positionAmt := size
positionSideLower := strings.ToLower(positionSide)
if positionSideLower == "sell" {
side = "short"
positionAmt = -size
}
logger.Infof("[Bybit] GetPositions converted: symbol=%v, rawSide=%s -> side=%s", pos["symbol"], positionSide, side)
position := map[string]interface{}{
"symbol": pos["symbol"],
"side": side,
"positionAmt": positionAmt,
"entryPrice": entryPrice,
"markPrice": markPrice,
"unRealizedProfit": unrealisedPnl,
"unrealizedPnL": unrealisedPnl,
"liquidationPrice": liqPrice,
"leverage": leverage,
"createdTime": createdTime, // Position open time (ms)
"updatedTime": updatedTime, // Position last update time (ms)
}
positions = append(positions, position)
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = positions
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return positions, nil
}
-471
View File
@@ -1,471 +0,0 @@
package bybit
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"nofx/trader/testutil"
"nofx/trader/types"
)
// ============================================================
// Part 1: BybitTraderTestSuite - Inherits base test suite
// ============================================================
// BybitTraderTestSuite Bybit trader test suite
// Inherits TraderTestSuite and adds Bybit-specific mock logic
type BybitTraderTestSuite struct {
*testutil.TraderTestSuite // Embeds base test suite
mockServer *httptest.Server
}
// NewBybitTraderTestSuite Create Bybit test suite
// Note: Due to Bybit SDK encapsulation design, cannot easily inject mock HTTP client
// Therefore this test suite is mainly used for interface compliance verification, not API call testing
func NewBybitTraderTestSuite(t *testing.T) *BybitTraderTestSuite {
// Create mock HTTP server (for response format verification)
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
var respBody interface{}
switch {
case path == "/v5/account/wallet-balance":
respBody = map[string]interface{}{
"retCode": 0,
"retMsg": "OK",
"result": map[string]interface{}{
"list": []map[string]interface{}{
{
"accountType": "UNIFIED",
"totalEquity": "10100.50",
"coin": []map[string]interface{}{
{
"coin": "USDT",
"walletBalance": "10000.00",
"unrealisedPnl": "100.50",
"availableToWithdraw": "8000.00",
},
},
},
},
},
}
default:
respBody = map[string]interface{}{
"retCode": 0,
"retMsg": "OK",
"result": map[string]interface{}{},
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// Create real Bybit trader (for interface compliance testing)
traderInstance := NewBybitTrader("test_api_key", "test_secret_key")
// Create base suite
baseSuite := testutil.NewTraderTestSuite(t, traderInstance)
return &BybitTraderTestSuite{
TraderTestSuite: baseSuite,
mockServer: mockServer,
}
}
// Cleanup Clean up resources
func (s *BybitTraderTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
}
s.TraderTestSuite.Cleanup()
}
// ============================================================
// Part 2: Interface compliance tests
// ============================================================
// TestBybitTrader_InterfaceCompliance Test interface compliance
func TestBybitTrader_InterfaceCompliance(t *testing.T) {
var _ types.Trader = (*BybitTrader)(nil)
}
// ============================================================
// Part 3: Bybit-specific feature unit tests
// ============================================================
// TestNewBybitTrader Test creating Bybit trader
func TestNewBybitTrader(t *testing.T) {
tests := []struct {
name string
apiKey string
secretKey string
wantNil bool
}{
{
name: "Successfully create",
apiKey: "test_api_key",
secretKey: "test_secret_key",
wantNil: false,
},
{
name: "Empty API Key can still create",
apiKey: "",
secretKey: "test_secret_key",
wantNil: false,
},
{
name: "Empty Secret Key can still create",
apiKey: "test_api_key",
secretKey: "",
wantNil: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bt := NewBybitTrader(tt.apiKey, tt.secretKey)
if tt.wantNil {
assert.Nil(t, bt)
} else {
assert.NotNil(t, bt)
assert.NotNil(t, bt.client)
}
})
}
}
// TestBybitTrader_SymbolFormat Test symbol format
func TestBybitTrader_SymbolFormat(t *testing.T) {
// Bybit uses uppercase symbol format (e.g. BTCUSDT)
tests := []struct {
name string
symbol string
isValid bool
}{
{
name: "Standard USDT contract",
symbol: "BTCUSDT",
isValid: true,
},
{
name: "ETH contract",
symbol: "ETHUSDT",
isValid: true,
},
{
name: "SOL contract",
symbol: "SOLUSDT",
isValid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Verify symbol format is correct (all uppercase, ends with USDT)
assert.True(t, tt.symbol == strings.ToUpper(tt.symbol))
assert.True(t, strings.HasSuffix(tt.symbol, "USDT"))
})
}
}
// TestBybitTrader_FormatQuantity Test quantity formatting
func TestBybitTrader_FormatQuantity(t *testing.T) {
bt := NewBybitTrader("test", "test")
tests := []struct {
name string
symbol string
quantity float64
expected string
hasError bool
}{
{
name: "BTC quantity formatting",
symbol: "BTCUSDT",
quantity: 0.12345,
expected: "0.123", // Bybit defaults to 3 decimal places
hasError: false,
},
{
name: "ETH quantity formatting",
symbol: "ETHUSDT",
quantity: 1.2345,
expected: "1.234",
hasError: false,
},
{
name: "Integer quantity",
symbol: "SOLUSDT",
quantity: 10.0,
expected: "10.000",
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := bt.FormatQuantity(tt.symbol, tt.quantity)
if tt.hasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
// TestBybitTrader_ParseResponse Test response parsing
func TestBybitTrader_ParseResponse(t *testing.T) {
tests := []struct {
name string
retCode int
retMsg string
expectErr bool
errContain string
}{
{
name: "Success response",
retCode: 0,
retMsg: "OK",
expectErr: false,
},
{
name: "API error",
retCode: 10001,
retMsg: "Invalid symbol",
expectErr: true,
errContain: "Invalid symbol",
},
{
name: "Permission error",
retCode: 10003,
retMsg: "Invalid API key",
expectErr: true,
errContain: "Invalid API key",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkBybitResponse(tt.retCode, tt.retMsg)
if tt.expectErr {
assert.Error(t, err)
if tt.errContain != "" {
assert.Contains(t, err.Error(), tt.errContain)
}
} else {
assert.NoError(t, err)
}
})
}
}
// checkBybitResponse Check if Bybit API response has errors
func checkBybitResponse(retCode int, retMsg string) error {
if retCode != 0 {
return &BybitAPIError{
Code: retCode,
Message: retMsg,
}
}
return nil
}
// BybitAPIError Bybit API error type
type BybitAPIError struct {
Code int
Message string
}
func (e *BybitAPIError) Error() string {
return e.Message
}
// TestBybitTrader_PositionSideConversion Test position side conversion
func TestBybitTrader_PositionSideConversion(t *testing.T) {
tests := []struct {
name string
side string
expected string
}{
{
name: "Buy to Long",
side: "Buy",
expected: "long",
},
{
name: "Sell to Short",
side: "Sell",
expected: "short",
},
{
name: "Other values remain unchanged",
side: "Unknown",
expected: "unknown",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertBybitSide(tt.side)
assert.Equal(t, tt.expected, result)
})
}
}
// convertBybitSide Convert Bybit position side
func convertBybitSide(side string) string {
switch side {
case "Buy":
return "long"
case "Sell":
return "short"
default:
return "unknown"
}
}
// TestBybitTrader_CategoryLinear Test using only linear category
func TestBybitTrader_CategoryLinear(t *testing.T) {
// Bybit trader should only use linear category (USDT perpetual contracts)
bt := NewBybitTrader("test", "test")
assert.NotNil(t, bt)
// Verify default configuration
assert.NotNil(t, bt.client)
}
// TestBybitTrader_CacheDuration Test cache duration
func TestBybitTrader_CacheDuration(t *testing.T) {
bt := NewBybitTrader("test", "test")
// Verify default cache time is 15 seconds
assert.Equal(t, 15*time.Second, bt.cacheDuration)
}
// ============================================================
// Part 4: Mock server integration tests
// ============================================================
// TestBybitTrader_MockServerGetBalance Test getting balance through Mock server
func TestBybitTrader_MockServerGetBalance(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/account/wallet-balance" {
respBody := map[string]interface{}{
"retCode": 0,
"retMsg": "OK",
"result": map[string]interface{}{
"list": []map[string]interface{}{
{
"accountType": "UNIFIED",
"totalEquity": "10100.50",
"coin": []map[string]interface{}{
{
"coin": "USDT",
"walletBalance": "10000.00",
"unrealisedPnl": "100.50",
"availableToWithdraw": "8000.00",
},
},
},
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
return
}
http.NotFound(w, r)
}))
defer mockServer.Close()
// Due to Bybit SDK encapsulation, cannot directly inject mock URL
// This test verifies mock server response format is correct
assert.NotNil(t, mockServer)
}
// TestBybitTrader_MockServerGetPositions Test getting positions through Mock server
func TestBybitTrader_MockServerGetPositions(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/position/list" {
respBody := map[string]interface{}{
"retCode": 0,
"retMsg": "OK",
"result": map[string]interface{}{
"list": []map[string]interface{}{
{
"symbol": "BTCUSDT",
"side": "Buy",
"size": "0.5",
"avgPrice": "50000.00",
"markPrice": "50500.00",
"unrealisedPnl": "250.00",
"liqPrice": "45000.00",
"leverage": "10",
"positionIdx": 0,
},
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
return
}
http.NotFound(w, r)
}))
defer mockServer.Close()
assert.NotNil(t, mockServer)
}
// TestBybitTrader_MockServerPlaceOrder Test placing order through Mock server
func TestBybitTrader_MockServerPlaceOrder(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/order/create" && r.Method == "POST" {
respBody := map[string]interface{}{
"retCode": 0,
"retMsg": "OK",
"result": map[string]interface{}{
"orderId": "1234567890",
"orderLinkId": "test-order-id",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
return
}
http.NotFound(w, r)
}))
defer mockServer.Close()
assert.NotNil(t, mockServer)
}
// TestBybitTrader_MockServerSetLeverage Test setting leverage through Mock server
func TestBybitTrader_MockServerSetLeverage(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/position/set-leverage" && r.Method == "POST" {
respBody := map[string]interface{}{
"retCode": 0,
"retMsg": "OK",
"result": map[string]interface{}{},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
return
}
http.NotFound(w, r)
}))
defer mockServer.Close()
assert.NotNil(t, mockServer)
}
+1 -785
View File
@@ -3,16 +3,12 @@ package gate
import (
"context"
"fmt"
"math"
"strconv"
"nofx/trader/types"
"strings"
"sync"
"time"
"github.com/antihax/optional"
"github.com/gateio/gateapi-go/v6"
"nofx/logger"
"nofx/trader/types"
)
// GateTrader implements types.Trader interface for Gate.io Futures
@@ -58,118 +54,6 @@ func NewGateTrader(apiKey, secretKey string) *GateTrader {
}
}
// GetBalance retrieves account balance
func (t *GateTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
cached := t.cachedBalance
t.balanceCacheMutex.RUnlock()
return cached, nil
}
t.balanceCacheMutex.RUnlock()
// Fetch from API
accounts, _, err := t.client.FuturesApi.ListFuturesAccounts(t.ctx, "usdt")
if err != nil {
return nil, fmt.Errorf("failed to get balance: %w", err)
}
total, _ := strconv.ParseFloat(accounts.Total, 64)
available, _ := strconv.ParseFloat(accounts.Available, 64)
unrealizedPnl, _ := strconv.ParseFloat(accounts.UnrealisedPnl, 64)
result := map[string]interface{}{
"totalWalletBalance": total,
"availableBalance": available,
"totalUnrealizedProfit": unrealizedPnl,
}
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
// GetPositions retrieves all open positions
func (t *GateTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
cached := t.cachedPositions
t.positionsCacheMutex.RUnlock()
return cached, nil
}
t.positionsCacheMutex.RUnlock()
// Fetch from API
positions, _, err := t.client.FuturesApi.ListPositions(t.ctx, "usdt", nil)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
if pos.Size == 0 {
continue // Skip empty positions
}
entryPrice, _ := strconv.ParseFloat(pos.EntryPrice, 64)
markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64)
liqPrice, _ := strconv.ParseFloat(pos.LiqPrice, 64)
unrealizedPnl, _ := strconv.ParseFloat(pos.UnrealisedPnl, 64)
leverage, _ := strconv.ParseFloat(pos.Leverage, 64)
// Gate returns position size in contracts, need to convert to base currency
// Each contract = quanto_multiplier base currency
contractSize := float64(pos.Size)
if pos.Size < 0 {
contractSize = float64(-pos.Size)
}
// Get quanto_multiplier from contract info to convert contracts to actual quantity
quantoMultiplier := 1.0
contract, err := t.getContract(pos.Contract)
if err == nil && contract != nil {
qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if qm > 0 {
quantoMultiplier = qm
}
}
// Convert contract count to actual token quantity
positionAmt := contractSize * quantoMultiplier
// Determine side based on position size
side := "long"
if pos.Size < 0 {
side = "short"
}
result = append(result, map[string]interface{}{
"symbol": pos.Contract,
"positionAmt": positionAmt,
"entryPrice": entryPrice,
"markPrice": markPrice,
"unRealizedProfit": unrealizedPnl,
"leverage": int(leverage),
"liquidationPrice": liqPrice,
"side": side,
})
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// convertSymbol converts symbol format (e.g., BTCUSDT -> BTC_USDT)
func (t *GateTrader) convertSymbol(symbol string) string {
// If already in correct format
@@ -215,674 +99,6 @@ func (t *GateTrader) getContract(symbol string) (*gateapi.Contract, error) {
return &contract, nil
}
// SetLeverage sets the leverage for a symbol
func (t *GateTrader) SetLeverage(symbol string, leverage int) error {
symbol = t.convertSymbol(symbol)
_, _, err := t.client.FuturesApi.UpdatePositionLeverage(t.ctx, "usdt", symbol, fmt.Sprintf("%d", leverage), nil)
if err != nil {
// Gate.io may return error if leverage is already set
if strings.Contains(err.Error(), "RISK_LIMIT_EXCEEDED") {
logger.Warnf(" [Gate] Leverage %d exceeds limit for %s", leverage, symbol)
return nil
}
return fmt.Errorf("failed to set leverage: %w", err)
}
logger.Infof(" [Gate] Leverage set to %dx for %s", leverage, symbol)
return nil
}
// SetMarginMode sets margin mode (cross or isolated)
func (t *GateTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// Gate.io uses leverage=0 for cross margin, positive number for isolated
// This is handled through UpdatePositionLeverage with cross_leverage_limit
// For now, we'll skip explicit margin mode setting as it's tied to leverage
logger.Infof(" [Gate] Margin mode is set through leverage (0=cross)")
return nil
}
// OpenLong opens a long position
func (t *GateTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// Cancel old orders first
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Warnf(" [Gate] Failed to set leverage: %v", err)
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
// Gate uses contract size units (each contract = quanto_multiplier base currency)
// size = quantity / quanto_multiplier
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
order := gateapi.FuturesOrder{
Contract: symbol,
Size: size, // Positive for long
Price: "0", // Market order
Tif: "ioc",
Text: "t-nofx",
}
logger.Infof(" [Gate] OpenLong: symbol=%s, size=%d, leverage=%d", symbol, size, leverage)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Opened long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// OpenShort opens a short position
func (t *GateTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// Cancel old orders first
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Warnf(" [Gate] Failed to set leverage: %v", err)
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
// Gate uses contract size units
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
order := gateapi.FuturesOrder{
Contract: symbol,
Size: -size, // Negative for short
Price: "0", // Market order
Tif: "ioc",
Text: "t-nofx",
}
logger.Infof(" [Gate] OpenShort: symbol=%s, size=%d, leverage=%d", symbol, -size, leverage)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Opened short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// CloseLong closes a long position
func (t *GateTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// If quantity is 0, get current position
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
posSymbol := t.convertSymbol(pos["symbol"].(string))
if posSymbol == symbol && pos["side"] == "long" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("long position not found for %s", symbol)
}
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// Close long = sell (use ReduceOnly, not Close which requires Size=0)
order := gateapi.FuturesOrder{
Contract: symbol,
Size: -size, // Negative to close long
Price: "0",
Tif: "ioc",
ReduceOnly: true,
Text: "t-nofx-close",
}
logger.Infof(" [Gate] CloseLong: symbol=%s, size=%d", symbol, -size)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Closed long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// CloseShort closes a short position
func (t *GateTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// If quantity is 0, get current position
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
posSymbol := t.convertSymbol(pos["symbol"].(string))
if posSymbol == symbol && pos["side"] == "short" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("short position not found for %s", symbol)
}
}
// Ensure quantity is positive
if quantity < 0 {
quantity = -quantity
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// Close short = buy (use ReduceOnly, not Close which requires Size=0)
order := gateapi.FuturesOrder{
Contract: symbol,
Size: size, // Positive to close short
Price: "0",
Tif: "ioc",
ReduceOnly: true,
Text: "t-nofx-close",
}
logger.Infof(" [Gate] CloseShort: symbol=%s, size=%d", symbol, size)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Closed short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// GetMarketPrice gets the current market price
func (t *GateTrader) GetMarketPrice(symbol string) (float64, error) {
symbol = t.convertSymbol(symbol)
opts := &gateapi.ListFuturesTickersOpts{
Contract: optional.NewString(symbol),
}
tickers, _, err := t.client.FuturesApi.ListFuturesTickers(t.ctx, "usdt", opts)
if err != nil {
return 0, fmt.Errorf("failed to get market price: %w", err)
}
if len(tickers) == 0 {
return 0, fmt.Errorf("no ticker data for %s", symbol)
}
price, _ := strconv.ParseFloat(tickers[0].Last, 64)
return price, nil
}
// SetStopLoss sets a stop loss order
func (t *GateTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
symbol = t.convertSymbol(symbol)
contract, err := t.getContract(symbol)
if err != nil {
return err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// For long position, stop loss means sell when price drops
// For short position, stop loss means buy when price rises
if strings.ToUpper(positionSide) == "LONG" {
size = -size
}
// Use price trigger order
trigger := gateapi.FuturesPriceTriggeredOrder{
Initial: gateapi.FuturesInitialOrder{
Contract: symbol,
Size: size,
Price: "0", // Market order
Tif: "ioc",
ReduceOnly: true,
Close: true,
},
Trigger: gateapi.FuturesPriceTrigger{
StrategyType: 0, // Close position
PriceType: 0, // Latest price
Price: fmt.Sprintf("%.8f", stopPrice),
Rule: 1, // Price <= trigger price
},
}
if strings.ToUpper(positionSide) == "SHORT" {
trigger.Trigger.Rule = 2 // Price >= trigger price for short stop loss
}
_, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger)
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
logger.Infof(" [Gate] Stop loss set: %s @ %.4f", symbol, stopPrice)
return nil
}
// SetTakeProfit sets a take profit order
func (t *GateTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
symbol = t.convertSymbol(symbol)
contract, err := t.getContract(symbol)
if err != nil {
return err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// For long position, take profit means sell when price rises
// For short position, take profit means buy when price drops
if strings.ToUpper(positionSide) == "LONG" {
size = -size
}
trigger := gateapi.FuturesPriceTriggeredOrder{
Initial: gateapi.FuturesInitialOrder{
Contract: symbol,
Size: size,
Price: "0", // Market order
Tif: "ioc",
ReduceOnly: true,
Close: true,
},
Trigger: gateapi.FuturesPriceTrigger{
StrategyType: 0, // Close position
PriceType: 0, // Latest price
Price: fmt.Sprintf("%.8f", takeProfitPrice),
Rule: 2, // Price >= trigger price for long take profit
},
}
if strings.ToUpper(positionSide) == "SHORT" {
trigger.Trigger.Rule = 1 // Price <= trigger price for short take profit
}
_, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger)
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
logger.Infof(" [Gate] Take profit set: %s @ %.4f", symbol, takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *GateTrader) CancelStopLossOrders(symbol string) error {
return t.cancelTriggerOrders(symbol, "stop_loss")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *GateTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelTriggerOrders(symbol, "take_profit")
}
// cancelTriggerOrders cancels trigger orders of a specific type
func (t *GateTrader) cancelTriggerOrders(symbol string, orderType string) error {
symbol = t.convertSymbol(symbol)
opts := &gateapi.ListPriceTriggeredOrdersOpts{
Contract: optional.NewString(symbol),
}
orders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", opts)
if err != nil {
return err
}
for _, order := range orders {
// Determine if it's stop loss or take profit based on trigger rule and position
// For simplicity, cancel all matching symbol orders
_, _, err := t.client.FuturesApi.CancelPriceTriggeredOrder(t.ctx, "usdt", fmt.Sprintf("%d", order.Id))
if err != nil {
logger.Warnf(" [Gate] Failed to cancel trigger order %d: %v", order.Id, err)
}
}
return nil
}
// CancelAllOrders cancels all pending orders for a symbol
func (t *GateTrader) CancelAllOrders(symbol string) error {
symbol = t.convertSymbol(symbol)
// Cancel regular orders
_, _, err := t.client.FuturesApi.CancelFuturesOrders(t.ctx, "usdt", symbol, nil)
if err != nil {
// Ignore if no orders to cancel
if !strings.Contains(err.Error(), "ORDER_NOT_FOUND") {
logger.Warnf(" [Gate] Error canceling orders: %v", err)
}
}
// Cancel trigger orders
t.cancelTriggerOrders(symbol, "")
return nil
}
// CancelStopOrders cancels all stop orders (stop loss and take profit)
func (t *GateTrader) CancelStopOrders(symbol string) error {
t.CancelStopLossOrders(symbol)
t.CancelTakeProfitOrders(symbol)
return nil
}
// FormatQuantity formats quantity to correct precision
func (t *GateTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
contract, err := t.getContract(symbol)
if err != nil {
return fmt.Sprintf("%.4f", quantity), nil
}
// Gate uses quanto_multiplier for contract size
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if quantoMultiplier > 0 {
// Calculate number of contracts
numContracts := quantity / quantoMultiplier
return fmt.Sprintf("%.0f", math.Floor(numContracts)), nil
}
return fmt.Sprintf("%.4f", quantity), nil
}
// GetOrderStatus gets the status of an order
func (t *GateTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
order, _, err := t.client.FuturesApi.GetFuturesOrder(t.ctx, "usdt", orderID)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
fillPrice, _ := strconv.ParseFloat(order.FillPrice, 64)
tkFee, _ := strconv.ParseFloat(order.Tkfr, 64)
mkFee, _ := strconv.ParseFloat(order.Mkfr, 64)
totalFee := tkFee + mkFee
// Get quanto_multiplier to convert contracts to actual quantity
quantoMultiplier := 1.0
contract, contractErr := t.getContract(symbol)
if contractErr == nil && contract != nil {
qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if qm > 0 {
quantoMultiplier = qm
}
}
// Map status
status := "NEW"
switch order.Status {
case "finished":
if order.FinishAs == "filled" {
status = "FILLED"
} else if order.FinishAs == "cancelled" {
status = "CANCELED"
} else {
status = "CLOSED"
}
case "open":
status = "NEW"
}
side := "BUY"
if order.Size < 0 {
side = "SELL"
}
// Convert contract count to actual token quantity
executedQty := math.Abs(float64(order.Size-order.Left)) * quantoMultiplier
return map[string]interface{}{
"orderId": orderID,
"symbol": t.revertSymbol(symbol),
"status": status,
"avgPrice": fillPrice,
"executedQty": executedQty,
"side": side,
"type": order.Tif,
"time": int64(order.CreateTime * 1000),
"updateTime": int64(order.FinishTime * 1000),
"commission": totalFee,
}, nil
}
// GetClosedPnL retrieves closed position PnL records
func (t *GateTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 100 {
limit = 100
}
opts := &gateapi.ListPositionCloseOpts{
Limit: optional.NewInt32(int32(limit)),
From: optional.NewInt64(startTime.Unix()),
}
closedPositions, _, err := t.client.FuturesApi.ListPositionClose(t.ctx, "usdt", opts)
if err != nil {
return nil, fmt.Errorf("failed to get closed positions: %w", err)
}
records := make([]types.ClosedPnLRecord, 0, len(closedPositions))
for _, pos := range closedPositions {
pnl, _ := strconv.ParseFloat(pos.Pnl, 64)
record := types.ClosedPnLRecord{
Symbol: t.revertSymbol(pos.Contract),
Side: pos.Side,
RealizedPnL: pnl,
ExitTime: time.Unix(int64(pos.Time), 0).UTC(),
CloseType: "unknown",
}
records = append(records, record)
}
return records, nil
}
// GetOpenOrders gets open/pending orders
func (t *GateTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
symbol = t.convertSymbol(symbol)
opts := &gateapi.ListFuturesOrdersOpts{
Contract: optional.NewString(symbol),
}
orders, _, err := t.client.FuturesApi.ListFuturesOrders(t.ctx, "usdt", "open", opts)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
// Get quanto_multiplier to convert contracts to actual quantity
quantoMultiplier := 1.0
contract, err := t.getContract(symbol)
if err == nil && contract != nil {
qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if qm > 0 {
quantoMultiplier = qm
}
}
var result []types.OpenOrder
for _, order := range orders {
price, _ := strconv.ParseFloat(order.Price, 64)
side := "BUY"
if order.Size < 0 {
side = "SELL"
}
// Convert contract count to actual token quantity
quantity := math.Abs(float64(order.Size)) * quantoMultiplier
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.Id),
Symbol: t.revertSymbol(order.Contract),
Side: side,
Type: "LIMIT",
Price: price,
Quantity: quantity,
Status: "NEW",
})
}
// Also get trigger orders
triggerOpts := &gateapi.ListPriceTriggeredOrdersOpts{
Contract: optional.NewString(symbol),
}
triggerOrders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", triggerOpts)
if err == nil {
for _, order := range triggerOrders {
triggerPrice, _ := strconv.ParseFloat(order.Trigger.Price, 64)
side := "BUY"
if order.Initial.Size < 0 {
side = "SELL"
}
orderType := "STOP_MARKET"
if order.Trigger.Rule == 2 {
orderType = "TAKE_PROFIT_MARKET"
}
// Convert contract count to actual token quantity
quantity := math.Abs(float64(order.Initial.Size)) * quantoMultiplier
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.Id),
Symbol: t.revertSymbol(order.Initial.Contract),
Side: side,
Type: orderType,
StopPrice: triggerPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
return result, nil
}
// clearCache clears all caches
func (t *GateTrader) clearCache() {
t.balanceCacheMutex.Lock()
+160
View File
@@ -0,0 +1,160 @@
package gate
import (
"fmt"
"nofx/trader/types"
"strconv"
"time"
"github.com/antihax/optional"
"github.com/gateio/gateapi-go/v6"
)
// GetBalance retrieves account balance
func (t *GateTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
cached := t.cachedBalance
t.balanceCacheMutex.RUnlock()
return cached, nil
}
t.balanceCacheMutex.RUnlock()
// Fetch from API
accounts, _, err := t.client.FuturesApi.ListFuturesAccounts(t.ctx, "usdt")
if err != nil {
return nil, fmt.Errorf("failed to get balance: %w", err)
}
total, _ := strconv.ParseFloat(accounts.Total, 64)
available, _ := strconv.ParseFloat(accounts.Available, 64)
unrealizedPnl, _ := strconv.ParseFloat(accounts.UnrealisedPnl, 64)
result := map[string]interface{}{
"totalWalletBalance": total,
"availableBalance": available,
"totalUnrealizedProfit": unrealizedPnl,
}
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
// GetPositions retrieves all open positions
func (t *GateTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
cached := t.cachedPositions
t.positionsCacheMutex.RUnlock()
return cached, nil
}
t.positionsCacheMutex.RUnlock()
// Fetch from API
positions, _, err := t.client.FuturesApi.ListPositions(t.ctx, "usdt", nil)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
if pos.Size == 0 {
continue // Skip empty positions
}
entryPrice, _ := strconv.ParseFloat(pos.EntryPrice, 64)
markPrice, _ := strconv.ParseFloat(pos.MarkPrice, 64)
liqPrice, _ := strconv.ParseFloat(pos.LiqPrice, 64)
unrealizedPnl, _ := strconv.ParseFloat(pos.UnrealisedPnl, 64)
leverage, _ := strconv.ParseFloat(pos.Leverage, 64)
// Gate returns position size in contracts, need to convert to base currency
// Each contract = quanto_multiplier base currency
contractSize := float64(pos.Size)
if pos.Size < 0 {
contractSize = float64(-pos.Size)
}
// Get quanto_multiplier from contract info to convert contracts to actual quantity
quantoMultiplier := 1.0
contract, err := t.getContract(pos.Contract)
if err == nil && contract != nil {
qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if qm > 0 {
quantoMultiplier = qm
}
}
// Convert contract count to actual token quantity
positionAmt := contractSize * quantoMultiplier
// Determine side based on position size
side := "long"
if pos.Size < 0 {
side = "short"
}
result = append(result, map[string]interface{}{
"symbol": pos.Contract,
"positionAmt": positionAmt,
"entryPrice": entryPrice,
"markPrice": markPrice,
"unRealizedProfit": unrealizedPnl,
"leverage": int(leverage),
"liquidationPrice": liqPrice,
"side": side,
})
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// GetClosedPnL retrieves closed position PnL records
func (t *GateTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 100 {
limit = 100
}
opts := &gateapi.ListPositionCloseOpts{
Limit: optional.NewInt32(int32(limit)),
From: optional.NewInt64(startTime.Unix()),
}
closedPositions, _, err := t.client.FuturesApi.ListPositionClose(t.ctx, "usdt", opts)
if err != nil {
return nil, fmt.Errorf("failed to get closed positions: %w", err)
}
records := make([]types.ClosedPnLRecord, 0, len(closedPositions))
for _, pos := range closedPositions {
pnl, _ := strconv.ParseFloat(pos.Pnl, 64)
record := types.ClosedPnLRecord{
Symbol: t.revertSymbol(pos.Contract),
Side: pos.Side,
RealizedPnL: pnl,
ExitTime: time.Unix(int64(pos.Time), 0).UTC(),
CloseType: "unknown",
}
records = append(records, record)
}
return records, nil
}
+644
View File
@@ -0,0 +1,644 @@
package gate
import (
"fmt"
"math"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"github.com/antihax/optional"
"github.com/gateio/gateapi-go/v6"
)
// SetLeverage sets the leverage for a symbol
func (t *GateTrader) SetLeverage(symbol string, leverage int) error {
symbol = t.convertSymbol(symbol)
_, _, err := t.client.FuturesApi.UpdatePositionLeverage(t.ctx, "usdt", symbol, fmt.Sprintf("%d", leverage), nil)
if err != nil {
// Gate.io may return error if leverage is already set
if strings.Contains(err.Error(), "RISK_LIMIT_EXCEEDED") {
logger.Warnf(" [Gate] Leverage %d exceeds limit for %s", leverage, symbol)
return nil
}
return fmt.Errorf("failed to set leverage: %w", err)
}
logger.Infof(" [Gate] Leverage set to %dx for %s", leverage, symbol)
return nil
}
// SetMarginMode sets margin mode (cross or isolated)
func (t *GateTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// Gate.io uses leverage=0 for cross margin, positive number for isolated
// This is handled through UpdatePositionLeverage with cross_leverage_limit
// For now, we'll skip explicit margin mode setting as it's tied to leverage
logger.Infof(" [Gate] Margin mode is set through leverage (0=cross)")
return nil
}
// OpenLong opens a long position
func (t *GateTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// Cancel old orders first
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Warnf(" [Gate] Failed to set leverage: %v", err)
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
// Gate uses contract size units (each contract = quanto_multiplier base currency)
// size = quantity / quanto_multiplier
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
order := gateapi.FuturesOrder{
Contract: symbol,
Size: size, // Positive for long
Price: "0", // Market order
Tif: "ioc",
Text: "t-nofx",
}
logger.Infof(" [Gate] OpenLong: symbol=%s, size=%d, leverage=%d", symbol, size, leverage)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Opened long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// OpenShort opens a short position
func (t *GateTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// Cancel old orders first
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Warnf(" [Gate] Failed to set leverage: %v", err)
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
// Gate uses contract size units
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
order := gateapi.FuturesOrder{
Contract: symbol,
Size: -size, // Negative for short
Price: "0", // Market order
Tif: "ioc",
Text: "t-nofx",
}
logger.Infof(" [Gate] OpenShort: symbol=%s, size=%d, leverage=%d", symbol, -size, leverage)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Opened short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// CloseLong closes a long position
func (t *GateTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// If quantity is 0, get current position
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
posSymbol := t.convertSymbol(pos["symbol"].(string))
if posSymbol == symbol && pos["side"] == "long" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("long position not found for %s", symbol)
}
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// Close long = sell (use ReduceOnly, not Close which requires Size=0)
order := gateapi.FuturesOrder{
Contract: symbol,
Size: -size, // Negative to close long
Price: "0",
Tif: "ioc",
ReduceOnly: true,
Text: "t-nofx-close",
}
logger.Infof(" [Gate] CloseLong: symbol=%s, size=%d", symbol, -size)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Closed long position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// CloseShort closes a short position
func (t *GateTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
// If quantity is 0, get current position
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
return nil, err
}
for _, pos := range positions {
posSymbol := t.convertSymbol(pos["symbol"].(string))
if posSymbol == symbol && pos["side"] == "short" {
quantity = pos["positionAmt"].(float64)
break
}
}
if quantity == 0 {
return nil, fmt.Errorf("short position not found for %s", symbol)
}
}
// Ensure quantity is positive
if quantity < 0 {
quantity = -quantity
}
// Get contract info for size calculation
contract, err := t.getContract(symbol)
if err != nil {
return nil, err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// Close short = buy (use ReduceOnly, not Close which requires Size=0)
order := gateapi.FuturesOrder{
Contract: symbol,
Size: size, // Positive to close short
Price: "0",
Tif: "ioc",
ReduceOnly: true,
Text: "t-nofx-close",
}
logger.Infof(" [Gate] CloseShort: symbol=%s, size=%d", symbol, size)
result, _, err := t.client.FuturesApi.CreateFuturesOrder(t.ctx, "usdt", order, nil)
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
// Clear cache
t.clearCache()
// Parse fill price from result
fillPrice, _ := strconv.ParseFloat(result.FillPrice, 64)
logger.Infof(" [Gate] Closed short position: orderId=%d, fillPrice=%.4f", result.Id, fillPrice)
return map[string]interface{}{
"orderId": fmt.Sprintf("%d", result.Id),
"symbol": t.revertSymbol(symbol),
"status": "FILLED",
"fillPrice": fillPrice,
"avgPrice": fillPrice,
}, nil
}
// GetMarketPrice gets the current market price
func (t *GateTrader) GetMarketPrice(symbol string) (float64, error) {
symbol = t.convertSymbol(symbol)
opts := &gateapi.ListFuturesTickersOpts{
Contract: optional.NewString(symbol),
}
tickers, _, err := t.client.FuturesApi.ListFuturesTickers(t.ctx, "usdt", opts)
if err != nil {
return 0, fmt.Errorf("failed to get market price: %w", err)
}
if len(tickers) == 0 {
return 0, fmt.Errorf("no ticker data for %s", symbol)
}
price, _ := strconv.ParseFloat(tickers[0].Last, 64)
return price, nil
}
// SetStopLoss sets a stop loss order
func (t *GateTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
symbol = t.convertSymbol(symbol)
contract, err := t.getContract(symbol)
if err != nil {
return err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// For long position, stop loss means sell when price drops
// For short position, stop loss means buy when price rises
if strings.ToUpper(positionSide) == "LONG" {
size = -size
}
// Use price trigger order
trigger := gateapi.FuturesPriceTriggeredOrder{
Initial: gateapi.FuturesInitialOrder{
Contract: symbol,
Size: size,
Price: "0", // Market order
Tif: "ioc",
ReduceOnly: true,
Close: true,
},
Trigger: gateapi.FuturesPriceTrigger{
StrategyType: 0, // Close position
PriceType: 0, // Latest price
Price: fmt.Sprintf("%.8f", stopPrice),
Rule: 1, // Price <= trigger price
},
}
if strings.ToUpper(positionSide) == "SHORT" {
trigger.Trigger.Rule = 2 // Price >= trigger price for short stop loss
}
_, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger)
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
logger.Infof(" [Gate] Stop loss set: %s @ %.4f", symbol, stopPrice)
return nil
}
// SetTakeProfit sets a take profit order
func (t *GateTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
symbol = t.convertSymbol(symbol)
contract, err := t.getContract(symbol)
if err != nil {
return err
}
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
size := int64(quantity / quantoMultiplier)
if size <= 0 {
size = 1
}
// For long position, take profit means sell when price rises
// For short position, take profit means buy when price drops
if strings.ToUpper(positionSide) == "LONG" {
size = -size
}
trigger := gateapi.FuturesPriceTriggeredOrder{
Initial: gateapi.FuturesInitialOrder{
Contract: symbol,
Size: size,
Price: "0", // Market order
Tif: "ioc",
ReduceOnly: true,
Close: true,
},
Trigger: gateapi.FuturesPriceTrigger{
StrategyType: 0, // Close position
PriceType: 0, // Latest price
Price: fmt.Sprintf("%.8f", takeProfitPrice),
Rule: 2, // Price >= trigger price for long take profit
},
}
if strings.ToUpper(positionSide) == "SHORT" {
trigger.Trigger.Rule = 1 // Price <= trigger price for short take profit
}
_, _, err = t.client.FuturesApi.CreatePriceTriggeredOrder(t.ctx, "usdt", trigger)
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
logger.Infof(" [Gate] Take profit set: %s @ %.4f", symbol, takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *GateTrader) CancelStopLossOrders(symbol string) error {
return t.cancelTriggerOrders(symbol, "stop_loss")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *GateTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelTriggerOrders(symbol, "take_profit")
}
// cancelTriggerOrders cancels trigger orders of a specific type
func (t *GateTrader) cancelTriggerOrders(symbol string, orderType string) error {
symbol = t.convertSymbol(symbol)
opts := &gateapi.ListPriceTriggeredOrdersOpts{
Contract: optional.NewString(symbol),
}
orders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", opts)
if err != nil {
return err
}
for _, order := range orders {
// Determine if it's stop loss or take profit based on trigger rule and position
// For simplicity, cancel all matching symbol orders
_, _, err := t.client.FuturesApi.CancelPriceTriggeredOrder(t.ctx, "usdt", fmt.Sprintf("%d", order.Id))
if err != nil {
logger.Warnf(" [Gate] Failed to cancel trigger order %d: %v", order.Id, err)
}
}
return nil
}
// CancelAllOrders cancels all pending orders for a symbol
func (t *GateTrader) CancelAllOrders(symbol string) error {
symbol = t.convertSymbol(symbol)
// Cancel regular orders
_, _, err := t.client.FuturesApi.CancelFuturesOrders(t.ctx, "usdt", symbol, nil)
if err != nil {
// Ignore if no orders to cancel
if !strings.Contains(err.Error(), "ORDER_NOT_FOUND") {
logger.Warnf(" [Gate] Error canceling orders: %v", err)
}
}
// Cancel trigger orders
t.cancelTriggerOrders(symbol, "")
return nil
}
// CancelStopOrders cancels all stop orders (stop loss and take profit)
func (t *GateTrader) CancelStopOrders(symbol string) error {
t.CancelStopLossOrders(symbol)
t.CancelTakeProfitOrders(symbol)
return nil
}
// FormatQuantity formats quantity to correct precision
func (t *GateTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
contract, err := t.getContract(symbol)
if err != nil {
return fmt.Sprintf("%.4f", quantity), nil
}
// Gate uses quanto_multiplier for contract size
quantoMultiplier, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if quantoMultiplier > 0 {
// Calculate number of contracts
numContracts := quantity / quantoMultiplier
return fmt.Sprintf("%.0f", math.Floor(numContracts)), nil
}
return fmt.Sprintf("%.4f", quantity), nil
}
// GetOrderStatus gets the status of an order
func (t *GateTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
symbol = t.convertSymbol(symbol)
order, _, err := t.client.FuturesApi.GetFuturesOrder(t.ctx, "usdt", orderID)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
fillPrice, _ := strconv.ParseFloat(order.FillPrice, 64)
tkFee, _ := strconv.ParseFloat(order.Tkfr, 64)
mkFee, _ := strconv.ParseFloat(order.Mkfr, 64)
totalFee := tkFee + mkFee
// Get quanto_multiplier to convert contracts to actual quantity
quantoMultiplier := 1.0
contract, contractErr := t.getContract(symbol)
if contractErr == nil && contract != nil {
qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if qm > 0 {
quantoMultiplier = qm
}
}
// Map status
status := "NEW"
switch order.Status {
case "finished":
if order.FinishAs == "filled" {
status = "FILLED"
} else if order.FinishAs == "cancelled" {
status = "CANCELED"
} else {
status = "CLOSED"
}
case "open":
status = "NEW"
}
side := "BUY"
if order.Size < 0 {
side = "SELL"
}
// Convert contract count to actual token quantity
executedQty := math.Abs(float64(order.Size-order.Left)) * quantoMultiplier
return map[string]interface{}{
"orderId": orderID,
"symbol": t.revertSymbol(symbol),
"status": status,
"avgPrice": fillPrice,
"executedQty": executedQty,
"side": side,
"type": order.Tif,
"time": int64(order.CreateTime * 1000),
"updateTime": int64(order.FinishTime * 1000),
"commission": totalFee,
}, nil
}
// GetOpenOrders gets open/pending orders
func (t *GateTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
symbol = t.convertSymbol(symbol)
opts := &gateapi.ListFuturesOrdersOpts{
Contract: optional.NewString(symbol),
}
orders, _, err := t.client.FuturesApi.ListFuturesOrders(t.ctx, "usdt", "open", opts)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
// Get quanto_multiplier to convert contracts to actual quantity
quantoMultiplier := 1.0
contract, err := t.getContract(symbol)
if err == nil && contract != nil {
qm, _ := strconv.ParseFloat(contract.QuantoMultiplier, 64)
if qm > 0 {
quantoMultiplier = qm
}
}
var result []types.OpenOrder
for _, order := range orders {
price, _ := strconv.ParseFloat(order.Price, 64)
side := "BUY"
if order.Size < 0 {
side = "SELL"
}
// Convert contract count to actual token quantity
quantity := math.Abs(float64(order.Size)) * quantoMultiplier
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.Id),
Symbol: t.revertSymbol(order.Contract),
Side: side,
Type: "LIMIT",
Price: price,
Quantity: quantity,
Status: "NEW",
})
}
// Also get trigger orders
triggerOpts := &gateapi.ListPriceTriggeredOrdersOpts{
Contract: optional.NewString(symbol),
}
triggerOrders, _, err := t.client.FuturesApi.ListPriceTriggeredOrders(t.ctx, "usdt", "open", triggerOpts)
if err == nil {
for _, order := range triggerOrders {
triggerPrice, _ := strconv.ParseFloat(order.Trigger.Price, 64)
side := "BUY"
if order.Initial.Size < 0 {
side = "SELL"
}
orderType := "STOP_MARKET"
if order.Trigger.Rule == 2 {
orderType = "TAKE_PROFIT_MARKET"
}
// Convert contract count to actual token quantity
quantity := math.Abs(float64(order.Initial.Size)) * quantoMultiplier
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.Id),
Symbol: t.revertSymbol(order.Initial.Contract),
Side: side,
Type: orderType,
StopPrice: triggerPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
return result, nil
}
-295
View File
@@ -1,295 +0,0 @@
package hyperliquid
import (
"os"
"testing"
"time"
)
// TestHyperliquidBalanceCalculation tests the balance calculation for Hyperliquid
// including perp, spot, and xyz dex (stocks, forex, metals) accounts
// Run with: TEST_PRIVATE_KEY=xxx TEST_WALLET_ADDR=xxx go test -v -run TestHyperliquidBalanceCalculation ./trader/
func TestHyperliquidBalanceCalculation(t *testing.T) {
// Get credentials from environment
privateKeyHex := os.Getenv("TEST_PRIVATE_KEY")
walletAddr := os.Getenv("TEST_WALLET_ADDR")
if privateKeyHex == "" || walletAddr == "" {
t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required")
}
t.Logf("=== Testing Hyperliquid Balance Calculation ===")
t.Logf("Wallet: %s", walletAddr)
// Create trader instance
trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false)
if err != nil {
t.Fatalf("Failed to create trader: %v", err)
}
// Test GetBalance
t.Log("\n--- Testing GetBalance ---")
balance, err := trader.GetBalance()
if err != nil {
t.Fatalf("GetBalance failed: %v", err)
}
// Extract values
totalWalletBalance, _ := balance["totalWalletBalance"].(float64)
totalEquity, _ := balance["totalEquity"].(float64)
totalUnrealizedProfit, _ := balance["totalUnrealizedProfit"].(float64)
availableBalance, _ := balance["availableBalance"].(float64)
spotBalance, _ := balance["spotBalance"].(float64)
xyzDexBalance, _ := balance["xyzDexBalance"].(float64)
xyzDexUnrealizedPnl, _ := balance["xyzDexUnrealizedPnl"].(float64)
perpAccountValue, _ := balance["perpAccountValue"].(float64)
t.Logf("\n📊 Balance Results:")
t.Logf(" Perp Account Value: %.4f USDC", perpAccountValue)
t.Logf(" Spot Balance: %.4f USDC", spotBalance)
t.Logf(" xyz Dex Balance: %.4f USDC", xyzDexBalance)
t.Logf(" xyz Dex Unrealized PnL: %.4f USDC", xyzDexUnrealizedPnl)
t.Logf(" ---")
t.Logf(" Total Wallet Balance: %.4f USDC", totalWalletBalance)
t.Logf(" Total Unrealized PnL: %.4f USDC", totalUnrealizedProfit)
t.Logf(" Total Equity: %.4f USDC", totalEquity)
t.Logf(" Available Balance: %.4f USDC", availableBalance)
// Verify calculation: totalEquity should equal perpAccountValue + spotBalance + xyzDexBalance
expectedEquity := perpAccountValue + spotBalance + xyzDexBalance
t.Logf("\n🔍 Verification:")
t.Logf(" Expected Equity (Perp + Spot + xyz): %.4f", expectedEquity)
t.Logf(" Actual Total Equity: %.4f", totalEquity)
if abs(totalEquity-expectedEquity) > 0.01 {
t.Errorf("❌ Equity mismatch! Expected %.4f, got %.4f", expectedEquity, totalEquity)
} else {
t.Logf("✅ Equity calculation correct!")
}
// Verify: totalWalletBalance + totalUnrealizedProfit should equal totalEquity
calculatedEquity := totalWalletBalance + totalUnrealizedProfit
t.Logf("\n🔍 Secondary Verification:")
t.Logf(" Wallet + Unrealized = %.4f + %.4f = %.4f", totalWalletBalance, totalUnrealizedProfit, calculatedEquity)
t.Logf(" Total Equity: %.4f", totalEquity)
if abs(calculatedEquity-totalEquity) > 0.01 {
t.Errorf("❌ Secondary check failed! Wallet+Unrealized=%.4f != Equity=%.4f", calculatedEquity, totalEquity)
} else {
t.Logf("✅ Secondary verification passed!")
}
// Test GetPositions
t.Log("\n--- Testing GetPositions ---")
positions, err := trader.GetPositions()
if err != nil {
t.Fatalf("GetPositions failed: %v", err)
}
t.Logf("Found %d positions:", len(positions))
totalPositionValue := 0.0
totalPositionPnL := 0.0
for i, pos := range positions {
symbol, _ := pos["symbol"].(string)
side, _ := pos["side"].(string)
positionAmt, _ := pos["positionAmt"].(float64)
entryPrice, _ := pos["entryPrice"].(float64)
markPrice, _ := pos["markPrice"].(float64)
unrealizedPnL, _ := pos["unRealizedProfit"].(float64)
leverage, _ := pos["leverage"].(float64)
isXyzDex, _ := pos["isXyzDex"].(bool)
posValue := positionAmt * markPrice
totalPositionValue += posValue
totalPositionPnL += unrealizedPnL
assetType := "Crypto"
if isXyzDex {
assetType = "xyz Dex"
}
t.Logf(" [%d] %s (%s)", i+1, symbol, assetType)
t.Logf(" Side: %s, Qty: %.4f, Leverage: %.0fx", side, positionAmt, leverage)
t.Logf(" Entry: %.4f, Mark: %.4f", entryPrice, markPrice)
t.Logf(" Value: %.4f, PnL: %.4f", posValue, unrealizedPnL)
// Verify xyz dex position has valid entry/mark prices
if isXyzDex {
if entryPrice == 0 {
t.Errorf("❌ xyz dex position %s has zero entry price!", symbol)
}
if markPrice == 0 {
t.Errorf("❌ xyz dex position %s has zero mark price!", symbol)
}
}
}
t.Logf("\n📊 Position Summary:")
t.Logf(" Total Position Value: %.4f USDC", totalPositionValue)
t.Logf(" Total Position PnL: %.4f USDC", totalPositionPnL)
// Compare position PnL with balance unrealized PnL
t.Logf("\n🔍 PnL Comparison:")
t.Logf(" Balance Unrealized PnL: %.4f", totalUnrealizedProfit)
t.Logf(" Position Sum PnL: %.4f", totalPositionPnL)
if abs(totalUnrealizedProfit-totalPositionPnL) > 0.1 {
t.Logf("⚠️ PnL mismatch (may be due to funding fees or timing)")
} else {
t.Logf("✅ PnL values match!")
}
}
// TestXyzDexBalanceDirectQuery directly queries xyz dex balance for debugging
func TestXyzDexBalanceDirectQuery(t *testing.T) {
privateKeyHex := os.Getenv("TEST_PRIVATE_KEY")
walletAddr := os.Getenv("TEST_WALLET_ADDR")
if privateKeyHex == "" || walletAddr == "" {
t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required")
}
trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false)
if err != nil {
t.Fatalf("Failed to create trader: %v", err)
}
t.Log("=== Direct xyz Dex Balance Query ===")
accountValue, unrealizedPnl, positions, err := trader.getXYZDexBalance()
if err != nil {
t.Fatalf("getXYZDexBalance failed: %v", err)
}
t.Logf("xyz Dex Account Value: %.4f", accountValue)
t.Logf("xyz Dex Unrealized PnL: %.4f", unrealizedPnl)
t.Logf("xyz Dex Wallet Balance: %.4f", accountValue-unrealizedPnl)
t.Logf("xyz Dex Positions: %d", len(positions))
for i, pos := range positions {
entryPx := "nil"
if pos.Position.EntryPx != nil {
entryPx = *pos.Position.EntryPx
}
liqPx := "nil"
if pos.Position.LiquidationPx != nil {
liqPx = *pos.Position.LiquidationPx
}
t.Logf(" [%d] %s:", i+1, pos.Position.Coin)
t.Logf(" Size: %s", pos.Position.Szi)
t.Logf(" Entry Price: %s", entryPx)
t.Logf(" Position Value: %s", pos.Position.PositionValue)
t.Logf(" Unrealized PnL: %s", pos.Position.UnrealizedPnl)
t.Logf(" Liquidation Price: %s", liqPx)
t.Logf(" Leverage: %d (%s)", pos.Position.Leverage.Value, pos.Position.Leverage.Type)
}
}
// TestEquityAfterOpeningPosition simulates opening a position and verifies equity
func TestEquityAfterOpeningPosition(t *testing.T) {
privateKeyHex := os.Getenv("TEST_PRIVATE_KEY")
walletAddr := os.Getenv("TEST_WALLET_ADDR")
if privateKeyHex == "" || walletAddr == "" {
t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required")
}
if os.Getenv("XYZ_DEX_LIVE_TEST") != "1" {
t.Skip("Set XYZ_DEX_LIVE_TEST=1 to run live position test")
}
trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false)
if err != nil {
t.Fatalf("Failed to create trader: %v", err)
}
// Step 1: Record initial balance
t.Log("=== Step 1: Record Initial Balance ===")
initialBalance, _ := trader.GetBalance()
initialEquity, _ := initialBalance["totalEquity"].(float64)
t.Logf("Initial Equity: %.4f", initialEquity)
// Step 2: Fetch xyz meta
if err := trader.fetchXyzMeta(); err != nil {
t.Fatalf("Failed to fetch xyz meta: %v", err)
}
// Step 3: Get current price and place a small order
price, err := trader.getXyzMarketPrice("xyz:SILVER")
if err != nil {
t.Fatalf("Failed to get price: %v", err)
}
t.Logf("Current xyz:SILVER price: %.4f", price)
// Place a small buy order (minimum ~$10)
testSize := 0.14
testPrice := price * 1.05 // 5% above for IOC
t.Log("\n=== Step 2: Place Test Order ===")
t.Logf("Opening position: xyz:SILVER BUY %.4f @ %.4f", testSize, testPrice)
err = trader.placeXyzOrder("xyz:SILVER", true, testSize, testPrice, false)
if err != nil {
t.Logf("Order result: %v", err)
// Even if IOC doesn't fill, continue to check balance
}
// Wait a moment for the order to process
time.Sleep(2 * time.Second)
// Step 3: Check balance after order
t.Log("\n=== Step 3: Check Balance After Order ===")
afterBalance, _ := trader.GetBalance()
afterEquity, _ := afterBalance["totalEquity"].(float64)
afterPerpAV, _ := afterBalance["perpAccountValue"].(float64)
afterXyzAV, _ := afterBalance["xyzDexBalance"].(float64)
t.Logf("After Order:")
t.Logf(" Perp Account Value: %.4f", afterPerpAV)
t.Logf(" xyz Dex Balance: %.4f", afterXyzAV)
t.Logf(" Total Equity: %.4f", afterEquity)
equityChange := afterEquity - initialEquity
t.Logf("\nEquity Change: %.4f (%.2f%%)", equityChange, (equityChange/initialEquity)*100)
// Equity should not change significantly (only by trading fees/slippage)
if abs(equityChange) > initialEquity*0.05 { // More than 5% change is suspicious
t.Errorf("❌ Equity changed too much! Initial=%.4f, After=%.4f, Change=%.4f",
initialEquity, afterEquity, equityChange)
} else {
t.Logf("✅ Equity change is within acceptable range")
}
// Step 4: Close position if opened
t.Log("\n=== Step 4: Close Position ===")
positions, _ := trader.GetPositions()
for _, pos := range positions {
symbol, _ := pos["symbol"].(string)
if symbol == "xyz:SILVER" {
posAmt, _ := pos["positionAmt"].(float64)
if posAmt > 0 {
closePrice := price * 0.95 // 5% below for IOC sell
t.Logf("Closing position: SELL %.4f @ %.4f", posAmt, closePrice)
trader.placeXyzOrder("xyz:SILVER", false, posAmt, closePrice, true)
}
}
}
time.Sleep(2 * time.Second)
// Final balance check
t.Log("\n=== Step 5: Final Balance ===")
finalBalance, _ := trader.GetBalance()
finalEquity, _ := finalBalance["totalEquity"].(float64)
t.Logf("Final Equity: %.4f", finalEquity)
t.Logf("Net Change: %.4f", finalEquity-initialEquity)
}
func abs(x float64) float64 {
if x < 0 {
return -x
}
return x
}
File diff suppressed because it is too large Load Diff
+592
View File
@@ -0,0 +1,592 @@
package hyperliquid
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"time"
)
// GetBalance gets account balance
func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
logger.Infof("🔄 Calling Hyperliquid API to get account balance...")
// Step 1: Query Spot account balance
spotState, err := t.exchange.Info().SpotUserState(t.ctx, t.walletAddr)
var spotUSDCBalance float64 = 0.0
if err != nil {
logger.Infof("⚠️ Failed to query Spot balance (may have no spot assets): %v", err)
} else if spotState != nil && len(spotState.Balances) > 0 {
for _, balance := range spotState.Balances {
if balance.Coin == "USDC" {
spotUSDCBalance, _ = strconv.ParseFloat(balance.Total, 64)
logger.Infof("✓ Found Spot balance: %.2f USDC", spotUSDCBalance)
break
}
}
}
// Step 2: Query Perpetuals contract account status
accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr)
if err != nil {
logger.Infof("❌ Hyperliquid Perpetuals API call failed: %v", err)
return nil, fmt.Errorf("failed to get account information: %w", err)
}
// Parse balance information (MarginSummary fields are all strings)
result := make(map[string]interface{})
// Step 3: Dynamically select correct summary based on margin mode (CrossMarginSummary or MarginSummary)
var accountValue, totalMarginUsed float64
var summaryType string
var summary interface{}
if t.isCrossMargin {
// Cross margin mode: use CrossMarginSummary
accountValue, _ = strconv.ParseFloat(accountState.CrossMarginSummary.AccountValue, 64)
totalMarginUsed, _ = strconv.ParseFloat(accountState.CrossMarginSummary.TotalMarginUsed, 64)
summaryType = "CrossMarginSummary (cross margin)"
summary = accountState.CrossMarginSummary
} else {
// Isolated margin mode: use MarginSummary
accountValue, _ = strconv.ParseFloat(accountState.MarginSummary.AccountValue, 64)
totalMarginUsed, _ = strconv.ParseFloat(accountState.MarginSummary.TotalMarginUsed, 64)
summaryType = "MarginSummary (isolated margin)"
summary = accountState.MarginSummary
}
// Debug: Print complete summary structure returned by API
summaryJSON, _ := json.MarshalIndent(summary, " ", " ")
logger.Infof("🔍 [DEBUG] Hyperliquid API %s complete data:", summaryType)
logger.Infof("%s", string(summaryJSON))
// Critical fix: Accumulate actual unrealized PnL from all positions
totalUnrealizedPnl := 0.0
for _, assetPos := range accountState.AssetPositions {
unrealizedPnl, _ := strconv.ParseFloat(assetPos.Position.UnrealizedPnl, 64)
totalUnrealizedPnl += unrealizedPnl
}
// Correctly understand Hyperliquid fields:
// AccountValue = Total account equity (includes idle funds + position value + unrealized PnL)
// TotalMarginUsed = Margin used by positions (included in AccountValue, for display only)
//
// To be compatible with auto_types.go calculation logic (totalEquity = totalWalletBalance + totalUnrealizedProfit)
// Need to return "wallet balance without unrealized PnL"
walletBalanceWithoutUnrealized := accountValue - totalUnrealizedPnl
// Step 4: Use Withdrawable field (PR #443)
// Withdrawable is the official real withdrawable balance, more reliable than simple calculation
availableBalance := 0.0
if accountState.Withdrawable != "" {
withdrawable, err := strconv.ParseFloat(accountState.Withdrawable, 64)
if err == nil && withdrawable > 0 {
availableBalance = withdrawable
logger.Infof("✓ Using Withdrawable as available balance: %.2f", availableBalance)
}
}
// Fallback: If no Withdrawable, use simple calculation
if availableBalance == 0 && accountState.Withdrawable == "" {
availableBalance = accountValue - totalMarginUsed
if availableBalance < 0 {
logger.Infof("⚠️ Calculated available balance is negative (%.2f), reset to 0", availableBalance)
availableBalance = 0
}
}
// Step 5: Query xyz dex balance (stock perps, forex, commodities)
var xyzAccountValue, xyzUnrealizedPnl float64
var xyzPositions []xyzAssetPosition
xyzAccountValue, xyzUnrealizedPnl, xyzPositions, err = t.getXYZDexBalance()
if err != nil {
// xyz dex query failed - log warning but don't fail the entire balance query
logger.Infof("⚠️ Failed to query xyz dex balance: %v", err)
}
// Always log xyz dex state for debugging
logger.Infof("🔍 xyz dex state: accountValue=%.4f, unrealizedPnl=%.4f, positions=%d",
xyzAccountValue, xyzUnrealizedPnl, len(xyzPositions))
for _, pos := range xyzPositions {
entryPx := "nil"
if pos.Position.EntryPx != nil {
entryPx = *pos.Position.EntryPx
}
logger.Infof(" └─ %s: size=%s, entryPx=%s, posValue=%s, pnl=%s",
pos.Position.Coin, pos.Position.Szi, entryPx, pos.Position.PositionValue, pos.Position.UnrealizedPnl)
}
xyzWalletBalance := xyzAccountValue - xyzUnrealizedPnl
// Step 6: Correctly handle Spot + Perpetuals + xyz dex balance
// Important: Each account is independent, manual transfers required
totalWalletBalance := walletBalanceWithoutUnrealized + spotUSDCBalance + xyzWalletBalance
totalUnrealizedPnlAll := totalUnrealizedPnl + xyzUnrealizedPnl
// Calculate total equity properly: perpAccountValue + spotUSDCBalance + xyzAccountValue
// Note: totalWalletBalance + totalUnrealizedPnlAll should equal this
totalEquityCalculated := accountValue + spotUSDCBalance + xyzAccountValue
// Step 7: Unified Account mode - Spot USDC is used as collateral for Perps
// In this mode, available balance includes Spot USDC since it can be used for Perp margin
if t.isUnifiedAccount && spotUSDCBalance > 0 {
// Add Spot balance to available balance for trading
availableBalance = availableBalance + spotUSDCBalance
logger.Infof("✓ Unified Account: Spot %.2f USDC added to available balance (total: %.2f)",
spotUSDCBalance, availableBalance)
}
// Suppress unused variable warning
_ = totalUnrealizedPnlAll
result["totalWalletBalance"] = totalWalletBalance // Total assets (Perp + Spot + xyz) - unrealized
result["totalEquity"] = totalEquityCalculated // Total equity = Perp AV + Spot + xyz AV
result["availableBalance"] = availableBalance // Available balance (Perp + Spot if unified)
result["totalUnrealizedProfit"] = totalUnrealizedPnlAll // Unrealized PnL (Perpetuals + xyz)
result["spotBalance"] = spotUSDCBalance // Spot balance
result["xyzDexBalance"] = xyzAccountValue // xyz dex equity (stock perps, forex, commodities)
result["xyzDexUnrealizedPnl"] = xyzUnrealizedPnl // xyz dex unrealized PnL
result["perpAccountValue"] = accountValue // Perp account value for debugging
logger.Infof("✓ Hyperliquid complete account:")
logger.Infof(" • Spot balance: %.2f USDC", spotUSDCBalance)
logger.Infof(" • Perpetuals equity: %.2f USDC (wallet %.2f + unrealized %.2f)",
accountValue,
walletBalanceWithoutUnrealized,
totalUnrealizedPnl)
logger.Infof(" • Perpetuals available balance: %.2f USDC", availableBalance)
logger.Infof(" • Margin used: %.2f USDC", totalMarginUsed)
logger.Infof(" • xyz dex equity: %.2f USDC (wallet %.2f + unrealized %.2f)",
xyzAccountValue,
xyzWalletBalance,
xyzUnrealizedPnl)
logger.Infof(" • Total assets (Perp+Spot+xyz): %.2f USDC", totalWalletBalance)
logger.Infof(" ⭐ Total: %.2f USDC | Perp: %.2f | Spot: %.2f | xyz: %.2f",
totalWalletBalance, availableBalance, spotUSDCBalance, xyzAccountValue)
return result, nil
}
// xyzDexState represents the clearinghouse state for xyz dex
type xyzDexState struct {
MarginSummary *xyzMarginSummary `json:"marginSummary,omitempty"`
CrossMarginSummary *xyzMarginSummary `json:"crossMarginSummary,omitempty"`
Withdrawable string `json:"withdrawable,omitempty"`
AssetPositions []xyzAssetPosition `json:"assetPositions,omitempty"`
}
type xyzMarginSummary struct {
AccountValue string `json:"accountValue"`
TotalMarginUsed string `json:"totalMarginUsed"`
}
type xyzAssetPosition struct {
Position struct {
Coin string `json:"coin"`
Szi string `json:"szi"`
EntryPx *string `json:"entryPx"`
PositionValue string `json:"positionValue"`
UnrealizedPnl string `json:"unrealizedPnl"`
LiquidationPx *string `json:"liquidationPx"`
Leverage struct {
Type string `json:"type"`
Value int `json:"value"`
} `json:"leverage"`
} `json:"position"`
}
// getXYZDexBalance queries the xyz dex balance (stock perps, forex, commodities)
func (t *HyperliquidTrader) getXYZDexBalance() (accountValue float64, unrealizedPnl float64, positions []xyzAssetPosition, err error) {
// Build request for xyz dex clearinghouse state
reqBody := map[string]interface{}{
"type": "clearinghouseState",
"user": t.walletAddr,
"dex": "xyz",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return 0, 0, nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Determine API URL
apiURL := "https://api.hyperliquid.xyz/info"
// Note: xyz dex may not be available on testnet
req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return 0, 0, nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return 0, 0, nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return 0, 0, nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return 0, 0, nil, fmt.Errorf("xyz dex API error (status %d): %s", resp.StatusCode, string(body))
}
var state xyzDexState
if err := json.Unmarshal(body, &state); err != nil {
return 0, 0, nil, fmt.Errorf("failed to parse response: %w", err)
}
// Parse account value - xyz dex uses MarginSummary for isolated margin mode
// CrossMarginSummary may exist but with 0 values, so check MarginSummary first
if state.MarginSummary != nil && state.MarginSummary.AccountValue != "" {
av, _ := strconv.ParseFloat(state.MarginSummary.AccountValue, 64)
if av > 0 {
accountValue = av
}
}
// Fallback to CrossMarginSummary if MarginSummary is 0
if accountValue == 0 && state.CrossMarginSummary != nil && state.CrossMarginSummary.AccountValue != "" {
accountValue, _ = strconv.ParseFloat(state.CrossMarginSummary.AccountValue, 64)
}
// Calculate total unrealized PnL from positions
for _, pos := range state.AssetPositions {
pnl, _ := strconv.ParseFloat(pos.Position.UnrealizedPnl, 64)
unrealizedPnl += pnl
}
return accountValue, unrealizedPnl, state.AssetPositions, nil
}
// GetMarketPrice gets market price (supports both crypto and xyz dex assets)
func (t *HyperliquidTrader) GetMarketPrice(symbol string) (float64, error) {
coin := convertSymbolToHyperliquid(symbol)
// Check if this is an xyz dex asset
if strings.HasPrefix(coin, "xyz:") {
return t.getXyzMarketPrice(coin)
}
// Get all market prices for crypto
allMids, err := t.exchange.Info().AllMids(t.ctx)
if err != nil {
return 0, fmt.Errorf("failed to get price: %w", err)
}
// Find price for corresponding coin (allMids is map[string]string)
if priceStr, ok := allMids[coin]; ok {
priceFloat, err := strconv.ParseFloat(priceStr, 64)
if err == nil {
return priceFloat, nil
}
return 0, fmt.Errorf("price format error: %v", err)
}
return 0, fmt.Errorf("price not found for %s", symbol)
}
// getXyzMarketPrice gets market price for xyz dex assets
func (t *HyperliquidTrader) getXyzMarketPrice(coin string) (float64, error) {
// Build request for xyz dex allMids
reqBody := map[string]string{
"type": "allMids",
"dex": "xyz",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return 0, fmt.Errorf("failed to marshal request: %w", err)
}
apiURL := "https://api.hyperliquid.xyz/info"
req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return 0, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return 0, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return 0, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("xyz dex allMids API error (status %d): %s", resp.StatusCode, string(body))
}
var mids map[string]string
if err := json.Unmarshal(body, &mids); err != nil {
return 0, fmt.Errorf("failed to parse response: %w", err)
}
// The API returns keys with xyz: prefix, so ensure the coin has it
lookupKey := coin
if !strings.HasPrefix(lookupKey, "xyz:") {
lookupKey = "xyz:" + lookupKey
}
if priceStr, ok := mids[lookupKey]; ok {
priceFloat, err := strconv.ParseFloat(priceStr, 64)
if err == nil {
return priceFloat, nil
}
return 0, fmt.Errorf("price format error: %v", err)
}
return 0, fmt.Errorf("xyz dex price not found for %s (lookup key: %s)", coin, lookupKey)
}
// GetOrderStatus gets order status
// Hyperliquid uses IOC orders, usually filled or cancelled immediately
// For completed orders, need to query historical records
func (t *HyperliquidTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
// Hyperliquid's IOC orders are completed almost immediately
// If order was placed through this system, returned status will be FILLED
// Try to query open orders to determine if still pending
coin := convertSymbolToHyperliquid(symbol)
// First check if in open orders
openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr)
if err != nil {
// If query fails, assume order is completed
return map[string]interface{}{
"orderId": orderID,
"status": "FILLED",
"avgPrice": 0.0,
"executedQty": 0.0,
"commission": 0.0,
}, nil
}
// Check if order is in open orders list
for _, order := range openOrders {
if order.Coin == coin && fmt.Sprintf("%d", order.Oid) == orderID {
// Order is still pending
return map[string]interface{}{
"orderId": orderID,
"status": "NEW",
"avgPrice": 0.0,
"executedQty": 0.0,
"commission": 0.0,
}, nil
}
}
// Order not in open list, meaning completed or cancelled
// Hyperliquid IOC orders not in open list are usually filled
return map[string]interface{}{
"orderId": orderID,
"status": "FILLED",
"avgPrice": 0.0, // Hyperliquid does not directly return execution price, need to get from position info
"executedQty": 0.0,
"commission": 0.0,
}, nil
}
// GetClosedPnL gets recent closing trades from Hyperliquid
// Note: Hyperliquid does NOT have a position history API, only fill history.
// This returns individual closing trades for real-time position closure detection.
func (t *HyperliquidTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
trades, err := t.GetTrades(startTime, limit)
if err != nil {
return nil, err
}
// Filter only closing trades (realizedPnl != 0)
var records []types.ClosedPnLRecord
for _, trade := range trades {
if trade.RealizedPnL == 0 {
continue
}
// Determine side (Hyperliquid uses one-way mode)
side := "long"
if trade.Side == "SELL" || trade.Side == "Sell" {
side = "long" // Selling closes long
} else {
side = "short" // Buying closes short
}
// Calculate entry price from PnL
var entryPrice float64
if trade.Quantity > 0 {
if side == "long" {
entryPrice = trade.Price - trade.RealizedPnL/trade.Quantity
} else {
entryPrice = trade.Price + trade.RealizedPnL/trade.Quantity
}
}
records = append(records, types.ClosedPnLRecord{
Symbol: trade.Symbol,
Side: side,
EntryPrice: entryPrice,
ExitPrice: trade.Price,
Quantity: trade.Quantity,
RealizedPnL: trade.RealizedPnL,
Fee: trade.Fee,
ExitTime: trade.Time,
EntryTime: trade.Time,
OrderID: trade.TradeID,
ExchangeID: trade.TradeID,
CloseType: "unknown",
})
}
return records, nil
}
// GetTrades retrieves trade history from Hyperliquid
func (t *HyperliquidTrader) GetTrades(startTime time.Time, limit int) ([]types.TradeRecord, error) {
// Use UserFillsByTime API
startTimeMs := startTime.UnixMilli()
fills, err := t.exchange.Info().UserFillsByTime(t.ctx, t.walletAddr, startTimeMs, nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get user fills: %w", err)
}
var trades []types.TradeRecord
for _, fill := range fills {
price, _ := strconv.ParseFloat(fill.Price, 64)
qty, _ := strconv.ParseFloat(fill.Size, 64)
fee, _ := strconv.ParseFloat(fill.Fee, 64)
pnl, _ := strconv.ParseFloat(fill.ClosedPnl, 64)
// Determine side: "B" = Buy, "S" = Sell (or "A" = Ask, "B" = Bid)
var side string
if fill.Side == "B" || fill.Side == "Buy" || fill.Side == "bid" {
side = "BUY"
} else {
side = "SELL"
}
// Parse Dir field to get order action
// Hyperliquid Dir values: "Open Long", "Open Short", "Close Long", "Close Short"
var orderAction string
switch strings.ToLower(fill.Dir) {
case "open long":
orderAction = "open_long"
case "open short":
orderAction = "open_short"
case "close long":
orderAction = "close_long"
case "close short":
orderAction = "close_short"
default:
// Fallback: use RealizedPnL if Dir is missing/unknown
if pnl != 0 {
if side == "BUY" {
orderAction = "close_short"
} else {
orderAction = "close_long"
}
} else {
if side == "BUY" {
orderAction = "open_long"
} else {
orderAction = "open_short"
}
}
}
// Hyperliquid uses one-way mode, so PositionSide is "BOTH"
trade := types.TradeRecord{
TradeID: strconv.FormatInt(fill.Tid, 10),
Symbol: fill.Coin,
Side: side,
PositionSide: "BOTH", // Hyperliquid doesn't have hedge mode
OrderAction: orderAction,
Price: price,
Quantity: qty,
RealizedPnL: pnl,
Fee: fee,
Time: time.UnixMilli(fill.Time).UTC(),
}
trades = append(trades, trade)
}
return trades, nil
}
// GetOpenOrders gets all open/pending orders for a symbol
func (t *HyperliquidTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
var result []types.OpenOrder
for _, order := range openOrders {
if order.Coin != symbol {
continue
}
side := "BUY"
if order.Side == "A" {
side = "SELL"
}
result = append(result, types.OpenOrder{
OrderID: fmt.Sprintf("%d", order.Oid),
Symbol: order.Coin,
Side: side,
PositionSide: "",
Type: "LIMIT",
Price: order.LimitPx,
StopPrice: 0,
Quantity: order.Size,
Status: "NEW",
})
}
return result, nil
}
// GetOrderBook gets the order book for a symbol
// Implements GridTrader interface
func (t *HyperliquidTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) {
coin := convertSymbolToHyperliquid(symbol)
l2Book, err := t.exchange.Info().L2Snapshot(t.ctx, coin)
if err != nil {
return nil, nil, fmt.Errorf("failed to get order book: %w", err)
}
if l2Book == nil || len(l2Book.Levels) < 2 {
return nil, nil, fmt.Errorf("invalid order book data")
}
// Parse bids (first level array)
for i, level := range l2Book.Levels[0] {
if i >= depth {
break
}
bids = append(bids, []float64{level.Px, level.Sz})
}
// Parse asks (second level array)
for i, level := range l2Book.Levels[1] {
if i >= depth {
break
}
asks = append(asks, []float64{level.Px, level.Sz})
}
return bids, asks, nil
}
File diff suppressed because it is too large Load Diff
+167
View File
@@ -0,0 +1,167 @@
package hyperliquid
import (
"fmt"
"nofx/logger"
"strconv"
"strings"
)
// GetPositions gets all positions (including xyz dex positions)
func (t *HyperliquidTrader) GetPositions() ([]map[string]interface{}, error) {
// Get account status
accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result []map[string]interface{}
// Iterate through all perp positions
for _, assetPos := range accountState.AssetPositions {
position := assetPos.Position
// Position amount (string type)
posAmt, _ := strconv.ParseFloat(position.Szi, 64)
if posAmt == 0 {
continue // Skip positions with zero amount
}
posMap := make(map[string]interface{})
// Normalize symbol format (Hyperliquid uses "BTC", we convert to "BTCUSDT")
symbol := position.Coin + "USDT"
posMap["symbol"] = symbol
// Position amount and direction
if posAmt > 0 {
posMap["side"] = "long"
posMap["positionAmt"] = posAmt
} else {
posMap["side"] = "short"
posMap["positionAmt"] = -posAmt // Convert to positive number
}
// Price information (EntryPx and LiquidationPx are pointer types)
var entryPrice, liquidationPx float64
if position.EntryPx != nil {
entryPrice, _ = strconv.ParseFloat(*position.EntryPx, 64)
}
if position.LiquidationPx != nil {
liquidationPx, _ = strconv.ParseFloat(*position.LiquidationPx, 64)
}
positionValue, _ := strconv.ParseFloat(position.PositionValue, 64)
unrealizedPnl, _ := strconv.ParseFloat(position.UnrealizedPnl, 64)
// Calculate mark price (positionValue / abs(posAmt))
var markPrice float64
if posAmt != 0 {
markPrice = positionValue / absFloat(posAmt)
}
posMap["entryPrice"] = entryPrice
posMap["markPrice"] = markPrice
posMap["unRealizedProfit"] = unrealizedPnl
posMap["leverage"] = float64(position.Leverage.Value)
posMap["liquidationPrice"] = liquidationPx
result = append(result, posMap)
}
// Also get xyz dex positions (stocks, forex, commodities)
_, _, xyzPositions, err := t.getXYZDexBalance()
if err != nil {
// xyz dex query failed - log warning but don't fail
logger.Infof("⚠️ Failed to get xyz dex positions: %v", err)
} else {
for _, pos := range xyzPositions {
posAmt, _ := strconv.ParseFloat(pos.Position.Szi, 64)
if posAmt == 0 {
continue
}
posMap := make(map[string]interface{})
// xyz dex positions - the API returns coin names with xyz: prefix (e.g., "xyz:SILVER")
// Only add prefix if not already present
symbol := pos.Position.Coin
if !strings.HasPrefix(symbol, "xyz:") {
symbol = "xyz:" + symbol
}
posMap["symbol"] = symbol
if posAmt > 0 {
posMap["side"] = "long"
posMap["positionAmt"] = posAmt
} else {
posMap["side"] = "short"
posMap["positionAmt"] = -posAmt
}
// Parse price information
var entryPrice, liquidationPx float64
if pos.Position.EntryPx != nil {
entryPrice, _ = strconv.ParseFloat(*pos.Position.EntryPx, 64)
}
if pos.Position.LiquidationPx != nil {
liquidationPx, _ = strconv.ParseFloat(*pos.Position.LiquidationPx, 64)
}
positionValue, _ := strconv.ParseFloat(pos.Position.PositionValue, 64)
unrealizedPnl, _ := strconv.ParseFloat(pos.Position.UnrealizedPnl, 64)
// Calculate mark price from position value
var markPrice float64
if posAmt != 0 {
markPrice = positionValue / absFloat(posAmt)
}
// Get leverage (default to 1 if not available)
leverage := float64(pos.Position.Leverage.Value)
if leverage == 0 {
leverage = 1.0
}
posMap["entryPrice"] = entryPrice
posMap["markPrice"] = markPrice
posMap["unRealizedProfit"] = unrealizedPnl
posMap["leverage"] = leverage
posMap["liquidationPrice"] = liquidationPx
posMap["isXyzDex"] = true // Mark as xyz dex position
result = append(result, posMap)
}
}
return result, nil
}
// SetMarginMode sets margin mode (set together with SetLeverage)
func (t *HyperliquidTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// Hyperliquid's margin mode is set in SetLeverage, only record here
t.isCrossMargin = isCrossMargin
marginModeStr := "cross margin"
if !isCrossMargin {
marginModeStr = "isolated margin"
}
logger.Infof(" ✓ %s will use %s mode", symbol, marginModeStr)
return nil
}
// SetLeverage sets leverage
func (t *HyperliquidTrader) SetLeverage(symbol string, leverage int) error {
// Hyperliquid symbol format (remove USDT suffix)
coin := convertSymbolToHyperliquid(symbol)
// Call UpdateLeverage (leverage int, name string, isCross bool)
// Third parameter: true=cross margin mode, false=isolated margin mode
_, err := t.exchange.UpdateLeverage(t.ctx, leverage, coin, t.isCrossMargin)
if err != nil {
return fmt.Errorf("failed to set leverage: %w", err)
}
logger.Infof(" ✓ %s leverage switched to %dx", symbol, leverage)
return nil
}
+147
View File
@@ -0,0 +1,147 @@
package hyperliquid
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"nofx/logger"
"strings"
"time"
)
// refreshMetaIfNeeded refreshes meta information when invalid (triggered when Asset ID is 0)
func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error {
assetID := t.exchange.Info().NameToAsset(coin)
if assetID != 0 {
return nil // Meta is normal, no refresh needed
}
logger.Infof("⚠️ Asset ID for %s is 0, attempting to refresh Meta information...", coin)
// Refresh Meta information
meta, err := t.exchange.Info().Meta(t.ctx)
if err != nil {
return fmt.Errorf("failed to refresh Meta information: %w", err)
}
// Concurrency safe: Use write lock to protect meta field update
t.metaMutex.Lock()
t.meta = meta
t.metaMutex.Unlock()
logger.Infof("✅ Meta information refreshed, contains %d assets", len(meta.Universe))
// Verify Asset ID after refresh
assetID = t.exchange.Info().NameToAsset(coin)
if assetID == 0 {
return fmt.Errorf("❌ Even after refreshing Meta, Asset ID for %s is still 0. Possible reasons:\n"+
" 1. This coin is not listed on Hyperliquid\n"+
" 2. Coin name is incorrect (should be BTC not BTCUSDT)\n"+
" 3. API connection issue", coin)
}
logger.Infof("✅ Asset ID check passed after refresh: %s -> %d", coin, assetID)
return nil
}
// fetchXyzMeta fetches metadata for xyz dex assets (stocks, forex, commodities)
func (t *HyperliquidTrader) fetchXyzMeta() error {
// Build request for xyz dex meta
reqBody := map[string]string{
"type": "meta",
"dex": "xyz",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
apiURL := "https://api.hyperliquid.xyz/info"
req, err := http.NewRequestWithContext(t.ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("xyz dex meta API error (status %d): %s", resp.StatusCode, string(body))
}
var meta xyzDexMeta
if err := json.Unmarshal(body, &meta); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
t.xyzMetaMutex.Lock()
t.xyzMeta = &meta
t.xyzMetaMutex.Unlock()
logger.Infof("✅ xyz dex meta fetched, contains %d assets", len(meta.Universe))
return nil
}
// getXyzSzDecimals gets quantity precision for xyz dex asset
func (t *HyperliquidTrader) getXyzSzDecimals(coin string) int {
t.xyzMetaMutex.RLock()
defer t.xyzMetaMutex.RUnlock()
if t.xyzMeta == nil {
logger.Infof("⚠️ xyz meta information is empty, using default precision 2")
return 2 // Default precision for stocks/forex
}
// The meta API returns names with xyz: prefix, so ensure we match correctly
lookupName := coin
if !strings.HasPrefix(lookupName, "xyz:") {
lookupName = "xyz:" + lookupName
}
// Find corresponding asset in xyzMeta.Universe
for _, asset := range t.xyzMeta.Universe {
if asset.Name == lookupName {
return asset.SzDecimals
}
}
logger.Infof("⚠️ Precision information not found for %s, using default precision 2", lookupName)
return 2 // Default precision for stocks/forex
}
// getXyzAssetIndex gets the asset index for an xyz dex asset
func (t *HyperliquidTrader) getXyzAssetIndex(baseCoin string) int {
t.xyzMetaMutex.RLock()
defer t.xyzMetaMutex.RUnlock()
if t.xyzMeta == nil {
return -1
}
// The meta API returns names with xyz: prefix, so ensure we match correctly
lookupName := baseCoin
if !strings.HasPrefix(lookupName, "xyz:") {
lookupName = "xyz:" + lookupName
}
for i, asset := range t.xyzMeta.Universe {
if asset.Name == lookupName {
return i
}
}
return -1
}
-648
View File
@@ -1,648 +0,0 @@
package hyperliquid
import (
"context"
"crypto/ecdsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/sonirico/go-hyperliquid"
"github.com/stretchr/testify/assert"
"nofx/trader/testutil"
"nofx/trader/types"
)
// ============================================================
// Part 1: HyperliquidTestSuite - Inherits base test suite
// ============================================================
// HyperliquidTestSuite Hyperliquid trader test suite
// Inherits TraderTestSuite and adds Hyperliquid-specific mock logic
type HyperliquidTestSuite struct {
*testutil.TraderTestSuite // Embeds base test suite
mockServer *httptest.Server
privateKey *ecdsa.PrivateKey
}
// NewHyperliquidTestSuite Create Hyperliquid test suite
func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
// Create test private key
privateKey, err := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
if err != nil {
t.Fatalf("Failed to create test private key: %v", err)
}
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return different mock responses based on request path
var respBody interface{}
// Hyperliquid API uses POST requests with JSON body
// We need to distinguish different requests by the "type" field in request body
var reqBody map[string]interface{}
if r.Method == "POST" {
json.NewDecoder(r.Body).Decode(&reqBody)
}
// Try to get type from top level first, then from action object
reqType, _ := reqBody["type"].(string)
if reqType == "" && reqBody["action"] != nil {
if action, ok := reqBody["action"].(map[string]interface{}); ok {
reqType, _ = action["type"].(string)
}
}
switch reqType {
// Mock Meta - Get market metadata
case "meta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{
{
"name": "BTC",
"szDecimals": 4,
"maxLeverage": 50,
"onlyIsolated": false,
"isDelisted": false,
"marginTableId": 0,
},
{
"name": "ETH",
"szDecimals": 3,
"maxLeverage": 50,
"onlyIsolated": false,
"isDelisted": false,
"marginTableId": 0,
},
},
"marginTables": []interface{}{},
}
// Mock UserState - Get user account state (for GetBalance and GetPositions)
case "clearinghouseState":
user, _ := reqBody["user"].(string)
// Check if querying Agent wallet balance (for security check)
agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
if user == agentAddr {
// Agent wallet balance should be low
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "5.00",
"totalMarginUsed": "0.00",
},
"withdrawable": "5.00",
"assetPositions": []interface{}{},
}
} else {
// Main wallet account state
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "10000.00",
"totalMarginUsed": "2000.00",
},
"withdrawable": "8000.00",
"assetPositions": []map[string]interface{}{
{
"position": map[string]interface{}{
"coin": "BTC",
"szi": "0.5",
"entryPx": "50000.00",
"liquidationPx": "45000.00",
"positionValue": "25000.00",
"unrealizedPnl": "100.50",
"leverage": map[string]interface{}{
"type": "cross",
"value": 10,
},
},
},
},
}
}
// Mock SpotUserState - Get spot account state
case "spotClearinghouseState":
respBody = map[string]interface{}{
"balances": []map[string]interface{}{
{
"coin": "USDC",
"total": "500.00",
},
},
}
// Mock SpotMeta - Get spot market metadata
case "spotMeta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{},
"tokens": []map[string]interface{}{},
}
// Mock AllMids - Get all market prices
case "allMids":
respBody = map[string]string{
"BTC": "50000.00",
"ETH": "3000.00",
}
// Mock OpenOrders - Get open orders list
case "openOrders":
respBody = []interface{}{}
// Mock Order - Create order (open, close, stop-loss, take-profit)
case "order":
respBody = map[string]interface{}{
"status": "ok",
"response": map[string]interface{}{
"type": "order",
"data": map[string]interface{}{
"statuses": []map[string]interface{}{
{
"filled": map[string]interface{}{
"totalSz": "0.01",
"avgPx": "50000.00",
},
},
},
},
},
}
// Mock UpdateLeverage - Set leverage
case "updateLeverage":
respBody = map[string]interface{}{
"status": "ok",
}
// Mock Cancel - Cancel order
case "cancel":
respBody = map[string]interface{}{
"status": "ok",
}
default:
// Default return success response
respBody = map[string]interface{}{
"status": "ok",
}
}
// Serialize response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// Create HyperliquidTrader, using mock server URL
walletAddr := "0x9999999999999999999999999999999999999999"
ctx := context.Background()
// Create Exchange client, pointing to mock server
exchange := hyperliquid.NewExchange(
ctx,
privateKey,
mockServer.URL, // Use mock server URL
nil,
"",
walletAddr,
nil,
)
// Create meta (simulate successful fetch)
meta := &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 4},
{Name: "ETH", SzDecimals: 3},
},
}
traderInstance := &HyperliquidTrader{
exchange: exchange,
ctx: ctx,
walletAddr: walletAddr,
meta: meta,
isCrossMargin: true,
}
// Create base suite
baseSuite := testutil.NewTraderTestSuite(t, traderInstance)
return &HyperliquidTestSuite{
TraderTestSuite: baseSuite,
mockServer: mockServer,
privateKey: privateKey,
}
}
// Cleanup Clean up resources
func (s *HyperliquidTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
}
s.TraderTestSuite.Cleanup()
}
// ============================================================
// Part 2: Run common tests using HyperliquidTestSuite
// ============================================================
// TestHyperliquidTrader_InterfaceCompliance Test interface compliance
func TestHyperliquidTrader_InterfaceCompliance(t *testing.T) {
var _ types.Trader = (*HyperliquidTrader)(nil)
}
// TestHyperliquidTrader_CommonInterface Run all common interface tests using test suite
func TestHyperliquidTrader_CommonInterface(t *testing.T) {
// Create test suite
suite := NewHyperliquidTestSuite(t)
defer suite.Cleanup()
// Run all common interface tests
suite.RunAllTests()
}
// ============================================================
// Part 3: Hyperliquid-specific feature unit tests
// ============================================================
// TestNewHyperliquidTrader Test creating Hyperliquid trader
func TestNewHyperliquidTrader(t *testing.T) {
tests := []struct {
name string
privateKeyHex string
walletAddr string
testnet bool
wantError bool
errorContains string
}{
{
name: "Invalid private key format",
privateKeyHex: "invalid_key",
walletAddr: "0x1234567890123456789012345678901234567890",
testnet: true,
wantError: true,
errorContains: "Failed to parse private key",
},
{
name: "Empty wallet address",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
walletAddr: "",
testnet: true,
wantError: true,
errorContains: "Configuration error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
trader, err := NewHyperliquidTrader(tt.privateKeyHex, tt.walletAddr, tt.testnet)
if tt.wantError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, trader)
} else {
assert.NoError(t, err)
assert.NotNil(t, trader)
if trader != nil {
assert.Equal(t, tt.walletAddr, trader.walletAddr)
assert.NotNil(t, trader.exchange)
}
}
})
}
}
// TestNewHyperliquidTrader_Success Test successfully creating trader (requires mock HTTP)
func TestNewHyperliquidTrader_Success(t *testing.T) {
// Create test private key
privateKey, _ := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
reqType, _ := reqBody["type"].(string)
var respBody interface{}
switch reqType {
case "meta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{
{
"name": "BTC",
"szDecimals": 4,
"maxLeverage": 50,
"onlyIsolated": false,
"isDelisted": false,
"marginTableId": 0,
},
},
"marginTables": []interface{}{},
}
case "clearinghouseState":
user, _ := reqBody["user"].(string)
if user == agentAddr {
// Agent wallet low balance
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "5.00",
},
"assetPositions": []interface{}{},
}
} else {
// Main wallet
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "10000.00",
},
"assetPositions": []interface{}{},
}
}
default:
respBody = map[string]interface{}{"status": "ok"}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
defer mockServer.Close()
// Note: This test would actually call NewHyperliquidTrader, but will fail
// Because hyperliquid SDK doesn't allow us to inject custom URL in constructor
// So this test is only for verifying parameter handling logic
t.Skip("Skip this test: hyperliquid SDK calls real API during construction, cannot inject mock URL")
}
// ============================================================
// Part 4: Utility function unit tests (Hyperliquid-specific)
// ============================================================
// TestConvertSymbolToHyperliquid Test symbol conversion function
func TestConvertSymbolToHyperliquid(t *testing.T) {
tests := []struct {
name string
symbol string
expected string
}{
{
name: "BTCUSDT conversion",
symbol: "BTCUSDT",
expected: "BTC",
},
{
name: "ETHUSDT conversion",
symbol: "ETHUSDT",
expected: "ETH",
},
{
name: "No USDT suffix",
symbol: "BTC",
expected: "BTC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertSymbolToHyperliquid(tt.symbol)
assert.Equal(t, tt.expected, result)
})
}
}
// TestAbsFloat Test absolute value function
func TestAbsFloat(t *testing.T) {
tests := []struct {
name string
input float64
expected float64
}{
{
name: "Positive number",
input: 10.5,
expected: 10.5,
},
{
name: "Negative number",
input: -10.5,
expected: 10.5,
},
{
name: "Zero",
input: 0,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := absFloat(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestHyperliquidTrader_RoundToSzDecimals Test quantity precision handling
func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) {
trader := &HyperliquidTrader{
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 4},
{Name: "ETH", SzDecimals: 3},
},
},
}
tests := []struct {
name string
coin string
quantity float64
expected float64
}{
{
name: "BTC - round to 4 decimals",
coin: "BTC",
quantity: 1.23456789,
expected: 1.2346,
},
{
name: "ETH - round to 3 decimals",
coin: "ETH",
quantity: 10.12345,
expected: 10.123,
},
{
name: "Unknown coin - use default 4 decimals",
coin: "UNKNOWN",
quantity: 1.23456789,
expected: 1.2346,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := trader.roundToSzDecimals(tt.coin, tt.quantity)
assert.InDelta(t, tt.expected, result, 0.0001)
})
}
}
// TestHyperliquidTrader_RoundPriceToSigfigs Test price significant figures handling
func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) {
trader := &HyperliquidTrader{}
tests := []struct {
name string
price float64
expected float64
}{
{
name: "BTC price - 5 significant figures",
price: 50123.456789,
expected: 50123.0,
},
{
name: "Decimal price - 5 significant figures",
price: 0.0012345678,
expected: 0.0012346,
},
{
name: "Zero price",
price: 0,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := trader.roundPriceToSigfigs(tt.price)
assert.InDelta(t, tt.expected, result, tt.expected*0.001)
})
}
}
// TestHyperliquidTrader_GetSzDecimals Test getting precision
func TestHyperliquidTrader_GetSzDecimals(t *testing.T) {
tests := []struct {
name string
meta *hyperliquid.Meta
coin string
expected int
}{
{
name: "meta is nil - return default precision",
meta: nil,
coin: "BTC",
expected: 4,
},
{
name: "Found BTC - return correct precision",
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 5},
},
},
coin: "BTC",
expected: 5,
},
{
name: "Coin not found - return default precision",
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "ETH", SzDecimals: 3},
},
},
coin: "BTC",
expected: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ht := &HyperliquidTrader{meta: tt.meta}
result := ht.getSzDecimals(tt.coin)
assert.Equal(t, tt.expected, result)
})
}
}
// TestHyperliquidTrader_SetMarginMode Test setting margin mode
func TestHyperliquidTrader_SetMarginMode(t *testing.T) {
trader := &HyperliquidTrader{
ctx: context.Background(),
isCrossMargin: true,
}
tests := []struct {
name string
symbol string
isCrossMargin bool
wantError bool
}{
{
name: "Set to cross margin mode",
symbol: "BTCUSDT",
isCrossMargin: true,
wantError: false,
},
{
name: "Set to isolated margin mode",
symbol: "ETHUSDT",
isCrossMargin: false,
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := trader.SetMarginMode(tt.symbol, tt.isCrossMargin)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.isCrossMargin, trader.isCrossMargin)
}
})
}
}
// TestNewHyperliquidTrader_PrivateKeyProcessing Test private key processing
func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) {
tests := []struct {
name string
privateKeyHex string
shouldStripOx bool
expectedLength int
}{
{
name: "Private key with 0x prefix",
privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
shouldStripOx: true,
expectedLength: 64,
},
{
name: "Private key without prefix",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
shouldStripOx: false,
expectedLength: 64,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test private key prefix handling logic (without actually creating trader)
processed := tt.privateKeyHex
if len(processed) > 2 && (processed[:2] == "0x" || processed[:2] == "0X") {
processed = processed[2:]
}
assert.Equal(t, tt.expectedLength, len(processed))
})
}
}
-669
View File
@@ -1,669 +0,0 @@
package hyperliquid
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"time"
)
// testXyzDexAsset is a local copy of testXyzDexAsset for testing
type testXyzDexAsset struct {
Name string `json:"name"`
SzDecimals int `json:"szDecimals"`
MaxLeverage int `json:"maxLeverage"`
}
// testXyzDexMeta is a local copy of xyzDexMeta for testing
type testXyzDexMeta struct {
Universe []testXyzDexAsset `json:"universe"`
}
// TestXyzDexMetaFetch tests fetching xyz dex meta from Hyperliquid API
func TestXyzDexMetaFetch(t *testing.T) {
reqBody := map[string]string{
"type": "meta",
"dex": "xyz",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("API returned status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
var meta testXyzDexMeta
if err := json.Unmarshal(body, &meta); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if len(meta.Universe) == 0 {
t.Fatal("xyz meta universe is empty")
}
t.Logf("✅ xyz dex meta contains %d assets", len(meta.Universe))
// Check that SILVER exists
// HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta
// xyz dex is at perp_dex_index = 1
found := false
for i, asset := range meta.Universe {
if asset.Name == "xyz:SILVER" {
found = true
assetIndex := 100000 + 1*10000 + i // xyz dex index = 1
t.Logf("✅ Found xyz:SILVER at index %d (asset ID: %d)", i, assetIndex)
t.Logf(" SzDecimals: %d, MaxLeverage: %d", asset.SzDecimals, asset.MaxLeverage)
break
}
}
if !found {
t.Fatal("xyz:SILVER not found in meta")
}
}
// TestXyzDexPriceFetch tests fetching xyz dex prices from Hyperliquid API
func TestXyzDexPriceFetch(t *testing.T) {
reqBody := map[string]string{
"type": "allMids",
"dex": "xyz",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("API returned status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
var mids map[string]string
if err := json.Unmarshal(body, &mids); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Check that prices have xyz: prefix
silverPrice, ok := mids["xyz:SILVER"]
if !ok {
t.Fatal("xyz:SILVER price not found (key should include xyz: prefix)")
}
t.Logf("✅ xyz:SILVER price: %s", silverPrice)
// Verify a few more assets
testAssets := []string{"xyz:GOLD", "xyz:TSLA", "xyz:NVDA"}
for _, asset := range testAssets {
if price, ok := mids[asset]; ok {
t.Logf("✅ %s price: %s", asset, price)
} else {
t.Logf("⚠️ %s not found in prices", asset)
}
}
}
// TestXyzAssetIndexLookup tests the asset index lookup for xyz dex assets
func TestXyzAssetIndexLookup(t *testing.T) {
// Fetch xyz meta
reqBody := map[string]string{
"type": "meta",
"dex": "xyz",
}
jsonBody, _ := json.Marshal(reqBody)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to fetch meta: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var meta testXyzDexMeta
json.Unmarshal(body, &meta)
// Test lookup with different formats
testCases := []struct {
input string
expected string // expected match in meta
}{
{"SILVER", "xyz:SILVER"},
{"xyz:SILVER", "xyz:SILVER"},
{"GOLD", "xyz:GOLD"},
{"xyz:TSLA", "xyz:TSLA"},
}
for _, tc := range testCases {
lookupName := tc.input
if !strings.HasPrefix(lookupName, "xyz:") {
lookupName = "xyz:" + lookupName
}
found := false
for i, asset := range meta.Universe {
if asset.Name == lookupName {
found = true
assetIndex := 100000 + 1*10000 + i // HIP-3 formula: 100000 + xyz_dex_index(1) * 10000 + meta_index
t.Logf("✅ Lookup '%s' -> found at index %d (asset ID: %d)", tc.input, i, assetIndex)
break
}
}
if !found {
t.Errorf("❌ Lookup '%s' -> NOT FOUND (expected to match %s)", tc.input, tc.expected)
}
}
}
// TestXyzSzDecimalsLookup tests the szDecimals lookup for different xyz assets
func TestXyzSzDecimalsLookup(t *testing.T) {
reqBody := map[string]string{
"type": "meta",
"dex": "xyz",
}
jsonBody, _ := json.Marshal(reqBody)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to fetch meta: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var meta testXyzDexMeta
json.Unmarshal(body, &meta)
// Check szDecimals for various assets
expectedDecimals := map[string]int{
"xyz:SILVER": 2,
"xyz:GOLD": 4,
"xyz:TSLA": 3,
}
for name, expected := range expectedDecimals {
for _, asset := range meta.Universe {
if asset.Name == name {
if asset.SzDecimals == expected {
t.Logf("✅ %s szDecimals: %d (expected %d)", name, asset.SzDecimals, expected)
} else {
t.Logf("⚠️ %s szDecimals: %d (expected %d, may have changed)", name, asset.SzDecimals, expected)
}
break
}
}
}
}
// TestXyzOrderParameters tests order parameter calculation
func TestXyzOrderParameters(t *testing.T) {
// Simulate order parameter calculation
testCases := []struct {
price float64
size float64
szDecimals int
expectedSz float64
}{
{75.33, 1.0, 2, 1.00},
{75.33, 1.234, 2, 1.23},
{75.33, 5.567, 2, 5.57},
{188.15, 0.5, 3, 0.500},
{188.15, 0.1234, 3, 0.123},
}
for _, tc := range testCases {
multiplier := 1.0
for i := 0; i < tc.szDecimals; i++ {
multiplier *= 10.0
}
roundedSize := float64(int(tc.size*multiplier+0.5)) / multiplier
if roundedSize != tc.expectedSz {
t.Errorf("Size rounding failed: input=%v, decimals=%d, got=%v, expected=%v",
tc.size, tc.szDecimals, roundedSize, tc.expectedSz)
} else {
t.Logf("✅ Size rounding: %v (decimals=%d) -> %v", tc.size, tc.szDecimals, roundedSize)
}
}
}
// TestXyzAssetIndexCalculation tests the HIP-3 asset index calculation
// Formula: 100000 + perp_dex_index * 10000 + meta_index
// For xyz dex: perp_dex_index = 1, so asset_index = 110000 + meta_index
func TestXyzAssetIndexCalculation(t *testing.T) {
reqBody := map[string]string{
"type": "meta",
"dex": "xyz",
}
jsonBody, _ := json.Marshal(reqBody)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to fetch meta: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var meta testXyzDexMeta
json.Unmarshal(body, &meta)
// Test asset index calculation for SILVER
// HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta
// xyz dex is at perp_dex_index = 1
const xyzPerpDexIndex = 1
for i, asset := range meta.Universe {
if asset.Name == "xyz:SILVER" {
assetIndex := 100000 + xyzPerpDexIndex*10000 + i
t.Logf("✅ xyz:SILVER: meta_index=%d, asset_index=%d", i, assetIndex)
if assetIndex < 110000 {
t.Errorf("Asset index should be >= 110000, got %d", assetIndex)
}
break
}
}
// Log first few assets for reference
t.Log("\nFirst 5 xyz assets:")
for i := 0; i < 5 && i < len(meta.Universe); i++ {
asset := meta.Universe[i]
assetIndex := 100000 + xyzPerpDexIndex*10000 + i
t.Logf(" [%d] %s -> asset_index=%d, szDecimals=%d", i, asset.Name, assetIndex, asset.SzDecimals)
}
}
// TestIsXyzDexAsset tests the isXyzDexAsset function
func TestIsXyzDexAsset(t *testing.T) {
testCases := []struct {
symbol string
expected bool
}{
{"xyz:SILVER", true},
{"SILVER", true},
{"silver", true},
{"xyz:GOLD", true},
{"GOLD", true},
{"xyz:TSLA", true},
{"TSLA", true},
{"BTCUSDT", false},
{"BTC", false},
{"ETHUSDT", false},
{"SOLUSDT", false},
{"xyz:BTC", false}, // BTC is not an xyz asset
}
for _, tc := range testCases {
result := isXyzDexAsset(tc.symbol)
if result != tc.expected {
t.Errorf("isXyzDexAsset(%q) = %v, expected %v", tc.symbol, result, tc.expected)
} else {
t.Logf("✅ isXyzDexAsset(%q) = %v", tc.symbol, result)
}
}
}
// TestConvertSymbolToHyperliquidXyz tests symbol conversion for xyz assets
func TestConvertSymbolToHyperliquidXyz(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{"SILVER", "xyz:SILVER"},
{"silver", "xyz:SILVER"},
{"xyz:SILVER", "xyz:SILVER"},
{"GOLD", "xyz:GOLD"},
{"TSLA", "xyz:TSLA"},
{"BTC", "BTC"},
{"BTCUSDT", "BTC"},
{"ETH", "ETH"},
{"ETHUSDT", "ETH"},
}
for _, tc := range testCases {
result := convertSymbolToHyperliquid(tc.input)
if result != tc.expected {
t.Errorf("convertSymbolToHyperliquid(%q) = %q, expected %q", tc.input, result, tc.expected)
} else {
t.Logf("✅ convertSymbolToHyperliquid(%q) = %q", tc.input, result)
}
}
}
// TestXyzDexOrderFlow tests the complete order flow (without actually placing an order)
func TestXyzDexOrderFlow(t *testing.T) {
t.Log("=== Testing xyz Dex Order Flow ===")
// Step 1: Fetch meta
t.Log("\nStep 1: Fetching xyz meta...")
reqBody := map[string]string{"type": "meta", "dex": "xyz"}
jsonBody, _ := json.Marshal(reqBody)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to fetch meta: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var meta testXyzDexMeta
json.Unmarshal(body, &meta)
t.Logf("✅ Fetched %d xyz assets", len(meta.Universe))
// Step 2: Find SILVER
t.Log("\nStep 2: Looking up xyz:SILVER...")
var silverIndex int = -1
var silverAsset *testXyzDexAsset
for i, asset := range meta.Universe {
if asset.Name == "xyz:SILVER" {
silverIndex = i
silverAsset = &meta.Universe[i]
break
}
}
if silverIndex < 0 {
t.Fatal("SILVER not found in xyz meta")
}
t.Logf("✅ Found at index %d", silverIndex)
// Step 3: Fetch price
t.Log("\nStep 3: Fetching price...")
priceReq := map[string]string{"type": "allMids", "dex": "xyz"}
priceBody, _ := json.Marshal(priceReq)
req2, _ := http.NewRequestWithContext(ctx, "POST", "https://api.hyperliquid.xyz/info", bytes.NewBuffer(priceBody))
req2.Header.Set("Content-Type", "application/json")
resp2, _ := client.Do(req2)
body2, _ := io.ReadAll(resp2.Body)
resp2.Body.Close()
var mids map[string]string
json.Unmarshal(body2, &mids)
priceStr := mids["xyz:SILVER"]
var price float64
fmt.Sscanf(priceStr, "%f", &price)
t.Logf("✅ Price: %s", priceStr)
// Step 4: Calculate order parameters
t.Log("\nStep 4: Calculating order parameters...")
orderSize := 1.0
multiplier := 1.0
for i := 0; i < silverAsset.SzDecimals; i++ {
multiplier *= 10.0
}
roundedSize := float64(int(orderSize*multiplier+0.5)) / multiplier
roundedPrice := price * 1.001 // 0.1% slippage
// HIP-3 perp dex asset index formula: 100000 + perp_dex_index * 10000 + index_in_meta
// xyz dex is at perp_dex_index = 1
assetIndex := 100000 + 1*10000 + silverIndex
t.Logf(" Asset Index: %d (110000 + %d)", assetIndex, silverIndex)
t.Logf(" Size: %.4f (szDecimals=%d)", roundedSize, silverAsset.SzDecimals)
t.Logf(" Price: %.4f (with slippage)", roundedPrice)
// Step 5: Summary
t.Log("\n=== Order Flow Test Summary ===")
t.Log("✅ Meta fetch: OK")
t.Log("✅ Asset lookup: OK")
t.Log("✅ Price fetch: OK")
t.Log("✅ Parameter calculation: OK")
t.Logf("\n📋 Order would be placed with:")
t.Logf(" coin: xyz:SILVER")
t.Logf(" assetIndex: %d", assetIndex)
t.Logf(" isBuy: true")
t.Logf(" size: %.4f", roundedSize)
t.Logf(" price: %.4f", roundedPrice)
}
// TestXyzDexLiveOrder tests placing a real order on xyz dex
// This test requires:
// - XYZ_DEX_LIVE_TEST=1 to enable
// - TEST_PRIVATE_KEY - the private key for signing
// - TEST_WALLET_ADDR - the wallet address with funds
func TestXyzDexLiveOrder(t *testing.T) {
// Skip unless explicitly enabled
if os.Getenv("XYZ_DEX_LIVE_TEST") != "1" {
t.Skip("Skipping live order test. Set XYZ_DEX_LIVE_TEST=1 to run")
}
// Get credentials from environment variables
privateKeyHex := os.Getenv("TEST_PRIVATE_KEY")
walletAddr := os.Getenv("TEST_WALLET_ADDR")
if privateKeyHex == "" || walletAddr == "" {
t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required")
}
t.Logf("=== Live xyz Dex Order Test ===")
t.Logf("Wallet: %s", walletAddr)
// Create trader instance
trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false)
if err != nil {
t.Fatalf("Failed to create trader: %v", err)
}
// Check xyz dex balance first
xyzState, _ := trader.exchange.Info().UserState(trader.ctx, walletAddr, "xyz")
if xyzState != nil && xyzState.CrossMarginSummary.AccountValue == "0.0" {
t.Logf("⚠️ xyz dex account has no funds (balance: %s)", xyzState.CrossMarginSummary.AccountValue)
t.Logf(" To trade xyz dex, you need to transfer funds using perpDexClassTransfer")
t.Logf(" The test will still verify order signing and submission...")
}
// Fetch xyz meta first
if err := trader.fetchXyzMeta(); err != nil {
t.Fatalf("Failed to fetch xyz meta: %v", err)
}
// Get current price for xyz:SILVER
price, err := trader.getXyzMarketPrice("xyz:SILVER")
if err != nil {
t.Fatalf("Failed to get price: %v", err)
}
t.Logf("Current xyz:SILVER price: %.4f", price)
// Place a test order (minimum $10 value = 0.14 SILVER at ~$75)
// With 5% slippage for IOC (market order)
testSize := 0.14 // ~$10.5 at current price
testPrice := price * 1.05 // 5% above market for IOC buy (market order)
t.Logf("Attempting to place order:")
t.Logf(" Symbol: xyz:SILVER")
t.Logf(" Side: BUY")
t.Logf(" Size: %.4f", testSize)
t.Logf(" Price: %.4f", testPrice)
// Place the order using the new direct method
err = trader.placeXyzOrder("xyz:SILVER", true, testSize, testPrice, false)
if err != nil {
t.Logf("⚠️ Order result: %v", err)
// Check if this is an expected error (e.g., insufficient margin, no matching orders for IOC)
if strings.Contains(err.Error(), "insufficient") || strings.Contains(err.Error(), "margin") || strings.Contains(err.Error(), "minimum value") {
t.Logf("This may be expected if the test wallet has no margin in xyz dex")
t.Logf("✅ Order was properly signed and submitted (API validated format/signature)")
} else if strings.Contains(err.Error(), "could not immediately match") {
// IOC order didn't fill - this is actually SUCCESS!
// It means the order was properly signed, submitted, and processed
t.Logf("✅ Order was properly submitted but didn't fill (IOC with no matching orders)")
t.Logf(" This confirms the asset index (%d) and signing are correct!", 110026)
} else if strings.Contains(err.Error(), "Order has invalid price") || strings.Contains(err.Error(), "95% away") {
t.Errorf("FAILED: Order has invalid price - asset index issue")
} else {
t.Errorf("FAILED: Unexpected error: %v", err)
}
} else {
t.Logf("✅ Order placed and filled successfully!")
}
}
// TestXyzDexClosePosition tests closing a position on xyz dex
// This test requires the XYZ_DEX_LIVE_TEST environment variable to be set
func TestXyzDexClosePosition(t *testing.T) {
// Skip unless explicitly enabled
if os.Getenv("XYZ_DEX_LIVE_TEST") != "1" {
t.Skip("Skipping live close position test. Set XYZ_DEX_LIVE_TEST=1 to run")
}
// Get credentials from environment variables
privateKeyHex := os.Getenv("TEST_PRIVATE_KEY")
walletAddr := os.Getenv("TEST_WALLET_ADDR")
if privateKeyHex == "" || walletAddr == "" {
t.Skip("TEST_PRIVATE_KEY and TEST_WALLET_ADDR env vars required")
}
t.Logf("=== Live xyz Dex Close Position Test ===")
t.Logf("Wallet: %s", walletAddr)
// Create trader instance
trader, err := NewHyperliquidTrader(privateKeyHex, walletAddr, false)
if err != nil {
t.Fatalf("Failed to create trader: %v", err)
}
// Check current xyz dex position
xyzState, err := trader.exchange.Info().UserState(trader.ctx, walletAddr, "xyz")
if err != nil {
t.Fatalf("Failed to get xyz state: %v", err)
}
if len(xyzState.AssetPositions) == 0 {
t.Logf("No xyz dex positions to close")
return
}
// Get the position details
pos := xyzState.AssetPositions[0].Position
entryPx := ""
if pos.EntryPx != nil {
entryPx = *pos.EntryPx
}
t.Logf("Current position: %s size=%s entryPx=%s", pos.Coin, pos.Szi, entryPx)
// Fetch xyz meta
if err := trader.fetchXyzMeta(); err != nil {
t.Fatalf("Failed to fetch xyz meta: %v", err)
}
// Get current price
price, err := trader.getXyzMarketPrice(pos.Coin)
if err != nil {
t.Fatalf("Failed to get price: %v", err)
}
t.Logf("Current %s price: %.4f", pos.Coin, price)
// Parse position size
var posSize float64
fmt.Sscanf(pos.Szi, "%f", &posSize)
// Close position: if long (szi > 0), sell; if short (szi < 0), buy
isBuy := posSize < 0
closeSize := posSize
if closeSize < 0 {
closeSize = -closeSize
}
// Use aggressive slippage for close
closePrice := price * 0.95 // 5% below for sell
if isBuy {
closePrice = price * 1.05 // 5% above for buy
}
t.Logf("Closing position:")
t.Logf(" Side: %s", map[bool]string{true: "BUY", false: "SELL"}[isBuy])
t.Logf(" Size: %.4f", closeSize)
t.Logf(" Price: %.4f", closePrice)
// Place close order with reduceOnly=true
err = trader.placeXyzOrder(pos.Coin, isBuy, closeSize, closePrice, true)
if err != nil {
t.Logf("⚠️ Close order result: %v", err)
if strings.Contains(err.Error(), "could not immediately match") {
t.Logf("✅ Close order submitted but didn't fill (IOC)")
} else {
t.Errorf("FAILED: %v", err)
}
} else {
t.Logf("✅ Position closed successfully!")
}
// Verify position is closed
xyzState2, _ := trader.exchange.Info().UserState(trader.ctx, walletAddr, "xyz")
if len(xyzState2.AssetPositions) == 0 {
t.Logf("✅ Position confirmed closed (no positions remaining)")
} else {
newPos := xyzState2.AssetPositions[0].Position
t.Logf("Position after close: %s size=%s", newPos.Coin, newPos.Szi)
}
}
-557
View File
@@ -7,11 +7,9 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"sync"
@@ -299,561 +297,6 @@ func (t *IndodaxTrader) clearCache() {
t.cachedPositions = nil
}
// ============================================================
// types.Trader interface implementation
// ============================================================
// GetBalance gets account balance from Indodax
func (t *IndodaxTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.cacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
cached := t.cachedBalance
t.cacheMutex.RUnlock()
return cached, nil
}
t.cacheMutex.RUnlock()
params := url.Values{}
params.Set("method", "getInfo")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get account info: %w", err)
}
var result struct {
ServerTime int64 `json:"server_time"`
Balance map[string]interface{} `json:"balance"`
BalanceHold map[string]interface{} `json:"balance_hold"`
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse balance: %w", err)
}
// Calculate total balance in IDR
idrBalance := parseFloat(result.Balance["idr"])
idrHold := parseFloat(result.BalanceHold["idr"])
totalIDR := idrBalance + idrHold
balance := map[string]interface{}{
"totalWalletBalance": totalIDR,
"availableBalance": idrBalance,
"totalUnrealizedProfit": 0.0,
"totalEquity": totalIDR,
"balance": totalIDR,
"idr_balance": idrBalance,
"idr_hold": idrHold,
"currency": "IDR",
"user_id": result.UserID,
"server_time": result.ServerTime,
}
// Add individual crypto balances
for currency, amount := range result.Balance {
if currency != "idr" {
balance["balance_"+currency] = parseFloat(amount)
}
}
for currency, amount := range result.BalanceHold {
if currency != "idr" {
balance["hold_"+currency] = parseFloat(amount)
}
}
// Update cache
t.cacheMutex.Lock()
t.cachedBalance = balance
t.balanceCacheTime = time.Now()
t.cacheMutex.Unlock()
return balance, nil
}
// GetPositions returns currently held crypto balances as "positions"
// Since Indodax is spot-only, each non-zero crypto balance is treated as a position
func (t *IndodaxTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.cacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionCacheTime) < t.cacheDuration {
cached := t.cachedPositions
t.cacheMutex.RUnlock()
return cached, nil
}
t.cacheMutex.RUnlock()
params := url.Values{}
params.Set("method", "getInfo")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result struct {
Balance map[string]interface{} `json:"balance"`
BalanceHold map[string]interface{} `json:"balance_hold"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse positions: %w", err)
}
var positions []map[string]interface{}
for currency, amountRaw := range result.Balance {
if currency == "idr" {
continue
}
amount := parseFloat(amountRaw)
holdAmount := parseFloat(result.BalanceHold[currency])
totalAmount := amount + holdAmount
if totalAmount <= 0 {
continue
}
// Get market price for this coin
markPrice, _ := t.GetMarketPrice(strings.ToUpper(currency) + "IDR")
// Calculate position value in IDR
notionalValue := totalAmount * markPrice
position := map[string]interface{}{
"symbol": strings.ToUpper(currency) + "IDR",
"side": "LONG",
"positionAmt": totalAmount,
"entryPrice": markPrice, // Spot doesn't track entry price
"markPrice": markPrice,
"unRealizedProfit": 0.0, // Spot doesn't track unrealized PnL
"leverage": 1.0,
"mgnMode": "spot",
"notionalValue": notionalValue,
"currency": currency,
"available": amount,
"hold": holdAmount,
}
positions = append(positions, position)
}
// Update cache
t.cacheMutex.Lock()
t.cachedPositions = positions
t.positionCacheTime = time.Now()
t.cacheMutex.Unlock()
return positions, nil
}
// OpenLong opens a spot buy order
func (t *IndodaxTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
t.clearCache()
pair := t.convertSymbol(symbol)
coin := t.getCoinFromSymbol(symbol)
// Get market price to calculate IDR amount
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get market price: %w", err)
}
params := url.Values{}
params.Set("method", "trade")
params.Set("pair", pair)
params.Set("type", "buy")
params.Set("price", strconv.FormatFloat(price, 'f', 0, 64))
params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64))
params.Set("order_type", "limit")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to place buy order: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse trade response: %w", err)
}
logger.Infof("[Indodax] Buy order placed: %s qty=%.8f price=%.0f", symbol, quantity, price)
return map[string]interface{}{
"orderId": result["order_id"],
"symbol": symbol,
"side": "BUY",
"price": price,
"qty": quantity,
"status": "NEW",
}, nil
}
// OpenShort is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)")
}
// CloseLong closes a spot position by selling
func (t *IndodaxTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
t.clearCache()
pair := t.convertSymbol(symbol)
coin := t.getCoinFromSymbol(symbol)
// If quantity is 0, sell all available balance
if quantity <= 0 {
balance, err := t.GetBalance()
if err != nil {
return nil, fmt.Errorf("failed to get balance for close all: %w", err)
}
available := parseFloat(balance["balance_"+coin])
if available <= 0 {
return nil, fmt.Errorf("no %s balance to sell", coin)
}
quantity = available
}
// Get market price
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get market price: %w", err)
}
params := url.Values{}
params.Set("method", "trade")
params.Set("pair", pair)
params.Set("type", "sell")
params.Set("price", strconv.FormatFloat(price, 'f', 0, 64))
params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64))
params.Set("order_type", "limit")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to place sell order: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse trade response: %w", err)
}
logger.Infof("[Indodax] Sell order placed: %s qty=%.8f price=%.0f", symbol, quantity, price)
return map[string]interface{}{
"orderId": result["order_id"],
"symbol": symbol,
"side": "SELL",
"price": price,
"qty": quantity,
"status": "NEW",
}, nil
}
// CloseShort is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)")
}
// SetLeverage is a no-op for Indodax (spot-only, no leverage)
func (t *IndodaxTrader) SetLeverage(symbol string, leverage int) error {
logger.Infof("[Indodax] SetLeverage ignored (spot-only exchange, no leverage support)")
return nil
}
// SetMarginMode is a no-op for Indodax (spot-only, no margin)
func (t *IndodaxTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
logger.Infof("[Indodax] SetMarginMode ignored (spot-only exchange, no margin support)")
return nil
}
// GetMarketPrice gets the current market price for a symbol
func (t *IndodaxTrader) GetMarketPrice(symbol string) (float64, error) {
pairID := strings.ToLower(strings.ReplaceAll(t.convertSymbol(symbol), "_", ""))
data, err := t.doPublicRequest("/ticker/" + pairID)
if err != nil {
return 0, fmt.Errorf("failed to get ticker: %w", err)
}
var tickerResp IndodaxTickerResponse
if err := json.Unmarshal(data, &tickerResp); err != nil {
return 0, fmt.Errorf("failed to parse ticker: %w", err)
}
price, err := strconv.ParseFloat(tickerResp.Ticker.Last, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse price '%s': %w", tickerResp.Ticker.Last, err)
}
return price, nil
}
// SetStopLoss is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
return fmt.Errorf("stop-loss orders are not supported on Indodax (spot-only exchange)")
}
// SetTakeProfit is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
return fmt.Errorf("take-profit orders are not supported on Indodax (spot-only exchange)")
}
// CancelStopLossOrders is a no-op for Indodax
func (t *IndodaxTrader) CancelStopLossOrders(symbol string) error {
return nil
}
// CancelTakeProfitOrders is a no-op for Indodax
func (t *IndodaxTrader) CancelTakeProfitOrders(symbol string) error {
return nil
}
// CancelAllOrders cancels all open orders for a given symbol
func (t *IndodaxTrader) CancelAllOrders(symbol string) error {
t.clearCache()
pair := t.convertSymbol(symbol)
// First get open orders
params := url.Values{}
params.Set("method", "openOrders")
params.Set("pair", pair)
data, err := t.doPrivateRequest(params)
if err != nil {
return fmt.Errorf("failed to get open orders: %w", err)
}
var result struct {
Orders []struct {
OrderID json.Number `json:"order_id"`
Type string `json:"type"`
OrderType string `json:"order_type"`
} `json:"orders"`
}
if err := json.Unmarshal(data, &result); err != nil {
return fmt.Errorf("failed to parse open orders: %w", err)
}
// Cancel each order
for _, order := range result.Orders {
cancelParams := url.Values{}
cancelParams.Set("method", "cancelOrder")
cancelParams.Set("pair", pair)
cancelParams.Set("order_id", order.OrderID.String())
cancelParams.Set("type", order.Type)
if _, err := t.doPrivateRequest(cancelParams); err != nil {
logger.Warnf("[Indodax] Failed to cancel order %s: %v", order.OrderID, err)
} else {
logger.Infof("[Indodax] Cancelled order: %s", order.OrderID)
}
}
return nil
}
// CancelStopOrders is a no-op for Indodax (no stop orders)
func (t *IndodaxTrader) CancelStopOrders(symbol string) error {
return nil
}
// FormatQuantity formats quantity to correct precision for Indodax
func (t *IndodaxTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
pair, err := t.getPair(symbol)
if err != nil {
// Default: 8 decimal places
return strconv.FormatFloat(quantity, 'f', 8, 64), nil
}
precision := pair.PriceRound
if precision <= 0 {
precision = 8
}
// Round down to avoid exceeding balance
factor := math.Pow(10, float64(precision))
rounded := math.Floor(quantity*factor) / factor
return strconv.FormatFloat(rounded, 'f', precision, 64), nil
}
// GetOrderStatus gets the status of a specific order
func (t *IndodaxTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
pair := t.convertSymbol(symbol)
params := url.Values{}
params.Set("method", "getOrder")
params.Set("pair", pair)
params.Set("order_id", orderID)
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var result struct {
Order struct {
OrderID string `json:"order_id"`
Price string `json:"price"`
Type string `json:"type"`
Status string `json:"status"`
SubmitTime string `json:"submit_time"`
FinishTime string `json:"finish_time"`
ClientOrderID string `json:"client_order_id"`
} `json:"order"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order: %w", err)
}
// Map Indodax status to standard status
status := "NEW"
switch result.Order.Status {
case "filled":
status = "FILLED"
case "cancelled":
status = "CANCELED"
case "open":
status = "NEW"
}
price, _ := strconv.ParseFloat(result.Order.Price, 64)
return map[string]interface{}{
"status": status,
"avgPrice": price,
"executedQty": 0.0, // Indodax doesn't return executed qty in getOrder
"commission": 0.0,
"orderId": result.Order.OrderID,
}, nil
}
// GetClosedPnL gets closed position PnL records (trade history)
func (t *IndodaxTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
// Indodax trade history is limited to 7 days range
params := url.Values{}
params.Set("method", "tradeHistory")
params.Set("pair", "btc_idr") // Default pair; Indodax requires a pair
if limit > 0 {
params.Set("count", strconv.Itoa(limit))
}
if !startTime.IsZero() {
params.Set("since", strconv.FormatInt(startTime.Unix(), 10))
}
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get trade history: %w", err)
}
var result struct {
Trades []struct {
TradeID string `json:"trade_id"`
OrderID string `json:"order_id"`
Type string `json:"type"`
Price string `json:"price"`
Fee string `json:"fee"`
TradeTime string `json:"trade_time"`
ClientOrderID string `json:"client_order_id"`
} `json:"trades"`
}
if err := json.Unmarshal(data, &result); err != nil {
// Trade history might return empty, that's fine
return nil, nil
}
var records []types.ClosedPnLRecord
for _, trade := range result.Trades {
price, _ := strconv.ParseFloat(trade.Price, 64)
fee, _ := strconv.ParseFloat(trade.Fee, 64)
tradeTime, _ := strconv.ParseInt(trade.TradeTime, 10, 64)
side := "long"
if trade.Type == "sell" {
side = "long" // Selling from a spot position is closing long
}
records = append(records, types.ClosedPnLRecord{
Symbol: "BTCIDR",
Side: side,
ExitPrice: price,
Fee: fee,
ExitTime: time.Unix(tradeTime, 0),
OrderID: trade.OrderID,
CloseType: "manual",
})
}
return records, nil
}
// GetOpenOrders gets open/pending orders
func (t *IndodaxTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
pair := t.convertSymbol(symbol)
params := url.Values{}
params.Set("method", "openOrders")
if pair != "" {
params.Set("pair", pair)
}
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
var result struct {
Orders []struct {
OrderID json.Number `json:"order_id"`
ClientOrderID string `json:"client_order_id"`
SubmitTime string `json:"submit_time"`
Price string `json:"price"`
Type string `json:"type"`
OrderType string `json:"order_type"`
} `json:"orders"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse open orders: %w", err)
}
var orders []types.OpenOrder
for _, order := range result.Orders {
price, _ := strconv.ParseFloat(order.Price, 64)
side := "BUY"
if order.Type == "sell" {
side = "SELL"
}
orders = append(orders, types.OpenOrder{
OrderID: order.OrderID.String(),
Symbol: t.convertSymbolBack(pair),
Side: side,
PositionSide: "LONG",
Type: "LIMIT",
Price: price,
Status: "NEW",
})
}
return orders, nil
}
// ============================================================
// Helper functions
// ============================================================
// parseFloat safely parses a float from interface{}
func parseFloat(v interface{}) float64 {
if v == nil {
+221
View File
@@ -0,0 +1,221 @@
package indodax
import (
"encoding/json"
"fmt"
"net/url"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"time"
)
// GetBalance gets account balance from Indodax
func (t *IndodaxTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.cacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
cached := t.cachedBalance
t.cacheMutex.RUnlock()
return cached, nil
}
t.cacheMutex.RUnlock()
params := url.Values{}
params.Set("method", "getInfo")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get account info: %w", err)
}
var result struct {
ServerTime int64 `json:"server_time"`
Balance map[string]interface{} `json:"balance"`
BalanceHold map[string]interface{} `json:"balance_hold"`
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse balance: %w", err)
}
// Calculate total balance in IDR
idrBalance := parseFloat(result.Balance["idr"])
idrHold := parseFloat(result.BalanceHold["idr"])
totalIDR := idrBalance + idrHold
balance := map[string]interface{}{
"totalWalletBalance": totalIDR,
"availableBalance": idrBalance,
"totalUnrealizedProfit": 0.0,
"totalEquity": totalIDR,
"balance": totalIDR,
"idr_balance": idrBalance,
"idr_hold": idrHold,
"currency": "IDR",
"user_id": result.UserID,
"server_time": result.ServerTime,
}
// Add individual crypto balances
for currency, amount := range result.Balance {
if currency != "idr" {
balance["balance_"+currency] = parseFloat(amount)
}
}
for currency, amount := range result.BalanceHold {
if currency != "idr" {
balance["hold_"+currency] = parseFloat(amount)
}
}
// Update cache
t.cacheMutex.Lock()
t.cachedBalance = balance
t.balanceCacheTime = time.Now()
t.cacheMutex.Unlock()
return balance, nil
}
// GetPositions returns currently held crypto balances as "positions"
// Since Indodax is spot-only, each non-zero crypto balance is treated as a position
func (t *IndodaxTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.cacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionCacheTime) < t.cacheDuration {
cached := t.cachedPositions
t.cacheMutex.RUnlock()
return cached, nil
}
t.cacheMutex.RUnlock()
params := url.Values{}
params.Set("method", "getInfo")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result struct {
Balance map[string]interface{} `json:"balance"`
BalanceHold map[string]interface{} `json:"balance_hold"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse positions: %w", err)
}
var positions []map[string]interface{}
for currency, amountRaw := range result.Balance {
if currency == "idr" {
continue
}
amount := parseFloat(amountRaw)
holdAmount := parseFloat(result.BalanceHold[currency])
totalAmount := amount + holdAmount
if totalAmount <= 0 {
continue
}
// Get market price for this coin
markPrice, _ := t.GetMarketPrice(strings.ToUpper(currency) + "IDR")
// Calculate position value in IDR
notionalValue := totalAmount * markPrice
position := map[string]interface{}{
"symbol": strings.ToUpper(currency) + "IDR",
"side": "LONG",
"positionAmt": totalAmount,
"entryPrice": markPrice, // Spot doesn't track entry price
"markPrice": markPrice,
"unRealizedProfit": 0.0, // Spot doesn't track unrealized PnL
"leverage": 1.0,
"mgnMode": "spot",
"notionalValue": notionalValue,
"currency": currency,
"available": amount,
"hold": holdAmount,
}
positions = append(positions, position)
}
// Update cache
t.cacheMutex.Lock()
t.cachedPositions = positions
t.positionCacheTime = time.Now()
t.cacheMutex.Unlock()
return positions, nil
}
// GetClosedPnL gets closed position PnL records (trade history)
func (t *IndodaxTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
// Indodax trade history is limited to 7 days range
params := url.Values{}
params.Set("method", "tradeHistory")
params.Set("pair", "btc_idr") // Default pair; Indodax requires a pair
if limit > 0 {
params.Set("count", strconv.Itoa(limit))
}
if !startTime.IsZero() {
params.Set("since", strconv.FormatInt(startTime.Unix(), 10))
}
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get trade history: %w", err)
}
var result struct {
Trades []struct {
TradeID string `json:"trade_id"`
OrderID string `json:"order_id"`
Type string `json:"type"`
Price string `json:"price"`
Fee string `json:"fee"`
TradeTime string `json:"trade_time"`
ClientOrderID string `json:"client_order_id"`
} `json:"trades"`
}
if err := json.Unmarshal(data, &result); err != nil {
// Trade history might return empty, that's fine
logger.Infof("[Indodax] Trade history parse note: %v", err)
return nil, nil
}
var records []types.ClosedPnLRecord
for _, trade := range result.Trades {
price, _ := strconv.ParseFloat(trade.Price, 64)
fee, _ := strconv.ParseFloat(trade.Fee, 64)
tradeTime, _ := strconv.ParseInt(trade.TradeTime, 10, 64)
side := "long"
if trade.Type == "sell" {
side = "long" // Selling from a spot position is closing long
}
records = append(records, types.ClosedPnLRecord{
Symbol: "BTCIDR",
Side: side,
ExitPrice: price,
Fee: fee,
ExitTime: time.Unix(tradeTime, 0),
OrderID: trade.OrderID,
CloseType: "manual",
})
}
return records, nil
}
+351
View File
@@ -0,0 +1,351 @@
package indodax
import (
"encoding/json"
"fmt"
"math"
"net/url"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
)
// OpenLong opens a spot buy order
func (t *IndodaxTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
t.clearCache()
pair := t.convertSymbol(symbol)
coin := t.getCoinFromSymbol(symbol)
// Get market price to calculate IDR amount
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get market price: %w", err)
}
params := url.Values{}
params.Set("method", "trade")
params.Set("pair", pair)
params.Set("type", "buy")
params.Set("price", strconv.FormatFloat(price, 'f', 0, 64))
params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64))
params.Set("order_type", "limit")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to place buy order: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse trade response: %w", err)
}
logger.Infof("[Indodax] Buy order placed: %s qty=%.8f price=%.0f", symbol, quantity, price)
return map[string]interface{}{
"orderId": result["order_id"],
"symbol": symbol,
"side": "BUY",
"price": price,
"qty": quantity,
"status": "NEW",
}, nil
}
// OpenShort is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)")
}
// CloseLong closes a spot position by selling
func (t *IndodaxTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
t.clearCache()
pair := t.convertSymbol(symbol)
coin := t.getCoinFromSymbol(symbol)
// If quantity is 0, sell all available balance
if quantity <= 0 {
balance, err := t.GetBalance()
if err != nil {
return nil, fmt.Errorf("failed to get balance for close all: %w", err)
}
available := parseFloat(balance["balance_"+coin])
if available <= 0 {
return nil, fmt.Errorf("no %s balance to sell", coin)
}
quantity = available
}
// Get market price
price, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get market price: %w", err)
}
params := url.Values{}
params.Set("method", "trade")
params.Set("pair", pair)
params.Set("type", "sell")
params.Set("price", strconv.FormatFloat(price, 'f', 0, 64))
params.Set(coin, strconv.FormatFloat(quantity, 'f', 8, 64))
params.Set("order_type", "limit")
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to place sell order: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse trade response: %w", err)
}
logger.Infof("[Indodax] Sell order placed: %s qty=%.8f price=%.0f", symbol, quantity, price)
return map[string]interface{}{
"orderId": result["order_id"],
"symbol": symbol,
"side": "SELL",
"price": price,
"qty": quantity,
"status": "NEW",
}, nil
}
// CloseShort is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
return nil, fmt.Errorf("short selling is not supported on Indodax (spot-only exchange)")
}
// SetLeverage is a no-op for Indodax (spot-only, no leverage)
func (t *IndodaxTrader) SetLeverage(symbol string, leverage int) error {
logger.Infof("[Indodax] SetLeverage ignored (spot-only exchange, no leverage support)")
return nil
}
// SetMarginMode is a no-op for Indodax (spot-only, no margin)
func (t *IndodaxTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
logger.Infof("[Indodax] SetMarginMode ignored (spot-only exchange, no margin support)")
return nil
}
// GetMarketPrice gets the current market price for a symbol
func (t *IndodaxTrader) GetMarketPrice(symbol string) (float64, error) {
pairID := strings.ToLower(strings.ReplaceAll(t.convertSymbol(symbol), "_", ""))
data, err := t.doPublicRequest("/ticker/" + pairID)
if err != nil {
return 0, fmt.Errorf("failed to get ticker: %w", err)
}
var tickerResp IndodaxTickerResponse
if err := json.Unmarshal(data, &tickerResp); err != nil {
return 0, fmt.Errorf("failed to parse ticker: %w", err)
}
price, err := strconv.ParseFloat(tickerResp.Ticker.Last, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse price '%s': %w", tickerResp.Ticker.Last, err)
}
return price, nil
}
// SetStopLoss is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
return fmt.Errorf("stop-loss orders are not supported on Indodax (spot-only exchange)")
}
// SetTakeProfit is not supported on Indodax (spot-only exchange)
func (t *IndodaxTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
return fmt.Errorf("take-profit orders are not supported on Indodax (spot-only exchange)")
}
// CancelStopLossOrders is a no-op for Indodax
func (t *IndodaxTrader) CancelStopLossOrders(symbol string) error {
return nil
}
// CancelTakeProfitOrders is a no-op for Indodax
func (t *IndodaxTrader) CancelTakeProfitOrders(symbol string) error {
return nil
}
// CancelAllOrders cancels all open orders for a given symbol
func (t *IndodaxTrader) CancelAllOrders(symbol string) error {
t.clearCache()
pair := t.convertSymbol(symbol)
// First get open orders
params := url.Values{}
params.Set("method", "openOrders")
params.Set("pair", pair)
data, err := t.doPrivateRequest(params)
if err != nil {
return fmt.Errorf("failed to get open orders: %w", err)
}
var result struct {
Orders []struct {
OrderID json.Number `json:"order_id"`
Type string `json:"type"`
OrderType string `json:"order_type"`
} `json:"orders"`
}
if err := json.Unmarshal(data, &result); err != nil {
return fmt.Errorf("failed to parse open orders: %w", err)
}
// Cancel each order
for _, order := range result.Orders {
cancelParams := url.Values{}
cancelParams.Set("method", "cancelOrder")
cancelParams.Set("pair", pair)
cancelParams.Set("order_id", order.OrderID.String())
cancelParams.Set("type", order.Type)
if _, err := t.doPrivateRequest(cancelParams); err != nil {
logger.Warnf("[Indodax] Failed to cancel order %s: %v", order.OrderID, err)
} else {
logger.Infof("[Indodax] Cancelled order: %s", order.OrderID)
}
}
return nil
}
// CancelStopOrders is a no-op for Indodax (no stop orders)
func (t *IndodaxTrader) CancelStopOrders(symbol string) error {
return nil
}
// FormatQuantity formats quantity to correct precision for Indodax
func (t *IndodaxTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
pair, err := t.getPair(symbol)
if err != nil {
// Default: 8 decimal places
return strconv.FormatFloat(quantity, 'f', 8, 64), nil
}
precision := pair.PriceRound
if precision <= 0 {
precision = 8
}
// Round down to avoid exceeding balance
factor := math.Pow(10, float64(precision))
rounded := math.Floor(quantity*factor) / factor
return strconv.FormatFloat(rounded, 'f', precision, 64), nil
}
// GetOrderStatus gets the status of a specific order
func (t *IndodaxTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
pair := t.convertSymbol(symbol)
params := url.Values{}
params.Set("method", "getOrder")
params.Set("pair", pair)
params.Set("order_id", orderID)
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var result struct {
Order struct {
OrderID string `json:"order_id"`
Price string `json:"price"`
Type string `json:"type"`
Status string `json:"status"`
SubmitTime string `json:"submit_time"`
FinishTime string `json:"finish_time"`
ClientOrderID string `json:"client_order_id"`
} `json:"order"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order: %w", err)
}
// Map Indodax status to standard status
status := "NEW"
switch result.Order.Status {
case "filled":
status = "FILLED"
case "cancelled":
status = "CANCELED"
case "open":
status = "NEW"
}
price, _ := strconv.ParseFloat(result.Order.Price, 64)
return map[string]interface{}{
"status": status,
"avgPrice": price,
"executedQty": 0.0, // Indodax doesn't return executed qty in getOrder
"commission": 0.0,
"orderId": result.Order.OrderID,
}, nil
}
// GetOpenOrders gets open/pending orders
func (t *IndodaxTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
pair := t.convertSymbol(symbol)
params := url.Values{}
params.Set("method", "openOrders")
if pair != "" {
params.Set("pair", pair)
}
data, err := t.doPrivateRequest(params)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
var result struct {
Orders []struct {
OrderID json.Number `json:"order_id"`
ClientOrderID string `json:"client_order_id"`
SubmitTime string `json:"submit_time"`
Price string `json:"price"`
Type string `json:"type"`
OrderType string `json:"order_type"`
} `json:"orders"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse open orders: %w", err)
}
var orders []types.OpenOrder
for _, order := range result.Orders {
price, _ := strconv.ParseFloat(order.Price, 64)
side := "BUY"
if order.Type == "sell" {
side = "SELL"
}
orders = append(orders, types.OpenOrder{
OrderID: order.OrderID.String(),
Symbol: t.convertSymbolBack(pair),
Side: side,
PositionSide: "LONG",
Type: "LIMIT",
Price: price,
Status: "NEW",
})
}
return orders, nil
}
-924
View File
@@ -11,7 +11,6 @@ import (
"math"
"net/http"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"sync"
@@ -281,164 +280,6 @@ func (t *KuCoinTrader) convertSymbolBack(kcSymbol string) string {
return sym
}
// GetBalance gets account balance
func (t *KuCoinTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
t.balanceCacheMutex.RUnlock()
return t.cachedBalance, nil
}
t.balanceCacheMutex.RUnlock()
data, err := t.doRequest("GET", kucoinAccountPath+"?currency=USDT", nil)
if err != nil {
return nil, fmt.Errorf("failed to get account balance: %w", err)
}
var account struct {
AccountEquity float64 `json:"accountEquity"`
UnrealisedPNL float64 `json:"unrealisedPNL"`
MarginBalance float64 `json:"marginBalance"`
PositionMargin float64 `json:"positionMargin"`
OrderMargin float64 `json:"orderMargin"`
FrozenFunds float64 `json:"frozenFunds"`
AvailableBalance float64 `json:"availableBalance"`
Currency string `json:"currency"`
}
if err := json.Unmarshal(data, &account); err != nil {
return nil, fmt.Errorf("failed to parse balance data: %w", err)
}
result := map[string]interface{}{
"totalWalletBalance": account.MarginBalance, // Wallet balance (without unrealized PnL)
"availableBalance": account.AvailableBalance,
"totalUnrealizedProfit": account.UnrealisedPNL,
"total_equity": account.AccountEquity,
"totalEquity": account.AccountEquity, // For GetAccountInfo compatibility
}
logger.Infof("✓ KuCoin balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f",
account.AccountEquity, account.AvailableBalance, account.UnrealisedPNL)
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
// GetPositions gets all positions
func (t *KuCoinTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
t.positionsCacheMutex.RUnlock()
return t.cachedPositions, nil
}
t.positionsCacheMutex.RUnlock()
data, err := t.doRequest("GET", kucoinPositionPath, nil)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var positions []struct {
Symbol string `json:"symbol"`
CurrentQty int64 `json:"currentQty"` // Position quantity (in lots, integer)
AvgEntryPrice float64 `json:"avgEntryPrice"` // Average entry price (string in API)
MarkPrice float64 `json:"markPrice"` // Mark price
UnrealisedPnl float64 `json:"unrealisedPnl"` // Unrealized PnL
Leverage float64 `json:"leverage"` // Leverage setting
RealLeverage float64 `json:"realLeverage"` // Effective leverage (may be nil in cross mode)
LiquidationPrice float64 `json:"liquidationPrice"`// Liquidation price
Multiplier float64 `json:"multiplier"` // Contract multiplier
IsOpen bool `json:"isOpen"`
CrossMode bool `json:"crossMode"`
OpeningTimestamp int64 `json:"openingTimestamp"`
SettleCurrency string `json:"settleCurrency"`
}
if err := json.Unmarshal(data, &positions); err != nil {
return nil, fmt.Errorf("failed to parse position data: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
if !pos.IsOpen || pos.CurrentQty == 0 {
continue
}
// Convert symbol format
symbol := t.convertSymbolBack(pos.Symbol)
// Determine side based on position quantity
// KuCoin: positive qty = long, negative qty = short
side := "long"
qty := pos.CurrentQty
if qty < 0 {
side = "short"
qty = -qty
}
// Convert lots to actual quantity using multiplier
// Position quantity = lots * multiplier
multiplier := pos.Multiplier
if multiplier == 0 {
multiplier = 0.001 // Default for BTC
}
positionAmt := float64(qty) * multiplier
// Determine margin mode
mgnMode := "isolated"
if pos.CrossMode {
mgnMode = "cross"
}
// Use Leverage field (setting), fallback to RealLeverage (effective), default to 10
leverage := pos.Leverage
if leverage == 0 {
leverage = pos.RealLeverage
}
if leverage == 0 {
leverage = 10 // Default leverage
}
posMap := map[string]interface{}{
"symbol": symbol,
"positionAmt": positionAmt,
"entryPrice": pos.AvgEntryPrice,
"markPrice": pos.MarkPrice,
"unRealizedProfit": pos.UnrealisedPnl,
"leverage": leverage,
"liquidationPrice": pos.LiquidationPrice,
"side": side,
"mgnMode": mgnMode,
"createdTime": pos.OpeningTimestamp,
}
result = append(result, posMap)
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// InvalidatePositionCache clears the position cache
func (t *KuCoinTrader) InvalidatePositionCache() {
t.positionsCacheMutex.Lock()
t.cachedPositions = nil
t.positionsCacheTime = time.Time{}
t.positionsCacheMutex.Unlock()
}
// getContract gets contract info
func (t *KuCoinTrader) getContract(symbol string) (*KuCoinContract, error) {
kcSymbol := t.convertSymbol(symbol)
@@ -526,768 +367,3 @@ func (t *KuCoinTrader) quantityToLots(symbol string, quantity float64) (int64, e
return lotsInt, nil
}
// SetMarginMode sets margin mode
func (t *KuCoinTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// KuCoin sets margin mode per position, handled automatically
logger.Infof("✓ KuCoin margin mode: %v (handled per position)", isCrossMargin)
return nil
}
// SetLeverage sets leverage for a symbol
func (t *KuCoinTrader) SetLeverage(symbol string, leverage int) error {
kcSymbol := t.convertSymbol(symbol)
body := map[string]interface{}{
"symbol": kcSymbol,
"leverage": fmt.Sprintf("%d", leverage),
}
_, err := t.doRequest("POST", kucoinLeveragePath, body)
if err != nil {
// Ignore if already at target leverage
if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") {
logger.Infof("✓ %s leverage is already %dx", symbol, leverage)
return nil
}
return fmt.Errorf("failed to set leverage: %w", err)
}
logger.Infof("✓ %s leverage set to %dx", symbol, leverage)
return nil
}
// OpenLong opens long position
func (t *KuCoinTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel old orders
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ Failed to set leverage: %v", err)
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "buy",
"type": "market",
"size": lots,
"leverage": fmt.Sprintf("%d", leverage),
"reduceOnly": false,
"marginMode": "CROSS", // Use cross margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin opened long position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId)
// Query order to get fill price
fillPrice := t.queryOrderFillPrice(result.OrderId)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
"fillPrice": fillPrice,
}, nil
}
// OpenShort opens short position
func (t *KuCoinTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel old orders
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ Failed to set leverage: %v", err)
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "sell",
"type": "market",
"size": lots,
"leverage": fmt.Sprintf("%d", leverage),
"reduceOnly": false,
"marginMode": "CROSS", // Use cross margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin opened short position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId)
// Query order to get fill price
fillPrice := t.queryOrderFillPrice(result.OrderId)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
"fillPrice": fillPrice,
}, nil
}
// queryOrderFillPrice queries order status and returns fill price
func (t *KuCoinTrader) queryOrderFillPrice(orderId string) float64 {
// Wait a bit for order to fill
time.Sleep(500 * time.Millisecond)
path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
logger.Warnf("Failed to query order %s: %v", orderId, err)
return 0
}
var order struct {
DealAvgPrice float64 `json:"dealAvgPrice"`
Status string `json:"status"`
DealSize int64 `json:"dealSize"`
}
if err := json.Unmarshal(data, &order); err != nil {
return 0
}
return order.DealAvgPrice
}
// CloseLong closes long position
func (t *KuCoinTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// Invalidate position cache and get fresh positions
t.InvalidatePositionCache()
positions, err := t.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
// Find actual position and get margin mode
var actualQty float64
var posFound bool
var marginMode string = "CROSS" // Default to CROSS
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "long" {
actualQty = pos["positionAmt"].(float64)
posFound = true
// Get margin mode from position
if mgnMode, ok := pos["mgnMode"].(string); ok {
marginMode = strings.ToUpper(mgnMode)
}
break
}
}
if !posFound || actualQty == 0 {
return map[string]interface{}{
"status": "NO_POSITION",
"message": fmt.Sprintf("No long position found for %s on KuCoin", symbol),
}, nil
}
// Use actual quantity from exchange
if quantity == 0 || quantity > actualQty {
quantity = actualQty
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "sell",
"type": "market",
"size": lots,
"reduceOnly": true,
"closeOrder": true,
"marginMode": marginMode, // Use position's margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin closed long position: %s", symbol)
// Cancel pending orders
t.CancelAllOrders(symbol)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// CloseShort closes short position
func (t *KuCoinTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// Invalidate position cache and get fresh positions
t.InvalidatePositionCache()
positions, err := t.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
// Find actual position and get margin mode
var actualQty float64
var posFound bool
var marginMode string = "CROSS" // Default to CROSS
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "short" {
actualQty = pos["positionAmt"].(float64)
posFound = true
// Get margin mode from position
if mgnMode, ok := pos["mgnMode"].(string); ok {
marginMode = strings.ToUpper(mgnMode)
}
break
}
}
if !posFound || actualQty == 0 {
return map[string]interface{}{
"status": "NO_POSITION",
"message": fmt.Sprintf("No short position found for %s on KuCoin", symbol),
}, nil
}
// Use actual quantity from exchange
if quantity == 0 || quantity > actualQty {
quantity = actualQty
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "buy",
"type": "market",
"size": lots,
"reduceOnly": true,
"closeOrder": true,
"marginMode": marginMode, // Use position's margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin closed short position: %s", symbol)
// Cancel pending orders
t.CancelAllOrders(symbol)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// GetMarketPrice gets market price
func (t *KuCoinTrader) GetMarketPrice(symbol string) (float64, error) {
kcSymbol := t.convertSymbol(symbol)
path := fmt.Sprintf("%s?symbol=%s", kucoinTickerPath, kcSymbol)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return 0, fmt.Errorf("failed to get price: %w", err)
}
var ticker struct {
Price string `json:"price"`
}
if err := json.Unmarshal(data, &ticker); err != nil {
return 0, err
}
price, _ := strconv.ParseFloat(ticker.Price, 64)
return price, nil
}
// SetStopLoss sets stop loss order
func (t *KuCoinTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return fmt.Errorf("failed to calculate lots: %w", err)
}
// Determine side: close long = sell, close short = buy
side := "sell"
stop := "down" // Long position: stop loss triggers when price goes down
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
stop = "up" // Short position: stop loss triggers when price goes up
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfxsl%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": side,
"type": "market",
"size": lots,
"stop": stop,
"stopPriceType": "MP", // Mark Price
"stopPrice": fmt.Sprintf("%.8f", stopPrice),
"reduceOnly": true,
"closeOrder": true,
}
_, err = t.doRequest("POST", kucoinStopOrderPath, body)
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
logger.Infof("✓ Stop loss set: %.4f", stopPrice)
return nil
}
// SetTakeProfit sets take profit order
func (t *KuCoinTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return fmt.Errorf("failed to calculate lots: %w", err)
}
// Determine side: close long = sell, close short = buy
side := "sell"
stop := "up" // Long position: take profit triggers when price goes up
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
stop = "down" // Short position: take profit triggers when price goes down
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfxtp%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": side,
"type": "market",
"size": lots,
"stop": stop,
"stopPriceType": "MP", // Mark Price
"stopPrice": fmt.Sprintf("%.8f", takeProfitPrice),
"reduceOnly": true,
"closeOrder": true,
}
_, err = t.doRequest("POST", kucoinStopOrderPath, body)
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
logger.Infof("✓ Take profit set: %.4f", takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *KuCoinTrader) CancelStopLossOrders(symbol string) error {
return t.cancelStopOrdersByType(symbol, "sl")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *KuCoinTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelStopOrdersByType(symbol, "tp")
}
// cancelStopOrdersByType cancels stop orders by type
func (t *KuCoinTrader) cancelStopOrdersByType(symbol string, orderType string) error {
kcSymbol := t.convertSymbol(symbol)
// Get pending stop orders
path := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return err
}
var response struct {
Items []struct {
Id string `json:"id"`
ClientOid string `json:"clientOid"`
Stop string `json:"stop"`
} `json:"items"`
}
if err := json.Unmarshal(data, &response); err != nil {
// Try alternate format (direct array)
var items []struct {
Id string `json:"id"`
ClientOid string `json:"clientOid"`
Stop string `json:"stop"`
}
if err := json.Unmarshal(data, &items); err != nil {
return err
}
response.Items = items
}
// Cancel matching orders
for _, order := range response.Items {
// Check if order matches type based on clientOid prefix
if orderType == "sl" && !strings.Contains(order.ClientOid, "sl") {
continue
}
if orderType == "tp" && !strings.Contains(order.ClientOid, "tp") {
continue
}
cancelPath := fmt.Sprintf("%s/%s", kucoinCancelStopPath, order.Id)
_, err := t.doRequest("DELETE", cancelPath, nil)
if err != nil {
logger.Warnf("Failed to cancel stop order %s: %v", order.Id, err)
}
}
return nil
}
// CancelStopOrders cancels all stop orders for symbol
func (t *KuCoinTrader) CancelStopOrders(symbol string) error {
kcSymbol := t.convertSymbol(symbol)
path := fmt.Sprintf("%s?symbol=%s", kucoinCancelStopPath, kcSymbol)
_, err := t.doRequest("DELETE", path, nil)
if err != nil {
// Ignore if no orders to cancel
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "400100") {
return nil
}
return err
}
logger.Infof("✓ Cancelled stop orders for %s", symbol)
return nil
}
// CancelAllOrders cancels all pending orders for symbol
func (t *KuCoinTrader) CancelAllOrders(symbol string) error {
kcSymbol := t.convertSymbol(symbol)
// Cancel regular orders
path := fmt.Sprintf("%s?symbol=%s", kucoinCancelOrderPath, kcSymbol)
_, err := t.doRequest("DELETE", path, nil)
if err != nil && !strings.Contains(err.Error(), "not found") {
logger.Warnf("Failed to cancel regular orders: %v", err)
}
// Cancel stop orders
t.CancelStopOrders(symbol)
return nil
}
// FormatQuantity formats quantity to correct precision
func (t *KuCoinTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
contract, err := t.getContract(symbol)
if err != nil {
return "", err
}
// Calculate lots
lots := quantity / contract.Multiplier
// Round to integer
lotsInt := int64(math.Round(lots))
return strconv.FormatInt(lotsInt, 10), nil
}
// GetOrderStatus gets order status
func (t *KuCoinTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderID)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var order struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Status string `json:"status"`
DealAvgPrice float64 `json:"dealAvgPrice"`
DealSize int64 `json:"dealSize"`
Fee float64 `json:"fee"`
Side string `json:"side"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, err
}
// Convert status
status := "NEW"
if order.Status == "done" {
status = "FILLED"
} else if order.Status == "cancelled" || order.Status == "canceled" {
status = "CANCELED"
}
return map[string]interface{}{
"orderId": order.Id,
"symbol": t.convertSymbolBack(order.Symbol),
"status": status,
"avgPrice": order.DealAvgPrice,
"executedQty": order.DealSize,
"commission": order.Fee,
}, nil
}
// GetClosedPnL gets closed position PnL records
func (t *KuCoinTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 100 {
limit = 100
}
// KuCoin closed positions API
path := fmt.Sprintf("/api/v1/history-positions?status=CLOSE&limit=%d", limit)
if !startTime.IsZero() {
path += fmt.Sprintf("&from=%d", startTime.UnixMilli())
}
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get closed PnL: %w", err)
}
var response struct {
HasMore bool `json:"hasMore"`
DataList []struct {
Symbol string `json:"symbol"`
OpenPrice float64 `json:"avgEntryPrice"`
ClosePrice float64 `json:"avgClosePrice"`
Qty int64 `json:"qty"`
RealisedPnl float64 `json:"realisedGrossCost"`
CloseTime int64 `json:"closeTime"`
OpenTime int64 `json:"openTime"`
PositionId string `json:"id"`
CloseType string `json:"type"`
Leverage int `json:"leverage"`
SettleCurrency string `json:"settleCurrency"`
} `json:"dataList"`
}
if err := json.Unmarshal(data, &response); err != nil {
return nil, fmt.Errorf("failed to parse closed PnL: %w", err)
}
var records []types.ClosedPnLRecord
for _, item := range response.DataList {
side := "long"
qty := item.Qty
if qty < 0 {
side = "short"
qty = -qty
}
// Map close type
closeType := "unknown"
switch strings.ToUpper(item.CloseType) {
case "CLOSE", "MANUAL":
closeType = "manual"
case "STOP", "STOPLOSS":
closeType = "stop_loss"
case "TAKEPROFIT", "TP":
closeType = "take_profit"
case "LIQUIDATION", "LIQ", "ADL":
closeType = "liquidation"
}
records = append(records, types.ClosedPnLRecord{
Symbol: t.convertSymbolBack(item.Symbol),
Side: side,
EntryPrice: item.OpenPrice,
ExitPrice: item.ClosePrice,
Quantity: float64(qty),
RealizedPnL: item.RealisedPnl,
Leverage: item.Leverage,
EntryTime: time.UnixMilli(item.OpenTime),
ExitTime: time.UnixMilli(item.CloseTime),
ExchangeID: item.PositionId,
CloseType: closeType,
})
}
return records, nil
}
// GetOpenOrders gets open/pending orders
func (t *KuCoinTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
kcSymbol := t.convertSymbol(symbol)
// Get regular orders
path := fmt.Sprintf("%s?symbol=%s&status=active", kucoinOrderPath, kcSymbol)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
var response struct {
Items []struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Side string `json:"side"`
Type string `json:"type"`
Price string `json:"price"`
Size int64 `json:"size"`
StopType string `json:"stopType"`
} `json:"items"`
}
if err := json.Unmarshal(data, &response); err != nil {
// Try alternate format
var items []struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Side string `json:"side"`
Type string `json:"type"`
Price string `json:"price"`
Size int64 `json:"size"`
StopType string `json:"stopType"`
}
if err := json.Unmarshal(data, &items); err != nil {
return nil, err
}
response.Items = items
}
var orders []types.OpenOrder
for _, item := range response.Items {
// Determine position side based on order side
positionSide := "LONG"
if item.Side == "sell" {
positionSide = "SHORT"
}
price, _ := strconv.ParseFloat(item.Price, 64)
orders = append(orders, types.OpenOrder{
OrderID: item.Id,
Symbol: t.convertSymbolBack(item.Symbol),
Side: strings.ToUpper(item.Side),
PositionSide: positionSide,
Type: strings.ToUpper(item.Type),
Price: price,
Quantity: float64(item.Size),
Status: "NEW",
})
}
// Get stop orders
stopPath := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol)
stopData, err := t.doRequest("GET", stopPath, nil)
if err == nil {
var stopResponse struct {
Items []struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Side string `json:"side"`
StopPrice string `json:"stopPrice"`
Size int64 `json:"size"`
} `json:"items"`
}
if json.Unmarshal(stopData, &stopResponse) == nil {
for _, item := range stopResponse.Items {
positionSide := "LONG"
if item.Side == "sell" {
positionSide = "SHORT"
}
stopPrice, _ := strconv.ParseFloat(item.StopPrice, 64)
orders = append(orders, types.OpenOrder{
OrderID: item.Id,
Symbol: t.convertSymbolBack(item.Symbol),
Side: strings.ToUpper(item.Side),
PositionSide: positionSide,
Type: "STOP_MARKET",
StopPrice: stopPrice,
Quantity: float64(item.Size),
Status: "NEW",
})
}
}
}
return orders, nil
}
+58
View File
@@ -0,0 +1,58 @@
package kucoin
import (
"encoding/json"
"fmt"
"nofx/logger"
"time"
)
// GetBalance gets account balance
func (t *KuCoinTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
t.balanceCacheMutex.RUnlock()
return t.cachedBalance, nil
}
t.balanceCacheMutex.RUnlock()
data, err := t.doRequest("GET", kucoinAccountPath+"?currency=USDT", nil)
if err != nil {
return nil, fmt.Errorf("failed to get account balance: %w", err)
}
var account struct {
AccountEquity float64 `json:"accountEquity"`
UnrealisedPNL float64 `json:"unrealisedPNL"`
MarginBalance float64 `json:"marginBalance"`
PositionMargin float64 `json:"positionMargin"`
OrderMargin float64 `json:"orderMargin"`
FrozenFunds float64 `json:"frozenFunds"`
AvailableBalance float64 `json:"availableBalance"`
Currency string `json:"currency"`
}
if err := json.Unmarshal(data, &account); err != nil {
return nil, fmt.Errorf("failed to parse balance data: %w", err)
}
result := map[string]interface{}{
"totalWalletBalance": account.MarginBalance, // Wallet balance (without unrealized PnL)
"availableBalance": account.AvailableBalance,
"totalUnrealizedProfit": account.UnrealisedPNL,
"total_equity": account.AccountEquity,
"totalEquity": account.AccountEquity, // For GetAccountInfo compatibility
}
logger.Infof("✓ KuCoin balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f",
account.AccountEquity, account.AvailableBalance, account.UnrealisedPNL)
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
+777
View File
@@ -0,0 +1,777 @@
package kucoin
import (
"encoding/json"
"fmt"
"math"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"time"
)
// OpenLong opens long position
func (t *KuCoinTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel old orders
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ Failed to set leverage: %v", err)
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "buy",
"type": "market",
"size": lots,
"leverage": fmt.Sprintf("%d", leverage),
"reduceOnly": false,
"marginMode": "CROSS", // Use cross margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin opened long position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId)
// Query order to get fill price
fillPrice := t.queryOrderFillPrice(result.OrderId)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
"fillPrice": fillPrice,
}, nil
}
// OpenShort opens short position
func (t *KuCoinTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel old orders
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ Failed to set leverage: %v", err)
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "sell",
"type": "market",
"size": lots,
"leverage": fmt.Sprintf("%d", leverage),
"reduceOnly": false,
"marginMode": "CROSS", // Use cross margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin opened short position: %s, lots=%d, orderId=%s", symbol, lots, result.OrderId)
// Query order to get fill price
fillPrice := t.queryOrderFillPrice(result.OrderId)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
"fillPrice": fillPrice,
}, nil
}
// queryOrderFillPrice queries order status and returns fill price
func (t *KuCoinTrader) queryOrderFillPrice(orderId string) float64 {
// Wait a bit for order to fill
time.Sleep(500 * time.Millisecond)
path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
logger.Warnf("Failed to query order %s: %v", orderId, err)
return 0
}
var order struct {
DealAvgPrice float64 `json:"dealAvgPrice"`
Status string `json:"status"`
DealSize int64 `json:"dealSize"`
}
if err := json.Unmarshal(data, &order); err != nil {
return 0
}
return order.DealAvgPrice
}
// CloseLong closes long position
func (t *KuCoinTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// Invalidate position cache and get fresh positions
t.InvalidatePositionCache()
positions, err := t.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
// Find actual position and get margin mode
var actualQty float64
var posFound bool
var marginMode string = "CROSS" // Default to CROSS
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "long" {
actualQty = pos["positionAmt"].(float64)
posFound = true
// Get margin mode from position
if mgnMode, ok := pos["mgnMode"].(string); ok {
marginMode = strings.ToUpper(mgnMode)
}
break
}
}
if !posFound || actualQty == 0 {
return map[string]interface{}{
"status": "NO_POSITION",
"message": fmt.Sprintf("No long position found for %s on KuCoin", symbol),
}, nil
}
// Use actual quantity from exchange
if quantity == 0 || quantity > actualQty {
quantity = actualQty
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "sell",
"type": "market",
"size": lots,
"reduceOnly": true,
"closeOrder": true,
"marginMode": marginMode, // Use position's margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin closed long position: %s", symbol)
// Cancel pending orders
t.CancelAllOrders(symbol)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// CloseShort closes short position
func (t *KuCoinTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// Invalidate position cache and get fresh positions
t.InvalidatePositionCache()
positions, err := t.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
// Find actual position and get margin mode
var actualQty float64
var posFound bool
var marginMode string = "CROSS" // Default to CROSS
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "short" {
actualQty = pos["positionAmt"].(float64)
posFound = true
// Get margin mode from position
if mgnMode, ok := pos["mgnMode"].(string); ok {
marginMode = strings.ToUpper(mgnMode)
}
break
}
}
if !posFound || actualQty == 0 {
return map[string]interface{}{
"status": "NO_POSITION",
"message": fmt.Sprintf("No short position found for %s on KuCoin", symbol),
}, nil
}
// Use actual quantity from exchange
if quantity == 0 || quantity > actualQty {
quantity = actualQty
}
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return nil, fmt.Errorf("failed to calculate lots: %w", err)
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfx%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": "buy",
"type": "market",
"size": lots,
"reduceOnly": true,
"closeOrder": true,
"marginMode": marginMode, // Use position's margin mode
}
data, err := t.doRequest("POST", kucoinOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
var result struct {
OrderId string `json:"orderId"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
logger.Infof("✓ KuCoin closed short position: %s", symbol)
// Cancel pending orders
t.CancelAllOrders(symbol)
return map[string]interface{}{
"orderId": result.OrderId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// GetMarketPrice gets market price
func (t *KuCoinTrader) GetMarketPrice(symbol string) (float64, error) {
kcSymbol := t.convertSymbol(symbol)
path := fmt.Sprintf("%s?symbol=%s", kucoinTickerPath, kcSymbol)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return 0, fmt.Errorf("failed to get price: %w", err)
}
var ticker struct {
Price string `json:"price"`
}
if err := json.Unmarshal(data, &ticker); err != nil {
return 0, err
}
price, _ := strconv.ParseFloat(ticker.Price, 64)
return price, nil
}
// SetStopLoss sets stop loss order
func (t *KuCoinTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return fmt.Errorf("failed to calculate lots: %w", err)
}
// Determine side: close long = sell, close short = buy
side := "sell"
stop := "down" // Long position: stop loss triggers when price goes down
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
stop = "up" // Short position: stop loss triggers when price goes up
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfxsl%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": side,
"type": "market",
"size": lots,
"stop": stop,
"stopPriceType": "MP", // Mark Price
"stopPrice": fmt.Sprintf("%.8f", stopPrice),
"reduceOnly": true,
"closeOrder": true,
}
_, err = t.doRequest("POST", kucoinStopOrderPath, body)
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
logger.Infof("✓ Stop loss set: %.4f", stopPrice)
return nil
}
// SetTakeProfit sets take profit order
func (t *KuCoinTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
kcSymbol := t.convertSymbol(symbol)
// Convert quantity to lots
lots, err := t.quantityToLots(symbol, quantity)
if err != nil {
return fmt.Errorf("failed to calculate lots: %w", err)
}
// Determine side: close long = sell, close short = buy
side := "sell"
stop := "up" // Long position: take profit triggers when price goes up
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
stop = "down" // Short position: take profit triggers when price goes down
}
body := map[string]interface{}{
"clientOid": fmt.Sprintf("nfxtp%d", time.Now().UnixNano()),
"symbol": kcSymbol,
"side": side,
"type": "market",
"size": lots,
"stop": stop,
"stopPriceType": "MP", // Mark Price
"stopPrice": fmt.Sprintf("%.8f", takeProfitPrice),
"reduceOnly": true,
"closeOrder": true,
}
_, err = t.doRequest("POST", kucoinStopOrderPath, body)
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
logger.Infof("✓ Take profit set: %.4f", takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *KuCoinTrader) CancelStopLossOrders(symbol string) error {
return t.cancelStopOrdersByType(symbol, "sl")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *KuCoinTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelStopOrdersByType(symbol, "tp")
}
// cancelStopOrdersByType cancels stop orders by type
func (t *KuCoinTrader) cancelStopOrdersByType(symbol string, orderType string) error {
kcSymbol := t.convertSymbol(symbol)
// Get pending stop orders
path := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return err
}
var response struct {
Items []struct {
Id string `json:"id"`
ClientOid string `json:"clientOid"`
Stop string `json:"stop"`
} `json:"items"`
}
if err := json.Unmarshal(data, &response); err != nil {
// Try alternate format (direct array)
var items []struct {
Id string `json:"id"`
ClientOid string `json:"clientOid"`
Stop string `json:"stop"`
}
if err := json.Unmarshal(data, &items); err != nil {
return err
}
response.Items = items
}
// Cancel matching orders
for _, order := range response.Items {
// Check if order matches type based on clientOid prefix
if orderType == "sl" && !strings.Contains(order.ClientOid, "sl") {
continue
}
if orderType == "tp" && !strings.Contains(order.ClientOid, "tp") {
continue
}
cancelPath := fmt.Sprintf("%s/%s", kucoinCancelStopPath, order.Id)
_, err := t.doRequest("DELETE", cancelPath, nil)
if err != nil {
logger.Warnf("Failed to cancel stop order %s: %v", order.Id, err)
}
}
return nil
}
// CancelStopOrders cancels all stop orders for symbol
func (t *KuCoinTrader) CancelStopOrders(symbol string) error {
kcSymbol := t.convertSymbol(symbol)
path := fmt.Sprintf("%s?symbol=%s", kucoinCancelStopPath, kcSymbol)
_, err := t.doRequest("DELETE", path, nil)
if err != nil {
// Ignore if no orders to cancel
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "400100") {
return nil
}
return err
}
logger.Infof("✓ Cancelled stop orders for %s", symbol)
return nil
}
// CancelAllOrders cancels all pending orders for symbol
func (t *KuCoinTrader) CancelAllOrders(symbol string) error {
kcSymbol := t.convertSymbol(symbol)
// Cancel regular orders
path := fmt.Sprintf("%s?symbol=%s", kucoinCancelOrderPath, kcSymbol)
_, err := t.doRequest("DELETE", path, nil)
if err != nil && !strings.Contains(err.Error(), "not found") {
logger.Warnf("Failed to cancel regular orders: %v", err)
}
// Cancel stop orders
t.CancelStopOrders(symbol)
return nil
}
// SetMarginMode sets margin mode
func (t *KuCoinTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// KuCoin sets margin mode per position, handled automatically
logger.Infof("✓ KuCoin margin mode: %v (handled per position)", isCrossMargin)
return nil
}
// SetLeverage sets leverage for a symbol
func (t *KuCoinTrader) SetLeverage(symbol string, leverage int) error {
kcSymbol := t.convertSymbol(symbol)
body := map[string]interface{}{
"symbol": kcSymbol,
"leverage": fmt.Sprintf("%d", leverage),
}
_, err := t.doRequest("POST", kucoinLeveragePath, body)
if err != nil {
// Ignore if already at target leverage
if strings.Contains(err.Error(), "same") || strings.Contains(err.Error(), "already") {
logger.Infof("✓ %s leverage is already %dx", symbol, leverage)
return nil
}
return fmt.Errorf("failed to set leverage: %w", err)
}
logger.Infof("✓ %s leverage set to %dx", symbol, leverage)
return nil
}
// FormatQuantity formats quantity to correct precision
func (t *KuCoinTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
contract, err := t.getContract(symbol)
if err != nil {
return "", err
}
// Calculate lots
lots := quantity / contract.Multiplier
// Round to integer
lotsInt := int64(math.Round(lots))
return strconv.FormatInt(lotsInt, 10), nil
}
// GetOrderStatus gets order status
func (t *KuCoinTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
path := fmt.Sprintf("%s/%s", kucoinOrderPath, orderID)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var order struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Status string `json:"status"`
DealAvgPrice float64 `json:"dealAvgPrice"`
DealSize int64 `json:"dealSize"`
Fee float64 `json:"fee"`
Side string `json:"side"`
}
if err := json.Unmarshal(data, &order); err != nil {
return nil, err
}
// Convert status
status := "NEW"
if order.Status == "done" {
status = "FILLED"
} else if order.Status == "cancelled" || order.Status == "canceled" {
status = "CANCELED"
}
return map[string]interface{}{
"orderId": order.Id,
"symbol": t.convertSymbolBack(order.Symbol),
"status": status,
"avgPrice": order.DealAvgPrice,
"executedQty": order.DealSize,
"commission": order.Fee,
}, nil
}
// GetClosedPnL gets closed position PnL records
func (t *KuCoinTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 100 {
limit = 100
}
// KuCoin closed positions API
path := fmt.Sprintf("/api/v1/history-positions?status=CLOSE&limit=%d", limit)
if !startTime.IsZero() {
path += fmt.Sprintf("&from=%d", startTime.UnixMilli())
}
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get closed PnL: %w", err)
}
var response struct {
HasMore bool `json:"hasMore"`
DataList []struct {
Symbol string `json:"symbol"`
OpenPrice float64 `json:"avgEntryPrice"`
ClosePrice float64 `json:"avgClosePrice"`
Qty int64 `json:"qty"`
RealisedPnl float64 `json:"realisedGrossCost"`
CloseTime int64 `json:"closeTime"`
OpenTime int64 `json:"openTime"`
PositionId string `json:"id"`
CloseType string `json:"type"`
Leverage int `json:"leverage"`
SettleCurrency string `json:"settleCurrency"`
} `json:"dataList"`
}
if err := json.Unmarshal(data, &response); err != nil {
return nil, fmt.Errorf("failed to parse closed PnL: %w", err)
}
var records []types.ClosedPnLRecord
for _, item := range response.DataList {
side := "long"
qty := item.Qty
if qty < 0 {
side = "short"
qty = -qty
}
// Map close type
closeType := "unknown"
switch strings.ToUpper(item.CloseType) {
case "CLOSE", "MANUAL":
closeType = "manual"
case "STOP", "STOPLOSS":
closeType = "stop_loss"
case "TAKEPROFIT", "TP":
closeType = "take_profit"
case "LIQUIDATION", "LIQ", "ADL":
closeType = "liquidation"
}
records = append(records, types.ClosedPnLRecord{
Symbol: t.convertSymbolBack(item.Symbol),
Side: side,
EntryPrice: item.OpenPrice,
ExitPrice: item.ClosePrice,
Quantity: float64(qty),
RealizedPnL: item.RealisedPnl,
Leverage: item.Leverage,
EntryTime: time.UnixMilli(item.OpenTime),
ExitTime: time.UnixMilli(item.CloseTime),
ExchangeID: item.PositionId,
CloseType: closeType,
})
}
return records, nil
}
// GetOpenOrders gets open/pending orders
func (t *KuCoinTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
kcSymbol := t.convertSymbol(symbol)
// Get regular orders
path := fmt.Sprintf("%s?symbol=%s&status=active", kucoinOrderPath, kcSymbol)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get open orders: %w", err)
}
var response struct {
Items []struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Side string `json:"side"`
Type string `json:"type"`
Price string `json:"price"`
Size int64 `json:"size"`
StopType string `json:"stopType"`
} `json:"items"`
}
if err := json.Unmarshal(data, &response); err != nil {
// Try alternate format
var items []struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Side string `json:"side"`
Type string `json:"type"`
Price string `json:"price"`
Size int64 `json:"size"`
StopType string `json:"stopType"`
}
if err := json.Unmarshal(data, &items); err != nil {
return nil, err
}
response.Items = items
}
var orders []types.OpenOrder
for _, item := range response.Items {
// Determine position side based on order side
positionSide := "LONG"
if item.Side == "sell" {
positionSide = "SHORT"
}
price, _ := strconv.ParseFloat(item.Price, 64)
orders = append(orders, types.OpenOrder{
OrderID: item.Id,
Symbol: t.convertSymbolBack(item.Symbol),
Side: strings.ToUpper(item.Side),
PositionSide: positionSide,
Type: strings.ToUpper(item.Type),
Price: price,
Quantity: float64(item.Size),
Status: "NEW",
})
}
// Get stop orders
stopPath := fmt.Sprintf("%s?symbol=%s", kucoinStopOrderPath, kcSymbol)
stopData, err := t.doRequest("GET", stopPath, nil)
if err == nil {
var stopResponse struct {
Items []struct {
Id string `json:"id"`
Symbol string `json:"symbol"`
Side string `json:"side"`
StopPrice string `json:"stopPrice"`
Size int64 `json:"size"`
} `json:"items"`
}
if json.Unmarshal(stopData, &stopResponse) == nil {
for _, item := range stopResponse.Items {
positionSide := "LONG"
if item.Side == "sell" {
positionSide = "SHORT"
}
stopPrice, _ := strconv.ParseFloat(item.StopPrice, 64)
orders = append(orders, types.OpenOrder{
OrderID: item.Id,
Symbol: t.convertSymbolBack(item.Symbol),
Side: strings.ToUpper(item.Side),
PositionSide: positionSide,
Type: "STOP_MARKET",
StopPrice: stopPrice,
Quantity: float64(item.Size),
Status: "NEW",
})
}
}
}
return orders, nil
}
+115
View File
@@ -0,0 +1,115 @@
package kucoin
import (
"encoding/json"
"fmt"
"time"
)
// GetPositions gets all positions
func (t *KuCoinTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
t.positionsCacheMutex.RUnlock()
return t.cachedPositions, nil
}
t.positionsCacheMutex.RUnlock()
data, err := t.doRequest("GET", kucoinPositionPath, nil)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var positions []struct {
Symbol string `json:"symbol"`
CurrentQty int64 `json:"currentQty"` // Position quantity (in lots, integer)
AvgEntryPrice float64 `json:"avgEntryPrice"` // Average entry price
MarkPrice float64 `json:"markPrice"` // Mark price
UnrealisedPnl float64 `json:"unrealisedPnl"` // Unrealized PnL
Leverage float64 `json:"leverage"` // Leverage setting
RealLeverage float64 `json:"realLeverage"` // Effective leverage (may be nil in cross mode)
LiquidationPrice float64 `json:"liquidationPrice"`// Liquidation price
Multiplier float64 `json:"multiplier"` // Contract multiplier
IsOpen bool `json:"isOpen"`
CrossMode bool `json:"crossMode"`
OpeningTimestamp int64 `json:"openingTimestamp"`
SettleCurrency string `json:"settleCurrency"`
}
if err := json.Unmarshal(data, &positions); err != nil {
return nil, fmt.Errorf("failed to parse position data: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
if !pos.IsOpen || pos.CurrentQty == 0 {
continue
}
// Convert symbol format
symbol := t.convertSymbolBack(pos.Symbol)
// Determine side based on position quantity
// KuCoin: positive qty = long, negative qty = short
side := "long"
qty := pos.CurrentQty
if qty < 0 {
side = "short"
qty = -qty
}
// Convert lots to actual quantity using multiplier
// Position quantity = lots * multiplier
multiplier := pos.Multiplier
if multiplier == 0 {
multiplier = 0.001 // Default for BTC
}
positionAmt := float64(qty) * multiplier
// Determine margin mode
mgnMode := "isolated"
if pos.CrossMode {
mgnMode = "cross"
}
// Use Leverage field (setting), fallback to RealLeverage (effective), default to 10
leverage := pos.Leverage
if leverage == 0 {
leverage = pos.RealLeverage
}
if leverage == 0 {
leverage = 10 // Default leverage
}
posMap := map[string]interface{}{
"symbol": symbol,
"positionAmt": positionAmt,
"entryPrice": pos.AvgEntryPrice,
"markPrice": pos.MarkPrice,
"unRealizedProfit": pos.UnrealisedPnl,
"leverage": leverage,
"liquidationPrice": pos.LiquidationPrice,
"side": side,
"mgnMode": mgnMode,
"createdTime": pos.OpeningTimestamp,
}
result = append(result, posMap)
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// InvalidatePositionCache clears the position cache
func (t *KuCoinTrader) InvalidatePositionCache() {
t.positionsCacheMutex.Lock()
t.cachedPositions = nil
t.positionsCacheTime = time.Time{}
t.positionsCacheMutex.Unlock()
}
+6 -1388
View File
File diff suppressed because it is too large Load Diff
+280
View File
@@ -0,0 +1,280 @@
package okx
import (
"encoding/json"
"fmt"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
"time"
)
// GetBalance gets account balance
func (t *OKXTrader) GetBalance() (map[string]interface{}, error) {
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
t.balanceCacheMutex.RUnlock()
logger.Infof("✓ Using cached OKX account balance")
return t.cachedBalance, nil
}
t.balanceCacheMutex.RUnlock()
logger.Infof("🔄 Calling OKX API to get account balance...")
data, err := t.doRequest("GET", okxAccountPath, nil)
if err != nil {
return nil, fmt.Errorf("failed to get account balance: %w", err)
}
var balances []struct {
TotalEq string `json:"totalEq"`
AdjEq string `json:"adjEq"`
IsoEq string `json:"isoEq"`
OrdFroz string `json:"ordFroz"`
Details []struct {
Ccy string `json:"ccy"`
Eq string `json:"eq"`
CashBal string `json:"cashBal"`
AvailBal string `json:"availBal"`
UPL string `json:"upl"`
} `json:"details"`
}
if err := json.Unmarshal(data, &balances); err != nil {
return nil, fmt.Errorf("failed to parse balance data: %w", err)
}
if len(balances) == 0 {
return nil, fmt.Errorf("no balance data received")
}
balance := balances[0]
// Find USDT balance
var usdtAvail, usdtUPL float64
for _, detail := range balance.Details {
if detail.Ccy == "USDT" {
usdtAvail, _ = strconv.ParseFloat(detail.AvailBal, 64)
usdtUPL, _ = strconv.ParseFloat(detail.UPL, 64)
break
}
}
totalEq, _ := strconv.ParseFloat(balance.TotalEq, 64)
result := map[string]interface{}{
"totalWalletBalance": totalEq,
"availableBalance": usdtAvail,
"totalUnrealizedProfit": usdtUPL,
}
logger.Infof("✓ OKX balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f", totalEq, usdtAvail, usdtUPL)
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = result
t.balanceCacheTime = time.Now()
t.balanceCacheMutex.Unlock()
return result, nil
}
// SetMarginMode sets margin mode
func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
instId := t.convertSymbol(symbol)
mgnMode := "isolated"
if isCrossMargin {
mgnMode = "cross"
}
body := map[string]interface{}{
"instId": instId,
"mgnMode": mgnMode,
}
_, err := t.doRequest("POST", "/api/v5/account/set-isolated-mode", body)
if err != nil {
// Ignore error if already in target mode
if strings.Contains(err.Error(), "already") {
logger.Infof(" ✓ %s margin mode is already %s", symbol, mgnMode)
return nil
}
// Cannot change when there are positions
if strings.Contains(err.Error(), "position") {
logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol)
return nil
}
return err
}
logger.Infof(" ✓ %s margin mode set to %s", symbol, mgnMode)
return nil
}
// SetLeverage sets leverage
func (t *OKXTrader) SetLeverage(symbol string, leverage int) error {
instId := t.convertSymbol(symbol)
// Set leverage for both long and short
for _, posSide := range []string{"long", "short"} {
body := map[string]interface{}{
"instId": instId,
"lever": strconv.Itoa(leverage),
"mgnMode": "cross",
"posSide": posSide,
}
_, err := t.doRequest("POST", okxLeveragePath, body)
if err != nil {
// Ignore if already at target leverage
if strings.Contains(err.Error(), "same") {
continue
}
logger.Infof(" ⚠️ Failed to set %s %s leverage: %v", symbol, posSide, err)
}
}
logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage)
return nil
}
// GetMarketPrice gets market price
func (t *OKXTrader) GetMarketPrice(symbol string) (float64, error) {
instId := t.convertSymbol(symbol)
path := fmt.Sprintf("%s?instId=%s", okxTickerPath, instId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return 0, fmt.Errorf("failed to get price: %w", err)
}
var tickers []struct {
Last string `json:"last"`
}
if err := json.Unmarshal(data, &tickers); err != nil {
return 0, err
}
if len(tickers) == 0 {
return 0, fmt.Errorf("no price data received")
}
price, err := strconv.ParseFloat(tickers[0].Last, 64)
if err != nil {
return 0, err
}
return price, nil
}
// GetClosedPnL retrieves closed position PnL records from OKX
// OKX API: /api/v5/account/positions-history
func (t *OKXTrader) GetClosedPnL(startTime time.Time, limit int) ([]types.ClosedPnLRecord, error) {
if limit <= 0 {
limit = 100
}
if limit > 100 {
limit = 100
}
// Build query path with parameters
path := fmt.Sprintf("/api/v5/account/positions-history?instType=SWAP&limit=%d", limit)
if !startTime.IsZero() {
path += fmt.Sprintf("&after=%d", startTime.UnixMilli())
}
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get positions history: %w", err)
}
var resp struct {
Code string `json:"code"`
Msg string `json:"msg"`
Data []struct {
InstID string `json:"instId"` // Instrument ID (e.g., "BTC-USDT-SWAP")
Direction string `json:"direction"` // Position direction: "long" or "short"
OpenAvgPx string `json:"openAvgPx"` // Average open price
CloseAvgPx string `json:"closeAvgPx"` // Average close price
CloseTotalPos string `json:"closeTotalPos"` // Closed position quantity
RealizedPnl string `json:"realizedPnl"` // Realized PnL
Fee string `json:"fee"` // Total fee
FundingFee string `json:"fundingFee"` // Funding fee
Lever string `json:"lever"` // Leverage
CTime string `json:"cTime"` // Position open time
UTime string `json:"uTime"` // Position close time
Type string `json:"type"` // Close type: 1=close position, 2=partial close, 3=liquidation, 4=partial liquidation
PosId string `json:"posId"` // Position ID
} `json:"data"`
}
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
if resp.Code != "0" {
return nil, fmt.Errorf("OKX API error: %s - %s", resp.Code, resp.Msg)
}
records := make([]types.ClosedPnLRecord, 0, len(resp.Data))
for _, pos := range resp.Data {
record := types.ClosedPnLRecord{}
// Convert instrument ID to standard format (BTC-USDT-SWAP -> BTCUSDT)
parts := strings.Split(pos.InstID, "-")
if len(parts) >= 2 {
record.Symbol = parts[0] + parts[1]
} else {
record.Symbol = pos.InstID
}
// Side
record.Side = pos.Direction // OKX already returns "long" or "short"
// Prices
record.EntryPrice, _ = strconv.ParseFloat(pos.OpenAvgPx, 64)
record.ExitPrice, _ = strconv.ParseFloat(pos.CloseAvgPx, 64)
// Quantity
record.Quantity, _ = strconv.ParseFloat(pos.CloseTotalPos, 64)
// PnL
record.RealizedPnL, _ = strconv.ParseFloat(pos.RealizedPnl, 64)
// Fee
fee, _ := strconv.ParseFloat(pos.Fee, 64)
fundingFee, _ := strconv.ParseFloat(pos.FundingFee, 64)
record.Fee = -fee + fundingFee // Fee is negative in OKX
// Leverage
lev, _ := strconv.ParseFloat(pos.Lever, 64)
record.Leverage = int(lev)
// Times
cTime, _ := strconv.ParseInt(pos.CTime, 10, 64)
uTime, _ := strconv.ParseInt(pos.UTime, 10, 64)
record.EntryTime = time.UnixMilli(cTime).UTC()
record.ExitTime = time.UnixMilli(uTime).UTC()
// Close type
switch pos.Type {
case "1", "2":
record.CloseType = "unknown" // Could be manual or AI, need to cross-reference
case "3", "4":
record.CloseType = "liquidation"
default:
record.CloseType = "unknown"
}
// Exchange ID
record.ExchangeID = pos.PosId
records = append(records, record)
}
return records, nil
}
+938
View File
@@ -0,0 +1,938 @@
package okx
import (
"encoding/json"
"fmt"
"nofx/logger"
"nofx/trader/types"
"strconv"
"strings"
)
// OpenLong opens long position
func (t *OKXTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel old orders
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof(" ⚠️ Failed to set leverage: %v", err)
}
instId := t.convertSymbol(symbol)
// Get instrument info and calculate contract size
inst, err := t.getInstrument(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get instrument info: %w", err)
}
// OKX uses contract count, need to convert quantity (in base asset) to contract count
// sz = quantity / ctVal (number of contracts = asset amount / asset per contract)
sz := quantity / inst.CtVal
szStr := t.formatSize(sz, inst)
logger.Infof(" 📊 OKX OpenLong: quantity=%.6f, ctVal=%.6f, contracts=%.2f", quantity, inst.CtVal, sz)
// Check max market order size limit
if inst.MaxMktSz > 0 && sz > inst.MaxMktSz {
logger.Infof(" ⚠️ OKX market order size %.2f exceeds max %.2f, reducing to max", sz, inst.MaxMktSz)
sz = inst.MaxMktSz
szStr = t.formatSize(sz, inst)
}
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"side": "buy",
"posSide": "long",
"ordType": "market",
"sz": szStr,
"clOrdId": genOkxClOrdID(),
"tag": okxTag,
}
data, err := t.doRequest("POST", okxOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open long position: %w", err)
}
var orders []struct {
OrdId string `json:"ordId"`
ClOrdId string `json:"clOrdId"`
SCode string `json:"sCode"`
SMsg string `json:"sMsg"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
if len(orders) == 0 || orders[0].SCode != "0" {
msg := "unknown error"
if len(orders) > 0 {
msg = orders[0].SMsg
}
return nil, fmt.Errorf("failed to open long position: %s", msg)
}
logger.Infof("✓ OKX opened long position successfully: %s size: %s", symbol, szStr)
logger.Infof(" Order ID: %s", orders[0].OrdId)
return map[string]interface{}{
"orderId": orders[0].OrdId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// OpenShort opens short position
func (t *OKXTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// Cancel old orders
t.CancelAllOrders(symbol)
// Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof(" ⚠️ Failed to set leverage: %v", err)
}
instId := t.convertSymbol(symbol)
// Get instrument info and calculate contract size
inst, err := t.getInstrument(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get instrument info: %w", err)
}
// OKX uses contract count, need to convert quantity (in base asset) to contract count
// sz = quantity / ctVal (number of contracts = asset amount / asset per contract)
sz := quantity / inst.CtVal
szStr := t.formatSize(sz, inst)
logger.Infof(" 📊 OKX OpenShort: quantity=%.6f, ctVal=%.6f, contracts=%.2f", quantity, inst.CtVal, sz)
// Check max market order size limit
if inst.MaxMktSz > 0 && sz > inst.MaxMktSz {
logger.Infof(" ⚠️ OKX market order size %.2f exceeds max %.2f, reducing to max", sz, inst.MaxMktSz)
sz = inst.MaxMktSz
szStr = t.formatSize(sz, inst)
}
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"side": "sell",
"posSide": "short",
"ordType": "market",
"sz": szStr,
"clOrdId": genOkxClOrdID(),
"tag": okxTag,
}
data, err := t.doRequest("POST", okxOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to open short position: %w", err)
}
var orders []struct {
OrdId string `json:"ordId"`
ClOrdId string `json:"clOrdId"`
SCode string `json:"sCode"`
SMsg string `json:"sMsg"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
if len(orders) == 0 || orders[0].SCode != "0" {
msg := "unknown error"
if len(orders) > 0 {
msg = orders[0].SMsg
}
return nil, fmt.Errorf("failed to open short position: %s", msg)
}
logger.Infof("✓ OKX opened short position successfully: %s size: %s", symbol, szStr)
logger.Infof(" Order ID: %s", orders[0].OrdId)
return map[string]interface{}{
"orderId": orders[0].OrdId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// CloseLong closes long position
func (t *OKXTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
instId := t.convertSymbol(symbol)
// Get instrument info for contract conversion
inst, err := t.getInstrument(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get instrument info: %w", err)
}
// Invalidate position cache and get fresh positions
t.InvalidatePositionCache()
positions, err := t.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
// Find actual position from exchange
var actualQty float64
var posFound bool
var posMgnMode string = "cross" // Default to cross margin
logger.Infof("🔍 OKX CloseLong: searching for symbol=%s in %d positions", symbol, len(positions))
for _, pos := range positions {
logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v, mgnMode=%v", pos["symbol"], pos["side"], pos["positionAmt"], pos["mgnMode"])
if pos["symbol"] == symbol {
side := pos["side"].(string)
// In net_mode, "long" means positive position
// In dual mode, check explicit "long" side
if side == "long" || (t.positionMode == "net_mode" && side == "long") {
actualQty = pos["positionAmt"].(float64)
posFound = true
if mgnMode, ok := pos["mgnMode"].(string); ok && mgnMode != "" {
posMgnMode = mgnMode
}
logger.Infof("🔍 OKX CloseLong: found matching position! qty=%.6f, mgnMode=%s", actualQty, posMgnMode)
break
}
}
}
if !posFound || actualQty == 0 {
logger.Infof("🔍 OKX CloseLong: NO position found for %s LONG", symbol)
return map[string]interface{}{
"status": "NO_POSITION",
"message": fmt.Sprintf("No long position found for %s on OKX", symbol),
}, nil
}
// Use actual quantity from exchange (more accurate than passed quantity)
if quantity == 0 || quantity > actualQty {
quantity = actualQty
}
// Convert quantity (base asset) to contract count
// contracts = quantity / ctVal
contracts := quantity / inst.CtVal
szStr := t.formatSize(contracts, inst)
logger.Infof("🔻 OKX close long: symbol=%s, instId=%s, quantity=%.6f, ctVal=%.6f, contracts=%.2f, szStr=%s, posMode=%s, mgnMode=%s",
symbol, instId, quantity, inst.CtVal, contracts, szStr, t.positionMode, posMgnMode)
body := map[string]interface{}{
"instId": instId,
"tdMode": posMgnMode, // Use position's actual margin mode (cross or isolated)
"side": "sell",
"ordType": "market",
"sz": szStr,
"clOrdId": genOkxClOrdID(),
"tag": okxTag,
}
// Only add posSide in dual mode (long_short_mode)
if t.positionMode == "long_short_mode" {
body["posSide"] = "long"
}
data, err := t.doRequest("POST", okxOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close long position: %w", err)
}
var orders []struct {
OrdId string `json:"ordId"`
SCode string `json:"sCode"`
SMsg string `json:"sMsg"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return nil, err
}
if len(orders) == 0 || orders[0].SCode != "0" {
msg := "unknown error"
if len(orders) > 0 {
msg = orders[0].SMsg
}
return nil, fmt.Errorf("failed to close long position: %s", msg)
}
logger.Infof("✓ OKX closed long position successfully: %s", symbol)
// Cancel pending orders after closing position
t.CancelAllOrders(symbol)
return map[string]interface{}{
"orderId": orders[0].OrdId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// CloseShort closes short position
func (t *OKXTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
instId := t.convertSymbol(symbol)
// Get instrument info for contract conversion
inst, err := t.getInstrument(symbol)
if err != nil {
return nil, fmt.Errorf("failed to get instrument info: %w", err)
}
// Invalidate position cache and get fresh positions
t.InvalidatePositionCache()
positions, err := t.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
// Find actual position from exchange
var actualQty float64
var posFound bool
var posMgnMode string = "cross" // Default to cross margin
logger.Infof("🔍 OKX CloseShort searching positions: symbol=%s, current position count=%d", symbol, len(positions))
for _, pos := range positions {
logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v, mgnMode=%v",
pos["symbol"], pos["side"], pos["positionAmt"], pos["mgnMode"])
if pos["symbol"] == symbol && pos["side"] == "short" {
actualQty = pos["positionAmt"].(float64)
posFound = true
if mgnMode, ok := pos["mgnMode"].(string); ok && mgnMode != "" {
posMgnMode = mgnMode
}
logger.Infof("🔍 OKX found short position: quantity=%f (base asset), mgnMode=%s", actualQty, posMgnMode)
break
}
}
if !posFound || actualQty == 0 {
return map[string]interface{}{
"status": "NO_POSITION",
"message": fmt.Sprintf("No short position found for %s on OKX", symbol),
}, nil
}
// Use actual quantity from exchange (more accurate than passed quantity)
if quantity == 0 || quantity > actualQty {
quantity = actualQty
}
// Ensure quantity is positive (OKX sz parameter must be positive)
if quantity < 0 {
quantity = -quantity
}
// Convert quantity (base asset) to contract count
// contracts = quantity / ctVal
contracts := quantity / inst.CtVal
szStr := t.formatSize(contracts, inst)
logger.Infof("🔻 OKX close short: symbol=%s, quantity=%.6f, ctVal=%.6f, contracts=%.2f, szStr=%s, posMode=%s, mgnMode=%s",
symbol, quantity, inst.CtVal, contracts, szStr, t.positionMode, posMgnMode)
body := map[string]interface{}{
"instId": instId,
"tdMode": posMgnMode, // Use position's actual margin mode (cross or isolated)
"side": "buy",
"ordType": "market",
"sz": szStr,
"clOrdId": genOkxClOrdID(),
"tag": okxTag,
}
// Only add posSide in dual mode (long_short_mode)
if t.positionMode == "long_short_mode" {
body["posSide"] = "short"
}
logger.Infof("🔻 OKX close short request body: %+v", body)
data, err := t.doRequest("POST", okxOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to close short position: %w", err)
}
var orders []struct {
OrdId string `json:"ordId"`
SCode string `json:"sCode"`
SMsg string `json:"sMsg"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return nil, err
}
if len(orders) == 0 || orders[0].SCode != "0" {
msg := "unknown error"
if len(orders) > 0 {
msg = fmt.Sprintf("sCode=%s, sMsg=%s", orders[0].SCode, orders[0].SMsg)
}
logger.Infof("❌ OKX failed to close short position: %s, response: %s", msg, string(data))
return nil, fmt.Errorf("failed to close short position: %s", msg)
}
logger.Infof("✓ OKX closed short position successfully: %s, ordId=%s", symbol, orders[0].OrdId)
// Cancel pending orders after closing position
t.CancelAllOrders(symbol)
return map[string]interface{}{
"orderId": orders[0].OrdId,
"symbol": symbol,
"status": "FILLED",
}, nil
}
// SetStopLoss sets stop loss order
func (t *OKXTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
instId := t.convertSymbol(symbol)
// Get instrument info
inst, err := t.getInstrument(symbol)
if err != nil {
return fmt.Errorf("failed to get instrument info: %w", err)
}
// Calculate contract size: quantity (in base asset) / ctVal (asset per contract)
sz := quantity / inst.CtVal
szStr := t.formatSize(sz, inst)
// Determine direction
side := "sell"
posSide := "long"
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
posSide = "short"
}
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"side": side,
"posSide": posSide,
"ordType": "conditional",
"sz": szStr,
"slTriggerPx": fmt.Sprintf("%.8f", stopPrice),
"slOrdPx": "-1", // Market price
"tag": okxTag,
}
_, err = t.doRequest("POST", okxAlgoOrderPath, body)
if err != nil {
return fmt.Errorf("failed to set stop loss: %w", err)
}
logger.Infof(" Stop loss price set: %.4f", stopPrice)
return nil
}
// SetTakeProfit sets take profit order
func (t *OKXTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
instId := t.convertSymbol(symbol)
// Get instrument info
inst, err := t.getInstrument(symbol)
if err != nil {
return fmt.Errorf("failed to get instrument info: %w", err)
}
// Calculate contract size: quantity (in base asset) / ctVal (asset per contract)
sz := quantity / inst.CtVal
szStr := t.formatSize(sz, inst)
// Determine direction
side := "sell"
posSide := "long"
if strings.ToUpper(positionSide) == "SHORT" {
side = "buy"
posSide = "short"
}
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"side": side,
"posSide": posSide,
"ordType": "conditional",
"sz": szStr,
"tpTriggerPx": fmt.Sprintf("%.8f", takeProfitPrice),
"tpOrdPx": "-1", // Market price
"tag": okxTag,
}
_, err = t.doRequest("POST", okxAlgoOrderPath, body)
if err != nil {
return fmt.Errorf("failed to set take profit: %w", err)
}
logger.Infof(" Take profit price set: %.4f", takeProfitPrice)
return nil
}
// CancelStopLossOrders cancels stop loss orders
func (t *OKXTrader) CancelStopLossOrders(symbol string) error {
return t.cancelAlgoOrders(symbol, "sl")
}
// CancelTakeProfitOrders cancels take profit orders
func (t *OKXTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelAlgoOrders(symbol, "tp")
}
// cancelAlgoOrders cancels algo orders
func (t *OKXTrader) cancelAlgoOrders(symbol string, orderType string) error {
instId := t.convertSymbol(symbol)
// Get pending algo orders
path := fmt.Sprintf("%s?instType=SWAP&instId=%s&ordType=conditional", okxAlgoPendingPath, instId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return err
}
var orders []struct {
AlgoId string `json:"algoId"`
InstId string `json:"instId"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return err
}
canceledCount := 0
for _, order := range orders {
body := []map[string]interface{}{
{
"algoId": order.AlgoId,
"instId": order.InstId,
},
}
_, err := t.doRequest("POST", okxCancelAlgoPath, body)
if err != nil {
logger.Infof(" ⚠️ Failed to cancel algo order: %v", err)
continue
}
canceledCount++
}
if canceledCount > 0 {
logger.Infof(" ✓ Canceled %d algo orders for %s", canceledCount, symbol)
}
return nil
}
// CancelAllOrders cancels all pending orders
func (t *OKXTrader) CancelAllOrders(symbol string) error {
instId := t.convertSymbol(symbol)
// Get pending orders
path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxPendingOrdersPath, instId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return err
}
var orders []struct {
OrdId string `json:"ordId"`
InstId string `json:"instId"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return err
}
// Batch cancel
for _, order := range orders {
body := map[string]interface{}{
"instId": order.InstId,
"ordId": order.OrdId,
}
t.doRequest("POST", okxCancelOrderPath, body)
}
// Also cancel algo orders
t.cancelAlgoOrders(symbol, "")
if len(orders) > 0 {
logger.Infof(" ✓ Canceled all pending orders for %s", symbol)
}
return nil
}
// CancelStopOrders cancels stop loss and take profit orders
func (t *OKXTrader) CancelStopOrders(symbol string) error {
return t.cancelAlgoOrders(symbol, "")
}
// GetOrderStatus gets order status
func (t *OKXTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
instId := t.convertSymbol(symbol)
path := fmt.Sprintf("/api/v5/trade/order?instId=%s&ordId=%s", instId, orderID)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to get order status: %w", err)
}
var orders []struct {
OrdId string `json:"ordId"`
State string `json:"state"`
AvgPx string `json:"avgPx"`
AccFillSz string `json:"accFillSz"`
Fee string `json:"fee"`
Side string `json:"side"`
OrdType string `json:"ordType"`
CTime string `json:"cTime"`
UTime string `json:"uTime"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return nil, err
}
if len(orders) == 0 {
return nil, fmt.Errorf("order not found")
}
order := orders[0]
avgPrice, _ := strconv.ParseFloat(order.AvgPx, 64)
fillSz, _ := strconv.ParseFloat(order.AccFillSz, 64) // This is in contracts
fee, _ := strconv.ParseFloat(order.Fee, 64)
cTime, _ := strconv.ParseInt(order.CTime, 10, 64)
uTime, _ := strconv.ParseInt(order.UTime, 10, 64)
// Convert contract count to base asset quantity
// executedQty = contracts * ctVal
executedQty := fillSz
inst, err := t.getInstrument(symbol)
if err == nil && inst.CtVal > 0 {
executedQty = fillSz * inst.CtVal
logger.Debugf(" 📊 OKX order %s: fillSz(contracts)=%.4f, ctVal=%.6f, executedQty=%.6f", orderID, fillSz, inst.CtVal, executedQty)
}
// Status mapping
statusMap := map[string]string{
"filled": "FILLED",
"live": "NEW",
"partially_filled": "PARTIALLY_FILLED",
"canceled": "CANCELED",
}
status := statusMap[order.State]
if status == "" {
status = order.State
}
return map[string]interface{}{
"orderId": order.OrdId,
"symbol": symbol,
"status": status,
"avgPrice": avgPrice,
"executedQty": executedQty,
"side": order.Side,
"type": order.OrdType,
"time": cTime,
"updateTime": uTime,
"commission": -fee, // OKX returns negative value
}, nil
}
// GetOpenOrders gets all open/pending orders for a symbol
func (t *OKXTrader) GetOpenOrders(symbol string) ([]types.OpenOrder, error) {
instId := t.convertSymbol(symbol)
var result []types.OpenOrder
// 1. Get pending limit orders
path := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxPendingOrdersPath, instId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
logger.Warnf("[OKX] Failed to get pending orders: %v", err)
}
if err == nil && data != nil {
var orders []struct {
OrdId string `json:"ordId"`
InstId string `json:"instId"`
Side string `json:"side"` // buy/sell
PosSide string `json:"posSide"` // long/short/net
OrdType string `json:"ordType"` // limit/market/post_only
Px string `json:"px"` // price
Sz string `json:"sz"` // size
State string `json:"state"` // live/partially_filled
}
if err := json.Unmarshal(data, &orders); err == nil {
for _, order := range orders {
price, _ := strconv.ParseFloat(order.Px, 64)
quantity, _ := strconv.ParseFloat(order.Sz, 64)
// Convert OKX side to standard format
side := strings.ToUpper(order.Side)
positionSide := strings.ToUpper(order.PosSide)
if positionSide == "NET" {
positionSide = "BOTH"
}
result = append(result, types.OpenOrder{
OrderID: order.OrdId,
Symbol: symbol,
Side: side,
PositionSide: positionSide,
Type: strings.ToUpper(order.OrdType),
Price: price,
StopPrice: 0,
Quantity: quantity,
Status: "NEW",
})
}
}
}
// 2. Get pending algo orders (stop-loss/take-profit)
// OKX requires ordType parameter for algo orders API
algoPath := fmt.Sprintf("%s?instId=%s&instType=SWAP&ordType=conditional", okxAlgoPendingPath, instId)
algoData, err := t.doRequest("GET", algoPath, nil)
if err != nil {
logger.Warnf("[OKX] Failed to get algo orders: %v", err)
}
if err == nil && algoData != nil {
var algoOrders []struct {
AlgoId string `json:"algoId"`
InstId string `json:"instId"`
Side string `json:"side"`
PosSide string `json:"posSide"`
OrdType string `json:"ordType"` // conditional/oco/trigger
TriggerPx string `json:"triggerPx"`
SlTriggerPx string `json:"slTriggerPx"` // Stop loss trigger price
TpTriggerPx string `json:"tpTriggerPx"` // Take profit trigger price
Sz string `json:"sz"`
State string `json:"state"`
}
if err := json.Unmarshal(algoData, &algoOrders); err == nil {
for _, order := range algoOrders {
quantity, _ := strconv.ParseFloat(order.Sz, 64)
side := strings.ToUpper(order.Side)
positionSide := strings.ToUpper(order.PosSide)
if positionSide == "NET" {
positionSide = "BOTH"
}
// Check for stop loss order (slTriggerPx is set)
if order.SlTriggerPx != "" {
slPrice, _ := strconv.ParseFloat(order.SlTriggerPx, 64)
if slPrice > 0 {
result = append(result, types.OpenOrder{
OrderID: order.AlgoId + "_sl",
Symbol: symbol,
Side: side,
PositionSide: positionSide,
Type: "STOP_MARKET",
Price: 0,
StopPrice: slPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
// Check for take profit order (tpTriggerPx is set)
if order.TpTriggerPx != "" {
tpPrice, _ := strconv.ParseFloat(order.TpTriggerPx, 64)
if tpPrice > 0 {
result = append(result, types.OpenOrder{
OrderID: order.AlgoId + "_tp",
Symbol: symbol,
Side: side,
PositionSide: positionSide,
Type: "TAKE_PROFIT_MARKET",
Price: 0,
StopPrice: tpPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
// Fallback for trigger orders (triggerPx is set)
if order.TriggerPx != "" && order.SlTriggerPx == "" && order.TpTriggerPx == "" {
triggerPrice, _ := strconv.ParseFloat(order.TriggerPx, 64)
if triggerPrice > 0 {
result = append(result, types.OpenOrder{
OrderID: order.AlgoId,
Symbol: symbol,
Side: side,
PositionSide: positionSide,
Type: "STOP_MARKET",
Price: 0,
StopPrice: triggerPrice,
Quantity: quantity,
Status: "NEW",
})
}
}
}
}
}
logger.Infof("✓ OKX GetOpenOrders: found %d open orders for %s", len(result), symbol)
return result, nil
}
// PlaceLimitOrder places a limit order for grid trading
// Implements GridTrader interface
func (t *OKXTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitOrderResult, error) {
instId := t.convertSymbol(req.Symbol)
// Get instrument info
inst, err := t.getInstrument(req.Symbol)
if err != nil {
return nil, fmt.Errorf("failed to get instrument info: %w", err)
}
// Set leverage if specified
if req.Leverage > 0 {
if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil {
logger.Warnf("[OKX] Failed to set leverage: %v", err)
}
}
// Convert quantity to contract size
sz := req.Quantity / inst.CtVal
szStr := t.formatSize(sz, inst)
// Determine side and position side
side := "buy"
posSide := "long"
if req.Side == "SELL" {
side = "sell"
posSide = "short"
}
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"side": side,
"posSide": posSide,
"ordType": "limit",
"sz": szStr,
"px": fmt.Sprintf("%.8f", req.Price),
"clOrdId": genOkxClOrdID(),
"tag": okxTag,
}
// Add reduce only if specified
if req.ReduceOnly {
body["reduceOnly"] = true
}
logger.Infof("[OKX] PlaceLimitOrder: %s %s @ %.4f, sz=%s", instId, side, req.Price, szStr)
data, err := t.doRequest("POST", okxOrderPath, body)
if err != nil {
return nil, fmt.Errorf("failed to place limit order: %w", err)
}
var orders []struct {
OrdId string `json:"ordId"`
ClOrdId string `json:"clOrdId"`
SCode string `json:"sCode"`
SMsg string `json:"sMsg"`
}
if err := json.Unmarshal(data, &orders); err != nil {
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
if len(orders) == 0 {
return nil, fmt.Errorf("empty order response")
}
if orders[0].SCode != "0" {
return nil, fmt.Errorf("OKX order failed: %s", orders[0].SMsg)
}
logger.Infof("✓ [OKX] Limit order placed: %s %s @ %.4f, orderID=%s",
instId, side, req.Price, orders[0].OrdId)
return &types.LimitOrderResult{
OrderID: orders[0].OrdId,
ClientID: orders[0].ClOrdId,
Symbol: req.Symbol,
Side: req.Side,
PositionSide: req.PositionSide,
Price: req.Price,
Quantity: req.Quantity,
Status: "NEW",
}, nil
}
// CancelOrder cancels a specific order by ID
// Implements GridTrader interface
func (t *OKXTrader) CancelOrder(symbol, orderID string) error {
instId := t.convertSymbol(symbol)
body := map[string]interface{}{
"instId": instId,
"ordId": orderID,
}
_, err := t.doRequest("POST", "/api/v5/trade/cancel-order", body)
if err != nil {
return fmt.Errorf("failed to cancel order: %w", err)
}
logger.Infof("✓ [OKX] Order cancelled: %s %s", symbol, orderID)
return nil
}
// GetOrderBook gets the order book for a symbol
// Implements GridTrader interface
func (t *OKXTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) {
instId := t.convertSymbol(symbol)
path := fmt.Sprintf("/api/v5/market/books?instId=%s&sz=%d", instId, depth)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to get order book: %w", err)
}
var result []struct {
Bids [][]string `json:"bids"`
Asks [][]string `json:"asks"`
}
if err := json.Unmarshal(data, &result); err != nil {
return nil, nil, fmt.Errorf("failed to parse order book: %w", err)
}
if len(result) == 0 {
return nil, nil, nil
}
// Parse bids
for _, b := range result[0].Bids {
if len(b) >= 2 {
price, _ := strconv.ParseFloat(b[0], 64)
qty, _ := strconv.ParseFloat(b[1], 64)
bids = append(bids, []float64{price, qty})
}
}
// Parse asks
for _, a := range result[0].Asks {
if len(a) >= 2 {
price, _ := strconv.ParseFloat(a[0], 64)
qty, _ := strconv.ParseFloat(a[1], 64)
asks = append(asks, []float64{price, qty})
}
}
return bids, asks, nil
}
+192
View File
@@ -0,0 +1,192 @@
package okx
import (
"encoding/json"
"fmt"
"nofx/logger"
"strconv"
"time"
)
// GetPositions gets all positions
func (t *OKXTrader) GetPositions() ([]map[string]interface{}, error) {
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
t.positionsCacheMutex.RUnlock()
logger.Infof("✓ Using cached OKX positions")
return t.cachedPositions, nil
}
t.positionsCacheMutex.RUnlock()
logger.Infof("🔄 Calling OKX API to get positions...")
data, err := t.doRequest("GET", okxPositionPath+"?instType=SWAP", nil)
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var positions []struct {
InstId string `json:"instId"`
PosSide string `json:"posSide"`
Pos string `json:"pos"`
AvgPx string `json:"avgPx"`
MarkPx string `json:"markPx"`
Upl string `json:"upl"`
Lever string `json:"lever"`
LiqPx string `json:"liqPx"`
Margin string `json:"margin"`
MgnMode string `json:"mgnMode"` // Margin mode: "cross" or "isolated"
CTime string `json:"cTime"` // Position created time (ms)
UTime string `json:"uTime"` // Position last update time (ms)
}
if err := json.Unmarshal(data, &positions); err != nil {
return nil, fmt.Errorf("failed to parse position data: %w", err)
}
logger.Infof("🔍 OKX raw positions response: %d positions", len(positions))
var result []map[string]interface{}
for _, pos := range positions {
logger.Infof("🔍 OKX raw position: instId=%s, posSide=%s, pos=%s, mgnMode=%s", pos.InstId, pos.PosSide, pos.Pos, pos.MgnMode)
contractCount, _ := strconv.ParseFloat(pos.Pos, 64)
if contractCount == 0 {
continue
}
entryPrice, _ := strconv.ParseFloat(pos.AvgPx, 64)
markPrice, _ := strconv.ParseFloat(pos.MarkPx, 64)
upl, _ := strconv.ParseFloat(pos.Upl, 64)
leverage, _ := strconv.ParseFloat(pos.Lever, 64)
liqPrice, _ := strconv.ParseFloat(pos.LiqPx, 64)
// Convert symbol format
symbol := t.convertSymbolBack(pos.InstId)
logger.Infof("🔍 OKX symbol conversion: %s → %s", pos.InstId, symbol)
// Determine direction and ensure contractCount is positive
side := "long"
if pos.PosSide == "short" {
side = "short"
}
// OKX short position's pos is negative, need to take absolute value
if contractCount < 0 {
contractCount = -contractCount
}
// Convert contract count to actual position amount (in base asset)
// positionAmt = contractCount * ctVal
inst, err := t.getInstrument(symbol)
posAmt := contractCount
if err == nil && inst.CtVal > 0 {
posAmt = contractCount * inst.CtVal
logger.Debugf(" 📊 OKX position %s: contracts=%.4f, ctVal=%.6f, posAmt=%.6f", symbol, contractCount, inst.CtVal, posAmt)
}
// Parse timestamps
cTime, _ := strconv.ParseInt(pos.CTime, 10, 64)
uTime, _ := strconv.ParseInt(pos.UTime, 10, 64)
// Default to cross margin mode if not specified
mgnMode := pos.MgnMode
if mgnMode == "" {
mgnMode = "cross"
}
posMap := map[string]interface{}{
"symbol": symbol,
"positionAmt": posAmt,
"entryPrice": entryPrice,
"markPrice": markPrice,
"unRealizedProfit": upl,
"leverage": leverage,
"liquidationPrice": liqPrice,
"side": side,
"mgnMode": mgnMode, // Margin mode: "cross" or "isolated"
"createdTime": cTime, // Position open time (ms)
"updatedTime": uTime, // Position last update time (ms)
}
result = append(result, posMap)
}
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = result
t.positionsCacheTime = time.Now()
t.positionsCacheMutex.Unlock()
return result, nil
}
// InvalidatePositionCache clears the position cache to force fresh data on next call
func (t *OKXTrader) InvalidatePositionCache() {
t.positionsCacheMutex.Lock()
t.cachedPositions = nil
t.positionsCacheTime = time.Time{}
t.positionsCacheMutex.Unlock()
}
// getInstrument gets instrument info
func (t *OKXTrader) getInstrument(symbol string) (*OKXInstrument, error) {
instId := t.convertSymbol(symbol)
// Check cache
t.instrumentsCacheMutex.RLock()
if inst, ok := t.instrumentsCache[instId]; ok && time.Since(t.instrumentsCacheTime) < 5*time.Minute {
t.instrumentsCacheMutex.RUnlock()
return inst, nil
}
t.instrumentsCacheMutex.RUnlock()
// Get instrument info
path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxInstrumentsPath, instId)
data, err := t.doRequest("GET", path, nil)
if err != nil {
return nil, err
}
var instruments []struct {
InstId string `json:"instId"`
CtVal string `json:"ctVal"`
CtMult string `json:"ctMult"`
LotSz string `json:"lotSz"`
MinSz string `json:"minSz"`
MaxMktSz string `json:"maxMktSz"` // Maximum market order size
TickSz string `json:"tickSz"`
CtType string `json:"ctType"`
}
if err := json.Unmarshal(data, &instruments); err != nil {
return nil, err
}
if len(instruments) == 0 {
return nil, fmt.Errorf("instrument info not found: %s", instId)
}
inst := instruments[0]
ctVal, _ := strconv.ParseFloat(inst.CtVal, 64)
ctMult, _ := strconv.ParseFloat(inst.CtMult, 64)
lotSz, _ := strconv.ParseFloat(inst.LotSz, 64)
minSz, _ := strconv.ParseFloat(inst.MinSz, 64)
maxMktSz, _ := strconv.ParseFloat(inst.MaxMktSz, 64)
tickSz, _ := strconv.ParseFloat(inst.TickSz, 64)
instrument := &OKXInstrument{
InstID: inst.InstId,
CtVal: ctVal,
CtMult: ctMult,
LotSz: lotSz,
MinSz: minSz,
MaxMktSz: maxMktSz,
TickSz: tickSz,
CtType: inst.CtType,
}
// Update cache
t.instrumentsCacheMutex.Lock()
t.instrumentsCache[instId] = instrument
t.instrumentsCacheTime = time.Now()
t.instrumentsCacheMutex.Unlock()
return instrument, nil
}