feat(mcp): add context length guard to prevent oversized requests

* feat: add X-Client-ID header for claw402 monitoring

* feat(mcp): add context length guard to prevent oversized requests

- Add MaxContext field to Config (default 0 = no limit)
- Add WithMaxContext() option for setting model context limits
- Add context_guard.go: token estimation + message truncation
- Integrate guard into both BuildMCPRequestBody and BuildRequestBodyFromRequest
- Support both map[string]string and map[string]any message formats
- Truncates oldest non-system messages when estimated tokens exceed limit
- Always preserves system messages and keeps at least 1 non-system message
- Logs warning when truncation occurs for debugging

Usage: mcp.NewDeepSeekClient(mcp.WithMaxContext(131072))
This commit is contained in:
shinchan-zhai
2026-03-18 11:10:22 +08:00
committed by GitHub
parent d5fbe445e1
commit 16ebe0a64c
5 changed files with 310 additions and 0 deletions
+24
View File
@@ -227,6 +227,16 @@ func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[s
"content": userPrompt, "content": userPrompt,
}) })
// Guard: truncate messages if they would exceed the model's context window
if client.Cfg.MaxContext > 0 {
truncated, removed := truncateMessages(messages, client.Cfg.MaxContext, client.MaxTokens)
if removed > 0 {
client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit",
client.String(), removed, client.Cfg.MaxContext)
messages = truncated
}
}
// Build request body // Build request body
requestBody := map[string]interface{}{ requestBody := map[string]interface{}{
"model": client.Model, "model": client.Model,
@@ -575,6 +585,20 @@ func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any {
messages = append(messages, m) messages = append(messages, m)
} }
// Guard: truncate messages if they would exceed the model's context window
maxOut := client.MaxTokens
if req.MaxTokens != nil {
maxOut = *req.MaxTokens
}
if client.Cfg.MaxContext > 0 {
truncated, removed := truncateMessagesAny(messages, client.Cfg.MaxContext, maxOut)
if removed > 0 {
client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit",
client.String(), removed, client.Cfg.MaxContext)
messages = truncated
}
}
// Build basic request body // Build basic request body
requestBody := map[string]interface{}{ requestBody := map[string]interface{}{
"model": req.Model, "model": req.Model,
+1
View File
@@ -20,6 +20,7 @@ type Config struct {
// Behavior configuration // Behavior configuration
MaxTokens int MaxTokens int
MaxContext int // Model's max context window in tokens (0 = no limit)
Temperature float64 Temperature float64
UseFullURL bool UseFullURL bool
+147
View File
@@ -0,0 +1,147 @@
package mcp
import (
"fmt"
"unicode/utf8"
)
// estimateMessageTokens returns a rough token count for a chat transcript.
//
// Each message is charged a flat 10 tokens for role/formatting overhead plus
// one token per 3 runes of its "content" value. Runes (not bytes) are counted,
// so multi-byte text such as CJK is not inflated by its UTF-8 encoding. The
// 3-runes-per-token ratio is a conservative heuristic, not a real tokenizer.
// A message without a "content" key costs only the overhead.
func estimateMessageTokens(messages []map[string]string) int {
	const (
		runesPerToken      = 3
		perMessageOverhead = 10
	)
	// The flat overhead is identical for every entry, so charge it up front.
	tokens := len(messages) * perMessageOverhead
	for i := range messages {
		tokens += utf8.RuneCountInString(messages[i]["content"]) / runesPerToken
	}
	return tokens
}
// estimateMessageTokensAny is the map[string]any counterpart of
// estimateMessageTokens, for messages that may carry non-string content and
// a "tool_calls" payload.
//
// Per message it charges:
//   - a flat 10 tokens of role/formatting overhead,
//   - one token per 3 runes of "content" (string content is measured
//     directly; nil or missing content costs nothing; any other value is
//     measured through its default fmt rendering, which is only a rough
//     proxy for its serialized size),
//   - one token per 3 runes of the default fmt rendering of "tool_calls",
//     when that key is present, so assistant turns that consist mostly of
//     tool-call arguments are not treated as nearly free.
func estimateMessageTokensAny(messages []map[string]any) int {
	const (
		runesPerToken      = 3
		perMessageOverhead = 10
	)

	// runeLen reports how many characters a payload contributes.
	runeLen := func(v any) int {
		switch val := v.(type) {
		case nil:
			// Previously rendered as "<nil>" and charged as 5 characters.
			return 0
		case string:
			return utf8.RuneCountInString(val)
		default:
			return utf8.RuneCountInString(fmt.Sprintf("%v", val))
		}
	}

	total := 0
	for _, msg := range messages {
		total += runeLen(msg["content"])/runesPerToken + perMessageOverhead
		if calls, ok := msg["tool_calls"]; ok {
			total += runeLen(calls) / runesPerToken
		}
	}
	return total
}
// truncateMessages drops the oldest non-system messages until the estimated
// prompt size fits the input budget, and reports how many were dropped.
//
// Parameters:
//   - messages:   chat transcript in chronological order.
//   - maxContext: model context window in tokens; <= 0 disables the guard.
//   - maxTokens:  tokens reserved for the completion. The input budget is
//     maxContext - maxTokens; if that is not positive, half of maxContext is
//     used instead so some room is always left for input.
//
// Rules:
//   - System messages (role == "system") are never removed.
//   - Non-system messages are removed oldest first; the most recent ones stay.
//   - At least one non-system message is always kept, even if the result is
//     still over budget (including when the system messages alone exceed it).
//   - If nothing is removed, the original slice is returned unchanged with 0.
//
// When truncation happens the result is a new slice holding all system
// messages first, followed by the surviving non-system messages in their
// original order. The input slice is not modified.
func truncateMessages(messages []map[string]string, maxContext, maxTokens int) ([]map[string]string, int) {
	if maxContext <= 0 {
		return messages, 0
	}

	budget := maxContext - maxTokens
	if budget <= 0 {
		budget = maxContext / 2 // safety: keep at least half the window for input
	}
	if estimateMessageTokens(messages) <= budget {
		return messages, 0
	}

	// Split the transcript into the part we must keep and the part we may trim.
	var systemMsgs, otherMsgs []map[string]string
	for _, msg := range messages {
		if msg["role"] == "system" {
			systemMsgs = append(systemMsgs, msg)
		} else {
			otherMsgs = append(otherMsgs, msg)
		}
	}

	// remaining may be <= 0 when the system prompt alone is over budget. The
	// loop below then trims the history down to its newest message instead of
	// giving up and sending everything.
	remaining := budget - estimateMessageTokens(systemMsgs)

	// The estimate is a plain per-message sum, so a running total that
	// subtracts each dropped message equals re-estimating the whole tail,
	// without the quadratic cost.
	pending := estimateMessageTokens(otherMsgs)
	removed := 0
	for len(otherMsgs) > 1 && pending > remaining {
		pending -= estimateMessageTokens(otherMsgs[:1])
		otherMsgs = otherMsgs[1:]
		removed++
	}
	if removed == 0 {
		return messages, 0
	}

	result := make([]map[string]string, 0, len(systemMsgs)+len(otherMsgs))
	result = append(result, systemMsgs...)
	result = append(result, otherMsgs...)
	return result, removed
}
// truncateMessagesAny is the map[string]any counterpart of truncateMessages:
// it drops the oldest non-system messages until the estimated prompt size
// fits the input budget, and reports how many were dropped.
//
// Parameters:
//   - messages:   chat transcript in chronological order.
//   - maxContext: model context window in tokens; <= 0 disables the guard.
//   - maxTokens:  tokens reserved for the completion. The input budget is
//     maxContext - maxTokens; if that is not positive, half of maxContext is
//     used instead.
//
// Rules:
//   - Messages whose "role" is the string "system" are never removed; a
//     missing or non-string role is treated as non-system.
//   - Non-system messages are removed oldest first.
//   - After trimming, leading role == "tool" messages are removed as well:
//     their originating assistant tool_calls message has just been dropped,
//     and OpenAI-compatible chat APIs reject a tool result that does not
//     follow such a message.
//   - At least one non-system message is always kept, even if the result is
//     still over budget (including when the system messages alone exceed it).
//   - If nothing is removed, the original slice is returned unchanged with 0.
//
// When truncation happens the result is a new slice holding all system
// messages first, followed by the surviving non-system messages in their
// original order. The input slice is not modified.
func truncateMessagesAny(messages []map[string]any, maxContext, maxTokens int) ([]map[string]any, int) {
	if maxContext <= 0 {
		return messages, 0
	}

	budget := maxContext - maxTokens
	if budget <= 0 {
		budget = maxContext / 2 // safety: keep at least half the window for input
	}
	if estimateMessageTokensAny(messages) <= budget {
		return messages, 0
	}

	roleOf := func(msg map[string]any) string {
		role, _ := msg["role"].(string) // missing/non-string role -> "" (non-system)
		return role
	}

	// Split the transcript into the part we must keep and the part we may trim.
	var systemMsgs, otherMsgs []map[string]any
	for _, msg := range messages {
		if roleOf(msg) == "system" {
			systemMsgs = append(systemMsgs, msg)
		} else {
			otherMsgs = append(otherMsgs, msg)
		}
	}

	// remaining may be <= 0 when the system prompt alone is over budget. The
	// loop below then trims the history down to its newest message instead of
	// giving up and sending everything.
	remaining := budget - estimateMessageTokensAny(systemMsgs)

	// The estimate is a plain per-message sum, so a running total that
	// subtracts each dropped message equals re-estimating the whole tail,
	// without the quadratic cost.
	pending := estimateMessageTokensAny(otherMsgs)
	removed := 0
	for len(otherMsgs) > 1 && pending > remaining {
		pending -= estimateMessageTokensAny(otherMsgs[:1])
		otherMsgs = otherMsgs[1:]
		removed++
	}
	if removed == 0 {
		return messages, 0
	}

	// Do not leave orphaned tool results at the head of the history.
	for len(otherMsgs) > 1 && roleOf(otherMsgs[0]) == "tool" {
		otherMsgs = otherMsgs[1:]
		removed++
	}

	result := make([]map[string]any, 0, len(systemMsgs)+len(otherMsgs))
	result = append(result, systemMsgs...)
	result = append(result, otherMsgs...)
	return result, removed
}
+125
View File
@@ -0,0 +1,125 @@
package mcp
import (
"strings"
"testing"
)
// TestEstimateMessageTokens checks that the heuristic estimate for a short
// two-message conversation is positive and lands near the hand-computed value.
func TestEstimateMessageTokens(t *testing.T) {
	conversation := []map[string]string{
		{"role": "system", "content": "You are a helpful assistant."},
		{"role": "user", "content": "Hello, how are you?"},
	}

	got := estimateMessageTokens(conversation)
	if !(got > 0) {
		t.Errorf("expected positive token count, got %d", got)
	}

	// Hand computation with the runes/3 + 10 formula:
	//   "You are a helpful assistant." -> 28/3 + 10 = 19
	//   "Hello, how are you?"          -> 19/3 + 10 = 16
	// for a total of 35; the window below is intentionally loose.
	const lowest, highest = 20, 60
	if withinWindow := got >= lowest && got <= highest; !withinWindow {
		t.Errorf("expected ~35 tokens, got %d", got)
	}
}
// TestTruncateMessages_NoTruncationNeeded verifies that a tiny conversation
// passes through untouched when the context window is far larger than needed.
func TestTruncateMessages_NoTruncationNeeded(t *testing.T) {
	const (
		contextWindow = 131072
		outputReserve = 2000
	)
	input := []map[string]string{
		{"role": "system", "content": "Be helpful."},
		{"role": "user", "content": "Hi"},
	}

	kept, dropped := truncateMessages(input, contextWindow, outputReserve)

	if dropped != 0 {
		t.Errorf("expected no truncation, got %d removed", dropped)
	}
	if n := len(kept); n != 2 {
		t.Errorf("expected 2 messages, got %d", n)
	}
}
// TestTruncateMessages_NoLimit verifies that maxContext == 0 disables the
// guard entirely, even for a message of one million characters.
func TestTruncateMessages_NoLimit(t *testing.T) {
	huge := strings.Repeat("x", 1000000)
	input := []map[string]string{
		{"role": "user", "content": huge},
	}

	kept, dropped := truncateMessages(input, 0, 2000)

	if dropped != 0 {
		t.Errorf("expected no truncation when maxContext=0, got %d removed", dropped)
	}
	if n := len(kept); n != 1 {
		t.Errorf("expected 1 message, got %d", n)
	}
}
// TestTruncateMessages_TruncatesOldest feeds a history that is far larger than
// a small context window and checks that truncation happens from the oldest
// end: the system prompt stays first and the latest user message stays last.
func TestTruncateMessages_TruncatesOldest(t *testing.T) {
	// bulk repeats a short phrase 500 times to build a large message body.
	bulk := func(phrase string) string { return strings.Repeat(phrase, 500) }

	history := []map[string]string{
		{"role": "system", "content": "System prompt"},
		{"role": "user", "content": bulk("old message ")},        // 6000 chars
		{"role": "assistant", "content": bulk("old reply ")},     // 5000 chars
		{"role": "user", "content": bulk("newer msg ")},          // 5000 chars
		{"role": "assistant", "content": bulk("newer reply ")},   // 6000 chars
		{"role": "user", "content": "latest question"},
	}

	// A 2000-token window with 500 reserved for output cannot hold this history.
	kept, dropped := truncateMessages(history, 2000, 500)
	if dropped == 0 {
		t.Fatal("expected some messages to be truncated")
	}

	// The system prompt must survive and lead the result.
	if head := kept[0]; head["role"] != "system" {
		t.Error("system message should be first")
	}

	// The most recent user turn must survive and close the result.
	if tail := kept[len(kept)-1]["content"]; tail != "latest question" {
		t.Errorf("last message should be 'latest question', got '%s'", tail)
	}

	// Truncation must actually shrink the transcript.
	if before, after := len(history), len(kept); after >= before {
		t.Errorf("expected fewer messages after truncation, got %d (original %d)", after, before)
	}
}
// TestTruncateMessages_PreservesSystemMessages checks that every system
// message is still present after an oversized history has been processed.
func TestTruncateMessages_PreservesSystemMessages(t *testing.T) {
	input := []map[string]string{
		{"role": "system", "content": "System 1"},
		{"role": "system", "content": "System 2"},
		{"role": "user", "content": strings.Repeat("long msg ", 1000)},
		{"role": "user", "content": "short"},
	}

	kept, _ := truncateMessages(input, 500, 100)

	// Tally the system messages that made it through.
	var systemKept int
	for i := range kept {
		if kept[i]["role"] != "system" {
			continue
		}
		systemKept++
	}
	if systemKept != 2 {
		t.Errorf("expected 2 system messages preserved, got %d", systemKept)
	}
}
// TestTruncateMessages_KeepsAtLeastOneNonSystem checks that the only
// non-system message is retained even when it alone exceeds the budget.
func TestTruncateMessages_KeepsAtLeastOneNonSystem(t *testing.T) {
	input := []map[string]string{
		{"role": "system", "content": "System"},
		{"role": "user", "content": strings.Repeat("very long ", 10000)},
	}

	kept, _ := truncateMessages(input, 100, 50)

	// One surviving non-system message is enough to satisfy the rule.
	hasNonSystem := false
	for i := range kept {
		if kept[i]["role"] != "system" {
			hasNonSystem = true
			break
		}
	}
	if !hasNonSystem {
		t.Error("should keep at least 1 non-system message")
	}
}
+13
View File
@@ -86,6 +86,19 @@ func WithMaxTokens(maxTokens int) ClientOption {
} }
} }
// WithMaxContext returns an option that records the model's context window
// size, in tokens, in Config.MaxContext.
//
// A value greater than 0 enables the context guard: when a request's
// estimated token count exceeds the limit, the oldest non-system messages are
// dropped before the request body is built. A value of 0 means no limit.
//
// Usage example:
//
//	client := mcp.NewClient(mcp.WithMaxContext(131072)) // DeepSeek 128K
func WithMaxContext(maxContext int) ClientOption {
	apply := func(cfg *Config) {
		cfg.MaxContext = maxContext
	}
	return apply
}
// WithTemperature sets temperature parameter // WithTemperature sets temperature parameter
// //
// Usage example: // Usage example: