Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion global.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"sdk": {
"version": "7.0.201"
"version": "7.0.201",
"rollForward": "latestFeature"
}
}
1 change: 1 addition & 0 deletions src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
<Compile Include="..\Shared\Server\UnaryServerMethodInvoker.cs" Link="Model\Internal\UnaryServerMethodInvoker.cs" />
<Compile Include="..\Shared\NullableAttributes.cs" Link="Internal\NullableAttributes.cs" />
<Compile Include="..\Shared\CodeAnalysisAttributes.cs" Link="Internal\CodeAnalysisAttributes.cs" />
<Compile Include="..\Shared\NonCapturingTimer.cs" Link="Internal\NonCapturingTimer.cs" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -91,12 +91,12 @@ public ServerCallDeadlineManager(HttpContextServerCallContext serverCallContext,
// Ensures there is no weird situation where the timer triggers
// before the field is set. Shouldn't happen because only long deadlines
// will take this path but better to be safe than sorry.
_longDeadlineTimer = new Timer(DeadlineExceededLongDelegate, (this, maxTimerDueTime), Timeout.Infinite, Timeout.Infinite);
_longDeadlineTimer = NonCapturingTimer.Create(DeadlineExceededLongDelegate, (this, maxTimerDueTime), Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_longDeadlineTimer.Change(timerMilliseconds, Timeout.Infinite);
}
else
{
_longDeadlineTimer = new Timer(DeadlineExceededDelegate, this, timerMilliseconds, Timeout.Infinite);
_longDeadlineTimer = NonCapturingTimer.Create(DeadlineExceededDelegate, this, TimeSpan.FromMilliseconds(timerMilliseconds), Timeout.InfiniteTimeSpan);
}
}

Expand Down
23 changes: 15 additions & 8 deletions src/Grpc.Net.Client/Balancer/BalancerAttributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#endregion

#if SUPPORT_LOAD_BALANCING
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -29,7 +30,7 @@ namespace Grpc.Net.Client.Balancer;
/// Note: Experimental API that can change or be removed without any prior notice.
/// </para>
/// </summary>
public sealed class BalancerAttributes : IDictionary<string, object?>
public sealed class BalancerAttributes : IDictionary<string, object?>, IReadOnlyDictionary<string, object?>
{
/// <summary>
/// Gets a read-only collection of metadata attributes.
Expand Down Expand Up @@ -61,22 +62,28 @@ private BalancerAttributes(IDictionary<string, object?> attributes)
_attributes[key] = value;
}
}

ICollection<string> IDictionary<string, object?>.Keys => _attributes.Keys;
ICollection<object?> IDictionary<string, object?>.Values => _attributes.Values;
int ICollection<KeyValuePair<string, object?>>.Count => _attributes.Count;
bool ICollection<KeyValuePair<string, object?>>.IsReadOnly => ((IDictionary<string, object?>)_attributes).IsReadOnly;
bool ICollection<KeyValuePair<string, object?>>.IsReadOnly => _attributes.IsReadOnly;
IEnumerable<string> IReadOnlyDictionary<string, object?>.Keys => _attributes.Keys;
IEnumerable<object?> IReadOnlyDictionary<string, object?>.Values => _attributes.Values;
int IReadOnlyCollection<KeyValuePair<string, object?>>.Count => _attributes.Count;
object? IReadOnlyDictionary<string, object?>.this[string key] => _attributes[key];
void IDictionary<string, object?>.Add(string key, object? value) => _attributes.Add(key, value);
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) => ((IDictionary<string, object?>)_attributes).Add(item);
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) => _attributes.Add(item);
void ICollection<KeyValuePair<string, object?>>.Clear() => _attributes.Clear();
bool ICollection<KeyValuePair<string, object?>>.Contains(KeyValuePair<string, object?> item) => ((IDictionary<string, object?>)_attributes).Contains(item);
bool ICollection<KeyValuePair<string, object?>>.Contains(KeyValuePair<string, object?> item) => _attributes.Contains(item);
bool IDictionary<string, object?>.ContainsKey(string key) => _attributes.ContainsKey(key);
void ICollection<KeyValuePair<string, object?>>.CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) =>
((IDictionary<string, object?>)_attributes).CopyTo(array, arrayIndex);
void ICollection<KeyValuePair<string, object?>>.CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) => _attributes.CopyTo(array, arrayIndex);
IEnumerator<KeyValuePair<string, object?>> IEnumerable<KeyValuePair<string, object?>>.GetEnumerator() => _attributes.GetEnumerator();
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => ((System.Collections.IEnumerable)_attributes).GetEnumerator();
IEnumerator System.Collections.IEnumerable.GetEnumerator() => ((System.Collections.IEnumerable)_attributes).GetEnumerator();
bool IDictionary<string, object?>.Remove(string key) => _attributes.Remove(key);
bool ICollection<KeyValuePair<string, object?>>.Remove(KeyValuePair<string, object?> item) => ((IDictionary<string, object?>)_attributes).Remove(item);
bool ICollection<KeyValuePair<string, object?>>.Remove(KeyValuePair<string, object?> item) => _attributes.Remove(item);
bool IDictionary<string, object?>.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value);
bool IReadOnlyDictionary<string, object?>.ContainsKey(string key) => _attributes.ContainsKey(key);
bool IReadOnlyDictionary<string, object?>.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value);

/// <summary>
/// Gets the value associated with the specified key.
Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Balancer/DnsResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected override void OnStarted()

if (_refreshInterval != Timeout.InfiniteTimeSpan)
{
_timer = new Timer(OnTimerCallback, null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_timer = NonCapturingTimer.Create(OnTimerCallback, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_timer.Change(_refreshInterval, _refreshInterval);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -77,7 +77,7 @@ public SocketConnectivitySubchannelTransport(
ConnectTimeout = connectTimeout;
_socketConnect = socketConnect ?? OnConnect;
_activeStreams = new List<ActiveStream>();
_socketConnectedTimer = new Timer(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_socketConnectedTimer = NonCapturingTimer.Create(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
}

private object Lock => _subchannel.Lock;
Expand Down
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Balancer/PollingResolver.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -86,7 +86,7 @@ protected PollingResolver(ILoggerFactory loggerFactory, IBackoffPolicyFactory? b
/// </para>
/// </summary>
/// <param name="listener">The callback used to receive updates on the target.</param>
public override sealed void Start(Action<ResolverResult> listener)
public sealed override void Start(Action<ResolverResult> listener)
{
if (listener == null)
{
Expand Down
18 changes: 16 additions & 2 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -233,7 +233,21 @@ public void RequestConnection()
}
}

_ = ConnectTransportAsync();
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
bool restoreFlow = false;
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

_ = Task.Run(ConnectTransportAsync);

// Restore the current ExecutionContext
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}

private void CancelInProgressConnect()
Expand Down
3 changes: 2 additions & 1 deletion src/Grpc.Net.Client/Grpc.Net.Client.csproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<Description>.NET client for gRPC</Description>
Expand Down Expand Up @@ -35,6 +35,7 @@
<Compile Include="..\Shared\NullableAttributes.cs" Link="Internal\NullableAttributes.cs" />
<Compile Include="..\Shared\Http2ErrorCode.cs" Link="Internal\Http2ErrorCode.cs" />
<Compile Include="..\Shared\Http3ErrorCode.cs" Link="Internal\Http3ErrorCode.cs" />
<Compile Include="..\Shared\NonCapturingTimer.cs" Link="Internal\NonCapturingTimer.cs" />
<Compile Include="..\Shared\NonDisposableMemoryStream.cs" Link="Internal\NonDisposableMemoryStream.cs" />
</ItemGroup>

Expand Down
22 changes: 15 additions & 7 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
private readonly Dictionary<MethodKey, MethodConfig>? _serviceConfigMethods;
private readonly bool _isSecure;
private readonly List<CallCredentials>? _callCredentials;
// Internal for testing
internal readonly HashSet<IDisposable> ActiveCalls;
private readonly HashSet<IDisposable> _activeCalls;

internal Uri Address { get; }
internal HttpMessageInvoker HttpInvoker { get; }
Expand Down Expand Up @@ -165,7 +164,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
ThrowOperationCanceledOnCancellation = channelOptions.ThrowOperationCanceledOnCancellation;
UnsafeUseInsecureChannelCallCredentials = channelOptions.UnsafeUseInsecureChannelCallCredentials;
_createMethodInfoFunc = CreateMethodInfo;
ActiveCalls = new HashSet<IDisposable>();
_activeCalls = new HashSet<IDisposable>();
if (channelOptions.ServiceConfig is { } serviceConfig)
{
RetryThrottling = serviceConfig.RetryThrottling != null ? CreateChannelRetryThrottling(serviceConfig.RetryThrottling) : null;
Expand Down Expand Up @@ -490,15 +489,15 @@ internal void RegisterActiveCall(IDisposable grpcCall)
throw new ObjectDisposedException(nameof(GrpcChannel));
}

ActiveCalls.Add(grpcCall);
_activeCalls.Add(grpcCall);
}
}

internal void FinishActiveCall(IDisposable grpcCall)
{
lock (_lock)
{
ActiveCalls.Remove(grpcCall);
_activeCalls.Remove(grpcCall);
}
}

Expand Down Expand Up @@ -749,9 +748,9 @@ public void Dispose()
return;
}

if (ActiveCalls.Count > 0)
if (_activeCalls.Count > 0)
{
activeCallsCopy = ActiveCalls.ToArray();
activeCallsCopy = _activeCalls.ToArray();
}

Disposed = true;
Expand Down Expand Up @@ -807,6 +806,15 @@ internal int GetRandomNumber(int minValue, int maxValue)
}
}

// Internal for testing
internal IDisposable[] GetActiveCalls()
{
lock (_lock)
{
return _activeCalls.ToArray();
}
}

#if SUPPORT_LOAD_BALANCING
private sealed class SubChannelTransportFactory : ISubchannelTransportFactory
{
Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ public Exception CreateFailureStatusException(Status status)
GrpcCallLog.StartingDeadlineTimeout(Logger, timeout.Value);

var dueTime = CommonGrpcProtocolHelpers.GetTimerDueTime(timeout.Value, Channel.MaxTimerDueTime);
_deadlineTimer = new Timer(DeadlineExceededCallback, null, dueTime, Timeout.Infinite);
_deadlineTimer = NonCapturingTimer.Create(DeadlineExceededCallback, state: null, TimeSpan.FromMilliseconds(dueTime), Timeout.InfiniteTimeSpan);
}
}

Expand Down
39 changes: 39 additions & 0 deletions src/Shared/NonCapturingTimer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Grpc.Shared;

// A convenience API for interacting with System.Threading.Timer in a way
// that doesn't capture the ExecutionContext. We should be using this (or equivalent)
// everywhere we use timers to avoid rooting any values stored in asynclocals.
internal static class NonCapturingTimer
{
public static Timer Create(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period)
{
if (callback is null)
{
throw new ArgumentNullException(nameof(callback));
}

// Don't capture the current ExecutionContext and its AsyncLocals onto the timer
bool restoreFlow = false;
try
{
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

return new Timer(callback, state, dueTime, period);
}
finally
{
// Restore the current ExecutionContext
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}
}
}
70 changes: 1 addition & 69 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -185,74 +185,6 @@ public static async Task<GrpcChannel> CreateChannel(
return channel;
}

public static Task WaitForChannelStateAsync(ILogger logger, GrpcChannel channel, ConnectivityState state, int channelId = 1)
{
return WaitForChannelStatesAsync(logger, channel, new[] { state }, channelId);
}

public static async Task WaitForChannelStatesAsync(ILogger logger, GrpcChannel channel, ConnectivityState[] states, int channelId = 1)
{
var statesText = string.Join(", ", states.Select(s => $"'{s}'"));
logger.LogInformation($"Channel id {channelId}: Waiting for channel states {statesText}.");

var currentState = channel.State;

while (!states.Contains(currentState))
{
logger.LogInformation($"Channel id {channelId}: Current channel state '{currentState}' doesn't match expected states {statesText}.");

await channel.WaitForStateChangedAsync(currentState).DefaultTimeout();
currentState = channel.State;
}

logger.LogInformation($"Channel id {channelId}: Current channel state '{currentState}' matches expected states {statesText}.");
}

public static async Task<Subchannel> WaitForSubchannelToBeReadyAsync(ILogger logger, GrpcChannel channel, Func<SubchannelPicker?, Subchannel[]>? getPickerSubchannels = null)
{
var subChannel = (await WaitForSubchannelsToBeReadyAsync(logger, channel, 1)).Single();
return subChannel;
}

public static async Task<Subchannel[]> WaitForSubchannelsToBeReadyAsync(ILogger logger, GrpcChannel channel, int expectedCount, Func<SubchannelPicker?, Subchannel[]>? getPickerSubchannels = null)
{
if (getPickerSubchannels == null)
{
getPickerSubchannels = (picker) =>
{
return picker switch
{
RoundRobinPicker roundRobinPicker => roundRobinPicker._subchannels.ToArray(),
PickFirstPicker pickFirstPicker => new[] { pickFirstPicker.Subchannel },
EmptyPicker emptyPicker => Array.Empty<Subchannel>(),
null => Array.Empty<Subchannel>(),
_ => throw new Exception("Unexpected picker type: " + picker.GetType().FullName)
};
};
}

logger.LogInformation($"Waiting for subchannel ready count: {expectedCount}");

Subchannel[]? subChannelsCopy = null;
await TestHelpers.AssertIsTrueRetryAsync(() =>
{
var picker = channel.ConnectionManager._picker;
subChannelsCopy = getPickerSubchannels(picker);
logger.LogInformation($"Current subchannel ready count: {subChannelsCopy.Length}");
for (var i = 0; i < subChannelsCopy.Length; i++)
{
logger.LogInformation($"Ready subchannel: {subChannelsCopy[i]}");
}

return subChannelsCopy.Length == expectedCount;
}, "Wait for all subconnections to be connected.");

logger.LogInformation($"Finished waiting for subchannel ready.");

Debug.Assert(subChannelsCopy != null);
return subChannelsCopy;
}

public static T? GetInnerLoadBalancer<T>(GrpcChannel channel) where T : LoadBalancer
{
var balancer = (ChildHandlerLoadBalancer)channel.ConnectionManager._balancer!;
Expand Down
4 changes: 2 additions & 2 deletions test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -352,7 +352,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)

await channel.ConnectAsync().DefaultTimeout();

await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, 2).DefaultTimeout();
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, 2).DefaultTimeout();

var client = TestClientFactory.Create(channel, endpoint1.Method);

Expand Down
4 changes: 2 additions & 2 deletions test/FunctionalTests/Balancer/LeastUsedBalancerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -67,7 +67,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte

var channel = await BalancerHelpers.CreateChannel(LoggerFactory, new LoadBalancingConfig("least_used"), new[] { endpoint1.Address, endpoint2.Address }, connect: true);

await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(
Logger,
channel,
expectedCount: 2,
Expand Down
Loading