diff --git a/mcp/client.go b/mcp/client.go index e914c79e..2d5c864c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -227,6 +227,16 @@ func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[s "content": userPrompt, }) + // Guard: truncate messages if they would exceed the model's context window + if client.Cfg.MaxContext > 0 { + truncated, removed := truncateMessages(messages, client.Cfg.MaxContext, client.MaxTokens) + if removed > 0 { + client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit", + client.String(), removed, client.Cfg.MaxContext) + messages = truncated + } + } + // Build request body requestBody := map[string]interface{}{ "model": client.Model, @@ -575,6 +585,20 @@ func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any { messages = append(messages, m) } + // Guard: truncate messages if they would exceed the model's context window + maxOut := client.MaxTokens + if req.MaxTokens != nil { + maxOut = *req.MaxTokens + } + if client.Cfg.MaxContext > 0 { + truncated, removed := truncateMessagesAny(messages, client.Cfg.MaxContext, maxOut) + if removed > 0 { + client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit", + client.String(), removed, client.Cfg.MaxContext) + messages = truncated + } + } + // Build basic request body requestBody := map[string]interface{}{ "model": req.Model, diff --git a/mcp/config.go b/mcp/config.go index 208b0180..2020f952 100644 --- a/mcp/config.go +++ b/mcp/config.go @@ -20,6 +20,7 @@ type Config struct { // Behavior configuration MaxTokens int + MaxContext int // Model's max context window in tokens (0 = no limit) Temperature float64 UseFullURL bool diff --git a/mcp/context_guard.go b/mcp/context_guard.go new file mode 100644 index 00000000..795bdac1 --- /dev/null +++ b/mcp/context_guard.go @@ -0,0 +1,147 @@ +package mcp + +import ( + "fmt" + "unicode/utf8" +) + +// estimateMessageTokens estimates the token count for a list of chat messages. +// Uses ~3 chars per token heuristic (conservative for mixed CJK/English text). +// Each message has ~10 tokens overhead for role/formatting. +func estimateMessageTokens(messages []map[string]string) int { + total := 0 + for _, msg := range messages { + content := msg["content"] + charCount := utf8.RuneCountInString(content) + total += charCount/3 + 10 // ~3 chars per token + overhead + } + return total +} + +// estimateMessageTokensAny is like estimateMessageTokens but for map[string]any messages +// (used by BuildRequestBodyFromRequest which needs tool_calls support). +func estimateMessageTokensAny(messages []map[string]any) int { + total := 0 + for _, msg := range messages { + content := fmt.Sprintf("%v", msg["content"]) + charCount := utf8.RuneCountInString(content) + total += charCount/3 + 10 + } + return total +} + +// truncateMessages removes oldest non-system messages until estimated tokens +// fit within the context limit. Returns the truncated messages and the number +// of messages removed. +// +// Rules: +// - Never removes system messages (role="system") +// - Removes from the oldest non-system message first +// - Keeps the most recent messages +// - Returns original messages unchanged if no truncation needed +func truncateMessages(messages []map[string]string, maxContext, maxTokens int) ([]map[string]string, int) { + if maxContext <= 0 { + return messages, 0 + } + + budget := maxContext - maxTokens + if budget <= 0 { + budget = maxContext / 2 // safety: at least half for input + } + + estimated := estimateMessageTokens(messages) + if estimated <= budget { + return messages, 0 + } + + // Separate system messages (keep all) from non-system (truncatable) + var systemMsgs []map[string]string + var otherMsgs []map[string]string + for _, msg := range messages { + if msg["role"] == "system" { + systemMsgs = append(systemMsgs, msg) + } else { + otherMsgs = append(otherMsgs, msg) + } + } + + // Calculate system message tokens (non-removable) + systemTokens := estimateMessageTokens(systemMsgs) + remainingBudget := budget - systemTokens + if remainingBudget <= 0 { + return messages, 0 + } + + // Remove oldest non-system messages until we fit + removed := 0 + for len(otherMsgs) > 1 { + currentTokens := estimateMessageTokens(otherMsgs) + if currentTokens <= remainingBudget { + break + } + otherMsgs = otherMsgs[1:] + removed++ + } + + if removed == 0 { + return messages, 0 + } + + result := make([]map[string]string, 0, len(systemMsgs)+len(otherMsgs)) + result = append(result, systemMsgs...) + result = append(result, otherMsgs...) + return result, removed +} + +// truncateMessagesAny is like truncateMessages but for map[string]any messages. +func truncateMessagesAny(messages []map[string]any, maxContext, maxTokens int) ([]map[string]any, int) { + if maxContext <= 0 { + return messages, 0 + } + + budget := maxContext - maxTokens + if budget <= 0 { + budget = maxContext / 2 + } + + estimated := estimateMessageTokensAny(messages) + if estimated <= budget { + return messages, 0 + } + + var systemMsgs []map[string]any + var otherMsgs []map[string]any + for _, msg := range messages { + role, _ := msg["role"].(string) + if role == "system" { + systemMsgs = append(systemMsgs, msg) + } else { + otherMsgs = append(otherMsgs, msg) + } + } + + systemTokens := estimateMessageTokensAny(systemMsgs) + remainingBudget := budget - systemTokens + if remainingBudget <= 0 { + return messages, 0 + } + + removed := 0 + for len(otherMsgs) > 1 { + currentTokens := estimateMessageTokensAny(otherMsgs) + if currentTokens <= remainingBudget { + break + } + otherMsgs = otherMsgs[1:] + removed++ + } + + if removed == 0 { + return messages, 0 + } + + result := make([]map[string]any, 0, len(systemMsgs)+len(otherMsgs)) + result = append(result, systemMsgs...) + result = append(result, otherMsgs...) + return result, removed +} diff --git a/mcp/context_guard_test.go b/mcp/context_guard_test.go new file mode 100644 index 00000000..d2b5cf73 --- /dev/null +++ b/mcp/context_guard_test.go @@ -0,0 +1,125 @@ +package mcp + +import ( + "strings" + "testing" +) + +func TestEstimateMessageTokens(t *testing.T) { + msgs := []map[string]string{ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + } + tokens := estimateMessageTokens(msgs) + if tokens <= 0 { + t.Errorf("expected positive token count, got %d", tokens) + } + // "You are a helpful assistant." = 28 chars / 3 + 10 = ~19 + // "Hello, how are you?" = 19 chars / 3 + 10 = ~16 + // Total ~35 + if tokens < 20 || tokens > 60 { + t.Errorf("expected ~35 tokens, got %d", tokens) + } +} + +func TestTruncateMessages_NoTruncationNeeded(t *testing.T) { + msgs := []map[string]string{ + {"role": "system", "content": "Be helpful."}, + {"role": "user", "content": "Hi"}, + } + result, removed := truncateMessages(msgs, 131072, 2000) + if removed != 0 { + t.Errorf("expected no truncation, got %d removed", removed) + } + if len(result) != 2 { + t.Errorf("expected 2 messages, got %d", len(result)) + } +} + +func TestTruncateMessages_NoLimit(t *testing.T) { + msgs := []map[string]string{ + {"role": "user", "content": strings.Repeat("x", 1000000)}, + } + result, removed := truncateMessages(msgs, 0, 2000) + if removed != 0 { + t.Errorf("expected no truncation when maxContext=0, got %d removed", removed) + } + if len(result) != 1 { + t.Errorf("expected 1 message, got %d", len(result)) + } +} + +func TestTruncateMessages_TruncatesOldest(t *testing.T) { + // Create messages that definitely exceed a small context limit + msgs := []map[string]string{ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": strings.Repeat("old message ", 500)}, // ~2000 chars + {"role": "assistant", "content": strings.Repeat("old reply ", 500)}, // ~2000 chars + {"role": "user", "content": strings.Repeat("newer msg ", 500)}, // ~2000 chars + {"role": "assistant", "content": strings.Repeat("newer reply ", 500)}, // ~2000 chars + {"role": "user", "content": "latest question"}, + } + + // Set a small context limit that forces truncation + result, removed := truncateMessages(msgs, 2000, 500) + if removed == 0 { + t.Fatal("expected some messages to be truncated") + } + + // System message should always be preserved + if result[0]["role"] != "system" { + t.Error("system message should be first") + } + + // Last message should be the latest user message + last := result[len(result)-1] + if last["content"] != "latest question" { + t.Errorf("last message should be 'latest question', got '%s'", last["content"]) + } + + // Should have fewer messages than original + if len(result) >= len(msgs) { + t.Errorf("expected fewer messages after truncation, got %d (original %d)", len(result), len(msgs)) + } +} + +func TestTruncateMessages_PreservesSystemMessages(t *testing.T) { + msgs := []map[string]string{ + {"role": "system", "content": "System 1"}, + {"role": "system", "content": "System 2"}, + {"role": "user", "content": strings.Repeat("long msg ", 1000)}, + {"role": "user", "content": "short"}, + } + + result, _ := truncateMessages(msgs, 500, 100) + + // Count system messages - should all be preserved + systemCount := 0 + for _, msg := range result { + if msg["role"] == "system" { + systemCount++ + } + } + if systemCount != 2 { + t.Errorf("expected 2 system messages preserved, got %d", systemCount) + } +} + +func TestTruncateMessages_KeepsAtLeastOneNonSystem(t *testing.T) { + msgs := []map[string]string{ + {"role": "system", "content": "System"}, + {"role": "user", "content": strings.Repeat("very long ", 10000)}, + } + + result, _ := truncateMessages(msgs, 100, 50) + + nonSystem := 0 + for _, msg := range result { + if msg["role"] != "system" { + nonSystem++ + } + } + if nonSystem < 1 { + t.Error("should keep at least 1 non-system message") + } +} diff --git a/mcp/options.go b/mcp/options.go index 12ee9a30..335cec28 100644 --- a/mcp/options.go +++ b/mcp/options.go @@ -86,6 +86,19 @@ func WithMaxTokens(maxTokens int) ClientOption { } } +// WithMaxContext sets the model's max context window in tokens. +// When set (> 0), the client will automatically truncate oldest non-system +// messages if the estimated token count exceeds this limit. +// +// Usage example: +// +// client := mcp.NewClient(mcp.WithMaxContext(131072)) // DeepSeek 128K +func WithMaxContext(maxContext int) ClientOption { + return func(c *Config) { + c.MaxContext = maxContext + } +} + // WithTemperature sets temperature parameter // // Usage example: