diff --git a/.gitignore b/.gitignore index 8d8db2cb4..a2ea2f790 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # Cake tools /[Tt]ools/ +# Language server cache +*.lscache + # Build output [Bb]uildArtifacts/ # Build results diff --git a/docs/concepts/tools/tools.md b/docs/concepts/tools/tools.md index 503307e66..1cd412b73 100644 --- a/docs/concepts/tools/tools.md +++ b/docs/concepts/tools/tools.md @@ -315,3 +315,28 @@ public static string Search( // Schema will include descriptions and default value for maxResults } ``` + +### Custom HTTP headers from tool parameters + +When using the Streamable HTTP transport, tool parameters can be mirrored as HTTP headers so that network infrastructure (load balancers, proxies, gateways) can make routing decisions without parsing the JSON-RPC request body. Apply the to a parameter to opt it in: + +```csharp +[McpServerTool, Description("Executes a SQL query in a specific region")] +public static string ExecuteSql( + [McpHeader("Region"), Description("Target datacenter region")] string region, + [Description("The SQL query to execute")] string query) +{ + // Clients will send an additional HTTP header: + // Mcp-Param-Region: +} +``` + +When the tool's schema is generated, the annotated parameter includes an `x-mcp-header` extension property. Clients read this annotation and automatically add the corresponding `Mcp-Param-{Name}` header on outgoing `tools/call` requests. The server validates that the header value matches the value in the JSON-RPC body. + +Rules and constraints: + +- Only primitive parameter types (`string`, numeric types, `bool`) are supported. +- The header name must contain only visible ASCII characters (0x21–0x7E) excluding colon (`:`). +- Values containing non-ASCII characters, control characters, or leading/trailing whitespace are Base64-encoded using the `=?base64?{value}?=` wrapper. +- Header names must be case-insensitively unique within the tool's input schema. +- Header validation is enforced only for protocol versions that support the HTTP Standardization feature (currently `DRAFT-2026-v1` and later). diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index ec28eff84..50fa24db2 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -23,9 +23,9 @@ internal sealed class StreamableHttpHandler( IServiceProvider applicationServices, ILoggerFactory loggerFactory) { - private const string McpSessionIdHeaderName = "Mcp-Session-Id"; - private const string McpProtocolVersionHeaderName = "MCP-Protocol-Version"; - private const string LastEventIdHeaderName = "Last-Event-ID"; + private const string McpSessionIdHeaderName = McpHttpHeaders.SessionId; + private const string McpProtocolVersionHeaderName = McpHttpHeaders.ProtocolVersion; + private const string LastEventIdHeaderName = McpHttpHeaders.LastEventId; /// /// All protocol versions supported by this implementation. @@ -37,6 +37,7 @@ internal sealed class StreamableHttpHandler( "2025-03-26", "2025-06-18", "2025-11-25", + "DRAFT-2026-v1", ]; private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); @@ -79,6 +80,12 @@ await WriteJsonRpcErrorAsync(context, return; } + if (!ValidateMcpHeaders(context, message, mcpServerOptionsSnapshot.Value.ToolCollection, out errorMessage)) + { + await WriteJsonRpcErrorAsync(context, errorMessage!, StatusCodes.Status400BadRequest, (int)McpErrorCode.HeaderMismatch); + return; + } + var session = await GetOrCreateSessionAsync(context, message); if (session is null) { @@ -540,6 +547,269 @@ private static bool ValidateProtocolVersionHeader(HttpContext context, out strin return true; } + /// + /// Validates standard MCP request headers (Mcp-Method, Mcp-Name) and custom parameter headers + /// (Mcp-Param-*) against the JSON-RPC request body. + /// Validation is only performed for protocol versions that include the HTTP Standardization feature. + /// + /// The HTTP context containing the request headers. + /// The JSON-RPC message to validate against. + /// The tool collection to look up tool schemas for parameter header validation. + /// Set to the error message if validation fails; null otherwise. + /// True if validation passes; false otherwise. + internal static bool ValidateMcpHeaders(HttpContext context, JsonRpcMessage message, McpServerPrimitiveCollection? toolCollection, out string? errorMessage) + { + // Only validate for protocol versions that support standard headers. + var protocolVersion = context.Request.Headers[McpProtocolVersionHeaderName].ToString(); + if (!McpHttpHeaders.SupportsStandardHeaders(protocolVersion)) + { + errorMessage = null; + return true; + } + + // Only validate for JSON-RPC requests and notifications, not responses. + if (!(message is JsonRpcRequest || message is JsonRpcNotification)) + { + errorMessage = null; + return true; + } + + // For requests that support standard headers, the Mcp-Method header must be present + // and match the method in the JSON-RPC body. + if (!context.Request.Headers.ContainsKey(McpHttpHeaders.Method)) + { + errorMessage = "Missing required Mcp-Method header."; + return false; + } + + var mcpMethodInHeader = context.Request.Headers[McpHttpHeaders.Method].ToString(); + var mcpMethodInBody = message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => null, // This case is already ruled out by the earlier check, but we need it to satisfy the compiler. + }; + + if (!string.Equals(mcpMethodInHeader, mcpMethodInBody, StringComparison.Ordinal)) + { + errorMessage = $"Header mismatch: Mcp-Method header value '{mcpMethodInHeader}' does not match body value '{mcpMethodInBody}'."; + return false; + } + + // From here on, only validate tools/read, tools/call, and prompts/get requests + if (mcpMethodInBody is not (RequestMethods.ToolsCall or RequestMethods.ResourcesRead or RequestMethods.PromptsGet)) + { + errorMessage = null; + return true; + } + + // For these requests, the Mcp-Name header must be present and match the name or uri in the JSON-RPC body. + if (!context.Request.Headers.ContainsKey(McpHttpHeaders.Name)) + { + errorMessage = "Missing required Mcp-Name header."; + return false; + } + + var mcpNameInHeader = context.Request.Headers[McpHttpHeaders.Name].ToString(); + + // Extract the params and name value from the body based on the method, if present. + var bodyParams = message switch + { + JsonRpcRequest request => request.Params, + JsonRpcNotification notification => notification.Params, + _ => null, + }; + var mcpNameInBody = mcpMethodInBody switch + { + RequestMethods.ToolsCall => GetJsonNodeStringProperty(bodyParams, "name"), + RequestMethods.ResourcesRead => GetJsonNodeStringProperty(bodyParams, "uri"), + RequestMethods.PromptsGet => GetJsonNodeStringProperty(bodyParams, "name"), + _ => null, + }; + + // Check that the header value matches the body value if the body value is present. + if (!string.Equals(mcpNameInHeader, mcpNameInBody, StringComparison.Ordinal)) + { + errorMessage = $"Header mismatch: Mcp-Name header value '{mcpNameInHeader}' does not match body value '{mcpNameInBody}'."; + return false; + } + + // Validate Mcp-Param-* custom headers against tool schema + if (!ValidateCustomParamHeaders(context, message, toolCollection, out errorMessage)) + { + return false; + } + + errorMessage = null; + return true; + } + + /// + /// Validates that all parameters annotated with x-mcp-header in the tool's input schema + /// have corresponding Mcp-Param-* headers present in the request, and that any present + /// Mcp-Param-* headers have valid encoding. + /// + private static bool ValidateCustomParamHeaders( + HttpContext context, + JsonRpcMessage message, + McpServerPrimitiveCollection? toolCollection, + out string? errorMessage) + { + // Custom param headers are only relevant for tools/call requests + if (message is not JsonRpcRequest { Method: RequestMethods.ToolsCall, Params: { } bodyParams }) + { + errorMessage = null; + return true; + } + + // Look up the tool to check for x-mcp-header annotations in the schema + var toolName = GetJsonNodeStringProperty(bodyParams, "name"); + if (toolName is null || toolCollection is null || !toolCollection.TryGetPrimitive(toolName, out var tool)) + { + errorMessage = null; + return true; + } + + var inputSchema = tool.ProtocolTool.InputSchema; + if (inputSchema.ValueKind != System.Text.Json.JsonValueKind.Object || + !inputSchema.TryGetProperty("properties", out var properties) || + properties.ValueKind != System.Text.Json.JsonValueKind.Object) + { + errorMessage = null; + return true; + } + + // Get the arguments from the body for value comparison + System.Text.Json.Nodes.JsonNode? arguments = null; + if (bodyParams is System.Text.Json.Nodes.JsonObject paramsObj) + { + paramsObj.TryGetPropertyValue("arguments", out arguments); + } + + // Check that every x-mcp-header annotated parameter has a corresponding header, + // that the header value is validly encoded, and that it matches the body value. + foreach (var property in properties.EnumerateObject()) + { + if (!property.Value.TryGetProperty("x-mcp-header", out var headerNameElement)) + { + continue; + } + + var headerName = headerNameElement.GetString(); + if (string.IsNullOrEmpty(headerName)) + { + continue; + } + + var fullHeaderName = $"{McpHttpHeaders.ParamPrefix}{headerName}"; + if (!context.Request.Headers.ContainsKey(fullHeaderName)) + { + // Per the SEP: if the parameter value is null or not provided in + // the arguments, the client MUST omit the header and the server + // MUST NOT expect it. Only reject when a non-null value is present + // in the body but the header is missing. + bool hasNonNullBodyValue = arguments is System.Text.Json.Nodes.JsonObject argsForMissing && + argsForMissing.TryGetPropertyValue(property.Name, out var argForMissing) && + argForMissing is not null && + argForMissing.GetValueKind() != System.Text.Json.JsonValueKind.Null; + + if (hasNonNullBodyValue) + { + errorMessage = $"Missing required {fullHeaderName} header for parameter '{property.Name}' annotated with x-mcp-header."; + return false; + } + + continue; + } + + var actualHeaderValue = context.Request.Headers[fullHeaderName].ToString(); + + // Validate the raw header value for invalid characters per SEP. + // Servers MUST reject headers containing characters outside the valid HTTP header value range. + if (!IsValidHeaderValue(actualHeaderValue)) + { + errorMessage = $"Header mismatch: {fullHeaderName} header contains invalid characters."; + return false; + } + + var decodedActual = Client.McpHeaderEncoder.DecodeValue(actualHeaderValue); + if (decodedActual is null) + { + errorMessage = $"Header mismatch: {fullHeaderName} header contains invalid Base64 encoding."; + return false; + } + + // Verify the header value matches the argument value in the body + if (arguments is System.Text.Json.Nodes.JsonObject argsObj && + argsObj.TryGetPropertyValue(property.Name, out var argNode) && + argNode is not null) + { + var expectedHeaderValue = ConvertJsonNodeToHeaderValue(argNode); + if (expectedHeaderValue is not null) + { + var decodedExpected = Client.McpHeaderEncoder.DecodeValue(expectedHeaderValue); + if (!string.Equals(decodedActual, decodedExpected, StringComparison.Ordinal)) + { + errorMessage = $"Header mismatch: {fullHeaderName} header value does not match body argument '{property.Name}'."; + return false; + } + } + } + } + + errorMessage = null; + return true; + } + + private static string? GetJsonNodeStringProperty(System.Text.Json.Nodes.JsonNode? node, string propertyName) + { + if (node is System.Text.Json.Nodes.JsonObject obj && obj.TryGetPropertyValue(propertyName, out var value)) + { + return value?.GetValue(); + } + + return null; + } + + /// + /// Validates that a header value contains only characters allowed in HTTP header field values + /// per RFC 9110: visible ASCII (0x21-0x7E), space (0x20), and horizontal tab (0x09). + /// + private static bool IsValidHeaderValue(string value) + { + foreach (char c in value) + { + if (c < 0x20 || c > 0x7E) + { + if (c != '\t') + { + return false; + } + } + } + + return true; + } + + private static string? ConvertJsonNodeToHeaderValue(System.Text.Json.Nodes.JsonNode node) + { + if (node is not System.Text.Json.Nodes.JsonValue jsonValue) + { + return null; + } + + object? value = jsonValue.GetValueKind() switch + { + System.Text.Json.JsonValueKind.String => jsonValue.GetValue(), + System.Text.Json.JsonValueKind.Number => jsonValue.ToJsonString(), + System.Text.Json.JsonValueKind.True => true, + System.Text.Json.JsonValueKind.False => false, + _ => null + }; + + return Client.McpHeaderEncoder.EncodeValue(value); + } + private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("application/json"); diff --git a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs index 673f66420..808f5b84b 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs @@ -184,6 +184,15 @@ public async ValueTask> ListToolsAsync( tools ??= new(toolResults.Tools.Count); foreach (var tool in toolResults.Tools) { + // Validate x-mcp-header annotations per SEP-2243. + // Clients MUST exclude tools with invalid annotations and SHOULD log a warning. + if (!McpHeaderExtractor.ValidateToolSchema(tool, out var rejectionReason)) + { + OnToolRejected(tool, rejectionReason!); + continue; + } + + OnToolDiscovered(tool); tools.Add(new(this, tool, options?.JsonSerializerOptions)); } @@ -194,6 +203,26 @@ public async ValueTask> ListToolsAsync( return tools; } + /// + /// Called when a tool definition is discovered from a tools/list response. + /// + /// + /// Override this method to cache or process tool definitions for use in + /// subsequent tools/call requests (e.g., for adding custom HTTP headers). + /// + internal virtual void OnToolDiscovered(Tool tool) + { + } + + /// + /// Called when a tool definition is rejected due to invalid x-mcp-header annotations. + /// + /// The tool that was rejected. + /// The reason the tool was rejected. + internal virtual void OnToolRejected(Tool tool, string reason) + { + } + /// /// Retrieves a list of available tools from the server. /// diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 4205c28e1..728a2c6c0 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -22,6 +23,7 @@ internal sealed partial class McpClientImpl : McpClient private readonly McpSessionHandler _sessionHandler; private readonly SemaphoreSlim _disposeLock = new(1, 1); private readonly McpTaskCancellationTokenProvider? _taskCancellationTokenProvider; + private readonly ConcurrentDictionary _toolCache = new(StringComparer.Ordinal); private ServerCapabilities? _serverCapabilities; private Implementation? _serverInfo; @@ -633,12 +635,37 @@ internal void ResumeSession(ResumeClientSessionOptions resumeOptions) /// public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) - => _sessionHandler.SendRequestAsync(request, cancellationToken); + { + // For tools/call requests, attach the cached tool definition to the message context + // so the transport can add custom Mcp-Param-* headers based on x-mcp-header schema annotations. + if (request.Method == RequestMethods.ToolsCall && + request.Params is System.Text.Json.Nodes.JsonObject paramsObj && + paramsObj.TryGetPropertyValue("name", out var nameNode) && + nameNode?.GetValue() is { } toolName && + _toolCache.TryGetValue(toolName, out var tool)) + { + request.Context ??= new(); + request.Context.Items ??= new Dictionary(); + request.Context.Items[McpHttpHeaders.ToolContextKey] = tool; + } + + return _sessionHandler.SendRequestAsync(request, cancellationToken); + } /// public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => _sessionHandler.SendMessageAsync(message, cancellationToken); + internal override void OnToolDiscovered(Tool tool) + { + _toolCache[tool.Name] = tool; + } + + internal override void OnToolRejected(Tool tool, string reason) + { + LogToolRejected(tool.Name, reason); + } + /// public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => _sessionHandler.RegisterNotificationHandler(method, handler); @@ -686,4 +713,7 @@ public override async ValueTask DisposeAsync() [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client resumed existing session.")] private partial void LogClientSessionResumed(string endpointName); + [LoggerMessage(Level = LogLevel.Warning, Message = "Tool '{ToolName}' excluded from tools/list: {Reason}")] + private partial void LogToolRejected(string toolName, string reason); + } diff --git a/src/ModelContextProtocol.Core/Client/McpHeaderEncoder.cs b/src/ModelContextProtocol.Core/Client/McpHeaderEncoder.cs new file mode 100644 index 000000000..0d4aee656 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpHeaderEncoder.cs @@ -0,0 +1,149 @@ +using System.Text; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Client; + +/// +/// Encodes and decodes parameter values for use in MCP HTTP headers according to the +/// HTTP Standardization SEP. +/// +/// +/// +/// This encoder handles conversion of parameter values to HTTP header-safe strings, +/// including Base64 encoding for values that cannot be safely transmitted as plain text. +/// +/// +/// Encoding rules: +/// +/// Plain ASCII values (0x20-0x7E): sent as-is +/// Values with leading/trailing whitespace: Base64 encoded with =?base64?{value}?= wrapper +/// Non-ASCII characters: Base64 encoded +/// Control characters: Base64 encoded +/// +/// +/// +public static class McpHeaderEncoder +{ + private const string Base64Prefix = "=?base64?"; + private const string Base64Suffix = "?="; + + /// + /// Encodes a parameter value for use in an HTTP header. + /// + /// The value to encode. Can be string, number, or boolean. + /// + /// The encoded header value, or if the value cannot be encoded + /// (e.g., is not a supported type). + /// + public static string? EncodeValue(object? value) + { + if (value is null) + { + return null; + } + + var stringValue = ConvertToString(value); + if (stringValue is null) + { + return null; + } + + if (RequiresBase64Encoding(stringValue)) + { + return EncodeAsBase64(stringValue); + } + + return stringValue; + } + + /// + /// Decodes a header value that may be Base64-encoded according to SEP rules. + /// + /// The header value to decode. + /// + /// The decoded string value, or if decoding fails. + /// If the value is not Base64-encoded, returns the original value. + /// + public static string? DecodeValue(string? headerValue) + { + if (headerValue is null || headerValue.Length == 0) + { + return headerValue; + } + + // Check for Base64 wrapper (case-insensitive prefix check per SEP) + if (headerValue.StartsWith(Base64Prefix, StringComparison.OrdinalIgnoreCase) && + headerValue.EndsWith(Base64Suffix, StringComparison.Ordinal)) + { + var base64Content = headerValue.Substring( + Base64Prefix.Length, + headerValue.Length - Base64Prefix.Length - Base64Suffix.Length); + + try + { + var bytes = Convert.FromBase64String(base64Content); + return Encoding.UTF8.GetString(bytes); + } + catch (FormatException) + { + return null; + } + } + + return headerValue; + } + + private static string? ConvertToString(object value) + { + return value switch + { + string s => s, + bool b => b ? "true" : "false", + byte n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + sbyte n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + short n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + ushort n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + int n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + uint n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + long n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + ulong n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + float n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + double n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + decimal n => n.ToString(System.Globalization.CultureInfo.InvariantCulture), + _ => null + }; + } + + private static bool RequiresBase64Encoding(string value) + { + if (value.Length == 0) + { + return false; + } + + // Check for leading/trailing whitespace (space or tab) + if (value[0] is ' ' or '\t' || value[^1] is ' ' or '\t') + { + return true; + } + + foreach (char c in value) + { + // Valid HTTP header field value characters per SEP: visible ASCII (0x21-0x7E) and space (0x20). + // All control characters (0x00-0x1F, 0x7F), including tab, must be Base64-encoded. + if (c < 0x20 || c > 0x7E) + { + return true; + } + } + + return false; + } + + private static string EncodeAsBase64(string value) + { + var bytes = Encoding.UTF8.GetBytes(value); + var base64 = Convert.ToBase64String(bytes); + return $"{Base64Prefix}{base64}{Base64Suffix}"; + } +} diff --git a/src/ModelContextProtocol.Core/Client/McpHeaderExtractor.cs b/src/ModelContextProtocol.Core/Client/McpHeaderExtractor.cs new file mode 100644 index 000000000..63b86cafa --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpHeaderExtractor.cs @@ -0,0 +1,161 @@ +using System.Net.Http.Headers; +using System.Text.Json; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Client; + +/// +/// Extracts parameter values from tool call arguments and adds them as HTTP headers +/// based on x-mcp-header schema extensions. +/// +internal static class McpHeaderExtractor +{ + private const string XMcpHeaderProperty = "x-mcp-header"; + + /// + /// Adds custom parameter headers to an HTTP request based on a tool's schema extensions. + /// + /// The HTTP request headers to add to. + /// The tool definition containing the input schema with x-mcp-header annotations. + /// The arguments being passed to the tool call. + public static void AddParameterHeaders( + HttpRequestHeaders headers, + Tool tool, + JsonElement? arguments) + { + if (!arguments.HasValue || arguments.Value.ValueKind != JsonValueKind.Object) + { + return; + } + + if (tool.InputSchema.ValueKind != JsonValueKind.Object || + !tool.InputSchema.TryGetProperty("properties", out var properties) || + properties.ValueKind != JsonValueKind.Object) + { + return; + } + + foreach (var property in properties.EnumerateObject()) + { + if (property.Value.ValueKind != JsonValueKind.Object || + !property.Value.TryGetProperty(XMcpHeaderProperty, out var headerNameElement)) + { + continue; + } + + var headerName = headerNameElement.GetString(); + if (string.IsNullOrEmpty(headerName)) + { + continue; + } + + // Look for the corresponding argument value + if (!arguments.Value.TryGetProperty(property.Name, out var argValue)) + { + continue; + } + + // Null values → omit header per SEP + if (argValue.ValueKind == JsonValueKind.Null) + { + continue; + } + + var headerValue = ConvertJsonElementToHeaderValue(argValue); + if (headerValue is not null) + { + headers.Add($"{McpHttpHeaders.ParamPrefix}{headerName}", headerValue); + } + } + } + + private static string? ConvertJsonElementToHeaderValue(JsonElement element) + { + object? value = element.ValueKind switch + { + JsonValueKind.String => element.GetString(), + JsonValueKind.Number => element.GetRawText(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => null + }; + + return McpHeaderEncoder.EncodeValue(value); + } + + /// + /// Validates a tool's inputSchema for valid x-mcp-header annotations. + /// Returns if the tool is valid; with a reason if it should be rejected. + /// + internal static bool ValidateToolSchema(Tool tool, out string? rejectionReason) + { + rejectionReason = null; + + if (tool.InputSchema.ValueKind != JsonValueKind.Object || + !tool.InputSchema.TryGetProperty("properties", out var properties) || + properties.ValueKind != JsonValueKind.Object) + { + return true; + } + + var headerNames = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var property in properties.EnumerateObject()) + { + // Skip properties whose schema is not an object (e.g., boolean `true`/`false` schemas) + if (property.Value.ValueKind != JsonValueKind.Object || + !property.Value.TryGetProperty(XMcpHeaderProperty, out var headerNameElement)) + { + continue; + } + + // x-mcp-header value must be a string + if (headerNameElement.ValueKind != JsonValueKind.String) + { + rejectionReason = $"Tool '{tool.Name}': x-mcp-header on property '{property.Name}' is not a string."; + return false; + } + + var headerName = headerNameElement.GetString(); + + // MUST NOT be empty + if (string.IsNullOrEmpty(headerName)) + { + rejectionReason = $"Tool '{tool.Name}': x-mcp-header on property '{property.Name}' is empty."; + return false; + } + + // MUST contain only ASCII characters (0x21-0x7E) excluding space and colon + foreach (char c in headerName!) + { + if (c < 0x21 || c > 0x7E || c == ':') + { + rejectionReason = $"Tool '{tool.Name}': x-mcp-header '{headerName}' contains invalid character '{c}' (0x{(int)c:X2})."; + return false; + } + } + + // MUST be case-insensitively unique + if (!headerNames.Add(headerName)) + { + rejectionReason = $"Tool '{tool.Name}': duplicate x-mcp-header name '{headerName}' (case-insensitive)."; + return false; + } + + // MUST only be applied to primitive types (string, number, boolean) + if (property.Value.TryGetProperty("type", out var typeElement) && + typeElement.ValueKind == JsonValueKind.String) + { + var typeName = typeElement.GetString(); + if (typeName is not ("string" or "number" or "integer" or "boolean")) + { + rejectionReason = $"Tool '{tool.Name}': x-mcp-header on property '{property.Name}' has non-primitive type '{typeName}'."; + return false; + } + } + } + + return true; + } +} diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index f51e236b4..2cebccb3b 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -4,6 +4,7 @@ using System.Net.Http.Headers; using System.Net.ServerSentEvents; using System.Text.Json; +using System.Text.Json.Nodes; using ModelContextProtocol.Protocol; using System.Threading.Channels; using System.Net; @@ -91,6 +92,8 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); + AddMcpRequestHeaders(httpRequestMessage.Headers, message); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); // We'll let the caller decide whether to throw or fall back given an unsuccessful response. @@ -431,17 +434,17 @@ internal static void CopyAdditionalHeaders( { if (sessionId is not null) { - headers.Add("Mcp-Session-Id", sessionId); + headers.Add(McpHttpHeaders.SessionId, sessionId); } if (protocolVersion is not null) { - headers.Add("MCP-Protocol-Version", protocolVersion); + headers.Add(McpHttpHeaders.ProtocolVersion, protocolVersion); } if (lastEventId is not null) { - headers.Add("Last-Event-ID", lastEventId); + headers.Add(McpHttpHeaders.LastEventId, lastEventId); } if (additionalHeaders is null) @@ -458,6 +461,78 @@ internal static void CopyAdditionalHeaders( } } + /// + /// Adds standard MCP request headers (Mcp-Method, Mcp-Name) and custom parameter headers + /// (Mcp-Param-{Name}) to an HTTP request based on the JSON-RPC message being sent. + /// + internal static void AddMcpRequestHeaders(HttpRequestHeaders headers, JsonRpcMessage message) + { + string? method = message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => null, + }; + + if (method is null) + { + return; + } + + headers.Add(McpHttpHeaders.Method, method); + + // Add Mcp-Name header for methods that target a specific named resource + string? name = message switch + { + JsonRpcRequest { Method: RequestMethods.ToolsCall or RequestMethods.PromptsGet } request + => GetParamsStringProperty(request.Params, "name"), + JsonRpcRequest { Method: RequestMethods.ResourcesRead } request + => GetParamsStringProperty(request.Params, "uri"), + _ => null, + }; + + if (name is not null) + { + headers.Add(McpHttpHeaders.Name, name); + } + + // Add custom Mcp-Param-{Name} headers for tools/call requests with x-mcp-header annotations + if (method == RequestMethods.ToolsCall && + message is JsonRpcRequest toolsCallRequest && + toolsCallRequest.Context?.Items?.TryGetValue(McpHttpHeaders.ToolContextKey, out var toolObj) == true && + toolObj is Tool tool) + { + var arguments = GetParamsArguments(toolsCallRequest.Params); + McpHeaderExtractor.AddParameterHeaders(headers, tool, arguments); + } + } + + /// + /// Extracts a string property from the JSON-RPC params object. + /// + private static string? GetParamsStringProperty(JsonNode? paramsNode, string propertyName) + { + if (paramsNode is JsonObject obj && obj.TryGetPropertyValue(propertyName, out var value)) + { + return value?.GetValue(); + } + + return null; + } + + /// + /// Extracts the arguments property from a tools/call params object as a JsonElement. + /// + private static JsonElement? GetParamsArguments(JsonNode? paramsNode) + { + if (paramsNode is JsonObject obj && obj.TryGetPropertyValue("arguments", out var argsNode) && argsNode is not null) + { + return JsonSerializer.Deserialize(argsNode, McpJsonUtilities.JsonContext.Default.JsonElement); + } + + return null; + } + /// /// Tracks state across SSE stream connections. /// diff --git a/src/ModelContextProtocol.Core/McpErrorCode.cs b/src/ModelContextProtocol.Core/McpErrorCode.cs index 33cd74a82..38c5f1161 100644 --- a/src/ModelContextProtocol.Core/McpErrorCode.cs +++ b/src/ModelContextProtocol.Core/McpErrorCode.cs @@ -5,6 +5,26 @@ namespace ModelContextProtocol; /// public enum McpErrorCode { + /// + /// Indicates that HTTP headers do not match the corresponding values in the request body, + /// or that required headers are missing or malformed. + /// + /// + /// + /// This error is returned when a Streamable HTTP request fails header validation. Validation failures include: + /// + /// + /// A required standard header (Mcp-Method, Mcp-Name) is missing. + /// A header value does not match the corresponding request body value. + /// A Base64-encoded header value cannot be decoded. + /// A header value contains invalid characters. + /// + /// + /// This error code is in the JSON-RPC implementation-defined server error range (-32000 to -32099). + /// + /// + HeaderMismatch = -32001, + /// /// Indicates that the requested resource could not be found. /// diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 24543fd3e..77c18b8be 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -41,6 +41,7 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable "2025-03-26", "2025-06-18", LatestProtocolVersion, + "DRAFT-2026-v1", ]; /// @@ -159,7 +160,7 @@ public McpSessionHandler( /// completes its channel with a , the wrapped /// is unwrapped. Otherwise, a default instance is returned. /// - internal Task CompletionTask => + internal Task CompletionTask => field ??= GetCompletionDetailsAsync(_transport.MessageReader.Completion); /// diff --git a/src/ModelContextProtocol.Core/Protocol/McpHttpHeaders.cs b/src/ModelContextProtocol.Core/Protocol/McpHttpHeaders.cs new file mode 100644 index 000000000..6db1e9821 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/McpHttpHeaders.cs @@ -0,0 +1,78 @@ +namespace ModelContextProtocol.Protocol; + +/// +/// Constants for MCP-specific HTTP header names used in the Streamable HTTP transport. +/// +/// +/// Per RFC 9110, HTTP header names are case-insensitive. Clients and servers must +/// use case-insensitive comparisons when processing these headers. +/// +public static class McpHttpHeaders +{ + /// + /// The minimum protocol version that requires standard MCP request headers. + /// + /// + /// Servers enforce missing Mcp-Method and Mcp-Name headers as errors only when + /// the client's MCP-Protocol-Version header indicates this version or later. + /// Clients using older versions are not required to send these headers. + /// + public const string MinVersionForStandardHeaders = "DRAFT-2026-v1"; + + /// The session identifier header. + public const string SessionId = "Mcp-Session-Id"; + + /// The negotiated protocol version header. + public const string ProtocolVersion = "MCP-Protocol-Version"; + + /// The last event ID for SSE stream resumption. + public const string LastEventId = "Last-Event-ID"; + + /// + /// The JSON-RPC method being invoked (e.g., "tools/call", "resources/read"). + /// + /// + /// Required on all Streamable HTTP POST requests. The value must match the method + /// field in the JSON-RPC request body. + /// + public const string Method = "Mcp-Method"; + + /// + /// The name or URI of the target resource for the request. + /// + /// + /// Required for tools/call, resources/read, and prompts/get requests. + /// For tools/call and prompts/get, the value is taken from params.name. + /// For resources/read, the value is taken from params.uri. + /// + public const string Name = "Mcp-Name"; + + /// + /// Prefix for custom parameter headers (Mcp-Param-{Name}). + /// + /// + /// When a tool's inputSchema includes properties annotated with x-mcp-header, + /// clients mirror those parameter values into HTTP headers using this prefix. + /// + public const string ParamPrefix = "Mcp-Param-"; + + /// + /// Key used in to store the + /// definition for the current request, enabling the transport to add custom parameter headers. + /// + internal const string ToolContextKey = "Mcp.Tool"; + + /// + /// Protocol versions that require standard MCP request headers (Mcp-Method, Mcp-Name). + /// + private static readonly HashSet s_versionsWithStandardHeaders = new(StringComparer.Ordinal) + { + MinVersionForStandardHeaders, + }; + + /// + /// Returns if the given protocol version requires standard MCP request headers. + /// + public static bool SupportsStandardHeaders(string? protocolVersion) + => !string.IsNullOrEmpty(protocolVersion) && s_versionsWithStandardHeaders.Contains(protocolVersion!); +} diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 700d9d26d..961344c2c 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -124,6 +124,12 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Icons = options?.Icons, }; + // Add x-mcp-header extensions to the input schema based on McpHeaderAttribute on parameters + if (function.UnderlyingMethod is { } method) + { + tool.InputSchema = AddMcpHeaderExtensions(tool.InputSchema, method); + } + if (options is not null) { if (options.Title is not null || @@ -600,4 +606,85 @@ private static CallToolResult ConvertAIContentEnumerableToCallToolResult(IEnumer IsError = allErrorContent && hasAny }; } + + /// + /// Post-processes the input schema to add x-mcp-header extensions based on + /// on method parameters. + /// + private static JsonElement AddMcpHeaderExtensions(JsonElement inputSchema, MethodInfo method) + { + // Collect parameters with McpHeaderAttribute + var headerParams = new List<(string ParameterName, string HeaderName, ParameterInfo Parameter)>(); + var headerNamesSet = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var param in method.GetParameters()) + { + var attr = param.GetCustomAttribute(); + if (attr is null) + { + continue; + } + + // Validate primitive type only + var paramType = Nullable.GetUnderlyingType(param.ParameterType) ?? param.ParameterType; + if (!IsPrimitiveHeaderType(paramType)) + { + throw new InvalidOperationException( + $"Parameter '{param.Name}' on method '{method.Name}' has [McpHeader] but is not a primitive type. " + + "Only string, numeric, and boolean types may be annotated with [McpHeader]."); + } + + // Validate case-insensitive uniqueness + if (!headerNamesSet.Add(attr.Name)) + { + throw new InvalidOperationException( + $"Duplicate x-mcp-header name '{attr.Name}' (case-insensitive) found on method '{method.Name}'. " + + "Header names must be case-insensitively unique within a tool's input schema."); + } + + headerParams.Add((param.Name!, attr.Name, param)); + } + + if (headerParams.Count == 0) + { + return inputSchema; + } + + // Parse the schema to a mutable JsonNode, add extensions, and convert back + var schemaNode = JsonNode.Parse(inputSchema.GetRawText()); + if (schemaNode is not JsonObject schemaObj || + !schemaObj.TryGetPropertyValue("properties", out var propertiesNode) || + propertiesNode is not JsonObject propertiesObj) + { + return inputSchema; + } + + foreach (var (parameterName, headerName, _) in headerParams) + { + if (propertiesObj.TryGetPropertyValue(parameterName, out var propNode) && + propNode is JsonObject propObj) + { + propObj["x-mcp-header"] = headerName; + } + } + + return JsonSerializer.Deserialize(schemaNode, McpJsonUtilities.JsonContext.Default.JsonElement); + } + + private static bool IsPrimitiveHeaderType(Type type) + { + return type == typeof(string) || + type == typeof(bool) || + type == typeof(byte) || + type == typeof(sbyte) || + type == typeof(short) || + type == typeof(ushort) || + type == typeof(int) || + type == typeof(uint) || + type == typeof(long) || + type == typeof(ulong) || + type == typeof(float) || + type == typeof(double) || + type == typeof(decimal); + } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpHeaderAttribute.cs b/src/ModelContextProtocol.Core/Server/McpHeaderAttribute.cs new file mode 100644 index 000000000..516b73580 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpHeaderAttribute.cs @@ -0,0 +1,81 @@ +namespace ModelContextProtocol.Server; + +/// +/// Indicates that a tool parameter should be mirrored as an HTTP header in client requests. +/// +/// +/// +/// When applied to a parameter, the SDK will include an x-mcp-header extension property +/// in the parameter's JSON schema. Clients will then mirror this parameter's value into an +/// HTTP header named Mcp-Param-{Name}. +/// +/// +/// Only parameters with primitive types (string, number, boolean) may use this attribute. +/// The header name must contain only ASCII characters (0x21-0x7E, excluding space and colon) +/// and must be case-insensitively unique within the tool's input schema. +/// +/// +/// This enables network infrastructure such as load balancers, proxies, and gateways to make +/// routing decisions based on tool parameter values without parsing the JSON-RPC request body. +/// +/// +/// +/// +/// [McpServerTool] +/// public static string ExecuteSql( +/// [McpHeader("Region")] string region, +/// string query) +/// { +/// // The client will add header: Mcp-Param-Region: {region value} +/// } +/// +/// +[AttributeUsage(AttributeTargets.Parameter | AttributeTargets.Property)] +public sealed class McpHeaderAttribute : Attribute +{ + /// + /// Initializes a new instance of the class. + /// + /// + /// The name portion of the header. The full header name will be Mcp-Param-{name}. + /// Must contain only ASCII characters (0x21-0x7E, excluding space and colon). + /// + /// + /// The name is null, empty, or contains invalid characters. + /// + public McpHeaderAttribute(string name) + { + Throw.IfNullOrWhiteSpace(name); + ValidateHeaderName(name); + Name = name; + } + + /// + /// Gets the name portion of the header. + /// + /// + /// The full header name sent by clients will be Mcp-Param-{Name}. + /// + public string Name { get; } + + /// + /// Validates that a header name contains only valid characters. + /// + /// The header name to validate. + /// The name contains invalid characters. + internal static void ValidateHeaderName(string name) + { + foreach (char c in name) + { + // Valid token characters per RFC 9110: visible ASCII (0x21-0x7E) excluding delimiters. + // Space (0x20) and colon (':') are explicitly prohibited. + if (c < 0x21 || c > 0x7E || c == ':') + { + throw new ArgumentException( + $"Header name contains invalid character '{c}' (0x{(int)c:X2}). " + + "Only ASCII characters (0x21-0x7E) excluding colon are allowed.", + nameof(name)); + } + } + } +} diff --git a/tests/Common/Utils/NodeHelpers.cs b/tests/Common/Utils/NodeHelpers.cs index 94ae206ab..22cb5c9f2 100644 --- a/tests/Common/Utils/NodeHelpers.cs +++ b/tests/Common/Utils/NodeHelpers.cs @@ -80,6 +80,13 @@ public static ProcessStartInfo ConformanceTestStartInfo(string arguments) { EnsureNpmDependenciesInstalled(); + // If MCP_CONFORMANCE_PROTOCOL_VERSION is set, pass it as --spec-version to the runner. + var protocolVersion = Environment.GetEnvironmentVariable("MCP_CONFORMANCE_PROTOCOL_VERSION"); + if (!string.IsNullOrEmpty(protocolVersion)) + { + arguments += $" --spec-version {protocolVersion}"; + } + var repoRoot = FindRepoRoot(); var binPath = Path.Combine(repoRoot, "node_modules", ".bin", "conformance"); @@ -126,6 +133,82 @@ public static ProcessStartInfo ConformanceTestStartInfo(string arguments) return startInfo; } + /// + /// Gets the installed conformance package version by running 'conformance --version'. + /// Returns null if the version cannot be determined. + /// + public static Version? GetConformanceVersion() + { + if (!IsNodeInstalled()) + { + return null; + } + + try + { + EnsureNpmDependenciesInstalled(); + var repoRoot = FindRepoRoot(); + var binPath = Path.Combine(repoRoot, "node_modules", ".bin", "conformance"); + + ProcessStartInfo startInfo; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + startInfo = new ProcessStartInfo + { + FileName = $"{binPath}.cmd", + Arguments = "--version", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true + }; + } + else + { + startInfo = new ProcessStartInfo + { + FileName = binPath, + Arguments = "--version", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true + }; + } + + using var process = Process.Start(startInfo); + if (process == null) + { + return null; + } + + var output = process.StandardOutput.ReadToEnd().Trim(); + process.WaitForExit(10_000); + + if (process.ExitCode == 0 && Version.TryParse(output, out var version)) + { + return version; + } + + return null; + } + catch + { + return null; + } + } + + /// + /// Checks if the installed conformance package version is at least the specified minimum. + /// + public static bool IsConformanceVersionAtLeast(string minimumVersion) + { + var installed = GetConformanceVersion(); + return installed != null + && Version.TryParse(minimumVersion, out var min) + && installed >= min; + } + /// /// Checks if Node.js is installed and available on the system. /// diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs index 72d075fe7..1ed231ef4 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs @@ -17,6 +17,9 @@ public class ClientConformanceTests // Public static property required for SkipUnless attribute public static bool IsNodeInstalled => NodeHelpers.IsNodeInstalled(); + // SEP-2243 scenarios require conformance package >= 0.1.16 + public static bool HasSep2243Scenarios => NodeHelpers.IsConformanceVersionAtLeast("0.1.16"); + public ClientConformanceTests(ITestOutputHelper output) { _output = output; @@ -61,6 +64,21 @@ public async Task RunConformanceTest(string scenario) $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); } + // HTTP Standardization (SEP-2243) — requires conformance package >= 0.1.16 + [Theory(Skip = "Conformance package >= 0.1.16 not available.", SkipUnless = nameof(HasSep2243Scenarios))] + [InlineData("http-standard-headers")] + [InlineData("http-custom-headers")] + [InlineData("http-invalid-tool-headers")] + public async Task RunConformanceTest_Sep2243(string scenario) + { + // Run the conformance test suite + var result = await RunClientConformanceScenario(scenario); + + // Report the results + Assert.True(result.Success, + $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); + } + private async Task<(bool Success, string Output, string Error)> RunClientConformanceScenario(string scenario) { // Construct an absolute path to the conformance client executable diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Sep2243HeaderTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Sep2243HeaderTests.cs new file mode 100644 index 000000000..389b9bac4 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Sep2243HeaderTests.cs @@ -0,0 +1,342 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for SEP-2243 HTTP header standardization features: +/// - Custom Mcp-Param-* header validation +/// - Tab/control character encoding +/// - Numeric precision in header values +/// - Empty string header validation +/// - Invalid header character rejection +/// +public class Sep2243HeaderTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private WebApplication? _app; + + private async Task StartAsync() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = nameof(Sep2243HeaderTests), + Version = "1.0", + }; + }).WithTools(Tools).WithHttpTransport(); + + _app = Builder.Build(); + _app.MapMcp(); + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + // Create a tool with x-mcp-header annotations in the schema. + // We set InputSchema directly because TransformSchemaNode doesn't provide + // property-level path context for lambda-based tool creation. + private static McpServerTool[] Tools { get; } = [CreateHeaderTestTool()]; + + private static readonly JsonSerializerOptions s_reflectionOptions = new() + { + TypeInfoResolver = new System.Text.Json.Serialization.Metadata.DefaultJsonTypeInfoResolver() + }; + + private static McpServerTool CreateHeaderTestTool() + { + var tool = McpServerTool.Create( + [McpServerTool(Name = "header_test")] + static (string region, int priority, bool verbose, string emptyVal) => + $"region={region},priority={priority},verbose={verbose},empty={emptyVal}", + new McpServerToolCreateOptions { SerializerOptions = s_reflectionOptions }); + + using var doc = JsonDocument.Parse(""" + { + "type": "object", + "properties": { + "region": { "type": "string", "x-mcp-header": "Region" }, + "priority": { "type": "integer", "x-mcp-header": "Priority" }, + "verbose": { "type": "boolean", "x-mcp-header": "Verbose" }, + "emptyVal": { "type": "string", "x-mcp-header": "EmptyVal" } + }, + "required": ["region", "priority", "verbose", "emptyVal"] + } + """); + tool.ProtocolTool.InputSchema = doc.RootElement.Clone(); + + return tool; + } + + #region Server-side validation tests + + [Fact] + public async Task Server_ValidatesEmptyStringHeaderValue_AgainstBodyValue() + { + await StartAsync(); + await InitializeWithDraftVersionAsync(); + + // Send a tools/call with an empty string param that has an x-mcp-header. + // The header should be present with an empty value, matching the body's empty string. + var callJson = CallTool("header_test", """{"region":"us-west1","priority":42,"verbose":false,"emptyVal":""}"""); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = new StringContent(callJson, Encoding.UTF8, "application/json"); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "tools/call"); + request.Headers.Add("Mcp-Name", "header_test"); + request.Headers.Add("Mcp-Param-Region", "us-west1"); + request.Headers.Add("Mcp-Param-Priority", "42"); + request.Headers.Add("Mcp-Param-Verbose", "false"); + request.Headers.Add("Mcp-Param-EmptyVal", ""); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task Server_RejectsHeaderMismatch_WhenEmptyHeaderDoesNotMatchBody() + { + await StartAsync(); + await InitializeWithDraftVersionAsync(); + + // Send a tools/call where the body has a non-empty value but the header is empty + var callJson = CallTool("header_test", """{"region":"us-west1","priority":42,"verbose":false,"emptyVal":"some-value"}"""); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = new StringContent(callJson, Encoding.UTF8, "application/json"); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "tools/call"); + request.Headers.Add("Mcp-Name", "header_test"); + request.Headers.Add("Mcp-Param-Region", "us-west1"); + request.Headers.Add("Mcp-Param-Priority", "42"); + request.Headers.Add("Mcp-Param-Verbose", "false"); + request.Headers.Add("Mcp-Param-EmptyVal", ""); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Server_AcceptsBase64EncodedHeaderWithControlChars() + { + await StartAsync(); + await InitializeWithDraftVersionAsync(); + + // Encode a value with a newline control character using Base64 + var valueWithNewline = "line1\nline2"; + var encodedValue = McpHeaderEncoder.EncodeValue(valueWithNewline); + + var callJson = CallTool("header_test", $$"""{"region":"{{valueWithNewline.Replace("\n", "\\n")}}","priority":42,"verbose":false,"emptyVal":""}"""); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = new StringContent(callJson, Encoding.UTF8, "application/json"); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "tools/call"); + request.Headers.Add("Mcp-Name", "header_test"); + request.Headers.Add("Mcp-Param-Region", encodedValue!); + request.Headers.Add("Mcp-Param-Priority", "42"); + request.Headers.Add("Mcp-Param-Verbose", "false"); + request.Headers.Add("Mcp-Param-EmptyVal", ""); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task Server_AcceptsLargeIntegerWithFullPrecision() + { + await StartAsync(); + await InitializeWithDraftVersionAsync(); + + // Use a large integer that would lose precision if converted through double + // 2^53 + 1 = 9007199254740993 (cannot be represented exactly as double) + const long largeInt = 9007199254740993L; + var callJson = CallTool("header_test", $$"""{"region":"test","priority":{{largeInt}},"verbose":false,"emptyVal":""}"""); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = new StringContent(callJson, Encoding.UTF8, "application/json"); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "tools/call"); + request.Headers.Add("Mcp-Name", "header_test"); + request.Headers.Add("Mcp-Param-Region", "test"); + request.Headers.Add("Mcp-Param-Priority", largeInt.ToString()); + request.Headers.Add("Mcp-Param-Verbose", "false"); + request.Headers.Add("Mcp-Param-EmptyVal", ""); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task Server_SkipsHeaderValidation_ForNonDraftVersion() + { + await StartAsync(); + await InitializeWithNonDraftVersionAsync(); + + // With non-draft version, Mcp-Param-* headers are NOT validated even if mismatched + var callJson = CallTool("header_test", """{"region":"us-west1","priority":42,"verbose":false,"emptyVal":""}"""); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = new StringContent(callJson, Encoding.UTF8, "application/json"); + // Send the WRONG header value — this should still succeed because version is non-draft + request.Headers.Add("MCP-Protocol-Version", "2025-11-25"); + request.Headers.Add("Mcp-Method", "tools/call"); + request.Headers.Add("Mcp-Name", "header_test"); + request.Headers.Add("Mcp-Param-Region", "WRONG-VALUE"); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + #endregion + + #region Client-side encoding tests (unit tests for McpHeaderEncoder) + + [Theory] + [InlineData("hello\tworld")] + [InlineData("col1\tcol2\tcol3")] + public void Client_TabInValue_IsBase64Encoded(string value) + { + var encoded = McpHeaderEncoder.EncodeValue(value); + Assert.NotNull(encoded); + Assert.StartsWith("=?base64?", encoded); + Assert.EndsWith("?=", encoded); + + // Verify round-trip + var decoded = McpHeaderEncoder.DecodeValue(encoded); + Assert.Equal(value, decoded); + } + + [Theory] + [InlineData("simple-text", false)] + [InlineData("with space", false)] + [InlineData("Hello, 世界", true)] + [InlineData("line1\nline2", true)] + [InlineData("\ttab-start", true)] + [InlineData("mid\ttab", true)] + [InlineData("control\x01char", true)] + public void Client_EncodeValue_Base64OnlyWhenNeeded(string value, bool expectBase64) + { + var encoded = McpHeaderEncoder.EncodeValue(value); + Assert.NotNull(encoded); + + if (expectBase64) + { + Assert.StartsWith("=?base64?", encoded); + } + else + { + Assert.DoesNotContain("=?base64?", encoded); + } + + // All values must round-trip + var decoded = McpHeaderEncoder.DecodeValue(encoded); + Assert.Equal(value, decoded); + } + + [Fact] + public void Client_EncodeValue_LargeInteger_PreservesFullPrecision() + { + // 2^53 + 1 cannot be represented exactly as a double + var encoded = McpHeaderEncoder.EncodeValue(9007199254740993L); + Assert.Equal("9007199254740993", encoded); + } + + [Fact] + public void Client_EncodeValue_Boolean_EncodesCorrectly() + { + Assert.Equal("true", McpHeaderEncoder.EncodeValue(true)); + Assert.Equal("false", McpHeaderEncoder.EncodeValue(false)); + } + + #endregion + + #region Version gating tests + + [Theory] + [InlineData("DRAFT-2026-v1", true)] + [InlineData("2025-11-25", false)] + [InlineData("2025-06-18", false)] + [InlineData("2024-11-05", false)] + [InlineData(null, false)] + [InlineData("", false)] + public void SupportsStandardHeaders_CorrectlyGatesVersions(string? version, bool expected) + { + Assert.Equal(expected, McpHttpHeaders.SupportsStandardHeaders(version)); + } + + #endregion + + #region Helpers + + private async Task InitializeWithDraftVersionAsync() + { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = JsonContent(InitializeRequestDraft); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "initialize"); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + } + + private async Task InitializeWithNonDraftVersionAsync() + { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + } + + private static StringContent JsonContent(string json) => new(json, Encoding.UTF8, "application/json"); + + private long _lastRequestId = 1; + + private string CallTool(string toolName, string arguments = "{}") + { + var id = Interlocked.Increment(ref _lastRequestId); + return $$$""" + {"jsonrpc":"2.0","id":{{{id}}},"method":"tools/call","params":{"name":"{{{toolName}}}","arguments":{{{arguments}}}}} + """; + } + + private static string InitializeRequest => """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0"}}} + """; + + private static string InitializeRequestDraft => """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"DRAFT-2026-v1","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0"}}} + """; + + #endregion +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs index e538a6f3f..618e8ae5b 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs @@ -127,6 +127,30 @@ public async Task RunPendingConformanceTest_ServerSsePolling() $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); } + [Fact] + public async Task RunConformanceTest_HttpHeaderValidation() + { + Assert.SkipWhen(!NodeHelpers.IsNodeInstalled(), "Node.js is not installed. Skipping conformance tests."); + Assert.SkipWhen(!NodeHelpers.IsConformanceVersionAtLeast("0.1.16"), "Conformance package >= 0.1.16 not available."); + + var result = await RunConformanceTestsAsync($"server --url {fixture.ServerUrl} --scenario http-header-validation"); + + Assert.True(result.Success, + $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); + } + + [Fact] + public async Task RunConformanceTest_HttpCustomHeaderServerValidation() + { + Assert.SkipWhen(!NodeHelpers.IsNodeInstalled(), "Node.js is not installed. Skipping conformance tests."); + Assert.SkipWhen(!NodeHelpers.IsConformanceVersionAtLeast("0.1.16"), "Conformance package >= 0.1.16 not available."); + + var result = await RunConformanceTestsAsync($"server --url {fixture.ServerUrl} --scenario http-custom-header-server-validation"); + + Assert.True(result.Success, + $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); + } + private async Task<(bool Success, string Output, string Error)> RunConformanceTestsAsync(string arguments) { var startInfo = NodeHelpers.ConformanceTestStartInfo(arguments); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index c8e3f8d7b..517d41e02 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -549,6 +549,229 @@ private static string Echo(string message) return message; } + #region SEP-2243 Client Header Tests + + [Fact] + public async Task ListTools_FiltersToolsWithInvalidHeaderAnnotations() + { + // Start a mock server that returns tools with both valid and invalid x-mcp-header annotations + await StartHeaderToolServer(); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // The server returns 3 tools: valid_tool, invalid_space_tool, invalid_duplicate_tool + // The client should filter out tools with invalid x-mcp-header annotations + var toolNames = tools.Select(t => t.Name).ToList(); + Assert.Contains("valid_tool", toolNames); + Assert.DoesNotContain("invalid_space_tool", toolNames); + Assert.DoesNotContain("invalid_duplicate_tool", toolNames); + } + + [Fact] + public async Task Client_SendsCorrectHeaders_EndToEnd() + { + // Start a server that captures request headers for verification + var capturedHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + await StartHeaderCapturingServer(capturedHeaders); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var tool = Assert.Single(tools); + Assert.Equal("header_tool", tool.Name); + + // Call the tool — client should send Mcp-Param-* headers automatically + capturedHeaders.Clear(); + await tool.CallAsync(new Dictionary { ["region"] = "us-west-2" }, cancellationToken: TestContext.Current.CancellationToken); + + // Verify the client sent the correct headers + Assert.True(capturedHeaders.ContainsKey("Mcp-Method"), "Expected Mcp-Method header"); + Assert.Equal("tools/call", capturedHeaders["Mcp-Method"]); + Assert.True(capturedHeaders.ContainsKey("Mcp-Name"), "Expected Mcp-Name header"); + Assert.Equal("header_tool", capturedHeaders["Mcp-Name"]); + Assert.True(capturedHeaders.ContainsKey("Mcp-Param-Region"), "Expected Mcp-Param-Region header"); + Assert.Equal("us-west-2", capturedHeaders["Mcp-Param-Region"]); + } + + private async Task StartHeaderToolServer() + { + Builder.Services.Configure(options => + { + options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!); + }); + _app = Builder.Build(); + + _app.MapPost("/mcp", (JsonRpcMessage message) => + { + if (message is not JsonRpcRequest request) + { + return Results.Accepted(); + } + + if (request.Method == "initialize") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new() { Tools = new() }, + ServerInfo = new Implementation { Name = "header-test-server", Version = "1.0" }, + }, McpJsonUtilities.DefaultOptions) + }); + } + + if (request.Method == "tools/list") + { + // Return tools with various x-mcp-header annotations — some valid, some invalid + var toolsJson = JsonSerializer.SerializeToNode(new ListToolsResult + { + Tools = + [ + CreateToolWithSchema("valid_tool", """ + { + "type": "object", + "properties": { + "region": { "type": "string", "x-mcp-header": "Region" } + } + } + """), + CreateToolWithSchema("invalid_space_tool", """ + { + "type": "object", + "properties": { + "value": { "type": "string", "x-mcp-header": "Invalid Name" } + } + } + """), + CreateToolWithSchema("invalid_duplicate_tool", """ + { + "type": "object", + "properties": { + "a": { "type": "string", "x-mcp-header": "Same" }, + "b": { "type": "string", "x-mcp-header": "Same" } + } + } + """), + ] + }, McpJsonUtilities.DefaultOptions); + + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = toolsJson, + }); + } + + return Results.Accepted(); + }); + + await _app.StartAsync(TestContext.Current.CancellationToken); + } + + private async Task StartHeaderCapturingServer(Dictionary capturedHeaders) + { + Builder.Services.Configure(options => + { + options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!); + }); + _app = Builder.Build(); + + _app.MapPost("/mcp", (JsonRpcMessage message, HttpContext context) => + { + if (message is not JsonRpcRequest request) + { + return Results.Accepted(); + } + + if (request.Method == "initialize") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new() { Tools = new() }, + ServerInfo = new Implementation { Name = "header-capture", Version = "1.0" }, + }, McpJsonUtilities.DefaultOptions) + }); + } + + if (request.Method == "tools/list") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new ListToolsResult + { + Tools = [CreateToolWithSchema("header_tool", """ + { + "type": "object", + "properties": { + "region": { "type": "string", "x-mcp-header": "Region" } + }, + "required": ["region"] + } + """)] + }, McpJsonUtilities.DefaultOptions), + }); + } + + if (request.Method == "tools/call") + { + // Capture all MCP headers for verification + foreach (var header in context.Request.Headers) + { + if (header.Key.StartsWith("Mcp-", StringComparison.OrdinalIgnoreCase)) + { + capturedHeaders[header.Key] = header.Value.ToString(); + } + } + + var parameters = JsonSerializer.Deserialize(request.Params, GetJsonTypeInfo()); + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new CallToolResult + { + Content = [new TextContentBlock { Text = "ok" }], + }, McpJsonUtilities.DefaultOptions), + }); + } + + return Results.Accepted(); + }); + + await _app.StartAsync(TestContext.Current.CancellationToken); + } + + private static Tool CreateToolWithSchema(string name, string schemaJson) + { + using var doc = JsonDocument.Parse(schemaJson); + return new Tool + { + Name = name, + InputSchema = doc.RootElement.Clone(), + }; + } + + #endregion + private sealed class ResumeTestServer { private static readonly Tool ResumeTool = new() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index bbe642ab6..38b1ca696 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -730,6 +730,98 @@ public async Task McpServer_UsedOutOfScope_CanSendNotifications() Assert.Equal(NotificationMethods.ResourceUpdatedNotification, notification.Method); } + #region SEP-2243 Header Validation Tests + + [Fact] + public async Task DraftVersion_RejectsMissingMcpMethodHeader() + { + await StartAsync(); + + // Initialize with draft version to enable header validation + await CallInitializeWithDraftVersionAndValidateAsync(); + + // Send a tools/call request without Mcp-Method header — should be rejected + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = JsonContent(CallTool("echo", """{"message":"test"}""")); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + // Deliberately omit Mcp-Method header + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task DraftVersion_RejectsMismatchedMcpMethodHeader() + { + await StartAsync(); + await CallInitializeWithDraftVersionAndValidateAsync(); + + // Send a tools/call request but set Mcp-Method to wrong value + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = JsonContent(CallTool("echo", """{"message":"test"}""")); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "resources/read"); // Wrong method + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task DraftVersion_AcceptsCorrectMcpMethodHeader() + { + await StartAsync(); + await CallInitializeWithDraftVersionAndValidateAsync(); + + // Send a tools/call request with correct Mcp-Method and Mcp-Name headers + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = JsonContent(CallTool("echo", """{"message":"hello"}""")); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "tools/call"); + request.Headers.Add("Mcp-Name", "echo"); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task NonDraftVersion_DoesNotRequireMcpMethodHeader() + { + await StartAsync(); + await CallInitializeAndValidateAsync(); + + // With non-draft version, Mcp-Method header is not required + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = JsonContent(CallTool("echo", """{"message":"hello"}""")); + request.Headers.Add("MCP-Protocol-Version", "2025-03-26"); + // No Mcp-Method header — should still work + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + private async Task CallInitializeWithDraftVersionAndValidateAsync() + { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + + using var request = new HttpRequestMessage(HttpMethod.Post, ""); + request.Content = JsonContent(InitializeRequestDraft); + request.Headers.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + request.Headers.Add("Mcp-Method", "initialize"); + + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + AssertServerInfo(rpcResponse); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + SetSessionId(sessionId); + } + + private static string InitializeRequestDraft => """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"DRAFT-2026-v1","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} + """; + + #endregion + private static StringContent JsonContent(string json) => new(json, Encoding.UTF8, "application/json"); private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); diff --git a/tests/ModelContextProtocol.ConformanceClient/Program.cs b/tests/ModelContextProtocol.ConformanceClient/Program.cs index b5e048dd0..7ce848907 100644 --- a/tests/ModelContextProtocol.ConformanceClient/Program.cs +++ b/tests/ModelContextProtocol.ConformanceClient/Program.cs @@ -179,6 +179,136 @@ } break; } + case "http-standard-headers": + { + // List and call tools to test Mcp-Method and Mcp-Name headers + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + + var tool = tools.FirstOrDefault(t => t.Name == "test_headers"); + if (tool is not null) + { + Console.WriteLine("Calling tool: test_headers"); + var result = await mcpClient.CallToolAsync(toolName: "test_headers", arguments: new Dictionary()); + success &= !(result.IsError == true); + } + + // List and get prompts to test Mcp-Method and Mcp-Name headers + var prompts = await mcpClient.ListPromptsAsync(); + Console.WriteLine($"Available prompts: {string.Join(", ", prompts.Select(p => p.Name))}"); + + foreach (var prompt in prompts) + { + Console.WriteLine($"Getting prompt: {prompt.Name}"); + try + { + await mcpClient.GetPromptAsync(prompt.Name); + } + catch (Exception ex) + { + Console.WriteLine($"Prompt get error (expected for test): {ex.Message}"); + } + } + + // List and read resources to test Mcp-Name with params.uri + var resources = await mcpClient.ListResourcesAsync(); + Console.WriteLine($"Available resources: {string.Join(", ", resources.Select(r => r.Uri))}"); + + foreach (var resource in resources) + { + Console.WriteLine($"Reading resource: {resource.Uri}"); + try + { + await mcpClient.ReadResourceAsync(resource.Uri); + } + catch (Exception ex) + { + Console.WriteLine($"Resource read error (expected for test): {ex.Message}"); + } + } + break; + } + case "http-custom-headers": + { + // List tools to discover x-mcp-header annotations (populates tool cache) + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + + // Parse conformance context for tool calls + if (!string.IsNullOrEmpty(conformanceContext)) + { + using var contextDoc = JsonDocument.Parse(conformanceContext); + + // Support both "toolCalls" (array) and legacy "toolCall" (single object) + var toolCallElements = new List(); + if (contextDoc.RootElement.TryGetProperty("toolCalls", out var toolCallsArray) && + toolCallsArray.ValueKind == JsonValueKind.Array) + { + foreach (var item in toolCallsArray.EnumerateArray()) + { + toolCallElements.Add(item); + } + } + else if (contextDoc.RootElement.TryGetProperty("toolCall", out var toolCallEl)) + { + toolCallElements.Add(toolCallEl); + } + + foreach (var toolCallEl in toolCallElements) + { + var toolName = toolCallEl.TryGetProperty("name", out var nameEl) + ? nameEl.GetString() ?? "test_custom_headers" + : "test_custom_headers"; + + Dictionary toolCallArgs = new(); + if (toolCallEl.TryGetProperty("arguments", out var argsEl)) + { + foreach (var prop in argsEl.EnumerateObject()) + { + object? value = prop.Value.ValueKind switch + { + JsonValueKind.String => prop.Value.GetString(), + JsonValueKind.Number => prop.Value.TryGetInt64(out var l) ? l : prop.Value.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => prop.Value.GetRawText(), + }; + toolCallArgs[prop.Name] = value; + } + } + + Console.WriteLine($"Calling tool: {toolName} with {toolCallArgs.Count} arguments"); + var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: toolCallArgs); + success &= !(result.IsError == true); + } + } + break; + } + case "http-invalid-tool-headers": + { + // List tools — the client should filter out tools with invalid x-mcp-header annotations + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools after filtering: {string.Join(", ", tools.Select(t => t.Name))}"); + + // Only call valid_tool — invalid tools should have been excluded + var validTool = tools.FirstOrDefault(t => t.Name == "valid_tool"); + if (validTool is not null) + { + Console.WriteLine("Calling valid_tool"); + var result = await mcpClient.CallToolAsync(toolName: "valid_tool", arguments: new Dictionary + { + { "region", "us-east1" } + }); + success &= !(result.IsError == true); + } + else + { + Console.WriteLine("ERROR: valid_tool was not found in the filtered tool list"); + success = false; + } + break; + } default: // No extra processing for other scenarios break; diff --git a/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs b/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs index d6db6f626..bef403404 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs @@ -442,4 +442,13 @@ public static string TestReconnection() // and the client must reconnect to get the result. return "Reconnection test completed successfully"; } + + [McpServerTool(Name = "test_header_tool")] + [Description("A tool with x-mcp-header annotations for conformance testing")] + public static string TestHeaderTool( + [McpHeader("Region"), Description("The deployment region")] string region, + [Description("The query to execute")] string query) + { + return $"Executed in region {region}: {query}"; + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Client/McpHeaderEncoderTests.cs b/tests/ModelContextProtocol.Tests/Client/McpHeaderEncoderTests.cs new file mode 100644 index 000000000..9f130db4c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/McpHeaderEncoderTests.cs @@ -0,0 +1,163 @@ +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.Tests.Client; + +public class McpHeaderEncoderTests +{ + [Theory] + [InlineData("us-west1", "us-west1")] + [InlineData("hello-world", "hello-world")] + [InlineData("my_tool_name", "my_tool_name")] + [InlineData("us west 1", "us west 1")] + [InlineData("", "")] + public void EncodeValue_PlainAscii_PassesThrough(string input, string expected) + { + var result = McpHeaderEncoder.EncodeValue(input); + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(" us-west1", "=?base64?IHVzLXdlc3Qx?=")] + [InlineData("us-west1 ", "=?base64?dXMtd2VzdDEg?=")] + [InlineData(" us-west1 ", "=?base64?IHVzLXdlc3QxIA==?=")] + [InlineData("\tindented", "=?base64?CWluZGVudGVk?=")] + public void EncodeValue_LeadingTrailingWhitespace_Base64Encodes(string input, string expected) + { + var result = McpHeaderEncoder.EncodeValue(input); + Assert.Equal(expected, result); + } + + [Fact] + public void EncodeValue_NonAsciiCharacters_Base64Encodes() + { + var result = McpHeaderEncoder.EncodeValue("日本語"); + Assert.Equal("=?base64?5pel5pys6Kqe?=", result); + } + + [Fact] + public void EncodeValue_NewlineCharacter_Base64Encodes() + { + var result = McpHeaderEncoder.EncodeValue("line1\nline2"); + Assert.Equal("=?base64?bGluZTEKbGluZTI=?=", result); + } + + [Fact] + public void EncodeValue_CarriageReturnNewline_Base64Encodes() + { + var result = McpHeaderEncoder.EncodeValue("line1\r\nline2"); + Assert.Equal("=?base64?bGluZTENCmxpbmUy?=", result); + } + + [Theory] + [InlineData(true, "true")] + [InlineData(false, "false")] + public void EncodeValue_Boolean_ConvertsToLowercase(bool input, string expected) + { + var result = McpHeaderEncoder.EncodeValue(input); + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(42, "42")] + [InlineData(3.14, "3.14")] + [InlineData(0, "0")] + [InlineData(-1, "-1")] + public void EncodeValue_Number_ConvertsToString(object input, string expected) + { + var result = McpHeaderEncoder.EncodeValue(input); + Assert.Equal(expected, result); + } + + [Fact] + public void EncodeValue_Null_ReturnsNull() + { + var result = McpHeaderEncoder.EncodeValue(null); + Assert.Null(result); + } + + [Fact] + public void EncodeValue_UnsupportedType_ReturnsNull() + { + var result = McpHeaderEncoder.EncodeValue(new object()); + Assert.Null(result); + } + + [Theory] + [InlineData("us-west1", "us-west1")] + [InlineData("", "")] + public void DecodeValue_PlainAscii_ReturnsAsIs(string input, string expected) + { + var result = McpHeaderEncoder.DecodeValue(input); + Assert.Equal(expected, result); + } + + [Fact] + public void DecodeValue_Null_ReturnsNull() + { + var result = McpHeaderEncoder.DecodeValue(null); + Assert.Null(result); + } + + [Fact] + public void DecodeValue_ValidBase64_Decodes() + { + var result = McpHeaderEncoder.DecodeValue("=?base64?SGVsbG8=?="); + Assert.Equal("Hello", result); + } + + [Fact] + public void DecodeValue_CaseInsensitivePrefix_Decodes() + { + var result = McpHeaderEncoder.DecodeValue("=?BASE64?SGVsbG8=?="); + Assert.Equal("Hello", result); + } + + [Fact] + public void DecodeValue_InvalidBase64_ReturnsNull() + { + var result = McpHeaderEncoder.DecodeValue("=?base64?SGVs!!!bG8=?="); + Assert.Null(result); + } + + [Fact] + public void DecodeValue_MissingPrefix_ReturnsLiteralValue() + { + var result = McpHeaderEncoder.DecodeValue("SGVsbG8="); + Assert.Equal("SGVsbG8=", result); + } + + [Fact] + public void DecodeValue_MissingSuffix_ReturnsLiteralValue() + { + var result = McpHeaderEncoder.DecodeValue("=?base64?SGVsbG8="); + Assert.Equal("=?base64?SGVsbG8=", result); + } + + [Theory] + [InlineData("us-west1")] + [InlineData("Hello, 世界")] + [InlineData(" padded ")] + [InlineData("line1\nline2")] + [InlineData("\tindented")] + [InlineData("a\tb")] + public void RoundTrip_EncodeDecode_PreservesValue(string original) + { + var encoded = McpHeaderEncoder.EncodeValue(original); + Assert.NotNull(encoded); + + var decoded = McpHeaderEncoder.DecodeValue(encoded); + Assert.Equal(original, decoded); + } + + [Fact] + public void EncodeValue_EmbeddedTab_Base64Encodes() + { + var result = McpHeaderEncoder.EncodeValue("col1\tcol2"); + Assert.StartsWith("=?base64?", result); + Assert.EndsWith("?=", result); + + // Verify round-trip + var decoded = McpHeaderEncoder.DecodeValue(result); + Assert.Equal("col1\tcol2", decoded); + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpRequestHeadersTests.cs b/tests/ModelContextProtocol.Tests/Client/McpRequestHeadersTests.cs new file mode 100644 index 000000000..83f9e610f --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/McpRequestHeadersTests.cs @@ -0,0 +1,35 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Tests.Client; + +public class McpRequestHeadersTests +{ + [Fact] + public void McpHttpHeaders_HasCorrectValues() + { + Assert.Equal("Mcp-Session-Id", McpHttpHeaders.SessionId); + Assert.Equal("MCP-Protocol-Version", McpHttpHeaders.ProtocolVersion); + Assert.Equal("Last-Event-ID", McpHttpHeaders.LastEventId); + Assert.Equal("Mcp-Method", McpHttpHeaders.Method); + Assert.Equal("Mcp-Name", McpHttpHeaders.Name); + Assert.Equal("Mcp-Param-", McpHttpHeaders.ParamPrefix); + } + + [Fact] + public void McpErrorCode_HeaderMismatch_HasCorrectValue() + { + Assert.Equal(-32001, (int)McpErrorCode.HeaderMismatch); + } + + [Theory] + [InlineData("DRAFT-2026-v1", true)] + [InlineData("2025-11-25", false)] + [InlineData("2025-06-18", false)] + [InlineData("2024-11-05", false)] + [InlineData(null, false)] + [InlineData("", false)] + public void SupportsStandardHeaders_ReturnsExpected(string? version, bool expected) + { + Assert.Equal(expected, McpHttpHeaders.SupportsStandardHeaders(version)); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpHeaderAttributeTests.cs b/tests/ModelContextProtocol.Tests/Server/McpHeaderAttributeTests.cs new file mode 100644 index 000000000..694869af9 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpHeaderAttributeTests.cs @@ -0,0 +1,59 @@ +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests.Server; + +public class McpHeaderAttributeTests +{ + [Theory] + [InlineData("Region")] + [InlineData("TenantId")] + [InlineData("Priority")] + [InlineData("X-Custom")] + public void Constructor_ValidHeaderName_Succeeds(string name) + { + var attr = new McpHeaderAttribute(name); + Assert.Equal(name, attr.Name); + } + + [Fact] + public void Constructor_NameWithSpace_Throws() + { + Assert.Throws(() => new McpHeaderAttribute("My Region")); + } + + [Fact] + public void Constructor_NameWithColon_Throws() + { + Assert.Throws(() => new McpHeaderAttribute("Region:Primary")); + } + + [Fact] + public void Constructor_NullName_Throws() + { + Assert.ThrowsAny(() => new McpHeaderAttribute(null!)); + } + + [Fact] + public void Constructor_EmptyName_Throws() + { + Assert.ThrowsAny(() => new McpHeaderAttribute("")); + } + + [Fact] + public void Constructor_WhitespaceName_Throws() + { + Assert.ThrowsAny(() => new McpHeaderAttribute(" ")); + } + + [Fact] + public void Constructor_NameWithControlCharacter_Throws() + { + Assert.Throws(() => new McpHeaderAttribute("Region\t1")); + } + + [Fact] + public void Constructor_NameWithNonAscii_Throws() + { + Assert.Throws(() => new McpHeaderAttribute("Région")); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 808ba7efe..a283bf18c 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1177,4 +1177,136 @@ private static string SyncTool() [JsonSerializable(typeof(DateTimeOffset?))] [JsonSerializable(typeof(Person))] partial class JsonContext2 : JsonSerializerContext; + + // ===== x-mcp-header tests ===== + + [Fact] + public void Create_WithMcpHeaderAttribute_AddsXMcpHeaderExtension() + { + var tool = McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithSingleHeader))!); + var schema = tool.ProtocolTool.InputSchema; + var props = schema.GetProperty("properties"); + var regionProp = props.GetProperty("region"); + Assert.True(regionProp.TryGetProperty("x-mcp-header", out var headerValue)); + Assert.Equal("Region", headerValue.GetString()); + } + + [Fact] + public void Create_WithMultipleMcpHeaderAttributes_AddsAllExtensions() + { + var tool = McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithMultipleHeaders))!); + var schema = tool.ProtocolTool.InputSchema; + var props = schema.GetProperty("properties"); + + var regionProp = props.GetProperty("region"); + Assert.True(regionProp.TryGetProperty("x-mcp-header", out var regionHeader)); + Assert.Equal("Region", regionHeader.GetString()); + + var tenantProp = props.GetProperty("tenantId"); + Assert.True(tenantProp.TryGetProperty("x-mcp-header", out var tenantHeader)); + Assert.Equal("TenantId", tenantHeader.GetString()); + } + + [Fact] + public void Create_WithDuplicateHeaderNames_ThrowsInvalidOperationException() + { + Assert.Throws(() => + McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithDuplicateHeaders))!)); + } + + [Fact] + public void Create_WithMcpHeaderOnNonPrimitiveType_ThrowsInvalidOperationException() + { + Assert.Throws(() => + McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithNonPrimitiveHeader))!)); + } + + [Fact] + public void Create_WithMcpHeaderOnNumericType_AddsExtension() + { + var tool = McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithNumericHeader))!); + var schema = tool.ProtocolTool.InputSchema; + var props = schema.GetProperty("properties"); + var countProp = props.GetProperty("count"); + Assert.True(countProp.TryGetProperty("x-mcp-header", out var headerValue)); + Assert.Equal("Count", headerValue.GetString()); + } + + [Fact] + public void Create_WithMcpHeaderOnBooleanType_AddsExtension() + { + var tool = McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithBooleanHeader))!); + var schema = tool.ProtocolTool.InputSchema; + var props = schema.GetProperty("properties"); + var flagProp = props.GetProperty("flag"); + Assert.True(flagProp.TryGetProperty("x-mcp-header", out var headerValue)); + Assert.Equal("Flag", headerValue.GetString()); + } + + [Fact] + public void Create_WithMcpHeaderOnNullableType_AddsExtension() + { + var tool = McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithNullableHeader))!); + var schema = tool.ProtocolTool.InputSchema; + var props = schema.GetProperty("properties"); + var countProp = props.GetProperty("count"); + Assert.True(countProp.TryGetProperty("x-mcp-header", out var headerValue)); + Assert.Equal("Count", headerValue.GetString()); + } + + [Fact] + public void Create_WithoutMcpHeaderAttribute_NoXMcpHeaderExtension() + { + var tool = McpServerTool.Create(typeof(McpHeaderToolType).GetMethod(nameof(McpHeaderToolType.ToolWithoutHeaders))!); + var schema = tool.ProtocolTool.InputSchema; + var props = schema.GetProperty("properties"); + var regionProp = props.GetProperty("region"); + Assert.False(regionProp.TryGetProperty("x-mcp-header", out _)); + } + + private static class McpHeaderToolType + { + [McpServerTool] + public static string ToolWithSingleHeader( + [McpHeader("Region")] string region, + string query) + => "result"; + + [McpServerTool] + public static string ToolWithMultipleHeaders( + [McpHeader("Region")] string region, + [McpHeader("TenantId")] string tenantId, + string query) + => "result"; + + [McpServerTool] + public static string ToolWithDuplicateHeaders( + [McpHeader("Region")] string region1, + [McpHeader("REGION")] string region2) + => "result"; + + [McpServerTool] + public static string ToolWithNonPrimitiveHeader( + [McpHeader("Data")] object data) + => "result"; + + [McpServerTool] + public static string ToolWithNumericHeader( + [McpHeader("Count")] int count) + => "result"; + + [McpServerTool] + public static string ToolWithBooleanHeader( + [McpHeader("Flag")] bool flag) + => "result"; + + [McpServerTool] + public static string ToolWithNullableHeader( + [McpHeader("Count")] int? count) + => "result"; + + [McpServerTool] + public static string ToolWithoutHeaders(string region, string query) + => "result"; + } }