Files
nofx/trader/okx/trader.go
T

310 lines
8.6 KiB
Go

package okx
import (
	"bytes"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"nofx/logger"
)
// okxBaseURL is the OKX REST host; every path constant below is appended to it.
const okxBaseURL = "https://www.okx.com"

// Account-scoped endpoints.
const (
	okxAccountPath       = "/api/v5/account/balance"
	okxPositionPath      = "/api/v5/account/positions"
	okxLeveragePath      = "/api/v5/account/set-leverage"
	okxPositionModePath  = "/api/v5/account/set-position-mode"
	okxAccountConfigPath = "/api/v5/account/config"
)

// Regular order endpoints.
const (
	okxOrderPath         = "/api/v5/trade/order"
	okxCancelOrderPath   = "/api/v5/trade/cancel-order"
	okxPendingOrdersPath = "/api/v5/trade/orders-pending"
)

// Algo order endpoints.
const (
	okxAlgoOrderPath   = "/api/v5/trade/order-algo"
	okxCancelAlgoPath  = "/api/v5/trade/cancel-algos"
	okxAlgoPendingPath = "/api/v5/trade/orders-algo-pending"
)

// Market and public-data endpoints.
const (
	okxTickerPath      = "/api/v5/market/ticker"
	okxInstrumentsPath = "/api/v5/public/instruments"
)
// OKXTrader trades OKX USDT-margined perpetual swaps (instrument IDs of the
// form "<BASE>-USDT-SWAP") through the OKX v5 REST API.
type OKXTrader struct {
	// API credentials. secretKey signs requests (see sign); apiKey and
	// passphrase are sent as headers by doRequest.
	apiKey     string
	secretKey  string
	passphrase string

	// isCrossMargin selects the margin mode used for new orders and leverage
	// changes: "cross" when true, "isolated" otherwise (see marginMode).
	isCrossMargin bool

	// positionMode is the account position mode as reported by OKX:
	// "long_short_mode" (hedge) or "net_mode" (one-way).
	positionMode string

	// httpClient sends every REST request. NewOKXTrader builds it on
	// http.DefaultTransport, so proxy settings from the environment ARE
	// honoured (needed where OKX is only reachable through a proxy).
	httpClient *http.Client

	// Balance cache; presumably guarded by balanceCacheMutex (the accessors
	// are not in this part of the file — confirm).
	cachedBalance     map[string]interface{}
	balanceCacheTime  time.Time
	balanceCacheMutex sync.RWMutex

	// Positions cache; presumably guarded by positionsCacheMutex.
	cachedPositions     []map[string]interface{}
	positionsCacheTime  time.Time
	positionsCacheMutex sync.RWMutex

	// Instrument metadata keyed by symbol; presumably guarded by
	// instrumentsCacheMutex. NewOKXTrader initialises the map.
	instrumentsCache      map[string]*OKXInstrument
	instrumentsCacheTime  time.Time
	instrumentsCacheMutex sync.RWMutex

	// cacheDuration is the cache lifetime (NewOKXTrader sets 15s).
	cacheDuration time.Duration
}
// OKXInstrument is the per-instrument metadata (cached in
// OKXTrader.instrumentsCache) used to convert and format order sizes.
type OKXInstrument struct {
	InstID   string  // Instrument ID, e.g. "BTC-USDT-SWAP"
	CtVal    float64 // Contract value: base-asset amount represented by one contract
	CtMult   float64 // Contract multiplier
	LotSz    float64 // Lot size: order-size step (in contracts for swaps); drives size precision
	MinSz    float64 // Minimum order size
	MaxMktSz float64 // Maximum size of a single market order
	TickSz   float64 // Tick size: minimum price increment
	CtType   string  // Contract type (e.g. "linear" or "inverse")
}
// OKXResponse is the JSON envelope that doRequest decodes for every OKX REST
// call. The payload is left undecoded so each caller can parse its own shape.
type OKXResponse struct {
	// Code is OKX's status code, transmitted as a string ("0" means success).
	Code string `json:"code"`
	// Msg is OKX's human-readable status message.
	Msg string `json:"msg"`
	// Data is the raw, endpoint-specific payload.
	Data json.RawMessage `json:"data"`
}
// okxTag is the fixed prefix genOkxClOrdID puts at the start of every client
// order ID. It is stored base64-encoded and decoded once at package init.
var okxTag = func() string {
	const encodedTag = "NGMzNjNjODFlZGM1QkNERQ=="
	decoded, err := base64.StdEncoding.DecodeString(encodedTag)
	if err != nil {
		// encodedTag is a fixed, valid base64 literal, so decoding cannot fail.
		return string(decoded)
	}
	return string(decoded)
}()
// genOkxClOrdID generates OKX order ID
func genOkxClOrdID() string {
timestamp := time.Now().UnixNano() % 10000000000000
randomBytes := make([]byte, 4)
rand.Read(randomBytes)
randomHex := hex.EncodeToString(randomBytes)
// OKX clOrdId max 32 characters
orderID := fmt.Sprintf("%s%d%s", okxTag, timestamp, randomHex)
if len(orderID) > 32 {
orderID = orderID[:32]
}
return orderID
}
// NewOKXTrader creates an OKX trader for the given API credentials.
//
// It reads the account's position mode and, when the account is not already
// in dual-side ("long_short_mode") mode, tries to switch it. Failures in
// either step are logged rather than returned, so a trader is always
// constructed. Margin mode defaults to cross and cacheDuration to 15s.
func NewOKXTrader(apiKey, secretKey, passphrase string) *OKXTrader {
	// Use the default transport, which respects system proxy settings.
	// OKX requires a proxy in China due to DNS pollution.
	httpClient := &http.Client{
		Timeout:   30 * time.Second,
		Transport: http.DefaultTransport,
	}

	trader := &OKXTrader{
		apiKey:           apiKey,
		secretKey:        secretKey,
		passphrase:       passphrase,
		isCrossMargin:    true,
		httpClient:       httpClient,
		cacheDuration:    15 * time.Second,
		instrumentsCache: make(map[string]*OKXInstrument),
	}

	// Get the current position mode first.
	if err := trader.detectPositionMode(); err != nil {
		logger.Infof("⚠️ Failed to detect OKX position mode: %v, assuming dual mode", err)
		trader.positionMode = "long_short_mode"
	}

	// Try to switch to dual position mode (only if not already in it).
	if trader.positionMode != "long_short_mode" {
		if err := trader.setPositionMode(); err != nil {
			logger.Infof("⚠️ Failed to set OKX position mode: %v (current mode: %s)", err, trader.positionMode)
		} else {
			// Record the switch so the startup log below and anything else that
			// reads positionMode see the mode actually in effect, not the stale one.
			trader.positionMode = "long_short_mode"
		}
	}

	logger.Infof("✓ OKX trader initialized with position mode: %s, default margin mode: %s",
		trader.positionMode, trader.marginMode())
	return trader
}
// marginMode reports the configured margin mode as OKX spells it:
// "cross" when isCrossMargin is set, "isolated" otherwise.
func (t *OKXTrader) marginMode() string {
	mode := "isolated"
	if t.isCrossMargin {
		mode = "cross"
	}
	return mode
}
// detectPositionMode reads the account configuration and stores the first
// entry's posMode in t.positionMode.
//
// It returns an error if the request or JSON decoding fails. When OKX returns
// an empty config list, t.positionMode is left untouched and nil is returned.
func (t *OKXTrader) detectPositionMode() error {
	raw, err := t.doRequest("GET", okxAccountConfigPath, nil)
	if err != nil {
		return fmt.Errorf("failed to get account config: %w", err)
	}

	var accountConfigs []struct {
		PosMode string `json:"posMode"`
	}
	if err = json.Unmarshal(raw, &accountConfigs); err != nil {
		return fmt.Errorf("failed to parse account config: %w", err)
	}
	if len(accountConfigs) == 0 {
		return nil
	}

	t.positionMode = accountConfigs[0].PosMode
	logger.Infof("✓ Detected OKX position mode: %s", t.positionMode)
	return nil
}
// setPositionMode switches the account to dual-side ("long_short_mode")
// position mode. On success it also records that mode in t.positionMode, so
// the cached value matches the account.
//
// An OKX error whose text says the mode is already set (contains "already" or
// "Position mode is not modified") is treated as success. Any other request
// error is returned unchanged, and t.positionMode is left as it was.
func (t *OKXTrader) setPositionMode() error {
	const dualMode = "long_short_mode"

	body := map[string]string{
		"posMode": dualMode, // Dual position mode
	}
	if _, err := t.doRequest("POST", okxPositionModePath, body); err != nil {
		msg := err.Error()
		if !strings.Contains(msg, "already") && !strings.Contains(msg, "Position mode is not modified") {
			return err
		}
		logger.Infof(" ✓ OKX account is already in dual position mode")
		t.positionMode = dualMode
		return nil
	}

	logger.Infof(" ✓ OKX account switched to dual position mode")
	t.positionMode = dualMode
	return nil
}
// sign returns the base64-encoded HMAC-SHA256, keyed by the API secret, of
// timestamp + method + requestPath + body. doRequest sends this value in the
// OK-ACCESS-SIGN header.
func (t *OKXTrader) sign(timestamp, method, requestPath, body string) string {
	mac := hmac.New(sha256.New, []byte(t.secretKey))
	// Feeding the parts one after another hashes exactly their concatenation.
	for _, part := range [...]string{timestamp, method, requestPath, body} {
		mac.Write([]byte(part))
	}
	digest := mac.Sum(nil)
	return base64.StdEncoding.EncodeToString(digest)
}
// doRequest sends a signed request to the OKX REST API and returns the
// "data" field of the JSON response envelope.
//
// path is relative to okxBaseURL and includes any query string; it is signed
// exactly as sent. When body is non-nil it is JSON-encoded, sent as the
// request body and included in the signature.
//
// OKX reports business failures in the envelope's "code" field. Per the OKX
// v5 error-code table, "0" is success, "1" means the operation failed and "2"
// means a batch operation only partially succeeded. In the non-zero cases the
// per-item reason is carried in data[].sCode / data[].sMsg.
//
// Codes "0" and "1" return their data, so callers can inspect the per-item
// results. A caller that only checks the returned error will NOT notice a
// code "1" failure. Every other code becomes an error that includes the
// non-zero per-item codes and messages when present.
func (t *OKXTrader) doRequest(method, path string, body interface{}) ([]byte, error) {
	var bodyBytes []byte
	if body != nil {
		encoded, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize request body: %w", err)
		}
		bodyBytes = encoded
	}

	// ISO-8601 UTC with millisecond precision. The same value goes into the
	// signature and the OK-ACCESS-TIMESTAMP header.
	timestamp := time.Now().UTC().Format("2006-01-02T15:04:05.000Z")
	signature := t.sign(timestamp, method, path, string(bodyBytes))

	req, err := http.NewRequest(method, okxBaseURL+path, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("OK-ACCESS-KEY", t.apiKey)
	req.Header.Set("OK-ACCESS-SIGN", signature)
	req.Header.Set("OK-ACCESS-TIMESTAMP", timestamp)
	req.Header.Set("OK-ACCESS-PASSPHRASE", t.passphrase)
	req.Header.Set("Content-Type", "application/json")
	// OKX selects demo trading with "1"; "0" asks for the live environment.
	req.Header.Set("x-simulated-trading", "0")

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	var okxResp OKXResponse
	if err := json.Unmarshal(respBody, &okxResp); err != nil {
		// Non-JSON bodies (e.g. proxy or gateway error pages) carry no OKX
		// code, so surface the HTTP status to make the failure diagnosable.
		return nil, fmt.Errorf("failed to parse response (HTTP %d): %w", resp.StatusCode, err)
	}

	if okxResp.Code != "0" && okxResp.Code != "1" {
		// The top-level msg is often generic; the real reason lives in the
		// per-item results. Extraction is best effort because data is not
		// always an array of such items.
		var items []struct {
			SCode string `json:"sCode"`
			SMsg  string `json:"sMsg"`
		}
		var details []string
		if json.Unmarshal(okxResp.Data, &items) == nil {
			for _, item := range items {
				if item.SCode != "" && item.SCode != "0" {
					details = append(details, item.SCode+": "+item.SMsg)
				}
			}
		}
		if len(details) > 0 {
			return nil, fmt.Errorf("OKX API error: code=%s, msg=%s, details=[%s]",
				okxResp.Code, okxResp.Msg, strings.Join(details, "; "))
		}
		return nil, fmt.Errorf("OKX API error: code=%s, msg=%s", okxResp.Code, okxResp.Msg)
	}

	return okxResp.Data, nil
}
// convertSymbol converts a generic symbol to an OKX USDT perpetual swap
// instrument ID, e.g. BTCUSDT -> BTC-USDT-SWAP.
//
// A symbol already in OKX swap form (ending in "-SWAP") is returned
// unchanged. This makes the conversion idempotent, so an instrument ID that
// is converted again does not become "BTC-USDT-SWAP-USDT-SWAP".
func (t *OKXTrader) convertSymbol(symbol string) string {
	if strings.HasSuffix(symbol, "-SWAP") {
		return symbol
	}
	// Remove the USDT quote suffix and build the OKX format.
	base := strings.TrimSuffix(symbol, "USDT")
	return fmt.Sprintf("%s-USDT-SWAP", base)
}
// convertSymbolBack turns an OKX instrument ID back into a generic symbol by
// joining its first two dash-separated fields, e.g. BTC-USDT-SWAP -> BTCUSDT.
// An ID without any dash is returned as-is.
func (t *OKXTrader) convertSymbolBack(instId string) string {
	firstDash := strings.Index(instId, "-")
	if firstDash < 0 {
		return instId
	}
	base, quote := instId[:firstDash], instId[firstDash+1:]
	// Drop everything from the second dash onward (e.g. the "-SWAP" suffix).
	if secondDash := strings.Index(quote, "-"); secondDash >= 0 {
		quote = quote[:secondDash]
	}
	return base + quote
}
// FormatQuantity converts a base-asset quantity into an OKX order size string.
//
// OKX swap orders are sized in contracts. quantity is therefore divided by
// the instrument's contract value (CtVal, base asset per contract) and
// formatted with formatSize.
//
// It returns an error when the instrument metadata cannot be loaded or has a
// non-positive CtVal. Guessing is not safe here: sending the base-asset
// quantity as a contract count would mis-size the order by a factor of CtVal.
func (t *OKXTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
	inst, err := t.getInstrument(symbol)
	if err != nil {
		return "", fmt.Errorf("failed to get OKX instrument info for %s: %w", symbol, err)
	}
	if inst == nil || inst.CtVal <= 0 {
		return "", fmt.Errorf("invalid OKX contract value for %s", symbol)
	}

	// Contract count = base-asset quantity / base asset per contract.
	sz := quantity / inst.CtVal
	return t.formatSize(sz, inst), nil
}
// formatSize renders a contract count with as many decimal places as the
// instrument's lot size has. For example, LotSz 0.01 gives two decimals, and
// LotSz >= 1 (or a lot size with no fractional part) gives whole contracts.
// The value is rounded to nearest by fmt, not floored to a LotSz multiple.
func (t *OKXTrader) formatSize(sz float64, inst *OKXInstrument) string {
	if inst.LotSz >= 1 {
		return fmt.Sprintf("%.0f", sz)
	}

	// Shortest exact decimal form of LotSz. Unlike "%f" this is not capped at
	// six decimals, so lot sizes such as 1e-7 keep their full precision
	// instead of collapsing to whole contracts.
	lot := strconv.FormatFloat(inst.LotSz, 'f', -1, 64)
	dot := strings.IndexByte(lot, '.')
	if dot < 0 {
		return fmt.Sprintf("%.0f", sz)
	}
	precision := len(lot) - dot - 1
	return fmt.Sprintf("%.*f", precision, sz)
}