diff --git a/mcp/client.go b/mcp/client.go index b70ac1cf..6916de3c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -725,21 +725,24 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) } - return ParseSSEStream(resp.Body, onChunk, func() { + text, usage, err := ParseSSEStream(resp.Body, onChunk, func() { select { case resetCh <- struct{}{}: default: } }) + ReportStreamUsage(usage, client.Provider, client.Model) + return text, err } // ParseSSEStream reads an SSE response body, accumulates text deltas, // and calls onChunk with the full accumulated text after each chunk. // If onLine is non-nil, it is called after each raw SSE line is scanned // (useful for resetting idle-timeout watchdogs). -// Returns the complete accumulated text. -func ParseSSEStream(body io.Reader, onChunk func(string), onLine func()) (string, error) { +// Returns the complete accumulated text and any parsed token usage (nil if absent). +func ParseSSEStream(body io.Reader, onChunk func(string), onLine func()) (string, *TokenUsage, error) { var accumulated strings.Builder + var usage *TokenUsage scanner := bufio.NewScanner(body) for scanner.Scan() { @@ -774,8 +777,11 @@ func ParseSSEStream(body io.Reader, onChunk func(string), onLine func()) (string } if chunk.Usage != nil && chunk.Usage.TotalTokens > 0 { - fmt.Printf("📊 [TokenUsage] prompt=%d, completion=%d, total=%d\n", - chunk.Usage.PromptTokens, chunk.Usage.CompletionTokens, chunk.Usage.TotalTokens) + usage = &TokenUsage{ + PromptTokens: chunk.Usage.PromptTokens, + CompletionTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, + } } if len(chunk.Choices) == 0 { @@ -794,8 +800,23 @@ func ParseSSEStream(body io.Reader, onChunk func(string), onLine func()) (string } if err := scanner.Err(); err != nil { - return accumulated.String(), fmt.Errorf("stream interrupted: %w", err) + return accumulated.String(), usage, fmt.Errorf("stream interrupted: %w", err) } - return accumulated.String(), nil + return accumulated.String(), usage, nil +} + +// ReportStreamUsage fires TokenUsageCallback with the given usage, provider, and model. +// No-op if usage is nil or callback is unset. +func ReportStreamUsage(usage *TokenUsage, provider, model string) { + if usage == nil || TokenUsageCallback == nil || usage.TotalTokens <= 0 { + return + } + TokenUsageCallback(TokenUsage{ + Provider: provider, + Model: model, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + }) } diff --git a/mcp/payment/x402.go b/mcp/payment/x402.go index 7ce70fdc..577da51f 100644 --- a/mcp/payment/x402.go +++ b/mcp/payment/x402.go @@ -452,7 +452,8 @@ func X402CallStream(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt var bodyBuf bytes.Buffer tee := io.TeeReader(resp.Body, &bodyBuf) - text, sseErr := mcp.ParseSSEStream(tee, onChunk, onLine) + text, usage, sseErr := mcp.ParseSSEStream(tee, onChunk, onLine) + mcp.ReportStreamUsage(usage, c.Provider, c.Model) if text != "" { c.Log.Infof("📡 [%s] SSE stream complete, got %d chars", tag, len(text))