diff --git a/mcp/server.go b/mcp/server.go index 183226d1..732b1ba0 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1488,9 +1488,17 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } + var duplicate bool ss.updateState(func(state *ServerSessionState) { - state.InitializeParams = params + duplicate = state.InitializeParams != nil + if !duplicate { + state.InitializeParams = params + } }) + if duplicate { + ss.server.opts.Logger.Error("duplicate initialize request") + return nil, fmt.Errorf("duplicate %q received", methodInitialize) + } s := ss.server return &InitializeResult{ diff --git a/mcp/server_test.go b/mcp/server_test.go index 2937ea2b..dcc7fb0e 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -825,6 +825,92 @@ func TestClientRootCapabilities(t *testing.T) { } } +func TestServerRejectsDuplicateInitialize(t *testing.T) { + ctx := context.Background() + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + cTransport, sTransport := NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cConn, err := cTransport.Connect(ctx) + if err != nil { + t.Fatal(err) + } + defer cConn.Close() + + firstParams := json.RawMessage(`{ + "protocolVersion": "2025-11-25", + "clientInfo": {"name": "first-client", "version": "1.0.0"} + }`) + firstReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(1), methodInitialize, firstParams) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, firstReq); err != nil { + t.Fatalf("first initialize write failed: %v", err) + } + msg, err := cConn.Read(ctx) + if err != nil { + t.Fatalf("first initialize read failed: %v", err) + } + resp, ok := msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected Response, got %T", msg) + } + if resp.Error != nil { + t.Fatalf("first initialize failed: %v", resp.Error) + } + + initializedReq, err := jsonrpc2.NewNotification(notificationInitialized, &InitializedParams{}) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, initializedReq); err != nil { + t.Fatalf("initialized notification write failed: %v", err) + } + + secondParams := json.RawMessage(`{ + "protocolVersion": "2024-11-05", + "clientInfo": {"name": "second-client", "version": "2.0.0"} + }`) + secondReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(2), methodInitialize, secondParams) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, secondReq); err != nil { + t.Fatalf("second initialize write failed: %v", err) + } + msg, err = cConn.Read(ctx) + if err != nil { + t.Fatalf("second initialize read failed: %v", err) + } + resp, ok = msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected Response, got %T", msg) + } + if resp.Error == nil { + t.Fatal("second initialize unexpectedly succeeded") + } + if !strings.Contains(resp.Error.Error(), `duplicate "initialize" received`) { + t.Fatalf("second initialize error = %v, want duplicate initialize", resp.Error) + } + + got := ss.InitializeParams() + if got == nil { + t.Fatal("InitializeParams is nil") + } + if got.ProtocolVersion != "2025-11-25" { + t.Fatalf("ProtocolVersion = %q, want first initialize value", got.ProtocolVersion) + } + if got.ClientInfo == nil || got.ClientInfo.Name != "first-client" { + t.Fatalf("ClientInfo = %#v, want first initialize value", got.ClientInfo) + } +} + // TODO: move this to tool_test.go func TestToolForSchemas(t *testing.T) { // Validate that toolForErr handles schemas properly.