Skip to content

Commit 7786079

Browse files
authored
Fix bugs related to channel dispose while there are active calls (#2120)
1 parent 38738af commit 7786079

File tree

11 files changed

+225
-31
lines changed

11 files changed

+225
-31
lines changed

src/Grpc.Net.Client/Configuration/HedgingPolicy.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -37,7 +37,7 @@ public sealed class HedgingPolicy : ConfigObject
3737
internal const string HedgingDelayPropertyName = "hedgingDelay";
3838
internal const string NonFatalStatusCodesPropertyName = "nonFatalStatusCodes";
3939

40-
private ConfigProperty<Values<StatusCode, object>, IList<object>> _nonFatalStatusCodes =
40+
private readonly ConfigProperty<Values<StatusCode, object>, IList<object>> _nonFatalStatusCodes =
4141
new(i => new Values<StatusCode, object>(i ?? new List<object>(), s => ConvertHelpers.ConvertStatusCode(s), s => ConvertHelpers.ConvertStatusCode(s.ToString()!)), NonFatalStatusCodesPropertyName);
4242

4343
/// <summary>

src/Grpc.Net.Client/GrpcChannel.cs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,13 @@ internal void RegisterActiveCall(IDisposable grpcCall)
483483
{
484484
lock (_lock)
485485
{
486+
// Test the disposed flag inside the lock to ensure there is no chance of a race that adds a call after dispose.
487+
// Note that a GrpcCall has been created but hasn't been started. The error will prevent it from starting.
488+
if (Disposed)
489+
{
490+
throw new ObjectDisposedException(nameof(GrpcChannel));
491+
}
492+
486493
ActiveCalls.Add(grpcCall);
487494
}
488495
}
@@ -733,23 +740,29 @@ public Task WaitForStateChangedAsync(ConnectivityState lastObservedState, Cancel
733740
/// </summary>
734741
public void Dispose()
735742
{
736-
if (Disposed)
737-
{
738-
return;
739-
}
740-
743+
IDisposable[]? activeCallsCopy = null;
741744
lock (_lock)
742745
{
746+
// Check and set disposed flag inside lock.
747+
if (Disposed)
748+
{
749+
return;
750+
}
751+
743752
if (ActiveCalls.Count > 0)
744753
{
745-
// Disposing a call will remove it from ActiveCalls. Need to take a copy
746-
// to avoid enumeration from being modified
747-
var activeCallsCopy = ActiveCalls.ToArray();
754+
activeCallsCopy = ActiveCalls.ToArray();
755+
}
748756

749-
foreach (var activeCall in activeCallsCopy)
750-
{
751-
activeCall.Dispose();
752-
}
757+
Disposed = true;
758+
}
759+
760+
// Dispose calls outside of lock to avoid chance of deadlock.
761+
if (activeCallsCopy is not null)
762+
{
763+
foreach (var activeCall in activeCallsCopy)
764+
{
765+
activeCall.Dispose();
753766
}
754767
}
755768

@@ -760,7 +773,6 @@ public void Dispose()
760773
#if SUPPORT_LOAD_BALANCING
761774
ConnectionManager.Dispose();
762775
#endif
763-
Disposed = true;
764776
}
765777

766778
internal bool TryAddToRetryBuffer(long messageSize)

src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -150,11 +150,6 @@ public static GrpcCall<TRequest, TResponse> CreateGrpcCall<TRequest, TResponse>(
150150
where TRequest : class
151151
where TResponse : class
152152
{
153-
if (channel.Disposed)
154-
{
155-
throw new ObjectDisposedException(nameof(GrpcChannel));
156-
}
157-
158153
var methodInfo = channel.GetCachedGrpcMethodInfo(method);
159154
var call = new GrpcCall<TRequest, TResponse>(method, methodInfo, options, channel, attempt);
160155

src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -54,6 +54,8 @@ public HedgingCall(HedgingPolicyInfo hedgingPolicy, GrpcChannel channel, Method<
5454
_delayInterruptTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
5555
_hedgingDelayCts = new CancellationTokenSource();
5656
}
57+
58+
Channel.RegisterActiveCall(this);
5759
}
5860

5961
private async Task StartCall(Action<GrpcCall<TRequest, TResponse>> startCallFunc)

src/Grpc.Net.Client/Internal/Retry/RetryCall.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -42,6 +42,8 @@ public RetryCall(RetryPolicyInfo retryPolicy, GrpcChannel channel, Method<TReque
4242
_retryPolicy = retryPolicy;
4343

4444
_nextRetryDelayMilliseconds = Convert.ToInt32(retryPolicy.InitialBackoff.TotalMilliseconds);
45+
46+
Channel.RegisterActiveCall(this);
4547
}
4648

4749
private int CalculateNextRetryDelay()

src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -440,6 +440,8 @@ protected virtual void Dispose(bool disposing)
440440

441441
protected void Cleanup()
442442
{
443+
Channel.FinishActiveCall(this);
444+
443445
_ctsRegistration?.Dispose();
444446
_ctsRegistration = null;
445447
CancellationTokenSource.Cancel();

test/FunctionalTests/Client/RetryTests.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -108,6 +108,8 @@ async Task<DataMessage> ClientStreamingWithReadFailures(IAsyncStreamReader<DataM
108108

109109
// Assert
110110
Assert.IsTrue(result.Data.Span.SequenceEqual(sentData.ToArray()));
111+
112+
Assert.AreEqual(0, channel.ActiveCalls.Count);
111113
}
112114

113115
[Test]
@@ -390,6 +392,8 @@ Task FakeServerStreamCall(DataMessage request, IServerStreamWriter<DataMessage>
390392
await MakeCallsAsync(channel, method, references, cts.Token).DefaultTimeout();
391393

392394
// Assert
395+
Assert.AreEqual(0, channel.ActiveCalls.Count);
396+
393397
// There is a race when cleaning up cancellation token registry.
394398
// Retry a few times to ensure GC is run after unregister.
395399
await TestHelpers.AssertIsTrueRetryAsync(() =>

test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -17,13 +17,16 @@
1717
#endregion
1818

1919
using System.Net;
20+
using System.Threading.Tasks;
2021
using Greet;
2122
using Grpc.Core;
2223
using Grpc.Net.Client.Internal;
2324
using Grpc.Net.Client.Internal.Http;
2425
using Grpc.Net.Client.Tests.Infrastructure;
2526
using Grpc.Shared;
2627
using Grpc.Tests.Shared;
28+
using Microsoft.Extensions.DependencyInjection;
29+
using Microsoft.Extensions.Logging;
2730
using NUnit.Framework;
2831

2932
namespace Grpc.Net.Client.Tests;
@@ -177,4 +180,81 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(
177180
Assert.IsTrue(moveNextTask4.IsCompleted);
178181
Assert.IsFalse(await moveNextTask3.DefaultTimeout());
179182
}
183+
184+
[Test]
185+
public async Task AsyncDuplexStreamingCall_CancellationDisposeRace_Success()
186+
{
187+
// Arrange
188+
var services = new ServiceCollection();
189+
services.AddNUnitLogger();
190+
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
191+
var logger = loggerFactory.CreateLogger(GetType());
192+
193+
for (var i = 0; i < 20; i++)
194+
{
195+
// Mimic a real call first to get GrpcCall.RunCall to the point needed to reproduce the deadlock.
196+
var streamContent = new SyncPointMemoryStream();
197+
var requestContentTcs = new TaskCompletionSource<Task<Stream>>(TaskCreationOptions.RunContinuationsAsynchronously);
198+
199+
PushStreamContent<HelloRequest, HelloReply>? content = null;
200+
201+
var handler = TestHttpMessageHandler.Create(async request =>
202+
{
203+
content = (PushStreamContent<HelloRequest, HelloReply>)request.Content!;
204+
var streamTask = content.ReadAsStreamAsync();
205+
requestContentTcs.SetResult(streamTask);
206+
// Wait for RequestStream.CompleteAsync()
207+
await streamTask;
208+
return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent));
209+
});
210+
var channel = GrpcChannel.ForAddress("http://localhost", new GrpcChannelOptions
211+
{
212+
HttpHandler = handler,
213+
LoggerFactory = loggerFactory
214+
});
215+
var invoker = channel.CreateCallInvoker();
216+
217+
var cts = new CancellationTokenSource();
218+
219+
var call = invoker.AsyncDuplexStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token));
220+
await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout();
221+
await call.RequestStream.CompleteAsync().DefaultTimeout();
222+
223+
// Let's read a response
224+
var deserializationContext = new DefaultDeserializationContext();
225+
var requestContent = await await requestContentTcs.Task.DefaultTimeout();
226+
var requestMessage = await StreamSerializationHelper.ReadMessageAsync(
227+
requestContent,
228+
ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
229+
GrpcProtocolConstants.IdentityGrpcEncoding,
230+
maximumMessageSize: null,
231+
GrpcProtocolConstants.DefaultCompressionProviders,
232+
singleMessage: false,
233+
CancellationToken.None).DefaultTimeout();
234+
Assert.AreEqual("1", requestMessage!.Name);
235+
236+
var actTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
237+
238+
var cancellationTask = Task.Run(async () =>
239+
{
240+
await actTcs.Task;
241+
cts.Cancel();
242+
});
243+
var disposingTask = Task.Run(async () =>
244+
{
245+
await actTcs.Task;
246+
channel.Dispose();
247+
});
248+
249+
// Small pause to make sure we're waiting at the TCS everywhere.
250+
await Task.Delay(50);
251+
252+
// Act
253+
actTcs.SetResult(null);
254+
255+
// Assert
256+
// Cancellation and disposing should both complete quickly. If there is a deadlock then the await will timeout.
257+
await Task.WhenAll(cancellationTask, disposingTask).DefaultTimeout();
258+
}
259+
}
180260
}

test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -72,6 +72,8 @@ public async Task AsyncUnaryCall_OneAttempt_Success(int maxAttempts)
7272
var rs = await call.ResponseAsync.DefaultTimeout();
7373
Assert.AreEqual("Hello world", rs.Message);
7474
Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode);
75+
76+
Assert.AreEqual(0, invoker.Channel.ActiveCalls.Count);
7577
}
7678

7779
[Test]
@@ -591,7 +593,6 @@ public async Task AsyncClientStreamingCall_SuccessAfterRetry_RequestContentSent(
591593
var responseTask = call.ResponseAsync;
592594
Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete.");
593595

594-
595596
await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout();
596597
await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout();
597598

@@ -687,6 +688,54 @@ public async Task AsyncClientStreamingCall_WriteAfterResult_Error()
687688
Assert.AreEqual(StatusCode.OK, ex.StatusCode);
688689
}
689690

691+
[Test]
692+
public void AsyncUnaryCall_DisposedChannel_Error()
693+
{
694+
// Arrange
695+
var httpClient = ClientTestHelpers.CreateTestClient(request =>
696+
{
697+
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK));
698+
});
699+
var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig();
700+
var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
701+
702+
// Act & Assert
703+
invoker.Channel.Dispose();
704+
Assert.Throws<ObjectDisposedException>(() => invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }));
705+
}
706+
707+
[Test]
708+
public async Task AsyncUnaryCall_ChannelDisposeDuringBackoff_CanceledStatus()
709+
{
710+
// Arrange
711+
var callCount = 0;
712+
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
713+
{
714+
callCount++;
715+
716+
await request.Content!.CopyToAsync(new MemoryStream());
717+
return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
718+
});
719+
var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10));
720+
var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
721+
var cts = new CancellationTokenSource();
722+
723+
// Act
724+
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token), new HelloRequest { Name = "World" });
725+
726+
var delayTask = Task.Delay(100);
727+
var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask);
728+
729+
// Assert
730+
Assert.AreEqual(delayTask, completedTask); // Ensure that we're waiting for retry
731+
732+
invoker.Channel.Dispose();
733+
734+
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
735+
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
736+
Assert.AreEqual("gRPC call disposed.", ex.Status.Detail);
737+
}
738+
690739
private static Task<HelloRequest?> ReadRequestMessage(Stream requestContent)
691740
{
692741
return StreamSerializationHelper.ReadMessageAsync(

0 commit comments

Comments
 (0)