Files
nofx/mcp/client.go
T
Lance a1f909adbe fix(telemetry): report token usage for SSE streaming paths (#1475)
* fix(telemetry): report token usage for SSE streaming paths

ParseSSEStream already parsed the usage block from SSE chunks but only
printed it, so claw402 streaming calls (and native streaming) never
fired TokenUsageCallback. GA4 therefore undercounted AI usage on the
streaming path.

Return the parsed usage from ParseSSEStream and have both callers fire
the callback with their own Provider/Model.

* chore: drop leftover debug Printf in ParseSSEStream

Telemetry is now wired via TokenUsageCallback, so the Printf is
redundant noise in the stream path.
2026-04-16 21:31:13 +08:00

823 lines
23 KiB
Go

package mcp
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// Package-level constants.
const (
// ProviderCustom is the Provider value that SetAPIKey assigns when the
// client is pointed at a user-supplied OpenAI-compatible endpoint.
ProviderCustom = "custom"
// MCPClientTemperature is a temperature constant. It is not referenced in
// this part of the file; presumably it is the default sampling temperature
// applied by the configuration code — TODO confirm.
MCPClientTemperature = 0.5
)
// Package-level defaults and the telemetry hook.
var (
// DefaultTimeout is a 120-second duration. NOTE(review): presumably the
// default HTTP client timeout; its consumer is not visible here — confirm.
DefaultTimeout = 120 * time.Second
// MaxRetryTimes is 3. NOTE(review): presumably the default retry budget;
// the retry loops in this file read Cfg.MaxRetries instead — confirm how
// the two are wired together.
MaxRetryTimes = 3
// retryableErrors lists error-message fragments that mark a failure as
// transient. NOTE(review): IsRetryableError matches against
// Cfg.RetryableErrors, so this slice is presumably the default for that
// field — confirm.
retryableErrors = []string{
"EOF",
"timeout",
"connection reset",
"connection refused",
"temporary failure",
"no such host",
"stream error", // HTTP/2 stream error
"INTERNAL_ERROR", // Server internal error
"status 502", // Bad Gateway
"status 503", // Service Unavailable
"status 520", // Cloudflare origin error
"status 524", // Cloudflare timeout
}
// TokenUsageCallback, when non-nil, is invoked with the token usage parsed
// from an AI response. In this file it is fired by ParseMCPResponseFull and
// ReportStreamUsage, and only when the reported total token count is positive.
TokenUsageCallback func(usage TokenUsage)
)
// TokenUsage represents token usage from an AI API response. It is the
// payload handed to TokenUsageCallback.
type TokenUsage struct {
Provider string // payment channel: "claw402" or native provider name
// Model is the model name the usage is attributed to.
Model string
// PromptTokens, CompletionTokens and TotalTokens mirror the
// prompt_tokens, completion_tokens and total_tokens fields of the
// response's usage object.
PromptTokens int
CompletionTokens int
TotalTokens int
}
// Channel maps the record's Provider onto the coarse payment-channel label
// used for telemetry: "claw402" when Provider equals ProviderClaw402, and
// "native" for every other provider.
func (u TokenUsage) Channel() string {
	if u.Provider == ProviderClaw402 {
		return "claw402"
	}
	return "native"
}
// Client holds the configuration and collaborators for calling an
// OpenAI-compatible chat-completions API.
type Client struct {
// Provider identifies the backend; it selects the token-limit key in the
// request body and is reported with token usage.
Provider string
// APIKey is sent as a Bearer token by SetAuthHeader; the Call* entry points
// refuse to run while it is empty.
APIKey string
// BaseURL is the endpoint root (or the complete URL when UseFullURL is set).
BaseURL string
// Model is the model name placed in request bodies.
Model string
UseFullURL bool // Whether to use full URL (without appending /chat/completions)
MaxTokens int // Maximum tokens for AI response
HTTPClient *http.Client // Exported for sub-packages
Log Logger // Exported for sub-packages
Cfg *Config // Exported for sub-packages
// Hooks are used to implement dynamic dispatch (polymorphism)
// When provider.DeepSeekClient embeds Client, Hooks point to DeepSeekClient
// This way methods called in Call() are automatically dispatched to the overridden version
// NewClient initialises Hooks to the client itself.
Hooks ClientHooks
}
// New returns a client built purely from the default configuration by
// delegating to NewClient with no options. It exists for backward
// compatibility.
//
// Deprecated: Recommend using NewClient(...opts) for better flexibility
func New() AIClient {
	defaultClient := NewClient()
	return defaultClient
}
// NewClient builds a client using the functional-options pattern.
//
// It starts from DefaultConfig, applies every option in order, copies the
// resulting settings onto a new Client, falls back to the DeepSeek provider,
// base URL and model when no provider was configured, and finally points
// Hooks at the client itself so hook calls dispatch to its own methods.
//
// Usage examples:
//
//	// Basic usage (backward compatible)
//	client := mcp.NewClient()
//
//	// Custom logger
//	client := mcp.NewClient(mcp.WithLogger(customLogger))
//
//	// Custom timeout
//	client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
//
//	// Combine multiple options
//	client := mcp.NewClient(
//		mcp.WithDeepSeekConfig("sk-xxx"),
//		mcp.WithLogger(customLogger),
//		mcp.WithTimeout(60*time.Second),
//	)
func NewClient(opts ...ClientOption) AIClient {
	// Defaults first, then let each option mutate the config in order.
	cfg := DefaultConfig()
	for i := range opts {
		opts[i](cfg)
	}

	// Mirror the resolved configuration onto the client.
	c := new(Client)
	c.Cfg = cfg
	c.Provider, c.APIKey = cfg.Provider, cfg.APIKey
	c.BaseURL, c.Model = cfg.BaseURL, cfg.Model
	c.MaxTokens, c.UseFullURL = cfg.MaxTokens, cfg.UseFullURL
	c.HTTPClient, c.Log = cfg.HTTPClient, cfg.Logger

	// No provider configured: fall back to DeepSeek.
	if c.Provider == "" {
		c.Provider = ProviderDeepSeek
		c.BaseURL = DefaultDeepSeekBaseURL
		c.Model = DefaultDeepSeekModel
	}

	// Hooks point at the client itself until an embedding type replaces them.
	c.Hooks = c
	return c
}
// SetAPIKey switches the client to a custom OpenAI-compatible API.
//
// It sets Provider to ProviderCustom and stores apiKey and customModel.
// A trailing "#" on apiURL is a marker meaning "use this URL verbatim": the
// marker is stripped and UseFullURL is enabled, so /chat/completions is not
// appended later. Without the marker, apiURL is stored unchanged and
// UseFullURL is disabled.
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
	verbatim := strings.HasSuffix(apiURL, "#")

	client.Provider = ProviderCustom
	client.APIKey = apiKey
	client.Model = customModel
	// TrimSuffix leaves the URL untouched when the marker is absent.
	client.BaseURL = strings.TrimSuffix(apiURL, "#")
	client.UseFullURL = verbatim
}
// SetTimeout sets the Timeout field of the client's HTTPClient.
// HTTPClient must already be non-nil.
func (client *Client) SetTimeout(timeout time.Duration) {
	httpClient := client.HTTPClient
	httpClient.Timeout = timeout
}
// CallWithMessages sends a system/user prompt pair to the AI API and returns
// the reply text. It is the template method that owns the retry flow; the
// single-call step is dispatched through Hooks.Call.
//
// At least one attempt is always made, even when Cfg.MaxRetries is zero or
// negative. After a failed attempt the error is returned immediately unless
// Hooks.IsRetryableError accepts it; otherwise the call sleeps
// Cfg.RetryWaitBase multiplied by the attempt number and tries again.
//
// Returns an error when the API key is empty, when a non-retryable error
// occurs, or — wrapping the last error — when every attempt has failed.
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
	if client.APIKey == "" {
		return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
	}

	// A non-positive retry budget must still allow a single attempt; without
	// this clamp the loop body would never run and no request would be sent.
	maxRetries := client.Cfg.MaxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
		}

		// Single call, dispatched through hooks so embedding types can override it.
		result, err := client.Hooks.Call(systemPrompt, userPrompt)
		if err == nil {
			if attempt > 1 {
				client.Log.Infof("✓ AI API retry succeeded")
			}
			return result, nil
		}
		lastErr = err

		// The retry policy is also hook-driven (supports custom strategies).
		if !client.Hooks.IsRetryableError(err) {
			return "", err
		}

		// Linear back-off; no point sleeping after the final attempt.
		if attempt < maxRetries {
			waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
			client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
			time.Sleep(waitTime)
		}
	}
	return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// SetAuthHeader writes the client's API key into reqHeader as an
// "Authorization: Bearer <key>" header, replacing any existing value.
func (client *Client) SetAuthHeader(reqHeader http.Header) {
	bearer := fmt.Sprintf("Bearer %s", client.APIKey)
	reqHeader.Set("Authorization", bearer)
}
// BuildMCPRequestBody assembles the chat-completions request body for a
// simple system/user prompt pair.
//
// The system message is included only when systemPrompt is non-empty. When
// Cfg.MaxContext is positive the messages are passed through
// truncateMessages, and a warning is logged if any were removed. The token
// limit is sent as max_completion_tokens for ProviderOpenAI and as
// max_tokens for every other provider.
func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
	userMsg := map[string]string{
		"role":    "user",
		"content": userPrompt,
	}

	// The optional system message always precedes the user message.
	var messages []map[string]string
	if systemPrompt == "" {
		messages = []map[string]string{userMsg}
	} else {
		systemMsg := map[string]string{
			"role":    "system",
			"content": systemPrompt,
		}
		messages = []map[string]string{systemMsg, userMsg}
	}

	// Context guard: only active when a context limit is configured.
	if limit := client.Cfg.MaxContext; limit > 0 {
		if kept, dropped := truncateMessages(messages, limit, client.MaxTokens); dropped > 0 {
			client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit",
				client.String(), dropped, limit)
			messages = kept
		}
	}

	// OpenAI newer models use max_completion_tokens instead of max_tokens.
	tokenKey := "max_tokens"
	if client.Provider == ProviderOpenAI {
		tokenKey = "max_completion_tokens"
	}

	return map[string]any{
		"model":       client.Model,
		"messages":    messages,
		"temperature": client.Cfg.Temperature,
		tokenKey:      client.MaxTokens,
	}
}
// MarshalRequestBody serialises requestBody to JSON. It is a hook, so
// embedding clients may override it. A marshalling failure is returned
// wrapped with a "failed to serialize request" prefix.
func (client *Client) MarshalRequestBody(requestBody map[string]any) ([]byte, error) {
	encoded, marshalErr := json.Marshal(requestBody)
	if marshalErr == nil {
		return encoded, nil
	}
	return nil, fmt.Errorf("failed to serialize request: %w", marshalErr)
}
// ParseMCPResponse extracts just the text content from an OpenAI-format
// response body. It delegates to ParseMCPResponseFull and discards any tool
// calls; parse errors are passed through unchanged.
func (client *Client) ParseMCPResponse(body []byte) (string, error) {
	full, err := client.ParseMCPResponseFull(body)
	if err == nil {
		return full.Content, nil
	}
	return "", err
}
// ParseMCPResponseFull decodes an OpenAI-format response body and returns the
// first choice's text content together with its tool calls.
//
// When TokenUsageCallback is set and the response reports a positive total
// token count, the callback is fired with the client's Provider and Model.
// Returns an error if the body is not valid JSON or contains no choices.
func (client *Client) ParseMCPResponseFull(body []byte) (*LLMResponse, error) {
	type usageBlock struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}
	type choice struct {
		Message struct {
			Content   string     `json:"content"`
			ToolCalls []ToolCall `json:"tool_calls"`
		} `json:"message"`
	}
	var payload struct {
		Choices []choice   `json:"choices"`
		Usage   usageBlock `json:"usage"`
	}

	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	if len(payload.Choices) == 0 {
		return nil, fmt.Errorf("API returned empty response")
	}

	// Telemetry: only report when the provider actually sent usage numbers.
	if used := payload.Usage; TokenUsageCallback != nil && used.TotalTokens > 0 {
		TokenUsageCallback(TokenUsage{
			Provider:         client.Provider,
			Model:            client.Model,
			PromptTokens:     used.PromptTokens,
			CompletionTokens: used.CompletionTokens,
			TotalTokens:      used.TotalTokens,
		})
	}

	first := payload.Choices[0].Message
	reply := &LLMResponse{
		Content:   first.Content,
		ToolCalls: first.ToolCalls,
	}
	return reply, nil
}
// BuildUrl returns the endpoint URL for a chat-completions request: BaseURL
// verbatim when UseFullURL is set, otherwise BaseURL with "/chat/completions"
// appended.
func (client *Client) BuildUrl() string {
	if !client.UseFullURL {
		return fmt.Sprintf("%s/chat/completions", client.BaseURL)
	}
	return client.BaseURL
}
// BuildRequest creates the POST request carrying jsonData to url. It sets
// the JSON content type and then lets Hooks.SetAuthHeader add credentials,
// so embedding clients can supply their own authentication scheme.
// Returns a wrapped error if the request cannot be constructed.
func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
	httpReq, buildErr := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData))
	if buildErr != nil {
		return nil, fmt.Errorf("fail to build request: %w", buildErr)
	}

	headers := httpReq.Header
	headers.Set("Content-Type", "application/json")
	// Auth goes through hooks so it can be overridden.
	client.Hooks.SetAuthHeader(headers)
	return httpReq, nil
}
// Call performs exactly one AI API round trip (no retries) and returns the
// text content of the reply.
//
// The overridable steps — building and marshalling the body, building the
// URL and the request, and parsing the response — are dispatched through
// client.Hooks. Sending, reading the body and the status check are fixed.
// Any status other than 200 is returned as an error that embeds the status
// code and the raw response body.
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
	client.Log.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
	client.Log.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
	// Only a masked key (first and last four characters) is logged, and only
	// when the key is longer than eight characters.
	if key := client.APIKey; len(key) > 8 {
		client.Log.Debugf("[%s] API Key: %s...%s", client.String(), key[:4], key[len(key)-4:])
	}

	// Build and serialise the request body via hooks.
	payload, err := client.Hooks.MarshalRequestBody(client.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt))
	if err != nil {
		return "", err
	}

	// Resolve the endpoint and create the HTTP request via hooks.
	endpoint := client.Hooks.BuildUrl()
	client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), endpoint)
	httpReq, err := client.Hooks.BuildRequest(endpoint, payload)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	// Send the request.
	resp, err := client.HTTPClient.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that error responses can
	// include the server's message.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	if status := resp.StatusCode; status != http.StatusOK {
		return "", fmt.Errorf("API returned error (status %d): %s", status, string(raw))
	}

	// Parse the successful response via hooks.
	content, err := client.Hooks.ParseMCPResponse(raw)
	if err != nil {
		return "", fmt.Errorf("fail to parse AI server response: %w", err)
	}
	return content, nil
}
// String returns a short "[Provider: …, Model: …]" label for the client,
// used as a prefix in log messages.
func (client *Client) String() string {
	provider, model := client.Provider, client.Model
	return fmt.Sprintf("[Provider: %s, Model: %s]", provider, model)
}
// BaseClient returns the receiver itself, giving callers access to the
// underlying *Client (satisfies the ClientEmbedder interface).
func (client *Client) BaseClient() *Client {
	return client
}
// IsRetryableError reports whether err looks like a transient failure
// (network errors, timeouts, EOF, selected HTTP statuses) worth retrying.
//
// The decision is a substring match of err's message against every entry in
// Cfg.RetryableErrors. A nil error is never retryable.
func (client *Client) IsRetryableError(err error) bool {
	// Guard against a nil error: calling Error() on it would panic.
	if err == nil {
		return false
	}
	errStr := err.Error()
	for _, retryable := range client.Cfg.RetryableErrors {
		if strings.Contains(errStr, retryable) {
			return true
		}
	}
	return false
}
// ============================================================
// Builder Pattern API (Advanced Features)
// ============================================================
// CallWithRequest calls the AI API using a Request object (supports advanced
// features) and returns the text content of the reply.
//
// If req.Model is empty it is filled in from client.Model, so req is
// mutated. At least one attempt is always made, even when Cfg.MaxRetries is
// zero or negative. After a failed attempt the error is returned immediately
// unless Hooks.IsRetryableError accepts it; otherwise the call sleeps
// Cfg.RetryWaitBase multiplied by the attempt number and tries again.
//
// Returns an error when the API key is empty, when a non-retryable error
// occurs, or — wrapping the last error — when every attempt has failed.
func (client *Client) CallWithRequest(req *Request) (string, error) {
	if client.APIKey == "" {
		return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
	}

	// If Model is not set in Request, use Client's Model.
	if req.Model == "" {
		req.Model = client.Model
	}

	// A non-positive retry budget must still allow a single attempt; without
	// this clamp the loop body would never run and no request would be sent.
	maxRetries := client.Cfg.MaxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
		}

		result, err := client.callWithRequest(req)
		if err == nil {
			if attempt > 1 {
				client.Log.Infof("✓ AI API retry succeeded")
			}
			return result, nil
		}
		lastErr = err

		// The retry policy is hook-driven (supports custom strategies).
		if !client.Hooks.IsRetryableError(err) {
			return "", err
		}

		// Linear back-off; no point sleeping after the final attempt.
		if attempt < maxRetries {
			waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
			client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
			time.Sleep(waitTime)
		}
	}
	return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// CallWithRequestFull calls the AI API using a Request object and returns
// both the text content and any tool calls.
//
// It follows the same retry flow as CallWithRequest: req.Model falls back to
// client.Model (mutating req), at least one attempt is always made even when
// Cfg.MaxRetries is zero or negative, non-retryable errors (per
// Hooks.IsRetryableError) are returned immediately, and retryable ones are
// followed by a sleep of Cfg.RetryWaitBase multiplied by the attempt number.
//
// Returns an error when the API key is empty, when a non-retryable error
// occurs, or — wrapping the last error — when every attempt has failed.
func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
	if client.APIKey == "" {
		return nil, fmt.Errorf("AI API key not set, please call SetAPIKey first")
	}
	if req.Model == "" {
		req.Model = client.Model
	}

	// A non-positive retry budget must still allow a single attempt; without
	// this clamp the loop body would never run and no request would be sent.
	maxRetries := client.Cfg.MaxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
		}

		result, err := client.callWithRequestFull(req)
		if err == nil {
			// Same success/wait logging as CallWithRequest.
			if attempt > 1 {
				client.Log.Infof("✓ AI API retry succeeded")
			}
			return result, nil
		}
		lastErr = err

		if !client.Hooks.IsRetryableError(err) {
			return nil, err
		}

		// Linear back-off; no point sleeping after the final attempt.
		if attempt < maxRetries {
			waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
			client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
			time.Sleep(waitTime)
		}
	}
	return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// callWithRequestFull performs one request/response round trip for req and
// returns the parsed LLMResponse (content plus tool calls). Body
// construction, marshalling, URL/request building and response parsing are
// dispatched through client.Hooks. Any status other than 200 becomes an
// error that embeds the status code and the raw body.
func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
	client.Log.Infof("📡 [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL)

	payload, err := client.Hooks.MarshalRequestBody(client.Hooks.BuildRequestBodyFromRequest(req))
	if err != nil {
		return nil, err
	}

	outgoing, err := client.Hooks.BuildRequest(client.Hooks.BuildUrl(), payload)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := client.HTTPClient.Do(outgoing)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that error responses can
	// include the server's message.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if status := resp.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("API returned error (status %d): %s", status, string(raw))
	}

	return client.Hooks.ParseMCPResponseFull(raw)
}
// callWithRequest performs one request/response round trip for req and
// returns only the text content of the reply. Body construction,
// marshalling, URL/request building and response parsing are dispatched
// through client.Hooks. Any status other than 200 becomes an error that
// embeds the status code and the raw body.
func (client *Client) callWithRequest(req *Request) (string, error) {
	client.Log.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
	client.Log.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))

	payload, err := client.Hooks.MarshalRequestBody(client.Hooks.BuildRequestBodyFromRequest(req))
	if err != nil {
		return "", err
	}

	endpoint := client.Hooks.BuildUrl()
	client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), endpoint)

	outgoing, err := client.Hooks.BuildRequest(endpoint, payload)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := client.HTTPClient.Do(outgoing)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that error responses can
	// include the server's message.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	if status := resp.StatusCode; status != http.StatusOK {
		return "", fmt.Errorf("API returned error (status %d): %s", status, string(raw))
	}

	content, err := client.Hooks.ParseMCPResponse(raw)
	if err != nil {
		return "", fmt.Errorf("fail to parse AI server response: %w", err)
	}
	return content, nil
}
// BuildRequestBodyFromRequest converts a Request into the JSON-ready body
// map for a chat-completions call.
//
// Messages are emitted as map[string]any so tool-call fields can be carried:
// a message with tool calls gets "tool_calls" and no "content"; a message
// with a ToolCallID gets "tool_call_id" plus "content"; any other message
// gets "content" only. When Cfg.MaxContext is positive the messages go
// through truncateMessagesAny, with a warning logged if any were removed.
//
// Temperature and the token limit fall back to the client's configuration
// when the request leaves them unset; the token limit key is
// max_completion_tokens for ProviderOpenAI and max_tokens otherwise. The
// remaining optional parameters are added only when set on the request.
func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any {
	messages := make([]map[string]any, 0, len(req.Messages))
	for _, msg := range req.Messages {
		entry := map[string]any{"role": msg.Role}
		switch {
		case len(msg.ToolCalls) > 0:
			// Assistant message with tool invocations: content is omitted
			// for OpenAI compatibility.
			entry["tool_calls"] = msg.ToolCalls
		case msg.ToolCallID != "":
			// Tool result message (role="tool").
			entry["tool_call_id"] = msg.ToolCallID
			entry["content"] = msg.Content
		default:
			entry["content"] = msg.Content
		}
		messages = append(messages, entry)
	}

	// Effective output budget: the request's value wins over the client's.
	maxOut := client.MaxTokens
	if req.MaxTokens != nil {
		maxOut = *req.MaxTokens
	}

	// Context guard: only active when a context limit is configured.
	if limit := client.Cfg.MaxContext; limit > 0 {
		if kept, dropped := truncateMessagesAny(messages, limit, maxOut); dropped > 0 {
			client.Log.Warnf("⚠️ [%s] Context guard: truncated %d oldest messages to fit within %d token limit",
				client.String(), dropped, limit)
			messages = kept
		}
	}

	// OpenAI newer models use max_completion_tokens instead of max_tokens.
	tokenKey := "max_tokens"
	if client.Provider == ProviderOpenAI {
		tokenKey = "max_completion_tokens"
	}

	body := map[string]any{
		"model":    req.Model,
		"messages": messages,
		tokenKey:   maxOut,
	}

	if req.Temperature == nil {
		body["temperature"] = client.Cfg.Temperature
	} else {
		body["temperature"] = *req.Temperature
	}

	// Optional parameters: included only when the request sets them.
	if req.TopP != nil {
		body["top_p"] = *req.TopP
	}
	if req.FrequencyPenalty != nil {
		body["frequency_penalty"] = *req.FrequencyPenalty
	}
	if req.PresencePenalty != nil {
		body["presence_penalty"] = *req.PresencePenalty
	}
	if len(req.Stop) > 0 {
		body["stop"] = req.Stop
	}
	if len(req.Tools) > 0 {
		body["tools"] = req.Tools
	}
	if req.ToolChoice != "" {
		body["tool_choice"] = req.ToolChoice
	}
	if req.Stream {
		body["stream"] = true
	}
	return body
}
// CallWithRequestStream streams the LLM response via SSE (Server-Sent Events).
// onChunk is called with the full accumulated text so far after each received chunk.
// Returns the complete final text when the stream ends. On a mid-stream
// failure the text accumulated so far is returned together with the error.
//
// req is mutated: Model falls back to client.Model when empty and Stream is
// forced to true. Unlike the other Call* entry points there is no retry loop.
//
// Idle timeout: if no SSE line arrives for 60 seconds the request context is
// cancelled automatically. This prevents the scanner from blocking
// indefinitely on a hung or stalled connection.
//
// Any token usage parsed from the stream is reported through
// ReportStreamUsage with the client's Provider and Model, whether or not the
// stream ended with an error.
//
// NOTE(review): if HTTPClient.Timeout is non-zero, net/http applies it to
// reading the response body as well, so it would bound the total stream
// duration independently of the idle watchdog — confirm the configured value.
func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API key not set")
}
if req.Model == "" {
req.Model = client.Model
}
req.Stream = true
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return "", err
}
url := client.Hooks.BuildUrl()
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
if err != nil {
return "", err
}
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
const idleTimeout = 60 * time.Second
ctx, cancel := context.WithCancel(context.Background())
// The deferred cancel also stops the watchdog goroutine on every return path.
defer cancel()
// Buffered with capacity 1 so the non-blocking send below can leave one
// pending reset signal for the watchdog.
resetCh := make(chan struct{}, 1)
go func() {
t := time.NewTimer(idleTimeout)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
cancel() // idle timeout: kill the connection
return
case <-resetCh:
// received a line — reset the idle timer
// Stop-and-drain before Reset so a stale expiry left in t.C
// cannot fire right after the timer is re-armed.
if !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(idleTimeout)
}
}
}()
httpReq = httpReq.WithContext(ctx)
resp, err := client.HTTPClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("streaming request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
// Best effort: the read error is ignored because the body is only
// used to enrich the error message.
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
// The onLine callback pokes the watchdog; the send is non-blocking so the
// parser never stalls when a reset signal is already pending.
text, usage, err := ParseSSEStream(resp.Body, onChunk, func() {
select {
case resetCh <- struct{}{}:
default:
}
})
ReportStreamUsage(usage, client.Provider, client.Model)
return text, err
}
// ParseSSEStream reads an SSE response body, accumulates text deltas,
// and calls onChunk with the full accumulated text after each chunk.
// If onLine is non-nil, it is called after each raw SSE line is scanned
// (useful for resetting idle-timeout watchdogs).
//
// Only "data" fields are interpreted. As the SSE format allows, the space
// after "data:" is optional. Parsing stops at the "[DONE]" sentinel or at end
// of stream. Payloads that are not valid JSON are skipped. Lines of up to
// 8 MiB are accepted, so a large single chunk does not abort the stream.
//
// Returns the complete accumulated text and any parsed token usage (nil if
// absent). Provider and Model are left empty on the returned usage; callers
// attach their own. If reading fails mid-stream, the text and usage gathered
// so far are returned together with a wrapped "stream interrupted" error.
func ParseSSEStream(body io.Reader, onChunk func(string), onLine func()) (string, *TokenUsage, error) {
	const (
		initialLineBuf = 64 * 1024       // starting scanner buffer
		maxLineSize    = 8 * 1024 * 1024 // upper bound for a single SSE line
	)

	var accumulated strings.Builder
	var usage *TokenUsage

	scanner := bufio.NewScanner(body)
	// The default Scanner limit (64 KiB) is too small for large chunks and
	// would end the stream with bufio.ErrTooLong.
	scanner.Buffer(make([]byte, 0, initialLineBuf), maxLineSize)

	for scanner.Scan() {
		if onLine != nil {
			onLine()
		}
		line := scanner.Text()
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		// Strip the field name and the single optional space that may follow it.
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if strings.TrimSpace(data) == "[DONE]" {
			break
		}

		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
				FinishReason *string `json:"finish_reason"`
			} `json:"choices"`
			Usage *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
				TotalTokens      int `json:"total_tokens"`
			} `json:"usage,omitempty"`
		}
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed chunks
		}

		// Usage may arrive on a chunk with no choices, so record it before
		// the choices check below.
		if chunk.Usage != nil && chunk.Usage.TotalTokens > 0 {
			usage = &TokenUsage{
				PromptTokens:     chunk.Usage.PromptTokens,
				CompletionTokens: chunk.Usage.CompletionTokens,
				TotalTokens:      chunk.Usage.TotalTokens,
			}
		}

		if len(chunk.Choices) == 0 {
			continue
		}
		delta := chunk.Choices[0].Delta.Content
		if delta == "" {
			continue
		}
		accumulated.WriteString(delta)
		if onChunk != nil {
			onChunk(accumulated.String())
		}
	}

	if err := scanner.Err(); err != nil {
		return accumulated.String(), usage, fmt.Errorf("stream interrupted: %w", err)
	}
	return accumulated.String(), usage, nil
}
// ReportStreamUsage forwards streaming token usage to TokenUsageCallback,
// stamping it with the caller's provider and model. The usage value passed
// in is not modified.
//
// Nothing happens when usage is nil, when no callback is registered, or when
// the reported total token count is not positive.
func ReportStreamUsage(usage *TokenUsage, provider, model string) {
	if usage != nil && TokenUsageCallback != nil && usage.TotalTokens > 0 {
		// Work on a copy so the caller's struct keeps its own Provider/Model.
		report := *usage
		report.Provider = provider
		report.Model = model
		TokenUsageCallback(report)
	}
}