mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 09:58:22 +08:00
feat(mcp): add context length guard to prevent oversized requests
* feat: add X-Client-ID header for claw402 monitoring * feat(mcp): add context length guard to prevent oversized requests - Add MaxContext field to Config (default 0 = no limit) - Add WithMaxContext() option for setting model context limits - Add context_guard.go: token estimation + message truncation - Integrate guard into both BuildMCPRequestBody and BuildRequestBodyFromRequest - Support both map[string]string and map[string]any message formats - Truncates oldest non-system messages when estimated tokens exceed limit - Always preserves system messages and keeps at least 1 non-system message - Logs warning when truncation occurs for debugging Usage: mcp.NewDeepSeekClient(mcp.WithMaxContext(131072))
This commit is contained in:
@@ -227,6 +227,16 @@ func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[s
|
||||
"content": userPrompt,
|
||||
})
|
||||
|
||||
// Guard: truncate messages if they would exceed the model's context window
|
||||
if client.Cfg.MaxContext > 0 {
|
||||
truncated, removed := truncateMessages(messages, client.Cfg.MaxContext, client.MaxTokens)
|
||||
if removed > 0 {
|
||||
client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit",
|
||||
client.String(), removed, client.Cfg.MaxContext)
|
||||
messages = truncated
|
||||
}
|
||||
}
|
||||
|
||||
// Build request body
|
||||
requestBody := map[string]interface{}{
|
||||
"model": client.Model,
|
||||
@@ -575,6 +585,20 @@ func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Guard: truncate messages if they would exceed the model's context window
|
||||
maxOut := client.MaxTokens
|
||||
if req.MaxTokens != nil {
|
||||
maxOut = *req.MaxTokens
|
||||
}
|
||||
if client.Cfg.MaxContext > 0 {
|
||||
truncated, removed := truncateMessagesAny(messages, client.Cfg.MaxContext, maxOut)
|
||||
if removed > 0 {
|
||||
client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit",
|
||||
client.String(), removed, client.Cfg.MaxContext)
|
||||
messages = truncated
|
||||
}
|
||||
}
|
||||
|
||||
// Build basic request body
|
||||
requestBody := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
|
||||
@@ -20,6 +20,7 @@ type Config struct {
|
||||
|
||||
// Behavior configuration
|
||||
MaxTokens int
|
||||
MaxContext int // Model's max context window in tokens (0 = no limit)
|
||||
Temperature float64
|
||||
UseFullURL bool
|
||||
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// estimateMessageTokens estimates the token count for a list of chat messages.
|
||||
// Uses ~3 chars per token heuristic (conservative for mixed CJK/English text).
|
||||
// Each message has ~10 tokens overhead for role/formatting.
|
||||
func estimateMessageTokens(messages []map[string]string) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
content := msg["content"]
|
||||
charCount := utf8.RuneCountInString(content)
|
||||
total += charCount/3 + 10 // ~3 chars per token + overhead
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// estimateMessageTokensAny is like estimateMessageTokens but for map[string]any messages
|
||||
// (used by BuildRequestBodyFromRequest which needs tool_calls support).
|
||||
func estimateMessageTokensAny(messages []map[string]any) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
content := fmt.Sprintf("%v", msg["content"])
|
||||
charCount := utf8.RuneCountInString(content)
|
||||
total += charCount/3 + 10
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// truncateMessages removes oldest non-system messages until estimated tokens
|
||||
// fit within the context limit. Returns the truncated messages and the number
|
||||
// of messages removed.
|
||||
//
|
||||
// Rules:
|
||||
// - Never removes system messages (role="system")
|
||||
// - Removes from the oldest non-system message first
|
||||
// - Keeps the most recent messages
|
||||
// - Returns original messages unchanged if no truncation needed
|
||||
func truncateMessages(messages []map[string]string, maxContext, maxTokens int) ([]map[string]string, int) {
|
||||
if maxContext <= 0 {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
budget := maxContext - maxTokens
|
||||
if budget <= 0 {
|
||||
budget = maxContext / 2 // safety: at least half for input
|
||||
}
|
||||
|
||||
estimated := estimateMessageTokens(messages)
|
||||
if estimated <= budget {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
// Separate system messages (keep all) from non-system (truncatable)
|
||||
var systemMsgs []map[string]string
|
||||
var otherMsgs []map[string]string
|
||||
for _, msg := range messages {
|
||||
if msg["role"] == "system" {
|
||||
systemMsgs = append(systemMsgs, msg)
|
||||
} else {
|
||||
otherMsgs = append(otherMsgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate system message tokens (non-removable)
|
||||
systemTokens := estimateMessageTokens(systemMsgs)
|
||||
remainingBudget := budget - systemTokens
|
||||
if remainingBudget <= 0 {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
// Remove oldest non-system messages until we fit
|
||||
removed := 0
|
||||
for len(otherMsgs) > 1 {
|
||||
currentTokens := estimateMessageTokens(otherMsgs)
|
||||
if currentTokens <= remainingBudget {
|
||||
break
|
||||
}
|
||||
otherMsgs = otherMsgs[1:]
|
||||
removed++
|
||||
}
|
||||
|
||||
if removed == 0 {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
result := make([]map[string]string, 0, len(systemMsgs)+len(otherMsgs))
|
||||
result = append(result, systemMsgs...)
|
||||
result = append(result, otherMsgs...)
|
||||
return result, removed
|
||||
}
|
||||
|
||||
// truncateMessagesAny is like truncateMessages but for map[string]any messages.
|
||||
func truncateMessagesAny(messages []map[string]any, maxContext, maxTokens int) ([]map[string]any, int) {
|
||||
if maxContext <= 0 {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
budget := maxContext - maxTokens
|
||||
if budget <= 0 {
|
||||
budget = maxContext / 2
|
||||
}
|
||||
|
||||
estimated := estimateMessageTokensAny(messages)
|
||||
if estimated <= budget {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
var systemMsgs []map[string]any
|
||||
var otherMsgs []map[string]any
|
||||
for _, msg := range messages {
|
||||
role, _ := msg["role"].(string)
|
||||
if role == "system" {
|
||||
systemMsgs = append(systemMsgs, msg)
|
||||
} else {
|
||||
otherMsgs = append(otherMsgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
systemTokens := estimateMessageTokensAny(systemMsgs)
|
||||
remainingBudget := budget - systemTokens
|
||||
if remainingBudget <= 0 {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
removed := 0
|
||||
for len(otherMsgs) > 1 {
|
||||
currentTokens := estimateMessageTokensAny(otherMsgs)
|
||||
if currentTokens <= remainingBudget {
|
||||
break
|
||||
}
|
||||
otherMsgs = otherMsgs[1:]
|
||||
removed++
|
||||
}
|
||||
|
||||
if removed == 0 {
|
||||
return messages, 0
|
||||
}
|
||||
|
||||
result := make([]map[string]any, 0, len(systemMsgs)+len(otherMsgs))
|
||||
result = append(result, systemMsgs...)
|
||||
result = append(result, otherMsgs...)
|
||||
return result, removed
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEstimateMessageTokens(t *testing.T) {
|
||||
msgs := []map[string]string{
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
}
|
||||
tokens := estimateMessageTokens(msgs)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("expected positive token count, got %d", tokens)
|
||||
}
|
||||
// "You are a helpful assistant." = 28 chars / 3 + 10 = ~19
|
||||
// "Hello, how are you?" = 19 chars / 3 + 10 = ~16
|
||||
// Total ~35
|
||||
if tokens < 20 || tokens > 60 {
|
||||
t.Errorf("expected ~35 tokens, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateMessages_NoTruncationNeeded(t *testing.T) {
|
||||
msgs := []map[string]string{
|
||||
{"role": "system", "content": "Be helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
}
|
||||
result, removed := truncateMessages(msgs, 131072, 2000)
|
||||
if removed != 0 {
|
||||
t.Errorf("expected no truncation, got %d removed", removed)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected 2 messages, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateMessages_NoLimit(t *testing.T) {
|
||||
msgs := []map[string]string{
|
||||
{"role": "user", "content": strings.Repeat("x", 1000000)},
|
||||
}
|
||||
result, removed := truncateMessages(msgs, 0, 2000)
|
||||
if removed != 0 {
|
||||
t.Errorf("expected no truncation when maxContext=0, got %d removed", removed)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateMessages_TruncatesOldest(t *testing.T) {
|
||||
// Create messages that definitely exceed a small context limit
|
||||
msgs := []map[string]string{
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": strings.Repeat("old message ", 500)}, // ~2000 chars
|
||||
{"role": "assistant", "content": strings.Repeat("old reply ", 500)}, // ~2000 chars
|
||||
{"role": "user", "content": strings.Repeat("newer msg ", 500)}, // ~2000 chars
|
||||
{"role": "assistant", "content": strings.Repeat("newer reply ", 500)}, // ~2000 chars
|
||||
{"role": "user", "content": "latest question"},
|
||||
}
|
||||
|
||||
// Set a small context limit that forces truncation
|
||||
result, removed := truncateMessages(msgs, 2000, 500)
|
||||
if removed == 0 {
|
||||
t.Fatal("expected some messages to be truncated")
|
||||
}
|
||||
|
||||
// System message should always be preserved
|
||||
if result[0]["role"] != "system" {
|
||||
t.Error("system message should be first")
|
||||
}
|
||||
|
||||
// Last message should be the latest user message
|
||||
last := result[len(result)-1]
|
||||
if last["content"] != "latest question" {
|
||||
t.Errorf("last message should be 'latest question', got '%s'", last["content"])
|
||||
}
|
||||
|
||||
// Should have fewer messages than original
|
||||
if len(result) >= len(msgs) {
|
||||
t.Errorf("expected fewer messages after truncation, got %d (original %d)", len(result), len(msgs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateMessages_PreservesSystemMessages(t *testing.T) {
|
||||
msgs := []map[string]string{
|
||||
{"role": "system", "content": "System 1"},
|
||||
{"role": "system", "content": "System 2"},
|
||||
{"role": "user", "content": strings.Repeat("long msg ", 1000)},
|
||||
{"role": "user", "content": "short"},
|
||||
}
|
||||
|
||||
result, _ := truncateMessages(msgs, 500, 100)
|
||||
|
||||
// Count system messages - should all be preserved
|
||||
systemCount := 0
|
||||
for _, msg := range result {
|
||||
if msg["role"] == "system" {
|
||||
systemCount++
|
||||
}
|
||||
}
|
||||
if systemCount != 2 {
|
||||
t.Errorf("expected 2 system messages preserved, got %d", systemCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateMessages_KeepsAtLeastOneNonSystem(t *testing.T) {
|
||||
msgs := []map[string]string{
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": strings.Repeat("very long ", 10000)},
|
||||
}
|
||||
|
||||
result, _ := truncateMessages(msgs, 100, 50)
|
||||
|
||||
nonSystem := 0
|
||||
for _, msg := range result {
|
||||
if msg["role"] != "system" {
|
||||
nonSystem++
|
||||
}
|
||||
}
|
||||
if nonSystem < 1 {
|
||||
t.Error("should keep at least 1 non-system message")
|
||||
}
|
||||
}
|
||||
@@ -86,6 +86,19 @@ func WithMaxTokens(maxTokens int) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxContext sets the model's max context window in tokens.
|
||||
// When set (> 0), the client will automatically truncate oldest non-system
|
||||
// messages if the estimated token count exceeds this limit.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// client := mcp.NewClient(mcp.WithMaxContext(131072)) // DeepSeek 128K
|
||||
func WithMaxContext(maxContext int) ClientOption {
|
||||
return func(c *Config) {
|
||||
c.MaxContext = maxContext
|
||||
}
|
||||
}
|
||||
|
||||
// WithTemperature sets temperature parameter
|
||||
//
|
||||
// Usage example:
|
||||
|
||||
Reference in New Issue
Block a user