diff --git a/src/Cli/dotnet/Commands/Test/MTP/IPC/NamedPipeServer.cs b/src/Cli/dotnet/Commands/Test/MTP/IPC/NamedPipeServer.cs index b06261dde858..ba9c04f6fe76 100644 --- a/src/Cli/dotnet/Commands/Test/MTP/IPC/NamedPipeServer.cs +++ b/src/Cli/dotnet/Commands/Test/MTP/IPC/NamedPipeServer.cs @@ -32,7 +32,15 @@ public NamedPipeServer( CancellationToken cancellationToken, bool skipUnknownMessages) { - _namedPipeServerStream = new(pipeName, PipeDirection.InOut, maxNumberOfServerInstances, PipeTransmissionMode.Byte, PipeOptions.Asynchronous | PipeOptions.CurrentUserOnly); + _namedPipeServerStream = new NamedPipeServerStream( + pipeName, + PipeDirection.InOut, + maxNumberOfServerInstances, + PipeTransmissionMode.Byte, + PipeOptions.Asynchronous | PipeOptions.CurrentUserOnly, + inBufferSize: 0, + outBufferSize: 0); + _callback = callback; _cancellationToken = cancellationToken; _skipUnknownMessages = skipUnknownMessages; @@ -67,51 +75,69 @@ public async Task WaitConnectionAsync(CancellationToken cancellationToken) /// private async Task InternalLoopAsync(CancellationToken cancellationToken) { - int currentMessageSize = 0; - int missingBytesToReadOfWholeMessage = 0; + // This is an indicator when reading from the pipe whether we are at the start of a new message (i.e, we should read 4 bytes as message size) + // Note that the implementation assumes no overlapping messages in the pipe. + // The flow goes like: + // 1. MTP sends a request (and acquires lock). + // 2. SDK reads the request. + // 3. SDK sends a response. + // 4. MTP reads the response (and releases lock). + // This means that no two requests can be in the pipe at the same time. + bool isStartOfNewMessage = true; + int remainingBytesToReadOfWholeMessage = 0; while (!cancellationToken.IsCancellationRequested) { - int missingBytesToReadOfCurrentChunk = 0; - int currentReadIndex = 0; - int currentReadBytes = await _namedPipeServerStream.ReadAsync(_readBuffer.AsMemory(currentReadIndex, _readBuffer.Length), cancellationToken); - if (currentReadBytes == 0) + // If we are at the start of a new message, we need to read at least the message size. + int currentReadBytes = isStartOfNewMessage + ? await _namedPipeServerStream.ReadAtLeastAsync(_readBuffer, minimumBytes: sizeof(int), throwOnEndOfStream: false, cancellationToken) + : await _namedPipeServerStream.ReadAsync(_readBuffer, cancellationToken); + + if (currentReadBytes == 0 || (isStartOfNewMessage && currentReadBytes < sizeof(int))) { // The client has disconnected return; } - // Reset the current chunk size - missingBytesToReadOfCurrentChunk = currentReadBytes; + // The local remainingBytesToProcess tracks the remaining bytes of what we have read from the pipe but not yet processed. + // At the beginning here, it contains everything we have read from the pipe. + // As we are processing the data in it, we continue to slice it. + Memory remainingBytesToProcess = _readBuffer.AsMemory(0, currentReadBytes); - // If currentRequestSize is 0, we need to read the message size - if (currentMessageSize == 0) + // If the current read is the start of a new message, we need to read the message size first. + if (isStartOfNewMessage) { // We need to read the message size, first 4 bytes - if (currentReadBytes < sizeof(int)) - { - throw new UnreachableException(CliCommandStrings.DotnetTestPipeIncompleteSize); - } + remainingBytesToReadOfWholeMessage = BitConverter.ToInt32(remainingBytesToProcess.Span); + + // Now that we have read the size, we slice the remainingBytesToProcess. + remainingBytesToProcess = remainingBytesToProcess.Slice(sizeof(int)); - currentMessageSize = BitConverter.ToInt32(_readBuffer, 0); - missingBytesToReadOfCurrentChunk = currentReadBytes - sizeof(int); - missingBytesToReadOfWholeMessage = currentMessageSize; - currentReadIndex = sizeof(int); + // Now that we have read the size, we are no longer at the start of a new message. + // If the current chunk ended up to be the full message, we will set this back to true later. + isStartOfNewMessage = false; } - if (missingBytesToReadOfCurrentChunk > 0) + // We read the rest of the message. + // Note that this assumes that no messages are overlapping in the pipe. + if (remainingBytesToProcess.Length > 0) { // We need to read the rest of the message - await _messageBuffer.WriteAsync(_readBuffer.AsMemory(currentReadIndex, missingBytesToReadOfCurrentChunk), cancellationToken); - missingBytesToReadOfWholeMessage -= missingBytesToReadOfCurrentChunk; + await _messageBuffer.WriteAsync(remainingBytesToProcess, cancellationToken); + remainingBytesToReadOfWholeMessage -= remainingBytesToProcess.Length; + + // At this point, we have read everything in the remainingBytesToProcess. + // Note that while remainingBytesToProcess isn't accessed after this point, we still maintain the + // invariant that it tracks what we have read from the pipe but not yet processed. + remainingBytesToProcess = Memory.Empty; } - if (missingBytesToReadOfWholeMessage < 0) + if (remainingBytesToReadOfWholeMessage < 0) { throw new UnreachableException(CliCommandStrings.DotnetTestPipeOverlapping); } // If we have read all the message, we can deserialize it - if (missingBytesToReadOfWholeMessage == 0) + if (remainingBytesToReadOfWholeMessage == 0) { // Deserialize the message _messageBuffer.Position = 0; @@ -147,12 +173,19 @@ private async Task InternalLoopAsync(CancellationToken cancellationToken) // Write the message size byte[] bytes = _sizeOfIntArray; - BitConverter.TryWriteBytes(bytes, sizeOfTheWholeMessage); + if (!BitConverter.TryWriteBytes(bytes, sizeOfTheWholeMessage)) + { + throw new UnreachableException(); + } + await _messageBuffer.WriteAsync(bytes, cancellationToken); // Write the serializer id bytes = _sizeOfIntArray; - BitConverter.TryWriteBytes(bytes, responseNamedPipeSerializer.Id); + if (!BitConverter.TryWriteBytes(bytes, responseNamedPipeSerializer.Id)) + { + throw new UnreachableException(); + } await _messageBuffer.WriteAsync(bytes.AsMemory(0, sizeof(int)), cancellationToken); @@ -164,10 +197,6 @@ private async Task InternalLoopAsync(CancellationToken cancellationToken) { await _namedPipeServerStream.WriteAsync(_messageBuffer.GetBuffer().AsMemory(0, (int)_messageBuffer.Position), cancellationToken); await _namedPipeServerStream.FlushAsync(cancellationToken); - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - _namedPipeServerStream.WaitForPipeDrain(); - } } finally { @@ -177,8 +206,8 @@ private async Task InternalLoopAsync(CancellationToken cancellationToken) } // Reset the control variables - currentMessageSize = 0; - missingBytesToReadOfWholeMessage = 0; + isStartOfNewMessage = true; + remainingBytesToReadOfWholeMessage = 0; } } } diff --git a/test/dotnet.Tests/CommandTests/Test/IPCTests.cs b/test/dotnet.Tests/CommandTests/Test/IPCTests.cs new file mode 100644 index 000000000000..124580f0f64f --- /dev/null +++ b/test/dotnet.Tests/CommandTests/Test/IPCTests.cs @@ -0,0 +1,255 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO.Pipes; +using Microsoft.DotNet.Cli.Commands.Test.IPC; +using Microsoft.DotNet.Cli.Commands.Test.IPC.Models; +using Microsoft.DotNet.Cli.Commands.Test.IPC.Serializers; + +namespace dotnet.Tests.CommandTests.Test; + +public class IPCTests +{ + [Fact] + public async Task SingleConnectionNamedPipeServer_MultipleConnection_Fails() + { + string pipeName = NamedPipeServer.GetPipeName(Guid.NewGuid().ToString("N")); + + List openedPipes = []; + List exceptions = []; + + ManualResetEventSlim waitException = new(false); + var waitTask = Task.Run( + async () => + { + try + { + while (true) + { + var singleConnectionNamedPipeServer = new NamedPipeServer( + pipeName, + (_, _) => Task.FromResult(VoidResponse.CachedInstance), + maxNumberOfServerInstances: 1, + CancellationToken.None, + skipUnknownMessages: false); + + await singleConnectionNamedPipeServer.WaitConnectionAsync(CancellationToken.None); + openedPipes.Add(singleConnectionNamedPipeServer); + } + } + catch (Exception ex) + { + exceptions.Add(ex); + waitException.Set(); + } + }); + + var namedPipeClient1 = new NamedPipeClient(pipeName); + await namedPipeClient1.ConnectAsync(CancellationToken.None); + waitException.Wait(); + + var openedPipe = Assert.Single(openedPipes); + var exception = Assert.Single(exceptions); + Assert.Equal(typeof(IOException), exception.GetType()); + Assert.Contains("All pipe instances are busy.", exception.Message); + + await waitTask; + namedPipeClient1.Dispose(); + openedPipe.Dispose(); + + // Verify double dispose + namedPipeClient1.Dispose(); + openedPipe.Dispose(); + } + + // CAREFUL: This test produces random test cases. + // So, flakiness in this test might be an indicator to a serious product bug. + [Fact] + public async Task SingleConnectionNamedPipeServer_RequestReplySerialization_Succeeded() + { + Queue receivedMessages = new(); + string pipeName = NamedPipeServer.GetPipeName(Guid.NewGuid().ToString("N")); + NamedPipeClient namedPipeClient = new(pipeName); + namedPipeClient.RegisterSerializer(new VoidResponseSerializer(), typeof(VoidResponse)); + namedPipeClient.RegisterSerializer(new TextMessageSerializer(), typeof(TextMessage)); + namedPipeClient.RegisterSerializer(new IntMessageSerializer(), typeof(IntMessage)); + namedPipeClient.RegisterSerializer(new LongMessageSerializer(), typeof(LongMessage)); + + ManualResetEventSlim manualResetEventSlim = new(false); + var clientConnected = Task.Run( + async () => + { + while (true) + { + try + { + await namedPipeClient.ConnectAsync(CancellationToken.None); + manualResetEventSlim.Set(); + break; + } + catch (OperationCanceledException) + { + throw new OperationCanceledException("SingleConnectionNamedPipeServer_RequestReplySerialization_Succeeded cancellation during connect"); + } + catch (Exception) + { + } + } + }, CancellationToken.None); + NamedPipeServer singleConnectionNamedPipeServer = new( + pipeName, + (_, request) => + { + receivedMessages.Enqueue((BaseMessage)request); + return Task.FromResult(VoidResponse.CachedInstance); + }, + NamedPipeServerStream.MaxAllowedServerInstances, + CancellationToken.None, + skipUnknownMessages: false); + singleConnectionNamedPipeServer.RegisterSerializer(new VoidResponseSerializer(), typeof(VoidResponse)); + singleConnectionNamedPipeServer.RegisterSerializer(new TextMessageSerializer(), typeof(TextMessage)); + singleConnectionNamedPipeServer.RegisterSerializer(new IntMessageSerializer(), typeof(IntMessage)); + singleConnectionNamedPipeServer.RegisterSerializer(new LongMessageSerializer(), typeof(LongMessage)); + await singleConnectionNamedPipeServer.WaitConnectionAsync(CancellationToken.None); + manualResetEventSlim.Wait(); + + await clientConnected; + + await namedPipeClient.RequestReplyAsync(new IntMessage(10), CancellationToken.None); + Assert.Equal(new IntMessage(10), receivedMessages.Dequeue()); + + await namedPipeClient.RequestReplyAsync(new LongMessage(11), CancellationToken.None); + Assert.Equal(new LongMessage(11), receivedMessages.Dequeue()); + + for (int i = 0; i < 100; i++) + { + await AssertWithLengthAsync(Random.Shared.Next(1024, 1024 * 1024 * 2)); + } + + // NOTE: 250000 is the buffer size of NamedPipeServer. + // We explicitly test around this size (and multiple of it) as most potential bugs can be around it. + for (int multiple = 1; multiple <= 3; multiple++) + { + const int namedPipeServerBufferSize = 250000; + int minLength = namedPipeServerBufferSize * multiple - 1000; + int maxLength = namedPipeServerBufferSize * multiple + 1000; + for (int randomLength = minLength; randomLength <= maxLength; randomLength++) + { + await AssertWithLengthAsync(randomLength); + } + } + + namedPipeClient.Dispose(); + singleConnectionNamedPipeServer.Dispose(); + + async Task AssertWithLengthAsync(int length) + { + string currentString = RandomString(length); + await namedPipeClient.RequestReplyAsync(new TextMessage(currentString), CancellationToken.None); + Assert.Single(receivedMessages); + Assert.Equal(new TextMessage(currentString), receivedMessages.Dequeue()); + } + } + + [Fact] + public async Task ConnectionNamedPipeServer_MultipleConnection_Succeeds() + { + string pipeName = NamedPipeServer.GetPipeName(Guid.NewGuid().ToString("N")); + + List pipes = []; + for (int i = 0; i < 3; i++) + { + pipes.Add(new NamedPipeServer( + pipeName, + (_, _) => Task.FromResult(VoidResponse.CachedInstance), + maxNumberOfServerInstances: 3, + CancellationToken.None, + skipUnknownMessages: false)); + } + + IOException exception = Assert.Throws(() => + new NamedPipeServer( + pipeName, + (_, _) => Task.FromResult(VoidResponse.CachedInstance), + maxNumberOfServerInstances: 3, + CancellationToken.None, + skipUnknownMessages: false)); + Assert.Contains("All pipe instances are busy.", exception.Message); + + List waitConnectionTask = []; + int connectionCompleted = 0; + foreach (NamedPipeServer namedPipeServer in pipes) + { + waitConnectionTask.Add(Task.Run( + async () => + { + await namedPipeServer.WaitConnectionAsync(CancellationToken.None); + Interlocked.Increment(ref connectionCompleted); + }, CancellationToken.None)); + } + + List connectedClients = []; + for (int i = 0; i < waitConnectionTask.Count; i++) + { + var namedPipeClient = new NamedPipeClient(pipeName); + connectedClients.Add(namedPipeClient); + await namedPipeClient.ConnectAsync(CancellationToken.None); + } + + await Task.WhenAll([.. waitConnectionTask]); + + Assert.Equal(3, connectionCompleted); + + foreach (NamedPipeClient namedPipeClient in connectedClients) + { + namedPipeClient.Dispose(); + } + + foreach (NamedPipeServer namedPipeServer in pipes) + { + namedPipeServer.Dispose(); + } + } + + private static string RandomString(int length) + { + const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + return new string([.. Enumerable.Repeat(chars, length).Select(s => s[Random.Shared.Next(s.Length)])]); + } + + private abstract record BaseMessage : IRequest; + + private sealed record TextMessage(string Text) : BaseMessage; + + private sealed class TextMessageSerializer : BaseSerializer, INamedPipeSerializer + { + public int Id => 2; + + public object Deserialize(Stream stream) => new TextMessage(ReadString(stream)); + + public void Serialize(object objectToSerialize, Stream stream) => WriteString(stream, ((TextMessage)objectToSerialize).Text); + } + + private sealed record IntMessage(int Integer) : BaseMessage; + + private sealed class IntMessageSerializer : BaseSerializer, INamedPipeSerializer + { + public int Id => 3; + + public object Deserialize(Stream stream) => new IntMessage(ReadInt(stream)); + + public void Serialize(object objectToSerialize, Stream stream) => WriteInt(stream, ((IntMessage)objectToSerialize).Integer); + } + + private sealed record LongMessage(long Long) : BaseMessage; + + private sealed class LongMessageSerializer : BaseSerializer, INamedPipeSerializer + { + public int Id => 4; + + public object Deserialize(Stream stream) => new LongMessage(ReadInt(stream)); + + public void Serialize(object objectToSerialize, Stream stream) => WriteLong(stream, ((LongMessage)objectToSerialize).Long); + } + +} diff --git a/test/dotnet.Tests/CommandTests/Test/NamedPipeClient.cs b/test/dotnet.Tests/CommandTests/Test/NamedPipeClient.cs new file mode 100644 index 000000000000..531ba8d08163 --- /dev/null +++ b/test/dotnet.Tests/CommandTests/Test/NamedPipeClient.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace dotnet.Tests.CommandTests.Test; + +using System.Buffers; +using System.IO.Pipes; +using Microsoft.DotNet.Cli.Commands.Test.IPC; + +internal sealed class NamedPipeClient : NamedPipeBase +{ + private readonly NamedPipeClientStream _namedPipeClientStream; + private readonly SemaphoreSlim _lock = new(1, 1); + + private readonly MemoryStream _serializationBuffer = new(); + private readonly MemoryStream _messageBuffer = new(); + private readonly byte[] _readBuffer = new byte[250000]; + + private bool _disposed; + + public NamedPipeClient(string name) + { + _namedPipeClientStream = new(".", name, PipeDirection.InOut, PipeOptions.CurrentUserOnly); + PipeName = name; + } + + public string PipeName { get; } + + public async Task ConnectAsync(CancellationToken cancellationToken) + => await _namedPipeClientStream.ConnectAsync(cancellationToken).ConfigureAwait(false); + + public async Task RequestReplyAsync(TRequest request, CancellationToken cancellationToken) + where TRequest : IRequest + where TResponse : IResponse + { + await _lock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + INamedPipeSerializer requestNamedPipeSerializer = GetSerializer(typeof(TRequest)); + + // Ask to serialize the body + _serializationBuffer.Position = 0; + requestNamedPipeSerializer.Serialize(request, _serializationBuffer); + + // Write the message size + _messageBuffer.Position = 0; + + // The length of the message is the size of the message plus one byte to store the serializer id + // Space for the message + int sizeOfTheWholeMessage = (int)_serializationBuffer.Position; + + // Space for the serializer id + sizeOfTheWholeMessage += sizeof(int); + + // Write the message size + byte[] bytes = ArrayPool.Shared.Rent(sizeof(int)); + try + { + BitConverter.TryWriteBytes(bytes, sizeOfTheWholeMessage); + await _messageBuffer.WriteAsync(bytes.AsMemory(0, sizeof(int)), cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(bytes); + } + + // Write the serializer id + bytes = ArrayPool.Shared.Rent(sizeof(int)); + try + { + BitConverter.TryWriteBytes(bytes, requestNamedPipeSerializer.Id); + await _messageBuffer.WriteAsync(bytes.AsMemory(0, sizeof(int)), cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(bytes); + } + + try + { + // Write the message + await _messageBuffer.WriteAsync(_serializationBuffer.GetBuffer().AsMemory(0, (int)_serializationBuffer.Position), cancellationToken).ConfigureAwait(false); + } + finally + { + // Reset the serialization buffer + _serializationBuffer.Position = 0; + } + + // Send the message + try + { + await _namedPipeClientStream.WriteAsync(_messageBuffer.GetBuffer().AsMemory(0, (int)_messageBuffer.Position), cancellationToken).ConfigureAwait(false); + await _namedPipeClientStream.FlushAsync(cancellationToken).ConfigureAwait(false); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _namedPipeClientStream.WaitForPipeDrain(); + } + } + finally + { + // Reset the buffers + _messageBuffer.Position = 0; + _serializationBuffer.Position = 0; + } + + // Read the response + int currentMessageSize = 0; + int missingBytesToReadOfWholeMessage = 0; + while (true) + { + int currentReadIndex = 0; + int currentReadBytes = await _namedPipeClientStream.ReadAsync(_readBuffer.AsMemory(currentReadIndex, _readBuffer.Length), cancellationToken).ConfigureAwait(false); + + if (currentReadBytes == 0) + { + // We are reading a message response. + // If we cannot get a response, there is no way we can recover and continue executing. + // This can happen if the other processes gets killed or crashes while while it's sending the response. + // This is especially important for 'dotnet test', where the user can simply kill the dotnet.exe process themselves. + // In that case, we want the MTP process to also die. + Environment.FailFast("[NamedPipeClient] Connection lost with the other side."); + } + + // Reset the current chunk size + int missingBytesToReadOfCurrentChunk = currentReadBytes; + + // If currentRequestSize is 0, we need to read the message size + if (currentMessageSize == 0) + { + // We need to read the message size, first 4 bytes + currentMessageSize = BitConverter.ToInt32(_readBuffer, 0); + missingBytesToReadOfCurrentChunk = currentReadBytes - sizeof(int); + missingBytesToReadOfWholeMessage = currentMessageSize; + currentReadIndex = sizeof(int); + } + + if (missingBytesToReadOfCurrentChunk > 0) + { + // We need to read the rest of the message + await _messageBuffer.WriteAsync(_readBuffer.AsMemory(currentReadIndex, missingBytesToReadOfCurrentChunk), cancellationToken).ConfigureAwait(false); + missingBytesToReadOfWholeMessage -= missingBytesToReadOfCurrentChunk; + } + + // If we have read all the message, we can deserialize it + if (missingBytesToReadOfWholeMessage == 0) + { + // Deserialize the message + _messageBuffer.Position = 0; + + // Get the serializer id + int serializerId = BitConverter.ToInt32(_messageBuffer.GetBuffer(), 0); + + // Get the serializer + _messageBuffer.Position += sizeof(int); // Skip the serializer id + INamedPipeSerializer responseNamedPipeSerializer = GetSerializer(serializerId); + + // Deserialize the message + try + { + return (TResponse)responseNamedPipeSerializer.Deserialize(_messageBuffer); + } + finally + { + // Reset the message buffer + _messageBuffer.Position = 0; + } + } + } + } + finally + { + _lock.Release(); + } + } + + public void Dispose() + { + if (_disposed) + { + return; + } + + _lock.Dispose(); + _serializationBuffer.Dispose(); + _messageBuffer.Dispose(); + _namedPipeClientStream.Dispose(); + _disposed = true; + } +}