Skip to content

Commit 6f03f0e

Browse files
authored
Fix race condition that caused inprogress connect to be canceled (#2618)
1 parent 30553b5 commit 6f03f0e

File tree

2 files changed

+139
-4
lines changed

2 files changed

+139
-4
lines changed

src/Grpc.Net.Client/Balancer/Subchannel.cs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,10 @@ public void RequestConnection()
266266
}
267267
}
268268

269-
if (connectionRequested)
270-
{
271-
UpdateConnectivityState(ConnectivityState.Connecting, "Connection requested.");
272-
}
269+
Debug.Assert(connectionRequested, "Ensure that only expected state made it to this point.");
270+
271+
SubchannelLog.StartingConnectionRequest(_logger, Id);
272+
UpdateConnectivityState(ConnectivityState.Connecting, "Connection requested.");
273273

274274
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
275275
var restoreFlow = false;
@@ -327,6 +327,15 @@ private async Task ConnectTransportAsync()
327327
return;
328328
}
329329

330+
// There is already a connect in-progress on this transport.
331+
// Don't cancel and start again as that causes queued requests waiting on the connect to fail.
332+
if (_connectContext != null && !_connectContext.Disposed)
333+
{
334+
SubchannelLog.ConnectionRequestedInNonIdleState(_logger, Id, _state);
335+
_delayInterruptTcs?.TrySetResult(null);
336+
return;
337+
}
338+
330339
connectContext = GetConnectContextUnsynchronized();
331340

332341
// Use a semaphore to limit one connection attempt at a time. This is done to prevent a race conditional where a canceled connect
@@ -633,7 +642,11 @@ public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOn
633642
AddressesUpdated(logger, subchannelId, addressesText);
634643
}
635644
}
645+
636646
[LoggerMessage(Level = LogLevel.Debug, EventId = 20, EventName = "QueuingConnect", Message = "Subchannel id '{SubchannelId}' queuing connect because a connect is already in progress.")]
637647
public static partial void QueuingConnect(ILogger logger, string subchannelId);
648+
649+
[LoggerMessage(Level = LogLevel.Trace, EventId = 21, EventName = "StartingConnectionRequest", Message = "Subchannel id '{SubchannelId}' starting connection request.")]
650+
public static partial void StartingConnectionRequest(ILogger logger, string subchannelId);
638651
}
639652
#endif

test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,128 @@ public async Task PickAsync_UpdateAddressesWhileRequestingConnection_DoesNotDead
620620
}
621621
}
622622

623+
[Test]
624+
public async Task PickAsync_MultipleRequestsRequestConnect_SingleConnectAttempt()
625+
{
626+
var services = new ServiceCollection();
627+
services.AddNUnitLogger();
628+
629+
var testSink = new TestSink();
630+
var testProvider = new TestLoggerProvider(testSink);
631+
632+
services.AddLogging(b =>
633+
{
634+
b.AddProvider(testProvider);
635+
});
636+
637+
await using var serviceProvider = services.BuildServiceProvider();
638+
var loggerFactory = serviceProvider.GetRequiredService<ILoggerFactory>();
639+
var logger = loggerFactory.CreateLogger(nameof(PickAsync_MultipleRequestsRequestConnect_SingleConnectAttempt));
640+
641+
var requestConnectionChannel = Channel.CreateUnbounded<SyncPoint>();
642+
var requestConnectionSyncPoint1 = new SyncPoint(runContinuationsAsynchronously: true);
643+
var requestConnectionSyncPoint2 = new SyncPoint(runContinuationsAsynchronously: true);
644+
requestConnectionChannel.Writer.TryWrite(requestConnectionSyncPoint1);
645+
requestConnectionChannel.Writer.TryWrite(requestConnectionSyncPoint2);
646+
647+
var connectingSyncPoint = new SyncPoint(runContinuationsAsynchronously: true);
648+
649+
var resolver = new TestResolver(loggerFactory);
650+
resolver.UpdateAddresses(new List<BalancerAddress>
651+
{
652+
new BalancerAddress("localhost", 80)
653+
});
654+
655+
var channelOptions = new GrpcChannelOptions();
656+
var acting = false;
657+
var transportFactory = TestSubchannelTransportFactory.Create(async (subChannel, attempt, cancellationToken) =>
658+
{
659+
cancellationToken.Register(() =>
660+
{
661+
logger.LogError("Connect cancellation token canceled.");
662+
});
663+
664+
if (!acting)
665+
{
666+
return new TryConnectResult(ConnectivityState.Ready);
667+
}
668+
669+
await connectingSyncPoint.WaitToContinue().WaitAsync(cancellationToken);
670+
671+
Assert.IsFalse(cancellationToken.IsCancellationRequested, "Cancellation token should not be canceled.");
672+
673+
return new TryConnectResult(ConnectivityState.Ready);
674+
});
675+
var clientChannel = CreateConnectionManager(loggerFactory, resolver, transportFactory, new[] { new PickFirstBalancerFactory() });
676+
// Configure balancer similar to how GrpcChannel constructor does it
677+
clientChannel.ConfigureBalancer(c => new ChildHandlerLoadBalancer(
678+
c,
679+
channelOptions.ServiceConfig,
680+
clientChannel));
681+
682+
await clientChannel.ConnectAsync(waitForReady: true, cancellationToken: CancellationToken.None);
683+
684+
transportFactory.Transports.ForEach(t => t.Disconnect());
685+
686+
testSink.MessageLogged += (w) =>
687+
{
688+
if (w.EventId.Name == "StartingConnectionRequest")
689+
{
690+
if (!requestConnectionChannel.Reader.TryRead(out var syncPoint))
691+
{
692+
throw new InvalidOperationException("Channel should have sync point.");
693+
}
694+
syncPoint.WaitToContinue().Wait();
695+
}
696+
};
697+
698+
acting = true;
699+
700+
logger.LogInformation("Start first pick.");
701+
var pickTask1 = Task.Run(() => clientChannel.PickAsync(
702+
new PickContext { Request = new HttpRequestMessage() },
703+
waitForReady: true,
704+
CancellationToken.None).AsTask());
705+
706+
logger.LogInformation("Wait for first pick to request connection.");
707+
await requestConnectionSyncPoint1.WaitForSyncPoint().DefaultTimeout();
708+
709+
logger.LogInformation("Start second pick.");
710+
var pickTask2 = Task.Run(() => clientChannel.PickAsync(
711+
new PickContext { Request = new HttpRequestMessage() },
712+
waitForReady: true,
713+
CancellationToken.None).AsTask());
714+
715+
logger.LogInformation("Wait for second pick to request connection.");
716+
await requestConnectionSyncPoint2.WaitForSyncPoint().DefaultTimeout();
717+
718+
logger.LogInformation("Allow first pick to start connecting.");
719+
requestConnectionSyncPoint1.Continue();
720+
await connectingSyncPoint.WaitForSyncPoint();
721+
722+
var connectionRequestedInNonIdleStateTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
723+
testSink.MessageLogged += (w) =>
724+
{
725+
if (w.EventId.Name == "ConnectionRequestedInNonIdleState")
726+
{
727+
connectionRequestedInNonIdleStateTcs.TrySetResult(null);
728+
}
729+
};
730+
731+
logger.LogInformation("Allow second pick to wait for connecting to complete.");
732+
requestConnectionSyncPoint2.Continue();
733+
734+
logger.LogInformation("Wait for second pick to report that there is already a connection requested.");
735+
await connectionRequestedInNonIdleStateTcs.Task.DefaultTimeout();
736+
737+
logger.LogInformation("Allow first pick connecting to complete.");
738+
connectingSyncPoint.Continue();
739+
740+
logger.LogInformation("Wait for both picks to complete successfully.");
741+
await pickTask1.DefaultTimeout();
742+
await pickTask2.DefaultTimeout();
743+
}
744+
623745
[Test]
624746
public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect()
625747
{

0 commit comments

Comments
 (0)