Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ internal sealed partial class GrpcCall<TRequest, TResponse> : GrpcCall, IGrpcCal
public HttpContentClientStreamWriter<TRequest, TResponse>? ClientStreamWriter { get; private set; }
public HttpContentClientStreamReader<TRequest, TResponse>? ClientStreamReader { get; private set; }

public GrpcCall(Method<TRequest, TResponse> method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel, int attemptCount)
public GrpcCall(Method<TRequest, TResponse> method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel, int attemptCount, bool forceAsyncHttpResponse)
: base(options, channel)
{
// Validate deadline before creating any objects that require cleanup
ValidateDeadline(options.Deadline);

_callCts = new CancellationTokenSource();
_httpResponseTcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
// Retries and hedging can run multiple calls at the same time and use locking for thread-safety.
// Running HTTP response continuation asynchronously is required for locking to work correctly.
_httpResponseTcs = new TaskCompletionSource<HttpResponseMessage>(forceAsyncHttpResponse ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
// Run the callTcs continuation immediately to keep the same context. Required for Activity.
_callTcs = new TaskCompletionSource<Status>();
Method = method;
Expand Down Expand Up @@ -142,7 +144,10 @@ public void StartDuplexStreaming()

internal void StartUnaryCore(HttpContent content)
{
_responseTcs = new TaskCompletionSource<TResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
// Not created with RunContinuationsAsynchronously to avoid unnecessary dispatch to the thread pool.
// The TCS is set from RunCall but it is the last operation before the method exits so there shouldn't
// be an impact from running the response continutation synchronously.
_responseTcs = new TaskCompletionSource<TResponse>();

var timeout = GetTimeout();
var message = CreateHttpRequestMessage(timeout);
Expand All @@ -161,7 +166,10 @@ internal void StartServerStreamingCore(HttpContent content)

internal void StartClientStreamingCore(HttpContentClientStreamWriter<TRequest, TResponse> clientStreamWriter, HttpContent content)
{
_responseTcs = new TaskCompletionSource<TResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
// Not created with RunContinuationsAsynchronously to avoid unnecessary dispatch to the thread pool.
// The TCS is set from RunCall but it is the last operation before the method exits so there shouldn't
// be an impact from running the response continutation synchronously.
_responseTcs = new TaskCompletionSource<TResponse>();

var timeout = GetTimeout();
var message = CreateHttpRequestMessage(timeout);
Expand Down Expand Up @@ -431,9 +439,6 @@ private void CancelCall(Status status)
// Cancellation will also cause reader/writer to throw if used afterwards.
_callCts.Cancel();

// Ensure any logic that is waiting on the HttpResponse is unstuck.
_httpResponseTcs.TrySetCanceled();

// Cancellation token won't send RST_STREAM if HttpClient.SendAsync is complete.
// Dispose HttpResponseMessage to send RST_STREAM to server for in-progress calls.
HttpResponse?.Dispose();
Expand Down Expand Up @@ -652,6 +657,9 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? timeout)
// Verify that FinishCall is called in every code path of this method.
// Should create an "Unassigned variable" compiler error if not set.
Debug.Assert(finished);
// Should be completed before exiting.
Debug.Assert(_httpResponseTcs.Task.IsCompleted);
Debug.Assert(_responseTcs == null || _responseTcs.Task.IsCompleted);
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ private static IGrpcCall<TRequest, TResponse> CreateRootGrpcCall<TRequest, TResp
{
// No retry/hedge policy configured. Fast path!
// Note that callWrapper is null here and will be set later.
return CreateGrpcCall<TRequest, TResponse>(channel, method, options, attempt: 1, callWrapper: null);
return CreateGrpcCall<TRequest, TResponse>(channel, method, options, attempt: 1, forceAsyncHttpResponse: false, callWrapper: null);
}
}

Expand Down Expand Up @@ -210,14 +210,15 @@ public static GrpcCall<TRequest, TResponse> CreateGrpcCall<TRequest, TResponse>(
Method<TRequest, TResponse> method,
CallOptions options,
int attempt,
bool forceAsyncHttpResponse,
object? callWrapper)
where TRequest : class
where TResponse : class
{
ObjectDisposedThrowHelper.ThrowIf(channel.Disposed, typeof(GrpcChannel));

var methodInfo = channel.GetCachedGrpcMethodInfo(method);
var call = new GrpcCall<TRequest, TResponse>(method, methodInfo, options, channel, attempt);
var call = new GrpcCall<TRequest, TResponse>(method, methodInfo, options, channel, attempt, forceAsyncHttpResponse);
call.CallWrapper = callWrapper;

return call;
Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private async Task StartCall(Action<GrpcCall<TRequest, TResponse>> startCallFunc

OnStartingAttempt();

call = HttpClientCallInvoker.CreateGrpcCall<TRequest, TResponse>(Channel, Method, Options, AttemptCount, CallWrapper);
call = HttpClientCallInvoker.CreateGrpcCall<TRequest, TResponse>(Channel, Method, Options, AttemptCount, forceAsyncHttpResponse: true, CallWrapper);
_activeCalls.Add(call);

startCallFunc(call);
Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Internal/Retry/RetryCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private async Task StartRetry(Action<GrpcCall<TRequest, TResponse>> startCallFun
// Start new call.
OnStartingAttempt();

currentCall = _activeCall = HttpClientCallInvoker.CreateGrpcCall<TRequest, TResponse>(Channel, Method, Options, AttemptCount, CallWrapper);
currentCall = _activeCall = HttpClientCallInvoker.CreateGrpcCall<TRequest, TResponse>(Channel, Method, Options, AttemptCount, forceAsyncHttpResponse: true, CallWrapper);
startCallFunc(currentCall);

SetNewActiveCallUnsynchronized(currentCall);
Expand Down
6 changes: 5 additions & 1 deletion src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ internal abstract partial class RetryCallBase<TRequest, TResponse> : IGrpcCall<T
private Task<TResponse>? _responseTask;
private Task<Metadata>? _responseHeadersTask;
private TRequest? _request;
private bool _commitStarted;

// Internal for unit testing.
internal CancellationTokenRegistration? _ctsRegistration;
Expand Down Expand Up @@ -369,8 +370,11 @@ protected void CommitCall(IGrpcCall<TRequest, TResponse> call, CommitReason comm
{
lock (Lock)
{
if (!CommitedCallTask.IsCompletedSuccessfully())
if (!_commitStarted)
{
// Specify that call is commiting. This is to prevent any chance of re-entrancy from logic run in OnCommitCall.
_commitStarted = true;

// The buffer size is verified in unit tests after calls are completed.
// Clear the buffer before commiting call.
ClearRetryBuffer();
Expand Down
2 changes: 1 addition & 1 deletion test/Grpc.Net.Client.Tests/CancellationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public async Task AsyncClientStreamingCall_CancellationDuringSend_ThrowOperation

cts.Cancel();

var ex = await ExceptionAssert.ThrowsAsync<TaskCanceledException>(() => responseHeadersTask).DefaultTimeout();
var ex = await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => responseHeadersTask).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode);
Assert.AreEqual("Call canceled by the client.", call.GetStatus().Detail);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ private static GrpcCall<HelloRequest, HelloReply> CreateGrpcCall(GrpcChannel cha
new GrpcMethodInfo(new GrpcCallScope(ClientTestHelpers.ServiceMethod.Type, uri), uri, methodConfig: null),
new CallOptions(),
channel,
attemptCount: 0);
attemptCount: 0,
forceAsyncHttpResponse: false);
}

private static GrpcChannel CreateChannel(HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool? throwOperationCanceledOnCancellation = null)
Expand Down
5 changes: 4 additions & 1 deletion test/Grpc.Net.Client.Tests/ResponseHeadersAsyncTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ public async Task AsyncUnaryCall_AuthInterceptorDispose_ResponseHeadersError()
var credentialsSyncPoint = new SyncPoint(runContinuationsAsynchronously: true);
var credentials = CallCredentials.FromInterceptor(async (context, metadata) =>
{
await credentialsSyncPoint.WaitToContinue();
var tcs = new TaskCompletionSource<bool>();
context.CancellationToken.Register(s => ((TaskCompletionSource<bool>)s!).SetResult(true), tcs);

await Task.WhenAny(credentialsSyncPoint.WaitToContinue(), tcs.Task);
metadata.Add("Authorization", $"Bearer TEST");
});

Expand Down
5 changes: 4 additions & 1 deletion test/Grpc.Net.Client.Tests/Retry/RetryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ public async Task AsyncUnaryCall_AuthInteceptorDispose_Error()
var credentialsSyncPoint = new SyncPoint(runContinuationsAsynchronously: true);
var credentials = CallCredentials.FromInterceptor(async (context, metadata) =>
{
await credentialsSyncPoint.WaitToContinue();
var tcs = new TaskCompletionSource<bool>();
context.CancellationToken.Register(s => ((TaskCompletionSource<bool>)s!).SetResult(true), tcs);

await Task.WhenAny(credentialsSyncPoint.WaitToContinue(), tcs.Task);
metadata.Add("Authorization", $"Bearer TEST");
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient, loggerFactory: provider.GetRequiredService<ILoggerFactory>(), serviceConfig: serviceConfig, configure: options => options.Credentials = ChannelCredentials.Create(new SslCredentials(), credentials));
Expand Down