Skip to content
13 changes: 13 additions & 0 deletions src/Common/EncodingUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ public static byte[] GetUtf8Bytes(ReadOnlySpan<char> utf16)
return bytes;
}

/// <summary>Decodes a UTF-8 <see cref="ReadOnlySequence{T}"/> to a <see cref="string"/>.</summary>
public static string GetUtf8String(in ReadOnlySequence<byte> sequence)
{
if (sequence.IsEmpty)
{
return string.Empty;
}

return sequence.IsSingleSegment
? Encoding.UTF8.GetString(sequence.First.Span)
: Encoding.UTF8.GetString(sequence.ToArray());
}

/// <summary>
/// Encodes binary data to base64-encoded UTF-8 bytes.
/// </summary>
Expand Down
19 changes: 19 additions & 0 deletions src/Common/Polyfills/System/Text/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Spa
}
}
}

/// <summary>
/// Decodes all the bytes in the specified span into a string.
/// </summary>
public static string GetString(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
if (bytes.IsEmpty)
{
return string.Empty;
}

unsafe
{
fixed (byte* bytesPtr = bytes)
{
return encoding.GetString(bytesPtr, bytes.Length);
}
}
}
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Client;
/// <summary>Provides the client side of a stdio-based session transport.</summary>
internal sealed class StdioClientSessionTransport(
StdioClientTransportOptions options, Process process, string endpointName, Queue<string> stderrRollingLog, ILoggerFactory? loggerFactory) :
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory)
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, endpointName, loggerFactory)
{
private readonly StdioClientTransportOptions _options = options;
private readonly Process _process = process;
Expand Down
10 changes: 6 additions & 4 deletions src/ModelContextProtocol.Core/Client/StdioClientTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public sealed partial class StdioClientTransport : IClientTransport
private static readonly object s_consoleEncodingLock = new();
#endif

private static readonly UTF8Encoding s_noBomUtf8Encoding = new(encoderShouldEmitUTF8Identifier: false);

private readonly StdioClientTransportOptions _options;
private readonly ILoggerFactory? _loggerFactory;

Expand Down Expand Up @@ -85,10 +87,10 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
UseShellExecute = false,
CreateNoWindow = true,
WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
StandardOutputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding,
StandardErrorEncoding = StreamClientSessionTransport.NoBomUtf8Encoding,
StandardOutputEncoding = s_noBomUtf8Encoding,
StandardErrorEncoding = s_noBomUtf8Encoding,
#if NET
StandardInputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding,
StandardInputEncoding = s_noBomUtf8Encoding,
#endif
};

Expand Down Expand Up @@ -173,7 +175,7 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
Encoding originalInputEncoding = Console.InputEncoding;
try
{
Console.InputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding;
Console.InputEncoding = s_noBomUtf8Encoding;
processStarted = process.Start();
}
finally
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using System.Text;
using System.Buffers;
using System.IO.Pipelines;
using System.Text.Json;

namespace ModelContextProtocol.Client;
Expand All @@ -9,10 +10,9 @@ namespace ModelContextProtocol.Client;
internal class StreamClientSessionTransport : TransportBase
{
private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();
private static readonly StreamPipeReaderOptions s_pipeReaderOptions = new(bufferSize: 64 * 1024); // 64KB minimum buffer

internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false);

private readonly TextReader _serverOutput;
private readonly PipeReader _serverOutputPipe;
private readonly Stream _serverInputStream;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private CancellationTokenSource? _shutdownCts = new();
Expand All @@ -27,9 +27,6 @@ internal class StreamClientSessionTransport : TransportBase
/// <param name="serverOutput">
/// The server's output stream. Messages read from this stream will be received from the server.
/// </param>
/// <param name="encoding">
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
/// </param>
/// <param name="endpointName">
/// A name that identifies this transport endpoint in logs.
/// </param>
Expand All @@ -40,18 +37,14 @@ internal class StreamClientSessionTransport : TransportBase
/// This constructor starts a background task to read messages from the server output stream.
/// The transport will be marked as connected once initialized.
/// </remarks>
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, string endpointName, ILoggerFactory? loggerFactory)
: base(endpointName, loggerFactory)
{
Throw.IfNull(serverInput);
Throw.IfNull(serverOutput);

_serverInputStream = serverInput;
#if NET
_serverOutput = new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#else
_serverOutput = new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#endif
_serverOutputPipe = PipeReader.Create(serverOutput, s_pipeReaderOptions);

SetConnected();

Expand Down Expand Up @@ -102,24 +95,8 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
try
{
LogTransportEnteringReadMessagesLoop(Name);

while (true)
{
if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
{
LogTransportEndOfStream(Name);
break;
}

if (string.IsNullOrWhiteSpace(line))
{
continue;
}

LogTransportReceivedMessageSensitive(Name, line);

await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
}
await _serverOutputPipe.ReadLinesAsync(ProcessLineAsync, cancellationToken).ConfigureAwait(false);
LogTransportEndOfStream(Name);
}
catch (OperationCanceledException)
{
Expand All @@ -137,25 +114,43 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
}
}

private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
private async Task ProcessLineAsync(ReadOnlySequence<byte> line, CancellationToken cancellationToken)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, EncodingUtilities.GetUtf8String(line));
}

try
{
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
JsonRpcMessage? message;
if (line.IsSingleSegment)
{
message = JsonSerializer.Deserialize(line.First.Span, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}
else
{
var reader = new Utf8JsonReader(line, isFinalBlock: true, state: default);
message = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}

if (message is not null)
{
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, line);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, EncodingUtilities.GetUtf8String(line));
}
}
}
catch (JsonException ex)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseFailedSensitive(Name, line, ex);
LogTransportMessageParseFailedSensitive(Name, EncodingUtilities.GetUtf8String(line), ex);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = defau
return Task.FromResult<ITransport>(new StreamClientSessionTransport(
_serverInput,
_serverOutput,
encoding: null,
"Client (stream)",
_loggerFactory));
}
Expand Down
72 changes: 72 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/PipeReaderExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using System.Buffers;
using System.IO.Pipelines;

namespace ModelContextProtocol.Protocol;

/// <summary>Internal helper for reading newline-delimited UTF-8 lines from a <see cref="PipeReader"/>.</summary>
internal static class PipeReaderExtensions
{
/// <summary>
/// Reads newline-delimited lines from <paramref name="reader"/>, invoking
/// <paramref name="processLine"/> for each non-empty line, until the reader signals completion.
/// </summary>
internal static async Task ReadLinesAsync(
this PipeReader reader,
Func<ReadOnlySequence<byte>, CancellationToken, Task> processLine,
CancellationToken cancellationToken)
{
while (true)
{
ReadResult result = await reader.ReadAsync(cancellationToken).ConfigureAwait(false);
ReadOnlySequence<byte> buffer = result.Buffer;

SequencePosition? position;
while ((position = buffer.PositionOf((byte)'\n')) != null)
{
ReadOnlySequence<byte> line = buffer.Slice(0, position.Value);

// Trim trailing \r for Windows-style CRLF line endings.
if (EndsWithCarriageReturn(line))
{
line = line.Slice(0, line.Length - 1);
}

if (!line.IsEmpty)
{
await processLine(line, cancellationToken).ConfigureAwait(false);
}

// Advance past the '\n'.
buffer = buffer.Slice(buffer.GetPosition(1, position.Value));
}

reader.AdvanceTo(buffer.Start, buffer.End);

if (result.IsCompleted)
{
break;
}
}
}

private static bool EndsWithCarriageReturn(in ReadOnlySequence<byte> sequence)
{
if (sequence.IsSingleSegment)
{
ReadOnlySpan<byte> span = sequence.First.Span;
return span.Length > 0 && span[span.Length - 1] == (byte)'\r';
}

// Multi-segment: find the last non-empty segment to check its last byte.
ReadOnlyMemory<byte> last = default;
foreach (ReadOnlyMemory<byte> segment in sequence)
{
if (!segment.IsEmpty)
{
last = segment;
}
}

return !last.IsEmpty && last.Span[last.Length - 1] == (byte)'\r';
}
}
Loading