feat(trader): implement margin mode handling for order and leverage settings

This commit is contained in:
Dean
2026-04-14 17:42:05 +08:00
parent e1b5a5d833
commit a3d8831b36
5 changed files with 204 additions and 14 deletions
+11
View File
@@ -324,6 +324,17 @@ func (at *AutoTrader) InitializeGrid() error {
at.gridState.IsInitialized = true
// Keep grid orders aligned with the trader's configured cross/isolated mode.
if err := at.trader.SetMarginMode(gridConfig.Symbol, at.config.IsCrossMargin); err != nil {
logger.Warnf("[Grid] Failed to set margin mode for %s: %v", gridConfig.Symbol, err)
} else {
marginMode := "cross"
if !at.config.IsCrossMargin {
marginMode = "isolated"
}
logger.Infof("[Grid] Margin mode set to %s for %s", marginMode, gridConfig.Symbol)
}
// CRITICAL: Set leverage on exchange before trading
if err := at.trader.SetLeverage(gridConfig.Symbol, gridConfig.Leverage); err != nil {
logger.Warnf("[Grid] Failed to set leverage %dx on exchange: %v", gridConfig.Leverage, err)
+11 -2
View File
@@ -41,7 +41,7 @@ type OKXTrader struct {
secretKey string
passphrase string
// Margin mode setting
// Margin mode setting used for new orders and leverage changes.
isCrossMargin bool
// Position mode: "long_short_mode" (hedge) or "net_mode" (one-way)
@@ -121,6 +121,7 @@ func NewOKXTrader(apiKey, secretKey, passphrase string) *OKXTrader {
apiKey: apiKey,
secretKey: secretKey,
passphrase: passphrase,
isCrossMargin: true,
httpClient: httpClient,
cacheDuration: 15 * time.Second,
instrumentsCache: make(map[string]*OKXInstrument),
@@ -139,10 +140,18 @@ func NewOKXTrader(apiKey, secretKey, passphrase string) *OKXTrader {
}
}
logger.Infof("✓ OKX trader initialized with position mode: %s", trader.positionMode)
logger.Infof("✓ OKX trader initialized with position mode: %s, default margin mode: %s",
trader.positionMode, trader.marginMode())
return trader
}
func (t *OKXTrader) marginMode() string {
if t.isCrossMargin {
return "cross"
}
return "isolated"
}
// detectPositionMode gets current position mode from account config
func (t *OKXTrader) detectPositionMode() error {
data, err := t.doRequest("GET", okxAccountConfigPath, nil)
+5 -7
View File
@@ -83,11 +83,8 @@ func (t *OKXTrader) GetBalance() (map[string]interface{}, error) {
// SetMarginMode sets margin mode
func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
instId := t.convertSymbol(symbol)
mgnMode := "isolated"
if isCrossMargin {
mgnMode = "cross"
}
t.isCrossMargin = isCrossMargin
mgnMode := t.marginMode()
body := map[string]interface{}{
"instId": instId,
@@ -116,13 +113,14 @@ func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// SetLeverage sets leverage
func (t *OKXTrader) SetLeverage(symbol string, leverage int) error {
instId := t.convertSymbol(symbol)
marginMode := t.marginMode()
// Set leverage for both long and short
for _, posSide := range []string{"long", "short"} {
body := map[string]interface{}{
"instId": instId,
"lever": strconv.Itoa(leverage),
"mgnMode": "cross",
"mgnMode": marginMode,
"posSide": posSide,
}
@@ -136,7 +134,7 @@ func (t *OKXTrader) SetLeverage(symbol string, leverage int) error {
}
}
logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage)
logger.Infof(" ✓ %s leverage set to %dx (%s)", symbol, leverage, marginMode)
return nil
}
+162
View File
@@ -0,0 +1,162 @@
package okx
import (
"bytes"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
"nofx/trader/types"
)
type capturedRequest struct {
Method string
Path string
Body map[string]interface{}
}
type recordingTransport struct {
requests []capturedRequest
}
func (rt *recordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
var body map[string]interface{}
if req.Body != nil {
data, _ := io.ReadAll(req.Body)
if len(data) > 0 && strings.HasPrefix(strings.TrimSpace(string(data)), "{") {
_ = json.Unmarshal(data, &body)
}
}
rt.requests = append(rt.requests, capturedRequest{
Method: req.Method,
Path: req.URL.Path,
Body: body,
})
response := `{"code":"0","msg":"","data":[]}`
switch req.URL.Path {
case okxInstrumentsPath:
response = `{"code":"0","msg":"","data":[{"instId":"BTC-USDT-SWAP","ctVal":"0.01","ctMult":"1","lotSz":"1","minSz":"1","maxMktSz":"100000","tickSz":"0.1","ctType":"linear"}]}`
case okxOrderPath:
response = `{"code":"0","msg":"","data":[{"ordId":"123","clOrdId":"abc","sCode":"0","sMsg":""}]}`
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewBufferString(response)),
}, nil
}
func (rt *recordingTransport) requestsForPath(path string) []capturedRequest {
var matches []capturedRequest
for _, req := range rt.requests {
if req.Path == path {
matches = append(matches, req)
}
}
return matches
}
func newTestOKXTrader(rt *recordingTransport, isCrossMargin bool) *OKXTrader {
return &OKXTrader{
apiKey: "key",
secretKey: "secret",
passphrase: "pass",
isCrossMargin: isCrossMargin,
positionMode: "long_short_mode",
httpClient: &http.Client{
Transport: rt,
},
cacheDuration: 15 * time.Second,
instrumentsCache: make(map[string]*OKXInstrument),
instrumentsCacheTime: time.Now(),
}
}
func TestOKXSetLeverageUsesConfiguredMarginMode(t *testing.T) {
rt := &recordingTransport{}
trader := newTestOKXTrader(rt, false)
if err := trader.SetLeverage("BTCUSDT", 5); err != nil {
t.Fatalf("SetLeverage failed: %v", err)
}
leverageRequests := rt.requestsForPath(okxLeveragePath)
if len(leverageRequests) != 2 {
t.Fatalf("expected 2 leverage requests, got %d", len(leverageRequests))
}
for _, req := range leverageRequests {
if req.Body["mgnMode"] != "isolated" {
t.Fatalf("expected isolated leverage mode, got %#v", req.Body["mgnMode"])
}
}
}
func TestOKXOpenLongUsesConfiguredMarginMode(t *testing.T) {
rt := &recordingTransport{}
trader := newTestOKXTrader(rt, false)
if _, err := trader.OpenLong("BTCUSDT", 0.1, 5); err != nil {
t.Fatalf("OpenLong failed: %v", err)
}
orderRequests := rt.requestsForPath(okxOrderPath)
if len(orderRequests) == 0 {
t.Fatal("expected at least one order request")
}
lastOrder := orderRequests[len(orderRequests)-1]
if lastOrder.Body["tdMode"] != "isolated" {
t.Fatalf("expected isolated tdMode, got %#v", lastOrder.Body["tdMode"])
}
}
func TestOKXSetStopLossUsesConfiguredMarginMode(t *testing.T) {
rt := &recordingTransport{}
trader := newTestOKXTrader(rt, false)
if err := trader.SetStopLoss("BTCUSDT", "LONG", 0.1, 90000); err != nil {
t.Fatalf("SetStopLoss failed: %v", err)
}
algoRequests := rt.requestsForPath(okxAlgoOrderPath)
if len(algoRequests) != 1 {
t.Fatalf("expected 1 algo order request, got %d", len(algoRequests))
}
if algoRequests[0].Body["tdMode"] != "isolated" {
t.Fatalf("expected isolated tdMode, got %#v", algoRequests[0].Body["tdMode"])
}
}
func TestOKXPlaceLimitOrderUsesConfiguredMarginMode(t *testing.T) {
rt := &recordingTransport{}
trader := newTestOKXTrader(rt, false)
_, err := trader.PlaceLimitOrder(&types.LimitOrderRequest{
Symbol: "BTCUSDT",
Side: "BUY",
PositionSide: "LONG",
Price: 95000,
Quantity: 0.1,
Leverage: 3,
})
if err != nil {
t.Fatalf("PlaceLimitOrder failed: %v", err)
}
orderRequests := rt.requestsForPath(okxOrderPath)
if len(orderRequests) != 1 {
t.Fatalf("expected 1 limit order request, got %d", len(orderRequests))
}
if orderRequests[0].Body["tdMode"] != "isolated" {
t.Fatalf("expected isolated tdMode, got %#v", orderRequests[0].Body["tdMode"])
}
}
+15 -5
View File
@@ -41,9 +41,11 @@ func (t *OKXTrader) OpenLong(symbol string, quantity float64, leverage int) (map
szStr = t.formatSize(sz, inst)
}
marginMode := t.marginMode()
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"tdMode": marginMode,
"side": "buy",
"posSide": "long",
"ordType": "market",
@@ -118,9 +120,11 @@ func (t *OKXTrader) OpenShort(symbol string, quantity float64, leverage int) (ma
szStr = t.formatSize(sz, inst)
}
marginMode := t.marginMode()
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"tdMode": marginMode,
"side": "sell",
"posSide": "short",
"ordType": "market",
@@ -410,9 +414,11 @@ func (t *OKXTrader) SetStopLoss(symbol string, positionSide string, quantity, st
posSide = "short"
}
marginMode := t.marginMode()
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"tdMode": marginMode,
"side": side,
"posSide": posSide,
"ordType": "conditional",
@@ -453,9 +459,11 @@ func (t *OKXTrader) SetTakeProfit(symbol string, positionSide string, quantity,
posSide = "short"
}
marginMode := t.marginMode()
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"tdMode": marginMode,
"side": side,
"posSide": posSide,
"ordType": "conditional",
@@ -815,9 +823,11 @@ func (t *OKXTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitO
posSide = "short"
}
marginMode := t.marginMode()
body := map[string]interface{}{
"instId": instId,
"tdMode": "cross",
"tdMode": marginMode,
"side": side,
"posSide": posSide,
"ordType": "limit",