Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,10 @@ public void RequestConnection()
}
}

if (connectionRequested)
{
UpdateConnectivityState(ConnectivityState.Connecting, "Connection requested.");
}
Debug.Assert(connectionRequested, "Ensure that only expected state made it to this point.");

SubchannelLog.StartingConnectionRequest(_logger, Id);
UpdateConnectivityState(ConnectivityState.Connecting, "Connection requested.");

// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
var restoreFlow = false;
Expand Down Expand Up @@ -327,6 +327,15 @@ private async Task ConnectTransportAsync()
return;
}

// There is already a connect in-progress on this transport.
// Don't cancel and start again as that causes queued requests waiting on the connect to fail.
if (_connectContext != null && !_connectContext.Disposed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to check the file to see whether this has a race condition (i.e. whether _connectContext can become null between the two reads; it can't). It doesn't, but may I suggest:

if (_connectContext is { Disposed: false })

so that it is obviously not a race (i.e. single-read)?

{
SubchannelLog.ConnectionRequestedInNonIdleState(_logger, Id, _state);
_delayInterruptTcs?.TrySetResult(null);
return;
}

connectContext = GetConnectContextUnsynchronized();

// Use a semaphore to limit one connection attempt at a time. This is done to prevent a race conditional where a canceled connect
Expand Down Expand Up @@ -633,7 +642,11 @@ public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOn
AddressesUpdated(logger, subchannelId, addressesText);
}
}

/// <summary>
/// Source-generated log method. Writes the Debug-level "QueuingConnect" event (id 20),
/// reporting that the given subchannel is queuing a connect because a connect is already in progress.
/// </summary>
/// <param name="logger">Logger the event is written to.</param>
/// <param name="subchannelId">Value substituted for the {SubchannelId} placeholder in the message.</param>
[LoggerMessage(Level = LogLevel.Debug, EventId = 20, EventName = "QueuingConnect", Message = "Subchannel id '{SubchannelId}' queuing connect because a connect is already in progress.")]
public static partial void QueuingConnect(ILogger logger, string subchannelId);

/// <summary>
/// Source-generated log method. Writes the Trace-level "StartingConnectionRequest" event (id 21),
/// reporting that the given subchannel is starting a connection request.
/// </summary>
/// <param name="logger">Logger the event is written to.</param>
/// <param name="subchannelId">Value substituted for the {SubchannelId} placeholder in the message.</param>
[LoggerMessage(Level = LogLevel.Trace, EventId = 21, EventName = "StartingConnectionRequest", Message = "Subchannel id '{SubchannelId}' starting connection request.")]
public static partial void StartingConnectionRequest(ILogger logger, string subchannelId);
}
#endif
122 changes: 122 additions & 0 deletions test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,128 @@ public async Task PickAsync_UpdateAddressesWhileRequestingConnection_DoesNotDead
}
}

/// <summary>
/// Two concurrent picks both request a connection after the transport has been disconnected.
/// The test holds each pick at the "StartingConnectionRequest" log event, lets the first one
/// begin connecting, then releases the second and expects it to log
/// "ConnectionRequestedInNonIdleState" rather than cancel the in-progress connect.
/// Both picks must then complete successfully.
/// </summary>
[Test]
public async Task PickAsync_MultipleRequestsRequestConnect_SingleConnectAttempt()
{
    var services = new ServiceCollection();
    services.AddNUnitLogger();

    // The sink lets the test observe (and block inside) individual log events.
    var testSink = new TestSink();
    var testProvider = new TestLoggerProvider(testSink);

    services.AddLogging(b =>
    {
        b.AddProvider(testProvider);
    });

    await using var serviceProvider = services.BuildServiceProvider();
    var loggerFactory = serviceProvider.GetRequiredService<ILoggerFactory>();
    var logger = loggerFactory.CreateLogger(nameof(PickAsync_MultipleRequestsRequestConnect_SingleConnectAttempt));

    // One sync point per expected connection request, handed out in order by the log callback below.
    var requestConnectionChannel = Channel.CreateUnbounded<SyncPoint>();
    var requestConnectionSyncPoint1 = new SyncPoint(runContinuationsAsynchronously: true);
    var requestConnectionSyncPoint2 = new SyncPoint(runContinuationsAsynchronously: true);
    requestConnectionChannel.Writer.TryWrite(requestConnectionSyncPoint1);
    requestConnectionChannel.Writer.TryWrite(requestConnectionSyncPoint2);

    // Held by the transport callback so the test controls when the connect attempt finishes.
    var connectingSyncPoint = new SyncPoint(runContinuationsAsynchronously: true);

    var resolver = new TestResolver(loggerFactory);
    resolver.UpdateAddresses(new List<BalancerAddress>
    {
        new BalancerAddress("localhost", 80)
    });

    var channelOptions = new GrpcChannelOptions();

    // While false the transport connects immediately; this lets the initial ConnectAsync succeed
    // before the scenario under test starts.
    var acting = false;
    var transportFactory = TestSubchannelTransportFactory.Create(async (subChannel, attempt, cancellationToken) =>
    {
        cancellationToken.Register(() =>
        {
            logger.LogError("Connect cancellation token canceled.");
        });

        if (!acting)
        {
            return new TryConnectResult(ConnectivityState.Ready);
        }

        await connectingSyncPoint.WaitToContinue().WaitAsync(cancellationToken);

        // The second connection request must not have canceled this in-progress connect.
        Assert.IsFalse(cancellationToken.IsCancellationRequested, "Cancellation token should not be canceled.");

        return new TryConnectResult(ConnectivityState.Ready);
    });
    var clientChannel = CreateConnectionManager(loggerFactory, resolver, transportFactory, new[] { new PickFirstBalancerFactory() });
    // Configure balancer similar to how GrpcChannel constructor does it
    clientChannel.ConfigureBalancer(c => new ChildHandlerLoadBalancer(
        c,
        channelOptions.ServiceConfig,
        clientChannel));

    await clientChannel.ConnectAsync(waitForReady: true, cancellationToken: CancellationToken.None);

    // Drop the established connection so the following picks have to request a new one.
    transportFactory.Transports.ForEach(t => t.Disconnect());

    // Deliberately block the logging thread at "StartingConnectionRequest" so each pick is paused
    // at that point until the test releases its sync point.
    testSink.MessageLogged += (w) =>
    {
        if (w.EventId.Name == "StartingConnectionRequest")
        {
            if (!requestConnectionChannel.Reader.TryRead(out var syncPoint))
            {
                throw new InvalidOperationException("Channel should have sync point.");
            }
            syncPoint.WaitToContinue().Wait();
        }
    };

    acting = true;

    logger.LogInformation("Start first pick.");
    var pickTask1 = Task.Run(() => clientChannel.PickAsync(
        new PickContext { Request = new HttpRequestMessage() },
        waitForReady: true,
        CancellationToken.None).AsTask());

    logger.LogInformation("Wait for first pick to request connection.");
    await requestConnectionSyncPoint1.WaitForSyncPoint().DefaultTimeout();

    logger.LogInformation("Start second pick.");
    var pickTask2 = Task.Run(() => clientChannel.PickAsync(
        new PickContext { Request = new HttpRequestMessage() },
        waitForReady: true,
        CancellationToken.None).AsTask());

    logger.LogInformation("Wait for second pick to request connection.");
    await requestConnectionSyncPoint2.WaitForSyncPoint().DefaultTimeout();

    logger.LogInformation("Allow first pick to start connecting.");
    requestConnectionSyncPoint1.Continue();
    // Bounded like every other wait in this test so a regression fails fast instead of hanging the run.
    await connectingSyncPoint.WaitForSyncPoint().DefaultTimeout();

    // Subscribe before releasing the second pick so the event cannot be missed.
    var connectionRequestedInNonIdleStateTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
    testSink.MessageLogged += (w) =>
    {
        if (w.EventId.Name == "ConnectionRequestedInNonIdleState")
        {
            connectionRequestedInNonIdleStateTcs.TrySetResult(null);
        }
    };

    logger.LogInformation("Allow second pick to wait for connecting to complete.");
    requestConnectionSyncPoint2.Continue();

    logger.LogInformation("Wait for second pick to report that there is already a connection requested.");
    await connectionRequestedInNonIdleStateTcs.Task.DefaultTimeout();

    logger.LogInformation("Allow first pick connecting to complete.");
    connectingSyncPoint.Continue();

    logger.LogInformation("Wait for both picks to complete successfully.");
    await pickTask1.DefaultTimeout();
    await pickTask2.DefaultTimeout();
}

[Test]
public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect()
{
Expand Down