diff --git a/trader/auto_trader_grid.go b/trader/auto_trader_grid.go index 6151dbbb..c16a5d8d 100644 --- a/trader/auto_trader_grid.go +++ b/trader/auto_trader_grid.go @@ -324,6 +324,17 @@ func (at *AutoTrader) InitializeGrid() error { at.gridState.IsInitialized = true + // Keep grid orders aligned with the trader's configured cross/isolated mode. + if err := at.trader.SetMarginMode(gridConfig.Symbol, at.config.IsCrossMargin); err != nil { + logger.Warnf("[Grid] Failed to set margin mode for %s: %v", gridConfig.Symbol, err) + } else { + marginMode := "cross" + if !at.config.IsCrossMargin { + marginMode = "isolated" + } + logger.Infof("[Grid] Margin mode set to %s for %s", marginMode, gridConfig.Symbol) + } + // CRITICAL: Set leverage on exchange before trading if err := at.trader.SetLeverage(gridConfig.Symbol, gridConfig.Leverage); err != nil { logger.Warnf("[Grid] Failed to set leverage %dx on exchange: %v", gridConfig.Leverage, err) diff --git a/trader/okx/trader.go b/trader/okx/trader.go index b46f23c4..41e4a45f 100644 --- a/trader/okx/trader.go +++ b/trader/okx/trader.go @@ -41,7 +41,7 @@ type OKXTrader struct { secretKey string passphrase string - // Margin mode setting + // Margin mode setting used for new orders and leverage changes. isCrossMargin bool // Position mode: "long_short_mode" (hedge) or "net_mode" (one-way) @@ -121,6 +121,7 @@ func NewOKXTrader(apiKey, secretKey, passphrase string) *OKXTrader { apiKey: apiKey, secretKey: secretKey, passphrase: passphrase, + isCrossMargin: true, httpClient: httpClient, cacheDuration: 15 * time.Second, instrumentsCache: make(map[string]*OKXInstrument), @@ -139,10 +140,18 @@ func NewOKXTrader(apiKey, secretKey, passphrase string) *OKXTrader { } } - logger.Infof("✓ OKX trader initialized with position mode: %s", trader.positionMode) + logger.Infof("✓ OKX trader initialized with position mode: %s, default margin mode: %s", + trader.positionMode, trader.marginMode()) return trader } +func (t *OKXTrader) marginMode() string { + if t.isCrossMargin { + return "cross" + } + return "isolated" +} + // detectPositionMode gets current position mode from account config func (t *OKXTrader) detectPositionMode() error { data, err := t.doRequest("GET", okxAccountConfigPath, nil) diff --git a/trader/okx/trader_account.go b/trader/okx/trader_account.go index e10fadf0..53fa0550 100644 --- a/trader/okx/trader_account.go +++ b/trader/okx/trader_account.go @@ -83,11 +83,8 @@ func (t *OKXTrader) GetBalance() (map[string]interface{}, error) { // SetMarginMode sets margin mode func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error { instId := t.convertSymbol(symbol) - - mgnMode := "isolated" - if isCrossMargin { - mgnMode = "cross" - } + t.isCrossMargin = isCrossMargin + mgnMode := t.marginMode() body := map[string]interface{}{ "instId": instId, @@ -116,13 +113,14 @@ func (t *OKXTrader) SetMarginMode(symbol string, isCrossMargin bool) error { // SetLeverage sets leverage func (t *OKXTrader) SetLeverage(symbol string, leverage int) error { instId := t.convertSymbol(symbol) + marginMode := t.marginMode() // Set leverage for both long and short for _, posSide := range []string{"long", "short"} { body := map[string]interface{}{ "instId": instId, "lever": strconv.Itoa(leverage), - "mgnMode": "cross", + "mgnMode": marginMode, "posSide": posSide, } @@ -136,7 +134,7 @@ func (t *OKXTrader) SetLeverage(symbol string, leverage int) error { } } - logger.Infof(" ✓ %s leverage set to %dx", symbol, leverage) + logger.Infof(" ✓ %s leverage set to %dx (%s)", symbol, leverage, marginMode) return nil } diff --git a/trader/okx/trader_margin_mode_test.go b/trader/okx/trader_margin_mode_test.go new file mode 100644 index 00000000..b0a78305 --- /dev/null +++ b/trader/okx/trader_margin_mode_test.go @@ -0,0 +1,162 @@ +package okx + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + "time" + + "nofx/trader/types" +) + +type capturedRequest struct { + Method string + Path string + Body map[string]interface{} +} + +type recordingTransport struct { + requests []capturedRequest +} + +func (rt *recordingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + var body map[string]interface{} + if req.Body != nil { + data, _ := io.ReadAll(req.Body) + if len(data) > 0 && strings.HasPrefix(strings.TrimSpace(string(data)), "{") { + _ = json.Unmarshal(data, &body) + } + } + + rt.requests = append(rt.requests, capturedRequest{ + Method: req.Method, + Path: req.URL.Path, + Body: body, + }) + + response := `{"code":"0","msg":"","data":[]}` + switch req.URL.Path { + case okxInstrumentsPath: + response = `{"code":"0","msg":"","data":[{"instId":"BTC-USDT-SWAP","ctVal":"0.01","ctMult":"1","lotSz":"1","minSz":"1","maxMktSz":"100000","tickSz":"0.1","ctType":"linear"}]}` + case okxOrderPath: + response = `{"code":"0","msg":"","data":[{"ordId":"123","clOrdId":"abc","sCode":"0","sMsg":""}]}` + } + + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewBufferString(response)), + }, nil +} + +func (rt *recordingTransport) requestsForPath(path string) []capturedRequest { + var matches []capturedRequest + for _, req := range rt.requests { + if req.Path == path { + matches = append(matches, req) + } + } + return matches +} + +func newTestOKXTrader(rt *recordingTransport, isCrossMargin bool) *OKXTrader { + return &OKXTrader{ + apiKey: "key", + secretKey: "secret", + passphrase: "pass", + isCrossMargin: isCrossMargin, + positionMode: "long_short_mode", + httpClient: &http.Client{ + Transport: rt, + }, + cacheDuration: 15 * time.Second, + instrumentsCache: make(map[string]*OKXInstrument), + instrumentsCacheTime: time.Now(), + } +} + +func TestOKXSetLeverageUsesConfiguredMarginMode(t *testing.T) { + rt := &recordingTransport{} + trader := newTestOKXTrader(rt, false) + + if err := trader.SetLeverage("BTCUSDT", 5); err != nil { + t.Fatalf("SetLeverage failed: %v", err) + } + + leverageRequests := rt.requestsForPath(okxLeveragePath) + if len(leverageRequests) != 2 { + t.Fatalf("expected 2 leverage requests, got %d", len(leverageRequests)) + } + + for _, req := range leverageRequests { + if req.Body["mgnMode"] != "isolated" { + t.Fatalf("expected isolated leverage mode, got %#v", req.Body["mgnMode"]) + } + } +} + +func TestOKXOpenLongUsesConfiguredMarginMode(t *testing.T) { + rt := &recordingTransport{} + trader := newTestOKXTrader(rt, false) + + if _, err := trader.OpenLong("BTCUSDT", 0.1, 5); err != nil { + t.Fatalf("OpenLong failed: %v", err) + } + + orderRequests := rt.requestsForPath(okxOrderPath) + if len(orderRequests) == 0 { + t.Fatal("expected at least one order request") + } + + lastOrder := orderRequests[len(orderRequests)-1] + if lastOrder.Body["tdMode"] != "isolated" { + t.Fatalf("expected isolated tdMode, got %#v", lastOrder.Body["tdMode"]) + } +} + +func TestOKXSetStopLossUsesConfiguredMarginMode(t *testing.T) { + rt := &recordingTransport{} + trader := newTestOKXTrader(rt, false) + + if err := trader.SetStopLoss("BTCUSDT", "LONG", 0.1, 90000); err != nil { + t.Fatalf("SetStopLoss failed: %v", err) + } + + algoRequests := rt.requestsForPath(okxAlgoOrderPath) + if len(algoRequests) != 1 { + t.Fatalf("expected 1 algo order request, got %d", len(algoRequests)) + } + + if algoRequests[0].Body["tdMode"] != "isolated" { + t.Fatalf("expected isolated tdMode, got %#v", algoRequests[0].Body["tdMode"]) + } +} + +func TestOKXPlaceLimitOrderUsesConfiguredMarginMode(t *testing.T) { + rt := &recordingTransport{} + trader := newTestOKXTrader(rt, false) + + _, err := trader.PlaceLimitOrder(&types.LimitOrderRequest{ + Symbol: "BTCUSDT", + Side: "BUY", + PositionSide: "LONG", + Price: 95000, + Quantity: 0.1, + Leverage: 3, + }) + if err != nil { + t.Fatalf("PlaceLimitOrder failed: %v", err) + } + + orderRequests := rt.requestsForPath(okxOrderPath) + if len(orderRequests) != 1 { + t.Fatalf("expected 1 limit order request, got %d", len(orderRequests)) + } + + if orderRequests[0].Body["tdMode"] != "isolated" { + t.Fatalf("expected isolated tdMode, got %#v", orderRequests[0].Body["tdMode"]) + } +} diff --git a/trader/okx/trader_orders.go b/trader/okx/trader_orders.go index 33697acc..2676db3f 100644 --- a/trader/okx/trader_orders.go +++ b/trader/okx/trader_orders.go @@ -41,9 +41,11 @@ func (t *OKXTrader) OpenLong(symbol string, quantity float64, leverage int) (map szStr = t.formatSize(sz, inst) } + marginMode := t.marginMode() + body := map[string]interface{}{ "instId": instId, - "tdMode": "cross", + "tdMode": marginMode, "side": "buy", "posSide": "long", "ordType": "market", @@ -118,9 +120,11 @@ func (t *OKXTrader) OpenShort(symbol string, quantity float64, leverage int) (ma szStr = t.formatSize(sz, inst) } + marginMode := t.marginMode() + body := map[string]interface{}{ "instId": instId, - "tdMode": "cross", + "tdMode": marginMode, "side": "sell", "posSide": "short", "ordType": "market", @@ -410,9 +414,11 @@ func (t *OKXTrader) SetStopLoss(symbol string, positionSide string, quantity, st posSide = "short" } + marginMode := t.marginMode() + body := map[string]interface{}{ "instId": instId, - "tdMode": "cross", + "tdMode": marginMode, "side": side, "posSide": posSide, "ordType": "conditional", @@ -453,9 +459,11 @@ func (t *OKXTrader) SetTakeProfit(symbol string, positionSide string, quantity, posSide = "short" } + marginMode := t.marginMode() + body := map[string]interface{}{ "instId": instId, - "tdMode": "cross", + "tdMode": marginMode, "side": side, "posSide": posSide, "ordType": "conditional", @@ -815,9 +823,11 @@ func (t *OKXTrader) PlaceLimitOrder(req *types.LimitOrderRequest) (*types.LimitO posSide = "short" } + marginMode := t.marginMode() + body := map[string]interface{}{ "instId": instId, - "tdMode": "cross", + "tdMode": marginMode, "side": side, "posSide": posSide, "ordType": "limit",