Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1488,9 +1488,17 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
if params == nil {
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
}
var duplicate bool
ss.updateState(func(state *ServerSessionState) {
state.InitializeParams = params
duplicate = state.InitializeParams != nil
if !duplicate {
state.InitializeParams = params
}
})
if duplicate {
ss.server.opts.Logger.Error("duplicate initialize request")
return nil, fmt.Errorf("%w: duplicate %q received", jsonrpc2.ErrInvalidRequest, methodInitialize)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// ErrInvalidRequest is used when the JSON sent is not a valid Request object.

This does not seem the right error for this case. We might align it to the 'duplicate initialized notification' case error

}

s := ss.server
return &InitializeResult{
Expand Down
87 changes: 87 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
Expand Down Expand Up @@ -825,6 +826,92 @@ func TestClientRootCapabilities(t *testing.T) {
}
}

func TestServerRejectsDuplicateInitialize(t *testing.T) {
ctx := context.Background()

server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil)
cTransport, sTransport := NewInMemoryTransports()
ss, err := server.Connect(ctx, sTransport, nil)
if err != nil {
t.Fatal(err)
}
defer ss.Close()

cConn, err := cTransport.Connect(ctx)
if err != nil {
t.Fatal(err)
}
defer cConn.Close()

firstParams := json.RawMessage(`{
"protocolVersion": "2025-11-25",
"clientInfo": {"name": "first-client", "version": "1.0.0"}
}`)
firstReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(1), methodInitialize, firstParams)
if err != nil {
t.Fatal(err)
}
if err := cConn.Write(ctx, firstReq); err != nil {
t.Fatalf("first initialize write failed: %v", err)
}
msg, err := cConn.Read(ctx)
if err != nil {
t.Fatalf("first initialize read failed: %v", err)
}
resp, ok := msg.(*jsonrpc2.Response)
if !ok {
t.Fatalf("expected Response, got %T", msg)
}
if resp.Error != nil {
t.Fatalf("first initialize failed: %v", resp.Error)
}

initializedReq, err := jsonrpc2.NewNotification(notificationInitialized, &InitializedParams{})
if err != nil {
t.Fatal(err)
}
if err := cConn.Write(ctx, initializedReq); err != nil {
t.Fatalf("initialized notification write failed: %v", err)
}

secondParams := json.RawMessage(`{
"protocolVersion": "2024-11-05",
"clientInfo": {"name": "second-client", "version": "2.0.0"}
}`)
secondReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(2), methodInitialize, secondParams)
if err != nil {
t.Fatal(err)
}
if err := cConn.Write(ctx, secondReq); err != nil {
t.Fatalf("second initialize write failed: %v", err)
}
msg, err = cConn.Read(ctx)
if err != nil {
t.Fatalf("second initialize read failed: %v", err)
}
resp, ok = msg.(*jsonrpc2.Response)
if !ok {
t.Fatalf("expected Response, got %T", msg)
}
if resp.Error == nil {
t.Fatal("second initialize unexpectedly succeeded")
}
if !errors.Is(resp.Error, jsonrpc2.ErrInvalidRequest) {
t.Fatalf("second initialize error = %v, want invalid request", resp.Error)
}

got := ss.InitializeParams()
if got == nil {
t.Fatal("InitializeParams is nil")
}
if got.ProtocolVersion != "2025-11-25" {
t.Fatalf("ProtocolVersion = %q, want first initialize value", got.ProtocolVersion)
}
if got.ClientInfo == nil || got.ClientInfo.Name != "first-client" {
t.Fatalf("ClientInfo = %#v, want first initialize value", got.ClientInfo)
}
}

// TODO: move this to tool_test.go
func TestToolForSchemas(t *testing.T) {
// Validate that toolForErr handles schemas properly.
Expand Down