package trader

import (
	"bytes"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"nofx/logger"
)

// OKX API endpoints
const (
	okxBaseURL           = "https://www.okx.com"
	okxAccountPath       = "/api/v5/account/balance"
	okxPositionPath      = "/api/v5/account/positions"
	okxOrderPath         = "/api/v5/trade/order"
	okxLeveragePath      = "/api/v5/account/set-leverage"
	okxTickerPath        = "/api/v5/market/ticker"
	okxInstrumentsPath   = "/api/v5/public/instruments"
	okxCancelOrderPath   = "/api/v5/trade/cancel-order"
	okxPendingOrdersPath = "/api/v5/trade/orders-pending"
	okxAlgoOrderPath     = "/api/v5/trade/order-algo"
	okxCancelAlgoPath    = "/api/v5/trade/cancel-algos"
	okxAlgoPendingPath   = "/api/v5/trade/orders-algo-pending"
	okxPositionModePath  = "/api/v5/account/set-position-mode"
)

// OKXTrader is an OKX futures (SWAP) trader. It signs every request with the
// configured API credentials and keeps short-lived caches for the balance,
// the open positions and the instrument metadata.
type OKXTrader struct {
	apiKey     string
	secretKey  string
	passphrase string

	// HTTP client used for every request (NewOKXTrader installs http.DefaultClient).
	httpClient *http.Client

	// Balance cache
	cachedBalance     map[string]interface{}
	balanceCacheTime  time.Time
	balanceCacheMutex sync.RWMutex

	// Positions cache
	cachedPositions     []map[string]interface{}
	positionsCacheTime  time.Time
	positionsCacheMutex sync.RWMutex

	// Instrument info cache, keyed by OKX instrument ID.
	// instrumentsCacheTime is a single timestamp shared by all entries.
	instrumentsCache      map[string]*OKXInstrument
	instrumentsCacheTime  time.Time
	instrumentsCacheMutex sync.RWMutex

	// Lifetime of the balance and positions caches.
	cacheDuration time.Duration
}

// OKXInstrument holds the instrument fields parsed from the OKX instruments endpoint.
type OKXInstrument struct {
	InstID string  // Instrument ID
	CtVal  float64 // Contract value
	CtMult float64 // Contract multiplier
	LotSz  float64 // Lot size; formatSize derives the order-size precision from it
	MinSz  float64 // Minimum order size
	TickSz float64 // Minimum price increment
	CtType string  // Contract type
}

// OKXResponse is the common envelope of every OKX API response.
type OKXResponse struct {
	Code string          `json:"code"`
	Msg  string          `json:"msg"`
	Data json.RawMessage `json:"data"`
}

// genOkxClOrdID generates a client order ID made of the order tag, a
// nanosecond-derived timestamp and random hex, truncated to 32 characters.
func genOkxClOrdID() string {
	timestamp := time.Now().UnixNano() % 10000000000000

	randomBytes := make([]byte, 4)
	// Best effort: if the random source fails the bytes stay zero and the
	// timestamp component still differentiates the IDs.
	_, _ = rand.Read(randomBytes)
	randomHex := hex.EncodeToString(randomBytes)

	// OKX clOrdId max 32 characters
	orderID := fmt.Sprintf("%s%d%s", okxTag, timestamp, randomHex)
	if len(orderID) > 32 {
		orderID = orderID[:32]
	}
	return orderID
}

// NewOKXTrader creates an OKX trader and tries to switch the account to dual
// (long/short) position mode. A failure to switch is logged, not returned.
func NewOKXTrader(apiKey, secretKey, passphrase string) *OKXTrader {
	// Use http.DefaultClient to stay consistent with Binance/Bybit SDK
	// DefaultClient uses DefaultTransport, which reads proxy settings from environment variables
	trader := &OKXTrader{
		apiKey:           apiKey,
		secretKey:        secretKey,
		passphrase:       passphrase,
		httpClient:       http.DefaultClient,
		cacheDuration:    15 * time.Second,
		instrumentsCache: make(map[string]*OKXInstrument),
	}

	// Set dual position mode
	if err := trader.setPositionMode(); err != nil {
		logger.Infof("⚠️ Failed to set OKX position mode: %v (ignore if already in dual mode)", err)
	}

	return trader
}

// setPositionMode requests dual position mode (long_short_mode). Errors whose
// text indicates the mode is already active are treated as success.
func (t *OKXTrader) setPositionMode() error {
	body := map[string]string{
		"posMode": "long_short_mode", // Dual position mode
	}

	_, err := t.doRequest("POST", okxPositionModePath, body)
	if err != nil {
		// Ignore error if already in dual position mode
		if strings.Contains(err.Error(), "already") || strings.Contains(err.Error(), "Position mode is not modified") {
			logger.Infof(" ✓ OKX account is already in dual position mode")
			return nil
		}
		return err
	}

	logger.Infof(" ✓ OKX account switched to dual position mode")
	return nil
}

// sign returns the base64-encoded HMAC-SHA256 of
// timestamp+method+requestPath+body keyed with the secret key.
func (t *OKXTrader) sign(timestamp, method, requestPath, body string) string {
	preHash := timestamp + method + requestPath + body
	h := hmac.New(sha256.New, []byte(t.secretKey))
	h.Write([]byte(preHash))
	return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

// doRequest sends a signed request to path (which may include a query string)
// and returns the raw "data" field of the response envelope.
//
// body is JSON-encoded when non-nil. Response codes "0" and "1" are returned
// to the caller; any other code becomes an error. Because code "1" means
// partial success, callers that place or cancel orders must inspect the
// per-item sCode values inside the returned data.
func (t *OKXTrader) doRequest(method, path string, body interface{}) ([]byte, error) {
	var bodyBytes []byte
	var err error

	if body != nil {
		bodyBytes, err = json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize request body: %w", err)
		}
	}

	timestamp := time.Now().UTC().Format("2006-01-02T15:04:05.000Z")
	signature := t.sign(timestamp, method, path, string(bodyBytes))

	req, err := http.NewRequest(method, okxBaseURL+path, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("OK-ACCESS-KEY", t.apiKey)
	req.Header.Set("OK-ACCESS-SIGN", signature)
	req.Header.Set("OK-ACCESS-TIMESTAMP", timestamp)
	req.Header.Set("OK-ACCESS-PASSPHRASE", t.passphrase)
	req.Header.Set("Content-Type", "application/json")
	// Set request header
	req.Header.Set("x-simulated-trading", "0")

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	var okxResp OKXResponse
	if err := json.Unmarshal(respBody, &okxResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// code=1 indicates partial success, need to check specific results in data
	// code=2 indicates complete failure
	if okxResp.Code != "0" && okxResp.Code != "1" {
		return nil, fmt.Errorf("OKX API error: code=%s, msg=%s", okxResp.Code, okxResp.Msg)
	}

	return okxResp.Data, nil
}

// okxOrderResultError inspects the per-item results of an order/algo/cancel
// response and returns an error for the first item whose sCode is set and
// not "0". Empty data is treated as "nothing to report".
func okxOrderResultError(data []byte) error {
	if len(data) == 0 {
		return nil
	}

	var results []struct {
		SCode string `json:"sCode"`
		SMsg  string `json:"sMsg"`
	}
	if err := json.Unmarshal(data, &results); err != nil {
		return fmt.Errorf("failed to parse order response: %w", err)
	}

	for _, r := range results {
		if r.SCode != "" && r.SCode != "0" {
			return fmt.Errorf("sCode=%s, sMsg=%s", r.SCode, r.SMsg)
		}
	}
	return nil
}

// invalidateAccountCaches drops the cached balance and positions so the next
// GetBalance/GetPositions call fetches fresh data. It is called after an
// order has been accepted, since the cached snapshot no longer reflects the account.
func (t *OKXTrader) invalidateAccountCaches() {
	t.balanceCacheMutex.Lock()
	t.cachedBalance = nil
	t.balanceCacheMutex.Unlock()

	t.positionsCacheMutex.Lock()
	t.cachedPositions = nil
	t.positionsCacheMutex.Unlock()
}

// quantityToContracts converts an order quantity into an OKX contract count
// using quantity * lastPrice / ctVal. It fails when the contract value is not
// positive or the market price cannot be fetched.
//
// NOTE(review): OKX documents ctVal for USDT-margined swaps in the base
// currency; if callers pass quantity as a base-coin amount, multiplying by
// price here over-sizes the order — confirm the unit of quantity with callers.
func (t *OKXTrader) quantityToContracts(symbol string, quantity float64, inst *OKXInstrument) (float64, error) {
	if inst.CtVal <= 0 {
		return 0, fmt.Errorf("invalid contract value %v for %s", inst.CtVal, inst.InstID)
	}

	price, err := t.GetMarketPrice(symbol)
	if err != nil {
		return 0, fmt.Errorf("failed to get market price: %w", err)
	}

	return quantity * price / inst.CtVal, nil
}

// convertSymbol converts generic symbol to OKX format
// e.g. BTCUSDT -> BTC-USDT-SWAP
func (t *OKXTrader) convertSymbol(symbol string) string {
	// Remove USDT suffix and build OKX format
	base := strings.TrimSuffix(symbol, "USDT")
	return fmt.Sprintf("%s-USDT-SWAP", base)
}

// convertSymbolBack converts OKX format back to generic symbol
// e.g. BTC-USDT-SWAP -> BTCUSDT
// Input without a "-" separator is returned unchanged.
func (t *OKXTrader) convertSymbolBack(instId string) string {
	parts := strings.Split(instId, "-")
	if len(parts) >= 2 {
		return parts[0] + parts[1]
	}
	return instId
}

// GetBalance returns the account balance as a map with the keys
// "totalWalletBalance" (account total equity), "availableBalance" (USDT
// available balance) and "totalUnrealizedProfit" (USDT unrealized PnL).
// Results are cached for cacheDuration; the cached map is returned as-is.
func (t *OKXTrader) GetBalance() (map[string]interface{}, error) {
	// Check cache
	t.balanceCacheMutex.RLock()
	if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
		t.balanceCacheMutex.RUnlock()
		logger.Infof("✓ Using cached OKX account balance")
		return t.cachedBalance, nil
	}
	t.balanceCacheMutex.RUnlock()

	logger.Infof("🔄 Calling OKX API to get account balance...")
	data, err := t.doRequest("GET", okxAccountPath, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get account balance: %w", err)
	}

	var balances []struct {
		TotalEq string `json:"totalEq"`
		AdjEq   string `json:"adjEq"`
		IsoEq   string `json:"isoEq"`
		OrdFroz string `json:"ordFroz"`
		Details []struct {
			Ccy      string `json:"ccy"`
			Eq       string `json:"eq"`
			CashBal  string `json:"cashBal"`
			AvailBal string `json:"availBal"`
			UPL      string `json:"upl"`
		} `json:"details"`
	}

	if err := json.Unmarshal(data, &balances); err != nil {
		return nil, fmt.Errorf("failed to parse balance data: %w", err)
	}

	if len(balances) == 0 {
		return nil, fmt.Errorf("no balance data received")
	}

	balance := balances[0]

	// Find USDT balance. Unparsable numbers fall back to 0.
	var usdtAvail, usdtUPL float64
	for _, detail := range balance.Details {
		if detail.Ccy == "USDT" {
			usdtAvail, _ = strconv.ParseFloat(detail.AvailBal, 64)
			usdtUPL, _ = strconv.ParseFloat(detail.UPL, 64)
			break
		}
	}

	totalEq, _ := strconv.ParseFloat(balance.TotalEq, 64)

	result := map[string]interface{}{
		"totalWalletBalance":    totalEq,
		"availableBalance":      usdtAvail,
		"totalUnrealizedProfit": usdtUPL,
	}

	logger.Infof("✓ OKX balance: Total equity=%.2f, Available=%.2f, Unrealized PnL=%.2f", totalEq, usdtAvail, usdtUPL)

	// Update cache
	t.balanceCacheMutex.Lock()
	t.cachedBalance = result
	t.balanceCacheTime = time.Now()
	t.balanceCacheMutex.Unlock()

	return result, nil
}

// GetPositions returns all non-zero SWAP positions. Each entry carries the
// generic symbol, the absolute position size as reported by OKX
// ("positionAmt"), prices, leverage, side ("long"/"short") and the
// created/updated timestamps in milliseconds. Results are cached for
// cacheDuration; an empty result is not cached because the cache check
// treats a nil slice as "no cache".
func (t *OKXTrader) GetPositions() ([]map[string]interface{}, error) {
	// Check cache
	t.positionsCacheMutex.RLock()
	if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
		t.positionsCacheMutex.RUnlock()
		logger.Infof("✓ Using cached OKX positions")
		return t.cachedPositions, nil
	}
	t.positionsCacheMutex.RUnlock()

	logger.Infof("🔄 Calling OKX API to get positions...")
	data, err := t.doRequest("GET", okxPositionPath+"?instType=SWAP", nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get positions: %w", err)
	}

	var positions []struct {
		InstId  string `json:"instId"`
		PosSide string `json:"posSide"`
		Pos     string `json:"pos"`
		AvgPx   string `json:"avgPx"`
		MarkPx  string `json:"markPx"`
		Upl     string `json:"upl"`
		Lever   string `json:"lever"`
		LiqPx   string `json:"liqPx"`
		Margin  string `json:"margin"`
		CTime   string `json:"cTime"` // Position created time (ms)
		UTime   string `json:"uTime"` // Position last update time (ms)
	}

	if err := json.Unmarshal(data, &positions); err != nil {
		return nil, fmt.Errorf("failed to parse position data: %w", err)
	}

	var result []map[string]interface{}
	for _, pos := range positions {
		posAmt, _ := strconv.ParseFloat(pos.Pos, 64)
		if posAmt == 0 {
			continue
		}

		entryPrice, _ := strconv.ParseFloat(pos.AvgPx, 64)
		markPrice, _ := strconv.ParseFloat(pos.MarkPx, 64)
		upl, _ := strconv.ParseFloat(pos.Upl, 64)
		leverage, _ := strconv.ParseFloat(pos.Lever, 64)
		liqPrice, _ := strconv.ParseFloat(pos.LiqPx, 64)

		// Convert symbol format
		symbol := t.convertSymbolBack(pos.InstId)

		// Determine direction and ensure posAmt is positive
		side := "long"
		if pos.PosSide == "short" {
			side = "short"
		}

		// OKX short position's pos is negative, need to take absolute value
		if posAmt < 0 {
			posAmt = -posAmt
		}

		// Parse timestamps
		cTime, _ := strconv.ParseInt(pos.CTime, 10, 64)
		uTime, _ := strconv.ParseInt(pos.UTime, 10, 64)

		posMap := map[string]interface{}{
			"symbol":           symbol,
			"positionAmt":      posAmt,
			"entryPrice":       entryPrice,
			"markPrice":        markPrice,
			"unRealizedProfit": upl,
			"leverage":         leverage,
			"liquidationPrice": liqPrice,
			"side":             side,
			"createdTime":      cTime, // Position open time (ms)
			"updatedTime":      uTime, // Position last update time (ms)
		}
		result = append(result, posMap)
	}

	// Update cache
	t.positionsCacheMutex.Lock()
	t.cachedPositions = result
	t.positionsCacheTime = time.Now()
	t.positionsCacheMutex.Unlock()

	return result, nil
}

// getInstrument returns the instrument info for symbol, served from the
// instrument cache while the shared cache timestamp is younger than five
// minutes and fetched from OKX otherwise.
func (t *OKXTrader) getInstrument(symbol string) (*OKXInstrument, error) {
	instId := t.convertSymbol(symbol)

	// Check cache
	t.instrumentsCacheMutex.RLock()
	if inst, ok := t.instrumentsCache[instId]; ok && time.Since(t.instrumentsCacheTime) < 5*time.Minute {
		t.instrumentsCacheMutex.RUnlock()
		return inst, nil
	}
	t.instrumentsCacheMutex.RUnlock()

	// Get instrument info
	path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxInstrumentsPath, instId)
	data, err := t.doRequest("GET", path, nil)
	if err != nil {
		return nil, err
	}

	var instruments []struct {
		InstId string `json:"instId"`
		CtVal  string `json:"ctVal"`
		CtMult string `json:"ctMult"`
		LotSz  string `json:"lotSz"`
		MinSz  string `json:"minSz"`
		TickSz string `json:"tickSz"`
		CtType string `json:"ctType"`
	}

	if err := json.Unmarshal(data, &instruments); err != nil {
		return nil, err
	}

	if len(instruments) == 0 {
		return nil, fmt.Errorf("instrument info not found: %s", instId)
	}

	inst := instruments[0]
	// Unparsable numeric fields fall back to 0.
	ctVal, _ := strconv.ParseFloat(inst.CtVal, 64)
	ctMult, _ := strconv.ParseFloat(inst.CtMult, 64)
	lotSz, _ := strconv.ParseFloat(inst.LotSz, 64)
	minSz, _ := strconv.ParseFloat(inst.MinSz, 64)
	tickSz, _ := strconv.ParseFloat(inst.TickSz, 64)

	instrument := &OKXInstrument{
		InstID: inst.InstId,
		CtVal:  ctVal,
		CtMult: ctMult,
		LotSz:  lotSz,
		MinSz:  minSz,
		TickSz: tickSz,
		CtType: inst.CtType,
	}

	// Update cache
	t.instrumentsCacheMutex.Lock()
	t.instrumentsCache[instId] = instrument
	t.instrumentsCacheTime = time.Now()
	t.instrumentsCacheMutex.Unlock()

	return instrument, nil
}

// SetMarginMode requests cross or isolated margin mode for symbol. Errors
// whose text contains "already" or "position" are logged and treated as success.
//
// NOTE(review): OKX documents the set-isolated-mode endpoint with
// isoMode/type parameters; confirm that this instId/mgnMode payload is accepted.
func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
	instId := t.convertSymbol(symbol)

	mgnMode := "isolated"
	if isCrossMargin {
		mgnMode = "cross"
	}

	body := map[string]interface{}{
		"instId":  instId,
		"mgnMode": mgnMode,
	}

	_, err := t.doRequest("POST", "/api/v5/account/set-isolated-mode", body)
	if err != nil {
		// Ignore error if already in target mode
		if strings.Contains(err.Error(), "already") {
			logger.Infof(" ✓ %s margin mode is already %s", symbol, mgnMode)
			return nil
		}
		// Cannot change when there are positions
		if strings.Contains(err.Error(), "position") {
			logger.Infof(" ⚠️ %s has positions, cannot change margin mode", symbol)
			return nil
		}
		return err
	}

	logger.Infof(" ✓ %s margin mode set to %s", symbol, mgnMode)
	return nil
}

// SetLeverage sets the cross-margin leverage of symbol for both the long and
// the short side. It is best effort: request failures are logged and nil is
// always returned. The success message is only logged when no request failed.
func (t *OKXTrader) SetLeverage(symbol string, leverage int) error {
	instId := t.convertSymbol(symbol)

	failed := false

	// Set leverage for both long and short
	for _, posSide := range []string{"long", "short"} {
		body := map[string]interface{}{
			"instId":  instId,
			"lever":   strconv.Itoa(leverage),
			"mgnMode": "cross",
			"posSide": posSide,
		}

		_, err := t.doRequest("POST", okxLeveragePath, body)
		if err != nil {
			// Ignore if already at target leverage
			if strings.Contains(err.Error(), "same") {
				continue
			}
			failed = true
			logger.Infof(" ⚠️ Failed to set %s %s leverage: %v", symbol, posSide, err)
		}
	}

	if !failed {
		logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage)
	}
	return nil
}

// OpenLong opens a long position with a market order. Pending orders for the
// symbol are cancelled and the leverage is set first (both best effort). The
// returned map contains "orderId", "symbol" and "status".
func (t *OKXTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
	// Cancel old orders
	if err := t.CancelAllOrders(symbol); err != nil {
		logger.Infof(" ⚠️ Failed to cancel old orders: %v", err)
	}

	// Set leverage
	if err := t.SetLeverage(symbol, leverage); err != nil {
		logger.Infof(" ⚠️ Failed to set leverage: %v", err)
	}

	instId := t.convertSymbol(symbol)

	// Get instrument info and calculate contract size
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return nil, fmt.Errorf("failed to get instrument info: %w", err)
	}

	// OKX uses contract size, need to convert based on contract value
	sz, err := t.quantityToContracts(symbol, quantity, inst)
	if err != nil {
		return nil, err
	}
	szStr := t.formatSize(sz, inst)

	body := map[string]interface{}{
		"instId":  instId,
		"tdMode":  "cross",
		"side":    "buy",
		"posSide": "long",
		"ordType": "market",
		"sz":      szStr,
		"clOrdId": genOkxClOrdID(),
		"tag":     okxTag,
	}

	data, err := t.doRequest("POST", okxOrderPath, body)
	if err != nil {
		return nil, fmt.Errorf("failed to open long position: %w", err)
	}

	var orders []struct {
		OrdId   string `json:"ordId"`
		ClOrdId string `json:"clOrdId"`
		SCode   string `json:"sCode"`
		SMsg    string `json:"sMsg"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return nil, fmt.Errorf("failed to parse order response: %w", err)
	}

	if len(orders) == 0 || orders[0].SCode != "0" {
		msg := "unknown error"
		if len(orders) > 0 {
			msg = orders[0].SMsg
		}
		return nil, fmt.Errorf("failed to open long position: %s", msg)
	}

	// The account changed: cached balance/positions are stale now.
	t.invalidateAccountCaches()

	logger.Infof("✓ OKX opened long position successfully: %s size: %s", symbol, szStr)
	logger.Infof(" Order ID: %s", orders[0].OrdId)

	return map[string]interface{}{
		"orderId": orders[0].OrdId,
		"symbol":  symbol,
		"status":  "FILLED",
	}, nil
}

// OpenShort opens a short position with a market order. Pending orders for
// the symbol are cancelled and the leverage is set first (both best effort).
// The returned map contains "orderId", "symbol" and "status".
func (t *OKXTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
	// Cancel old orders
	if err := t.CancelAllOrders(symbol); err != nil {
		logger.Infof(" ⚠️ Failed to cancel old orders: %v", err)
	}

	// Set leverage
	if err := t.SetLeverage(symbol, leverage); err != nil {
		logger.Infof(" ⚠️ Failed to set leverage: %v", err)
	}

	instId := t.convertSymbol(symbol)

	// Get instrument info and calculate contract size
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return nil, fmt.Errorf("failed to get instrument info: %w", err)
	}

	sz, err := t.quantityToContracts(symbol, quantity, inst)
	if err != nil {
		return nil, err
	}
	szStr := t.formatSize(sz, inst)

	body := map[string]interface{}{
		"instId":  instId,
		"tdMode":  "cross",
		"side":    "sell",
		"posSide": "short",
		"ordType": "market",
		"sz":      szStr,
		"clOrdId": genOkxClOrdID(),
		"tag":     okxTag,
	}

	data, err := t.doRequest("POST", okxOrderPath, body)
	if err != nil {
		return nil, fmt.Errorf("failed to open short position: %w", err)
	}

	var orders []struct {
		OrdId   string `json:"ordId"`
		ClOrdId string `json:"clOrdId"`
		SCode   string `json:"sCode"`
		SMsg    string `json:"sMsg"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return nil, fmt.Errorf("failed to parse order response: %w", err)
	}

	if len(orders) == 0 || orders[0].SCode != "0" {
		msg := "unknown error"
		if len(orders) > 0 {
			msg = orders[0].SMsg
		}
		return nil, fmt.Errorf("failed to open short position: %s", msg)
	}

	// The account changed: cached balance/positions are stale now.
	t.invalidateAccountCaches()

	logger.Infof("✓ OKX opened short position successfully: %s size: %s", symbol, szStr)
	logger.Infof(" Order ID: %s", orders[0].OrdId)

	return map[string]interface{}{
		"orderId": orders[0].OrdId,
		"symbol":  symbol,
		"status":  "FILLED",
	}, nil
}

// CloseLong closes a long position with a market order. quantity is a
// contract count; when it is 0 the size of the current long position is
// looked up via GetPositions. Pending orders are cancelled after a
// successful close.
func (t *OKXTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
	instId := t.convertSymbol(symbol)

	// If quantity is 0, get current position (positionAmt is the contract size)
	if quantity == 0 {
		positions, err := t.GetPositions()
		if err != nil {
			return nil, err
		}
		for _, pos := range positions {
			if pos["symbol"] == symbol && pos["side"] == "long" {
				quantity = pos["positionAmt"].(float64) // This is already contract size
				break
			}
		}
		if quantity == 0 {
			return nil, fmt.Errorf("long position not found for %s", symbol)
		}
	}

	// Get instrument info for formatting contract size
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return nil, fmt.Errorf("failed to get instrument info: %w", err)
	}

	// quantity is already contract size, format directly
	szStr := t.formatSize(quantity, inst)

	logger.Infof("🔻 OKX close long parameters: symbol=%s, instId=%s, quantity(contracts)=%f, szStr=%s", symbol, instId, quantity, szStr)

	body := map[string]interface{}{
		"instId":  instId,
		"tdMode":  "cross",
		"side":    "sell",
		"posSide": "long",
		"ordType": "market",
		"sz":      szStr,
		"clOrdId": genOkxClOrdID(),
		"tag":     okxTag,
	}

	data, err := t.doRequest("POST", okxOrderPath, body)
	if err != nil {
		return nil, fmt.Errorf("failed to close long position: %w", err)
	}

	var orders []struct {
		OrdId string `json:"ordId"`
		SCode string `json:"sCode"`
		SMsg  string `json:"sMsg"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return nil, err
	}

	if len(orders) == 0 || orders[0].SCode != "0" {
		msg := "unknown error"
		if len(orders) > 0 {
			msg = orders[0].SMsg
		}
		return nil, fmt.Errorf("failed to close long position: %s", msg)
	}

	// The account changed: cached balance/positions are stale now.
	t.invalidateAccountCaches()

	logger.Infof("✓ OKX closed long position successfully: %s", symbol)

	// Cancel pending orders after closing position
	if err := t.CancelAllOrders(symbol); err != nil {
		logger.Infof(" ⚠️ Failed to cancel pending orders: %v", err)
	}

	return map[string]interface{}{
		"orderId": orders[0].OrdId,
		"symbol":  symbol,
		"status":  "FILLED",
	}, nil
}

// CloseShort closes a short position with a market order. quantity is a
// contract count; when it is 0 the size of the current short position is
// looked up via GetPositions. Pending orders are cancelled after a
// successful close.
func (t *OKXTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
	instId := t.convertSymbol(symbol)

	// If quantity is 0, get current position (positionAmt is the contract size)
	if quantity == 0 {
		positions, err := t.GetPositions()
		if err != nil {
			return nil, err
		}
		logger.Infof("🔍 OKX CloseShort searching positions: symbol=%s, current position count=%d", symbol, len(positions))
		for _, pos := range positions {
			logger.Infof("🔍 OKX position: symbol=%v, side=%v, positionAmt=%v", pos["symbol"], pos["side"], pos["positionAmt"])
			if pos["symbol"] == symbol && pos["side"] == "short" {
				quantity = pos["positionAmt"].(float64)
				logger.Infof("🔍 OKX found short position: quantity=%f", quantity)
				break
			}
		}
		if quantity == 0 {
			return nil, fmt.Errorf("short position not found for %s", symbol)
		}
	}

	// Ensure quantity is positive (OKX sz parameter must be positive)
	if quantity < 0 {
		quantity = -quantity
	}

	// Get instrument info for formatting contract size
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return nil, fmt.Errorf("failed to get instrument info: %w", err)
	}

	logger.Infof("🔍 OKX instrument info: instId=%s, lotSz=%f, minSz=%f, ctVal=%f", inst.InstID, inst.LotSz, inst.MinSz, inst.CtVal)

	// quantity is already contract size, format directly
	szStr := t.formatSize(quantity, inst)

	logger.Infof("🔻 OKX close short parameters: symbol=%s, instId=%s, quantity(contracts)=%f, szStr=%s", symbol, instId, quantity, szStr)

	body := map[string]interface{}{
		"instId":  instId,
		"tdMode":  "cross",
		"side":    "buy",
		"posSide": "short",
		"ordType": "market",
		"sz":      szStr,
		"clOrdId": genOkxClOrdID(),
		"tag":     okxTag,
	}

	logger.Infof("🔻 OKX close short request body: %+v", body)

	data, err := t.doRequest("POST", okxOrderPath, body)
	if err != nil {
		return nil, fmt.Errorf("failed to close short position: %w", err)
	}

	var orders []struct {
		OrdId string `json:"ordId"`
		SCode string `json:"sCode"`
		SMsg  string `json:"sMsg"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return nil, err
	}

	if len(orders) == 0 || orders[0].SCode != "0" {
		msg := "unknown error"
		if len(orders) > 0 {
			msg = fmt.Sprintf("sCode=%s, sMsg=%s", orders[0].SCode, orders[0].SMsg)
		}
		logger.Infof("❌ OKX failed to close short position: %s, response: %s", msg, string(data))
		return nil, fmt.Errorf("failed to close short position: %s", msg)
	}

	// The account changed: cached balance/positions are stale now.
	t.invalidateAccountCaches()

	logger.Infof("✓ OKX closed short position successfully: %s, ordId=%s", symbol, orders[0].OrdId)

	// Cancel pending orders after closing position
	if err := t.CancelAllOrders(symbol); err != nil {
		logger.Infof(" ⚠️ Failed to cancel pending orders: %v", err)
	}

	return map[string]interface{}{
		"orderId": orders[0].OrdId,
		"symbol":  symbol,
		"status":  "FILLED",
	}, nil
}

// GetMarketPrice returns the last traded price of symbol from the ticker endpoint.
func (t *OKXTrader) GetMarketPrice(symbol string) (float64, error) {
	instId := t.convertSymbol(symbol)
	path := fmt.Sprintf("%s?instId=%s", okxTickerPath, instId)

	data, err := t.doRequest("GET", path, nil)
	if err != nil {
		return 0, fmt.Errorf("failed to get price: %w", err)
	}

	var tickers []struct {
		Last string `json:"last"`
	}

	if err := json.Unmarshal(data, &tickers); err != nil {
		return 0, err
	}

	if len(tickers) == 0 {
		return 0, fmt.Errorf("no price data received")
	}

	price, err := strconv.ParseFloat(tickers[0].Last, 64)
	if err != nil {
		return 0, err
	}

	return price, nil
}

// SetStopLoss places a conditional stop-loss algo order that closes the given
// position side at market once stopPrice is triggered. positionSide is
// compared case-insensitively against "SHORT"; anything else means long.
// An error is returned when sizing fails, the request fails, or OKX rejects
// the order in its per-order result.
func (t *OKXTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
	instId := t.convertSymbol(symbol)

	// Get instrument info
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return fmt.Errorf("failed to get instrument info: %w", err)
	}

	// Calculate contract size; a failed price lookup must not turn into a zero-size order.
	sz, err := t.quantityToContracts(symbol, quantity, inst)
	if err != nil {
		return fmt.Errorf("failed to set stop loss: %w", err)
	}
	szStr := t.formatSize(sz, inst)

	// Determine direction
	side := "sell"
	posSide := "long"
	if strings.ToUpper(positionSide) == "SHORT" {
		side = "buy"
		posSide = "short"
	}

	body := map[string]interface{}{
		"instId":      instId,
		"tdMode":      "cross",
		"side":        side,
		"posSide":     posSide,
		"ordType":     "conditional",
		"sz":          szStr,
		"slTriggerPx": fmt.Sprintf("%.8f", stopPrice),
		"slOrdPx":     "-1", // Market price
		"tag":         okxTag,
	}

	data, err := t.doRequest("POST", okxAlgoOrderPath, body)
	if err != nil {
		return fmt.Errorf("failed to set stop loss: %w", err)
	}
	// doRequest lets code=1 through, so the per-order result decides success.
	if err := okxOrderResultError(data); err != nil {
		return fmt.Errorf("failed to set stop loss: %w", err)
	}

	logger.Infof(" Stop loss price set: %.4f", stopPrice)
	return nil
}

// SetTakeProfit places a conditional take-profit algo order that closes the
// given position side at market once takeProfitPrice is triggered.
// positionSide is compared case-insensitively against "SHORT"; anything else
// means long. An error is returned when sizing fails, the request fails, or
// OKX rejects the order in its per-order result.
func (t *OKXTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
	instId := t.convertSymbol(symbol)

	// Get instrument info
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return fmt.Errorf("failed to get instrument info: %w", err)
	}

	// Calculate contract size; a failed price lookup must not turn into a zero-size order.
	sz, err := t.quantityToContracts(symbol, quantity, inst)
	if err != nil {
		return fmt.Errorf("failed to set take profit: %w", err)
	}
	szStr := t.formatSize(sz, inst)

	// Determine direction
	side := "sell"
	posSide := "long"
	if strings.ToUpper(positionSide) == "SHORT" {
		side = "buy"
		posSide = "short"
	}

	body := map[string]interface{}{
		"instId":      instId,
		"tdMode":      "cross",
		"side":        side,
		"posSide":     posSide,
		"ordType":     "conditional",
		"sz":          szStr,
		"tpTriggerPx": fmt.Sprintf("%.8f", takeProfitPrice),
		"tpOrdPx":     "-1", // Market price
		"tag":         okxTag,
	}

	data, err := t.doRequest("POST", okxAlgoOrderPath, body)
	if err != nil {
		return fmt.Errorf("failed to set take profit: %w", err)
	}
	// doRequest lets code=1 through, so the per-order result decides success.
	if err := okxOrderResultError(data); err != nil {
		return fmt.Errorf("failed to set take profit: %w", err)
	}

	logger.Infof(" Take profit price set: %.4f", takeProfitPrice)
	return nil
}

// CancelStopLossOrders cancels the pending stop-loss algo orders of symbol.
func (t *OKXTrader) CancelStopLossOrders(symbol string) error {
	return t.cancelAlgoOrders(symbol, "sl")
}

// CancelTakeProfitOrders cancels the pending take-profit algo orders of symbol.
func (t *OKXTrader) CancelTakeProfitOrders(symbol string) error {
	return t.cancelAlgoOrders(symbol, "tp")
}

// cancelAlgoOrders cancels pending conditional algo orders of symbol.
//
// orderType selects which orders are cancelled: "sl" only orders carrying a
// stop-loss trigger price, "tp" only orders carrying a take-profit trigger
// price, anything else (e.g. "") every pending conditional order. Individual
// cancel failures are logged and skipped; only a failure to list the pending
// orders is returned.
func (t *OKXTrader) cancelAlgoOrders(symbol string, orderType string) error {
	instId := t.convertSymbol(symbol)

	// Get pending algo orders
	path := fmt.Sprintf("%s?instType=SWAP&instId=%s&ordType=conditional", okxAlgoPendingPath, instId)
	data, err := t.doRequest("GET", path, nil)
	if err != nil {
		return err
	}

	var orders []struct {
		AlgoId      string `json:"algoId"`
		InstId      string `json:"instId"`
		SlTriggerPx string `json:"slTriggerPx"`
		TpTriggerPx string `json:"tpTriggerPx"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return err
	}

	canceledCount := 0
	for _, order := range orders {
		// Skip orders that are not of the requested kind.
		switch orderType {
		case "sl":
			if order.SlTriggerPx == "" {
				continue
			}
		case "tp":
			if order.TpTriggerPx == "" {
				continue
			}
		}

		body := []map[string]interface{}{
			{
				"algoId": order.AlgoId,
				"instId": order.InstId,
			},
		}

		respData, err := t.doRequest("POST", okxCancelAlgoPath, body)
		if err == nil {
			// code=1 passes doRequest; the per-order result tells whether it was cancelled.
			err = okxOrderResultError(respData)
		}
		if err != nil {
			logger.Infof(" ⚠️ Failed to cancel algo order: %v", err)
			continue
		}
		canceledCount++
	}

	if canceledCount > 0 {
		logger.Infof(" ✓ Canceled %d algo orders for %s", canceledCount, symbol)
	}

	return nil
}

// CancelAllOrders cancels every pending regular order and every pending
// conditional algo order of symbol. Only a failure to list the pending
// regular orders is returned; individual cancel failures are logged.
func (t *OKXTrader) CancelAllOrders(symbol string) error {
	instId := t.convertSymbol(symbol)

	// Get pending orders
	path := fmt.Sprintf("%s?instType=SWAP&instId=%s", okxPendingOrdersPath, instId)
	data, err := t.doRequest("GET", path, nil)
	if err != nil {
		return err
	}

	var orders []struct {
		OrdId  string `json:"ordId"`
		InstId string `json:"instId"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return err
	}

	// Cancel one by one
	failedCount := 0
	for _, order := range orders {
		body := map[string]interface{}{
			"instId": order.InstId,
			"ordId":  order.OrdId,
		}
		respData, err := t.doRequest("POST", okxCancelOrderPath, body)
		if err == nil {
			err = okxOrderResultError(respData)
		}
		if err != nil {
			failedCount++
			logger.Infof(" ⚠️ Failed to cancel order %s: %v", order.OrdId, err)
		}
	}

	// Also cancel algo orders
	if err := t.cancelAlgoOrders(symbol, ""); err != nil {
		logger.Infof(" ⚠️ Failed to cancel algo orders for %s: %v", symbol, err)
	}

	if len(orders) > 0 && failedCount == 0 {
		logger.Infof(" ✓ Canceled all pending orders for %s", symbol)
	}

	return nil
}

// CancelStopOrders cancels stop loss and take profit orders
func (t *OKXTrader) CancelStopOrders(symbol string) error {
	return t.cancelAlgoOrders(symbol, "")
}

// FormatQuantity converts quantity into a formatted OKX contract size string.
// It never returns an error: when the instrument cannot be loaded the
// quantity is formatted with three decimals, and when no price is available
// it is formatted with none.
func (t *OKXTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return fmt.Sprintf("%.3f", quantity), nil
	}

	// OKX uses contract size. A failed price lookup yields price 0, which
	// selects the fallback formatting below.
	price, _ := t.GetMarketPrice(symbol)
	if price == 0 {
		return fmt.Sprintf("%.0f", quantity), nil
	}

	sz := quantity * price / inst.CtVal
	return t.formatSize(sz, inst), nil
}

// formatSize formats a contract size with as many decimals as the lot size
// has (none when the lot size is >= 1). The value is rounded by fmt, not
// truncated.
func (t *OKXTrader) formatSize(sz float64, inst *OKXInstrument) string {
	// Determine precision based on lotSz
	if inst.LotSz >= 1 {
		return fmt.Sprintf("%.0f", sz)
	}

	// Calculate decimal places
	lotSzStr := fmt.Sprintf("%f", inst.LotSz)
	dotIndex := strings.Index(lotSzStr, ".")
	if dotIndex == -1 {
		return fmt.Sprintf("%.0f", sz)
	}

	// Remove trailing zeros
	lotSzStr = strings.TrimRight(lotSzStr, "0")
	precision := len(lotSzStr) - dotIndex - 1

	format := fmt.Sprintf("%%.%df", precision)
	return fmt.Sprintf(format, sz)
}

// GetOrderStatus fetches an order by ID and returns its status, average
// price, filled size, side, type, timestamps and commission. Known OKX
// states are mapped to upper-case generic names; unknown states are passed
// through unchanged.
func (t *OKXTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
	instId := t.convertSymbol(symbol)
	path := fmt.Sprintf("/api/v5/trade/order?instId=%s&ordId=%s", instId, orderID)

	data, err := t.doRequest("GET", path, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get order status: %w", err)
	}

	var orders []struct {
		OrdId     string `json:"ordId"`
		State     string `json:"state"`
		AvgPx     string `json:"avgPx"`
		AccFillSz string `json:"accFillSz"`
		Fee       string `json:"fee"`
		Side      string `json:"side"`
		OrdType   string `json:"ordType"`
		CTime     string `json:"cTime"`
		UTime     string `json:"uTime"`
	}

	if err := json.Unmarshal(data, &orders); err != nil {
		return nil, err
	}

	if len(orders) == 0 {
		return nil, fmt.Errorf("order not found")
	}

	order := orders[0]
	// Unparsable numeric fields fall back to 0.
	avgPrice, _ := strconv.ParseFloat(order.AvgPx, 64)
	fillSz, _ := strconv.ParseFloat(order.AccFillSz, 64)
	fee, _ := strconv.ParseFloat(order.Fee, 64)
	cTime, _ := strconv.ParseInt(order.CTime, 10, 64)
	uTime, _ := strconv.ParseInt(order.UTime, 10, 64)

	// Status mapping
	statusMap := map[string]string{
		"filled":           "FILLED",
		"live":             "NEW",
		"partially_filled": "PARTIALLY_FILLED",
		"canceled":         "CANCELED",
	}

	status := statusMap[order.State]
	if status == "" {
		status = order.State
	}

	return map[string]interface{}{
		"orderId":     order.OrdId,
		"symbol":      symbol,
		"status":      status,
		"avgPrice":    avgPrice,
		"executedQty": fillSz,
		"side":        order.Side,
		"type":        order.OrdType,
		"time":        cTime,
		"updateTime":  uTime,
		"commission":  -fee, // OKX returns negative value
	}, nil
}

// OKX order tag, decoded once at package initialization. It is sent as the
// "tag" of every order and used as the prefix of generated client order IDs.
var okxTag = func() string {
	b, _ := base64.StdEncoding.DecodeString("NGMzNjNjODFlZGM1QkNERQ==")
	return string(b)
}()

// GetClosedPnL retrieves closed position PnL records from OKX
// OKX API: /api/v5/account/positions-history
//
// limit is clamped to 1..100 (values <= 0 become 100). When startTime is not
// zero only records updated after it are requested. The fee of each record is
// computed as -fee + fundingFee from the OKX fields.
func (t *OKXTrader) GetClosedPnL(startTime time.Time, limit int) ([]ClosedPnLRecord, error) {
	if limit <= 0 {
		limit = 100
	}
	if limit > 100 {
		limit = 100
	}

	// Build query path with parameters
	path := fmt.Sprintf("/api/v5/account/positions-history?instType=SWAP&limit=%d", limit)
	if !startTime.IsZero() {
		// OKX pagination: "before" returns records newer than the timestamp,
		// "after" returns older ones. startTime is a lower bound, so use "before".
		path += fmt.Sprintf("&before=%d", startTime.UnixMilli())
	}

	data, err := t.doRequest("GET", path, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get positions history: %w", err)
	}

	// doRequest has already validated the envelope code and returns only the
	// "data" array, so decode the array directly.
	var history []struct {
		InstID        string `json:"instId"`        // Instrument ID (e.g., "BTC-USDT-SWAP")
		Direction     string `json:"direction"`     // Position direction: "long" or "short"
		OpenAvgPx     string `json:"openAvgPx"`     // Average open price
		CloseAvgPx    string `json:"closeAvgPx"`    // Average close price
		CloseTotalPos string `json:"closeTotalPos"` // Closed position quantity
		RealizedPnl   string `json:"realizedPnl"`   // Realized PnL
		Fee           string `json:"fee"`           // Total fee
		FundingFee    string `json:"fundingFee"`    // Funding fee
		Lever         string `json:"lever"`         // Leverage
		CTime         string `json:"cTime"`         // Position open time
		UTime         string `json:"uTime"`         // Position close time
		Type          string `json:"type"`          // Close type: 1=close position, 2=partial close, 3=liquidation, 4=partial liquidation
		PosId         string `json:"posId"`         // Position ID
	}

	if err := json.Unmarshal(data, &history); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	records := make([]ClosedPnLRecord, 0, len(history))

	for _, pos := range history {
		record := ClosedPnLRecord{}

		// Convert instrument ID to standard format (BTC-USDT-SWAP -> BTCUSDT)
		record.Symbol = t.convertSymbolBack(pos.InstID)

		// Side
		record.Side = pos.Direction // OKX already returns "long" or "short"

		// Prices
		record.EntryPrice, _ = strconv.ParseFloat(pos.OpenAvgPx, 64)
		record.ExitPrice, _ = strconv.ParseFloat(pos.CloseAvgPx, 64)

		// Quantity
		record.Quantity, _ = strconv.ParseFloat(pos.CloseTotalPos, 64)

		// PnL
		record.RealizedPnL, _ = strconv.ParseFloat(pos.RealizedPnl, 64)

		// Fee
		fee, _ := strconv.ParseFloat(pos.Fee, 64)
		fundingFee, _ := strconv.ParseFloat(pos.FundingFee, 64)
		record.Fee = -fee + fundingFee // Fee is negative in OKX

		// Leverage
		lev, _ := strconv.ParseFloat(pos.Lever, 64)
		record.Leverage = int(lev)

		// Times
		cTime, _ := strconv.ParseInt(pos.CTime, 10, 64)
		uTime, _ := strconv.ParseInt(pos.UTime, 10, 64)
		record.EntryTime = time.UnixMilli(cTime)
		record.ExitTime = time.UnixMilli(uTime)

		// Close type
		switch pos.Type {
		case "1", "2":
			record.CloseType = "unknown" // Could be manual or AI, need to cross-reference
		case "3", "4":
			record.CloseType = "liquidation"
		default:
			record.CloseType = "unknown"
		}

		// Exchange ID
		record.ExchangeID = pos.PosId

		records = append(records, record)
	}

	return records, nil
}