diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go
index 72cc7408..77247ee5 100644
--- a/internal/jsonrpc2/conn.go
+++ b/internal/jsonrpc2/conn.go
@@ -639,6 +639,11 @@ func (c *Connection) handleAsync() {
 		ctx := context.WithValue(req.ctx, asyncKey, releaser)
 		go func() {
 			defer releaser.release(true)
+			defer func() {
+				if r := recover(); r != nil {
+					c.processResult(c.handler, req, nil, fmt.Errorf("%w: panic in handler: %v", ErrInternal, r))
+				}
+			}()
 			result, err := c.handler.Handle(ctx, req.Request)
 			c.processResult(c.handler, req, result, err)
 		}()
diff --git a/mcp/panic_recovery_test.go b/mcp/panic_recovery_test.go
new file mode 100644
index 00000000..877979ac
--- /dev/null
+++ b/mcp/panic_recovery_test.go
@@ -0,0 +1,109 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/google/jsonschema-go/jsonschema"
+)
+
+// TestToolHandler_PanicRecovery verifies that a panicking tool handler does
+// not crash the server process. The panic should be recovered and returned
+// as a JSON-RPC internal error to the client.
+func TestToolHandler_PanicRecovery(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	ct, st := NewInMemoryTransports()
+
+	s := NewServer(testImpl, nil)
+	AddTool(s, &Tool{
+		Name:        "panic-tool",
+		Description: "a tool that panics",
+		InputSchema: &jsonschema.Schema{Type: "object"},
+	}, func(_ context.Context, _ *CallToolRequest, _ map[string]any) (*CallToolResult, any, error) {
+		panic("deliberate panic in tool handler")
+	})
+
+	ss, err := s.Connect(ctx, st, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ss.Close()
+
+	c := NewClient(testImpl, nil)
+	cs, err := c.Connect(ctx, ct, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cs.Close()
+
+	// Call the panicking tool. Without recovery, this crashes the process.
+	// With recovery, we get an error response.
+	_, err = cs.CallTool(ctx, &CallToolParams{Name: "panic-tool"})
+
+	// We expect an error (the panic is caught and returned as internal error).
+	if err == nil {
+		t.Fatal("expected error from panicking tool handler, got success")
+	}
+	// The important thing is we reached this line: the process didn't crash.
+}
+
+// TestToolHandler_PanicDoesNotAffectSubsequentCalls verifies that after a
+// panic in one tool handler, the server continues to serve other requests.
+func TestToolHandler_PanicDoesNotAffectSubsequentCalls(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	ct, st := NewInMemoryTransports()
+
+	s := NewServer(testImpl, nil)
+	AddTool(s, &Tool{
+		Name:        "panic-tool",
+		Description: "a tool that panics",
+		InputSchema: &jsonschema.Schema{Type: "object"},
+	}, func(_ context.Context, _ *CallToolRequest, _ map[string]any) (*CallToolResult, any, error) {
+		panic("deliberate panic")
+	})
+	AddTool(s, &Tool{
+		Name:        "safe-tool",
+		Description: "a tool that works",
+		InputSchema: &jsonschema.Schema{Type: "object"},
+	}, func(_ context.Context, _ *CallToolRequest, _ map[string]any) (*CallToolResult, any, error) {
+		return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil
+	})
+
+	ss, err := s.Connect(ctx, st, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ss.Close()
+
+	c := NewClient(testImpl, nil)
+	cs, err := c.Connect(ctx, ct, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cs.Close()
+
+	// First: call the panicking tool. The recovered panic must surface as an
+	// error; a nil error would mean the panic path was never exercised.
+	if _, err := cs.CallTool(ctx, &CallToolParams{Name: "panic-tool"}); err == nil {
+		t.Fatal("expected error from panicking tool handler, got success")
+	}
+
+	// Second: call the safe tool. This should succeed, proving the server
+	// is still alive after the panic.
+	result, err := cs.CallTool(ctx, &CallToolParams{Name: "safe-tool"})
+	if err != nil {
+		t.Fatalf("safe tool call failed after panic: %v", err)
+	}
+	if result == nil {
+		t.Fatal("expected result from safe tool, got nil")
+	}
+}
diff --git a/mcp/shared.go b/mcp/shared.go
index 078b401b..211324ba 100644
--- a/mcp/shared.go
+++ b/mcp/shared.go
@@ -606,6 +606,11 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr
 	*cancelPtr = cancel
 
 	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				logger.Error("panic in keepalive goroutine", "error", r)
+			}
+		}()
 		ticker := time.NewTicker(interval)
 		defer ticker.Stop()
 
diff --git a/mcp/sse.go b/mcp/sse.go
index 0e1ad79e..72f3b772 100644
--- a/mcp/sse.go
+++ b/mcp/sse.go
@@ -10,6 +10,7 @@ import (
 	"crypto/rand"
 	"fmt"
 	"io"
+	"log/slog"
 	"mime"
 	"net"
 	"net/http"
@@ -420,6 +421,11 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
 
 	go func() {
 		defer s.Close() // close the transport when the GET exits
+		defer func() {
+			if r := recover(); r != nil {
+				slog.Default().Error("panic in SSE reader goroutine", "error", r)
+			}
+		}()
 
 		for evt, err := range scanEvents(resp.Body) {
 			if err != nil {
diff --git a/mcp/streamable.go b/mcp/streamable.go
index b8e36553..d842a514 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -2058,6 +2058,11 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error {
 }
 
 func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) {
+	defer func() {
+		if r := recover(); r != nil {
+			c.fail(fmt.Errorf("%s: panic in handleJSON: %v", requestSummary, r))
+		}
+	}()
 	body, err := io.ReadAll(resp.Body)
 	resp.Body.Close()
 	if err != nil {
@@ -2083,6 +2088,11 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
 // stream is complete when we receive its response. Otherwise, this is the
 // standalone stream.
 func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc2.Request) {
+	defer func() {
+		if r := recover(); r != nil {
+			c.fail(fmt.Errorf("%s: panic in handleSSE: %v", requestSummary, r))
+		}
+	}()
 	// Track the last event ID to detect progress.
 	// The retry counter is only reset when progress is made (lastEventID advances).
 	// This prevents infinite retry loops when a server repeatedly terminates
diff --git a/mcp/transport.go b/mcp/transport.go
index ea447478..30e3cb97 100644
--- a/mcp/transport.go
+++ b/mcp/transport.go
@@ -403,6 +403,14 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
 	// but that is unavoidable since AFAIK there is no (easy and portable) way to
 	// guarantee that reads of stdin are unblocked when closed.
 	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				select {
+				case incoming <- msgOrErr{err: fmt.Errorf("panic in reader: %v", r)}:
+				case <-closed:
+				}
+			}
+		}()
 		dec := json.NewDecoder(rwc)
 		for {
 			var raw json.RawMessage