diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs index 56d00ccb3a98cd..01b5fb9f32f730 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs @@ -3,7 +3,6 @@ using System; using System.Net; -using System.Net.Internals; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -12,6 +11,6 @@ internal static partial class Interop internal static partial class Sys { [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Socket")] - internal static unsafe partial Error Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, IntPtr* socket); + internal static unsafe partial Error Socket(int addressFamily, int socketType, int protocolType, IntPtr* socket); } } diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs index 4c03a64e1beb3a..7a1dceee2fd868 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs @@ -9,9 +9,9 @@ internal static partial class Interop { internal static partial class Sys { - [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetIPSocketAddressSizes")] + [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetSocketAddressSizes")] [SuppressGCTransition] - internal static unsafe partial Error GetIPSocketAddressSizes(int* ipv4SocketAddressSize, int* ipv6SocketAddressSize); + internal static unsafe partial Error GetSocketAddressSizes(int* ipv4SocketAddressSize, int* ipv6SocketAddressSize, int* udsSocketAddressSize, int* maxSocketAddressSize); [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetAddressFamily")] [SuppressGCTransition] diff --git a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs index 090bf72a31a39d..1f71754b486f88 100644 --- a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs +++ b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs @@ -104,6 +104,6 @@ internal static partial uint IcmpSendEcho2(SafeCloseIcmpHandle icmpHandle, SafeW [LibraryImport(Interop.Libraries.IpHlpApi, SetLastError = true)] internal static unsafe partial uint Icmp6SendEcho2(SafeCloseIcmpHandle icmpHandle, SafeWaitHandle Event, IntPtr apcRoutine, IntPtr apcContext, - byte* sourceSocketAddress, byte[] destSocketAddress, SafeLocalAllocHandle data, ushort dataSize, ref IP_OPTION_INFORMATION options, SafeLocalAllocHandle replyBuffer, uint replySize, uint timeout); + Span sourceSocketAddress, Span destSocketAddress, SafeLocalAllocHandle data, ushort dataSize, ref IP_OPTION_INFORMATION options, SafeLocalAllocHandle replyBuffer, uint replySize, uint timeout); } } diff --git a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs index 71e67116e4ebe2..b01bb8ee9d3c9b 100644 --- a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs +++ b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs @@ -8,7 +8,6 @@ using System.Net.NetworkInformation; using System.Net.Sockets; using System.Runtime.InteropServices; -using Internals = System.Net.Internals; internal static partial class Interop { @@ -53,20 +52,14 @@ internal enum GetAdaptersAddressesFlags } [StructLayout(LayoutKind.Sequential)] - internal struct IpSocketAddress + internal unsafe struct IpSocketAddress { internal IntPtr address; internal int addressLength; internal IPAddress MarshalIPAddress() { - // Determine the address family used to create the IPAddress. - AddressFamily family = (addressLength > Internals.SocketAddress.IPv4AddressSize) - ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork; - Internals.SocketAddress sockAddress = new Internals.SocketAddress(family, addressLength); - Marshal.Copy(address, sockAddress.Buffer, 0, addressLength); - - return sockAddress.GetIPAddress(); + return IPEndPointExtensions.GetIPAddress(new Span((void*)address, addressLength)); } } @@ -511,7 +504,7 @@ internal static unsafe partial uint GetAdaptersAddresses( uint* outBufLen); [LibraryImport(Interop.Libraries.IpHlpApi)] - internal static unsafe partial uint GetBestInterfaceEx(byte* ipAddress, int* index); + internal static unsafe partial uint GetBestInterfaceEx(Span ipAddress, int* index); [LibraryImport(Interop.Libraries.IpHlpApi)] internal static partial uint GetIfEntry2(ref MibIfRow2 pIfRow); diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs index 77c846af58f8d0..6f3faacef32793 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs @@ -10,13 +10,22 @@ internal static partial class Interop internal static partial class Winsock { [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] - internal static partial SocketError WSAConnect( + private static partial SocketError WSAConnect( SafeSocketHandle socketHandle, - byte[] socketAddress, + ReadOnlySpan socketAddress, int socketAddressSize, IntPtr inBuffer, IntPtr outBuffer, IntPtr sQOS, IntPtr gQOS); + + internal static SocketError WSAConnect( + SafeSocketHandle socketHandle, + ReadOnlySpan socketAddress, + IntPtr inBuffer, + IntPtr outBuffer, + IntPtr sQOS, + IntPtr gQOS) => + WSAConnect(socketHandle, socketAddress, socketAddress.Length, inBuffer, outBuffer, sQOS, gQOS); } } diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs index 2202c0815ff177..531327f0e5ca9a 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs @@ -12,7 +12,7 @@ internal static partial class Winsock [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] internal static partial IntPtr accept( SafeSocketHandle socketHandle, - byte[] socketAddress, + Span socketAddress, ref int socketAddressSize); } } diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs index 8c470c30268ef3..70d76733825acc 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs @@ -13,10 +13,10 @@ internal static partial class Winsock [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] internal static unsafe partial int recvfrom( SafeSocketHandle socketHandle, - byte* pinnedBuffer, + Span pinnedBuffer, int len, SocketFlags socketFlags, - byte[] socketAddress, + Span socketAddress, ref int socketAddressSize); } } diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs index 01e036664049ec..e93d609401513b 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs @@ -15,7 +15,7 @@ internal static unsafe partial int sendto( byte* pinnedBuffer, int len, SocketFlags socketFlags, - byte[] socketAddress, + ReadOnlySpan socketAddress, int socketAddressSize); } } diff --git a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs new file mode 100644 index 00000000000000..2308d48c29671f --- /dev/null +++ b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs @@ -0,0 +1,62 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net; + +namespace System.Net.Sockets +{ + internal static class IPEndPointExtensions + { + public static IPAddress GetIPAddress(ReadOnlySpan socketAddressBuffer) + { + AddressFamily family = SocketAddressPal.GetAddressFamily(socketAddressBuffer); + + if (family == AddressFamily.InterNetworkV6) + { + Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + uint scope; + SocketAddressPal.GetIPv6Address(socketAddressBuffer, address, out scope); + return new IPAddress(address, (long)scope); + } + else if (family == AddressFamily.InterNetwork) + { + return new IPAddress((long)SocketAddressPal.GetIPv4Address(socketAddressBuffer) & 0x0FFFFFFFF); + } + + throw new SocketException((int)SocketError.AddressFamilyNotSupported); + } + + public static void SetIPAddress(Span socketAddressBuffer, IPAddress address) + { + SocketAddressPal.SetAddressFamily(socketAddressBuffer, address.AddressFamily); + SocketAddressPal.SetPort(socketAddressBuffer, 0); + if (address.AddressFamily == AddressFamily.InterNetwork) + { +#pragma warning disable CS0618 + SocketAddressPal.SetIPv4Address(socketAddressBuffer, (uint)address.Address); +#pragma warning restore CS0618 + } + else + { + Span addressBuffer = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + address.TryWriteBytes(addressBuffer, out int written); + Debug.Assert(written == IPAddressParserStatics.IPv6AddressBytes); + SocketAddressPal.SetIPv6Address(socketAddressBuffer, addressBuffer, (uint)address.ScopeId); + } + } + + public static IPEndPoint CreateIPEndPoint(ReadOnlySpan socketAddressBuffer) + { + return new IPEndPoint(GetIPAddress(socketAddressBuffer), SocketAddressPal.GetPort(socketAddressBuffer)); + } + + // suggestion from https://github.com/dotnet/runtime/issues/78993 + public static void Serialize(this IPEndPoint endPoint, Span destination) + { + SocketAddressPal.SetAddressFamily(destination, endPoint.AddressFamily); + SetIPAddress(destination, endPoint.Address); + SocketAddressPal.SetPort(destination, (ushort)endPoint.Port); + } + } +} diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 62323be5564998..c7ad3785555166 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -20,27 +20,26 @@ namespace System.Net.Internals #else internal sealed #endif - class SocketAddress + class SocketAddress : System.IEquatable { #pragma warning disable CA1802 // these could be const on Windows but need to be static readonly for Unix internal static readonly int IPv6AddressSize = SocketAddressPal.IPv6AddressSize; internal static readonly int IPv4AddressSize = SocketAddressPal.IPv4AddressSize; + internal static readonly int UdsAddressSize = SocketAddressPal.UdsAddressSize; + internal static readonly int MaxAddressSize = SocketAddressPal.MaxAddressSize; #pragma warning restore CA1802 internal int InternalSize; - internal byte[] Buffer; + internal byte[] InternalBuffer; private const int MinSize = 2; - private const int MaxSize = 32; // IrDA requires 32 bytes private const int DataOffset = 2; - private bool _changed = true; - private int _hash; public AddressFamily Family { get { - return SocketAddressPal.GetAddressFamily(Buffer); + return SocketAddressPal.GetAddressFamily(InternalBuffer); } } @@ -50,6 +49,12 @@ public int Size { return InternalSize; } + set + { + ArgumentOutOfRangeException.ThrowIfGreaterThan(value, InternalBuffer.Length); + ArgumentOutOfRangeException.ThrowIfLessThan(value, MinSize); + InternalSize = value; + } } // Access to unmanaged serialized data. This doesn't @@ -60,27 +65,31 @@ public byte this[int offset] { get { - if (offset < 0 || offset >= Size) + if ((uint)offset >= (uint)Size) { throw new IndexOutOfRangeException(); } - return Buffer[offset]; + return InternalBuffer[offset]; } set { - if (offset < 0 || offset >= Size) + if ((uint)offset >= (uint)Size) { throw new IndexOutOfRangeException(); } - if (Buffer[offset] != value) - { - _changed = true; - } - Buffer[offset] = value; + InternalBuffer[offset] = value; } } - public SocketAddress(AddressFamily family) : this(family, MaxSize) + public static int GetMaximumAddressSize(AddressFamily addressFamily) => addressFamily switch + { + AddressFamily.InterNetwork => IPv4AddressSize, + AddressFamily.InterNetworkV6 => IPv6AddressSize, + AddressFamily.Unix => UdsAddressSize, + _ => MaxAddressSize + }; + + public SocketAddress(AddressFamily family) : this(family, GetMaximumAddressSize(family)) { } @@ -95,17 +104,19 @@ public SocketAddress(AddressFamily family, int size) // The following formula will extend 'size' to the alignment boundary then add IntPtr.Size more bytes. size = (size + IntPtr.Size - 1) / IntPtr.Size * IntPtr.Size + IntPtr.Size; #endif - Buffer = new byte[size]; + InternalBuffer = new byte[size]; + InternalBuffer[0] = (byte)InternalSize; - SocketAddressPal.SetAddressFamily(Buffer, family); + SocketAddressPal.SetAddressFamily(InternalBuffer, family); } internal SocketAddress(IPAddress ipAddress) : this(ipAddress.AddressFamily, ((ipAddress.AddressFamily == AddressFamily.InterNetwork) ? IPv4AddressSize : IPv6AddressSize)) { + // No Port. - SocketAddressPal.SetPort(Buffer, 0); + SocketAddressPal.SetPort(InternalBuffer, 0); if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) { @@ -113,7 +124,7 @@ internal SocketAddress(IPAddress ipAddress) ipAddress.TryWriteBytes(addressBytes, out int bytesWritten); Debug.Assert(bytesWritten == IPAddressParserStatics.IPv6AddressBytes); - SocketAddressPal.SetIPv6Address(Buffer, addressBytes, (uint)ipAddress.ScopeId); + SocketAddressPal.SetIPv6Address(InternalBuffer, addressBytes, (uint)ipAddress.ScopeId); } else { @@ -122,21 +133,33 @@ internal SocketAddress(IPAddress ipAddress) #pragma warning restore CS0618 Debug.Assert(ipAddress.AddressFamily == AddressFamily.InterNetwork); - SocketAddressPal.SetIPv4Address(Buffer, address); + SocketAddressPal.SetIPv4Address(InternalBuffer, address); } } internal SocketAddress(IPAddress ipaddress, int port) : this(ipaddress) { - SocketAddressPal.SetPort(Buffer, unchecked((ushort)port)); + SocketAddressPal.SetPort(InternalBuffer, unchecked((ushort)port)); } internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) { - Buffer = buffer.ToArray(); - InternalSize = Buffer.Length; - SocketAddressPal.SetAddressFamily(Buffer, addressFamily); + InternalBuffer = buffer.ToArray(); + InternalSize = InternalBuffer.Length; + SocketAddressPal.SetAddressFamily(InternalBuffer, addressFamily); + } + + /// This represents underlying memory that can be passed to native OS calls. + /// + /// Content of the memory can be invalidated if is changed or if the SocketAddress is used in another receive call. + /// + public Memory Buffer + { + get + { + return new Memory(InternalBuffer, 0, InternalSize); + } } internal IPAddress GetIPAddress() @@ -147,14 +170,14 @@ internal IPAddress GetIPAddress() Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; uint scope; - SocketAddressPal.GetIPv6Address(Buffer, address, out scope); + SocketAddressPal.GetIPv6Address(InternalBuffer, address, out scope); return new IPAddress(address, (long)scope); } else if (Family == AddressFamily.InterNetwork) { Debug.Assert(Size >= IPv4AddressSize); - long address = (long)SocketAddressPal.GetIPv4Address(Buffer) & 0x0FFFFFFFF; + long address = (long)SocketAddressPal.GetIPv4Address(InternalBuffer) & 0x0FFFFFFFF; return new IPAddress(address); } else @@ -167,7 +190,7 @@ internal IPAddress GetIPAddress() } } - internal int GetPort() => (int)SocketAddressPal.GetPort(Buffer); + internal int GetPort() => (int)SocketAddressPal.GetPort(InternalBuffer); internal IPEndPoint GetIPEndPoint() { @@ -179,51 +202,29 @@ internal IPEndPoint GetIPEndPoint() internal void CopyAddressSizeIntoBuffer() { int addressSizeOffset = GetAddressSizeOffset(); - Buffer[addressSizeOffset] = unchecked((byte)(InternalSize)); - Buffer[addressSizeOffset + 1] = unchecked((byte)(InternalSize >> 8)); - Buffer[addressSizeOffset + 2] = unchecked((byte)(InternalSize >> 16)); - Buffer[addressSizeOffset + 3] = unchecked((byte)(InternalSize >> 24)); + InternalBuffer[addressSizeOffset] = unchecked((byte)(InternalSize)); + InternalBuffer[addressSizeOffset + 1] = unchecked((byte)(InternalSize >> 8)); + InternalBuffer[addressSizeOffset + 2] = unchecked((byte)(InternalSize >> 16)); + InternalBuffer[addressSizeOffset + 3] = unchecked((byte)(InternalSize >> 24)); } // Can be called after the above method did work. internal int GetAddressSizeOffset() { - return Buffer.Length - IntPtr.Size; + return InternalBuffer.Length - IntPtr.Size; } #endif public override bool Equals(object? comparand) => - comparand is SocketAddress other && - Buffer.AsSpan(0, Size).SequenceEqual(other.Buffer.AsSpan(0, other.Size)); + comparand is SocketAddress other && Equals(other); + + public bool Equals(SocketAddress? comparand) => comparand != null && Buffer.Span.SequenceEqual(comparand.Buffer.Span); public override int GetHashCode() { - if (_changed) - { - _changed = false; - _hash = 0; - - int i; - int size = Size & ~3; - - for (i = 0; i < size; i += 4) - { - _hash ^= BinaryPrimitives.ReadInt32LittleEndian(Buffer.AsSpan(i)); - } - if ((Size & 3) != 0) - { - int remnant = 0; - int shift = 0; - - for (; i < Size; ++i) - { - remnant |= ((int)Buffer[i]) << shift; - shift += 8; - } - _hash ^= remnant; - } - } - return _hash; + HashCode hash = default; + hash.AddBytes(new ReadOnlySpan(InternalBuffer, 0, InternalSize)); + return hash.ToHashCode(); } public override string ToString() @@ -257,7 +258,7 @@ public override string ToString() result[length++] = ':'; result[length++] = '{'; - byte[] buffer = Buffer; + byte[] buffer = InternalBuffer; for (int i = DataOffset; i < Size; i++) { if (i > DataOffset) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 3aa6d95337c54c..107d1d5e659671 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -11,24 +11,30 @@ namespace System.Net { internal static class SocketAddressPal { - public static readonly int IPv6AddressSize = GetIPv6AddressSize(); - public static readonly int IPv4AddressSize = GetIPv4AddressSize(); + public static readonly int IPv4AddressSize; + public static readonly int IPv6AddressSize; + public static readonly int UdsAddressSize; + public static readonly int MaxAddressSize; - private static unsafe int GetIPv6AddressSize() +#pragma warning disable CA1810 + static unsafe SocketAddressPal() { - int ipv6AddressSize, unused; - Interop.Error err = Interop.Sys.GetIPSocketAddressSizes(&unused, &ipv6AddressSize); + int ipv4 = 0; + int ipv6 = 0; + int uds = 0; + int max = 0; + Interop.Error err = Interop.Sys.GetSocketAddressSizes(&ipv4, &ipv6, &uds, &max); Debug.Assert(err == Interop.Error.SUCCESS, $"Unexpected err: {err}"); - return ipv6AddressSize; - } - - private static unsafe int GetIPv4AddressSize() - { - int ipv4AddressSize, unused; - Interop.Error err = Interop.Sys.GetIPSocketAddressSizes(&ipv4AddressSize, &unused); - Debug.Assert(err == Interop.Error.SUCCESS, $"Unexpected err: {err}"); - return ipv4AddressSize; + Debug.Assert(ipv4 > 0); + Debug.Assert(ipv6 > 0); + Debug.Assert(uds > 0); + Debug.Assert(max >= ipv4 && max >= ipv6 && max >= uds); + IPv4AddressSize = ipv4; + IPv6AddressSize = ipv6; + UdsAddressSize = uds; + MaxAddressSize = max; } +#pragma warning restore CA1810 private static void ThrowOnFailure(Interop.Error err) { @@ -64,7 +70,7 @@ public static unsafe AddressFamily GetAddressFamily(ReadOnlySpan buffer) return family; } - public static unsafe void SetAddressFamily(byte[] buffer, AddressFamily family) + public static unsafe void SetAddressFamily(Span buffer, AddressFamily family) { Interop.Error err; @@ -92,7 +98,7 @@ public static unsafe ushort GetPort(ReadOnlySpan buffer) return port; } - public static unsafe void SetPort(byte[] buffer, ushort port) + public static unsafe void SetPort(Span buffer, ushort port) { Interop.Error err; fixed (byte* rawAddress = buffer) @@ -130,7 +136,7 @@ public static unsafe void GetIPv6Address(ReadOnlySpan buffer, Span a scope = localScope; } - public static unsafe void SetIPv4Address(byte[] buffer, uint address) + public static unsafe void SetIPv4Address(Span buffer, uint address) { Interop.Error err; fixed (byte* rawAddress = buffer) @@ -141,21 +147,22 @@ public static unsafe void SetIPv4Address(byte[] buffer, uint address) ThrowOnFailure(err); } - public static unsafe void SetIPv4Address(byte[] buffer, byte* address) + public static unsafe void SetIPv4Address(Span buffer, byte* address) { uint addr = (uint)System.Runtime.InteropServices.Marshal.ReadInt32((IntPtr)address); SetIPv4Address(buffer, addr); } - public static unsafe void SetIPv6Address(byte[] buffer, Span address, uint scope) + public static unsafe void SetIPv6Address(Span buffer, Span address, uint scope) { + fixed (byte* rawInput = &MemoryMarshal.GetReference(address)) { SetIPv6Address(buffer, rawInput, address.Length, scope); } } - public static unsafe void SetIPv6Address(byte[] buffer, byte* address, int addressLength, uint scope) + public static unsafe void SetIPv6Address(Span buffer, byte* address, int addressLength, uint scope) { Interop.Error err; fixed (byte* rawAddress = buffer) @@ -165,5 +172,14 @@ public static unsafe void SetIPv6Address(byte[] buffer, byte* address, int addre ThrowOnFailure(err); } + + public static unsafe void Clear(Span buffer) + { + AddressFamily family = GetAddressFamily(buffer); + buffer.Clear(); + // platforms where this matters (OSXLike & BSD) use uint8 for SA length + buffer[0] = (byte)Math.Min(buffer.Length, 255); + SetAddressFamily(buffer, family); + } } } diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs index d9bdfb12797044..a563675dab62fd 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -10,13 +10,15 @@ internal static class SocketAddressPal { public const int IPv6AddressSize = 28; public const int IPv4AddressSize = 16; + public const int UdsAddressSize = 110; + public const int MaxAddressSize = 128; public static AddressFamily GetAddressFamily(ReadOnlySpan buffer) { return (AddressFamily)BitConverter.ToInt16(buffer); } - public static void SetAddressFamily(byte[] buffer, AddressFamily family) + public static void SetAddressFamily(Span buffer, AddressFamily family) { if ((int)(family) > ushort.MaxValue) { @@ -36,8 +38,8 @@ public static void SetAddressFamily(byte[] buffer, AddressFamily family) public static ushort GetPort(ReadOnlySpan buffer) => BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2)); - public static void SetPort(byte[] buffer, ushort port) - => BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(2), port); + public static void SetPort(Span buffer, ushort port) + => BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(2), port); public static uint GetIPv4Address(ReadOnlySpan buffer) => BinaryPrimitives.ReadUInt32LittleEndian(buffer.Slice(4)); @@ -49,22 +51,29 @@ public static void GetIPv6Address(ReadOnlySpan buffer, Span address, scope = BinaryPrimitives.ReadUInt32LittleEndian(buffer.Slice(24)); } - public static void SetIPv4Address(byte[] buffer, uint address) + public static void SetIPv4Address(Span buffer, uint address) { // IPv4 Address serialization - BinaryPrimitives.WriteUInt32LittleEndian(buffer.AsSpan(4), address); + BinaryPrimitives.WriteUInt32LittleEndian(buffer.Slice(4), address); } - public static void SetIPv6Address(byte[] buffer, Span address, uint scope) + public static void SetIPv6Address(Span buffer, Span address, uint scope) { // No handling for Flow Information - BinaryPrimitives.WriteUInt32LittleEndian(buffer.AsSpan(4), 0); + BinaryPrimitives.WriteUInt32LittleEndian(buffer.Slice(4), 0); // Scope serialization - BinaryPrimitives.WriteUInt32LittleEndian(buffer.AsSpan(24), scope); + BinaryPrimitives.WriteUInt32LittleEndian(buffer.Slice(24), scope); // Address serialization - address.CopyTo(buffer.AsSpan(8)); + address.CopyTo(buffer.Slice(8)); + } + + public static unsafe void Clear(Span buffer) + { + AddressFamily family = GetAddressFamily(buffer); + buffer.Clear(); + SetAddressFamily(buffer, family); } } } diff --git a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs index 973254788e5427..0e6e96a4dd20db 100644 --- a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Net.Internals; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -9,6 +8,7 @@ namespace System.Net { internal static partial class SocketProtocolSupportPal { + private const int DgramSocketType = 2; private static unsafe bool IsSupported(AddressFamily af) { // Check for AF_UNIX on iOS/tvOS. The OS claims to support this, but returns EPERM on bind. @@ -21,7 +21,7 @@ private static unsafe bool IsSupported(AddressFamily af) IntPtr socket = invalid; try { - Interop.Error result = Interop.Sys.Socket(af, SocketType.Dgram, 0, &socket); + Interop.Error result = Interop.Sys.Socket((int)af, DgramSocketType, 0, &socket); // we get EAFNOSUPPORT when family is not supported by Kernel, EPROTONOSUPPORT may come from policy enforcement like FreeBSD jail() return result != Interop.Error.EAFNOSUPPORT && result != Interop.Error.EPROTONOSUPPORT; } diff --git a/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj b/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj index e0aa8f994b0c03..5256dd88482328 100644 --- a/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj +++ b/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj @@ -19,9 +19,9 @@ - - diff --git a/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj b/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj index 2a742d79f71441..3cf6d5f597f673 100644 --- a/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj +++ b/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj @@ -82,7 +82,7 @@ - + diff --git a/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs b/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs index 4c62c3770d96a2..96256e552a0f1e 100644 --- a/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs +++ b/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs @@ -44,14 +44,13 @@ internal static int InternalIPv6LoopbackInterfaceIndex private static unsafe int GetBestInterfaceForAddress(IPAddress addr) { int index; - Internals.SocketAddress address = new Internals.SocketAddress(addr); - fixed (byte* buffer = address.Buffer) + Span buffer= stackalloc byte[SocketAddressPal.IPv6AddressSize]; + IPEndPointExtensions.SetIPAddress(buffer, addr); + + int error = (int)Interop.IpHlpApi.GetBestInterfaceEx(buffer, &index); + if (error != 0) { - int error = (int)Interop.IpHlpApi.GetBestInterfaceEx(buffer, &index); - if (error != 0) - { - throw new NetworkInformationException(error); - } + throw new NetworkInformationException(error); } return index; diff --git a/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj b/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj index 9c817f3908a98c..aaaf38f5bf51fa 100644 --- a/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj +++ b/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj @@ -24,17 +24,14 @@ - - - - + + @@ -96,9 +93,6 @@ Link="Common\Interop\Windows\WinSock\Interop.WSAStartup.cs" /> - - diff --git a/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs b/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs index 9a1b3075282e34..b5ca9832bdb3ed 100644 --- a/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs +++ b/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs @@ -183,10 +183,11 @@ private unsafe int SendEcho(IPAddress address, byte[] buffer, int timeout, PingO (uint)timeout); } - IPEndPoint ep = new IPEndPoint(address, 0); - Internals.SocketAddress remoteAddr = IPEndPointExtensions.Serialize(ep); - byte* sourceAddr = stackalloc byte[28]; - NativeMemory.Clear(sourceAddr, 28); + Span remoteAddr = stackalloc byte[SocketAddressPal.IPv6AddressSize]; + IPEndPointExtensions.SetIPAddress(remoteAddr, address); + + Span sourceAddr = stackalloc byte[SocketAddressPal.IPv6AddressSize]; + sourceAddr.Clear(); return (int)Interop.IpHlpApi.Icmp6SendEcho2( _handlePingV6!, @@ -194,7 +195,7 @@ private unsafe int SendEcho(IPAddress address, byte[] buffer, int timeout, PingO IntPtr.Zero, IntPtr.Zero, sourceAddr, - remoteAddr.Buffer, + remoteAddr, _requestBuffer!, (ushort)buffer.Length, ref ipOptions, diff --git a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs index 2ed166f381b9b0..480548a33f1b7d 100644 --- a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs +++ b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs @@ -349,14 +349,17 @@ public NetworkCredential(string? userName, string? password, string? domain) { } public System.Net.NetworkCredential GetCredential(string? host, int port, string? authenticationType) { throw null; } public System.Net.NetworkCredential GetCredential(System.Uri? uri, string? authenticationType) { throw null; } } - public partial class SocketAddress + public partial class SocketAddress : System.IEquatable { public SocketAddress(System.Net.Sockets.AddressFamily family) { } public SocketAddress(System.Net.Sockets.AddressFamily family, int size) { } public System.Net.Sockets.AddressFamily Family { get { throw null; } } public byte this[int offset] { get { throw null; } set { } } - public int Size { get { throw null; } } + public int Size { get { throw null; } set { } } + public static int GetMaximumAddressSize(System.Net.Sockets.AddressFamily addressFamily) { throw null; } + public System.Memory Buffer { get { throw null; } } public override bool Equals(object? comparand) { throw null; } + public bool Equals(System.Net.SocketAddress? comparand) { throw null; } public override int GetHashCode() { throw null; } public override string ToString() { throw null; } } diff --git a/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs b/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs index 70ef816d1d0af5..0a185b7b42296f 100644 --- a/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs +++ b/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs @@ -14,7 +14,7 @@ public static void Ctor_AddressFamily_Success() { SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork); Assert.Equal(AddressFamily.InterNetwork, sa.Family); - Assert.Equal(32, sa.Size); + Assert.Equal(16, sa.Size); } [Fact] @@ -31,6 +31,28 @@ public static void Ctor_AddressFamilySize_Invalid() Assert.Throws(() => new SocketAddress(AddressFamily.InterNetwork, 1)); //Size < MinSize (32) } + [Theory] + [InlineData(AddressFamily.InterNetwork)] + [InlineData(AddressFamily.InterNetworkV6)] + [InlineData(AddressFamily.Unix)] + public static void Ctor_AddressFamilySize_Correct(AddressFamily addressFamily) + { + SocketAddress sa = new SocketAddress(addressFamily); + Assert.Equal(SocketAddress.GetMaximumAddressSize(addressFamily), sa.Size); + Assert.Equal(SocketAddress.GetMaximumAddressSize(addressFamily), sa.Buffer.Length); + Assert.True(sa.Size <= SocketAddress.GetMaximumAddressSize(AddressFamily.Unknown)); + } + + [Fact] + public static void AddressFamily_Size_Correct() + { + SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork); + Assert.Throws(() => sa.Size = sa.Size + 1); + + sa.Size = 4; + Assert.Equal(4, sa.Buffer.Length); + } + [Fact] public static void Equals_Compare_Success() { diff --git a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj index 53b39492c39a4e..3b94c902e62a98 100644 --- a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj +++ b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj @@ -28,9 +28,8 @@ - - + diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs index fbade293cfda3a..a3d7bc6f3f7d3e 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs @@ -8,7 +8,7 @@ namespace Microsoft.Quic; internal unsafe partial struct QUIC_NEW_CONNECTION_INFO { public override string ToString() - => $"{{ {nameof(QuicVersion)} = {QuicVersion}, {nameof(LocalAddress)} = {LocalAddress->ToIPEndPoint()}, {nameof(RemoteAddress)} = {RemoteAddress->ToIPEndPoint()} }}"; + => $"{{ {nameof(QuicVersion)} = {QuicVersion}, {nameof(LocalAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(LocalAddress)}, {nameof(RemoteAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(RemoteAddress)} }}"; } internal unsafe partial struct QUIC_LISTENER_EVENT @@ -17,7 +17,7 @@ public override string ToString() => Type switch { QUIC_LISTENER_EVENT_TYPE.NEW_CONNECTION - => $"{{ {nameof(NEW_CONNECTION.Info)} = {{ {nameof(QUIC_NEW_CONNECTION_INFO.QuicVersion)} = {NEW_CONNECTION.Info->QuicVersion}, {nameof(QUIC_NEW_CONNECTION_INFO.LocalAddress)} = {NEW_CONNECTION.Info->LocalAddress->ToIPEndPoint()}, {nameof(QUIC_NEW_CONNECTION_INFO.RemoteAddress)} = {NEW_CONNECTION.Info->RemoteAddress->ToIPEndPoint()} }} }}", + => $"{{ {nameof(NEW_CONNECTION.Info)} = {{ {nameof(QUIC_NEW_CONNECTION_INFO.QuicVersion)} = {NEW_CONNECTION.Info->QuicVersion}, {nameof(QUIC_NEW_CONNECTION_INFO.LocalAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(NEW_CONNECTION.Info->LocalAddress)}, {nameof(QUIC_NEW_CONNECTION_INFO.RemoteAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(NEW_CONNECTION.Info->RemoteAddress)} }} }}", _ => string.Empty }; } @@ -36,9 +36,9 @@ public override string ToString() QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_COMPLETE => $"{{ {nameof(SHUTDOWN_COMPLETE.HandshakeCompleted)} = {SHUTDOWN_COMPLETE.HandshakeCompleted}, {nameof(SHUTDOWN_COMPLETE.PeerAcknowledgedShutdown)} = {SHUTDOWN_COMPLETE.PeerAcknowledgedShutdown}, {nameof(SHUTDOWN_COMPLETE.AppCloseInProgress)} = {SHUTDOWN_COMPLETE.AppCloseInProgress} }}", QUIC_CONNECTION_EVENT_TYPE.LOCAL_ADDRESS_CHANGED - => $"{{ {nameof(LOCAL_ADDRESS_CHANGED.Address)} = {LOCAL_ADDRESS_CHANGED.Address->ToIPEndPoint()} }}", + => $"{{ {nameof(LOCAL_ADDRESS_CHANGED.Address)} = {MsQuicHelpers.QuicAddrToIPEndPoint(LOCAL_ADDRESS_CHANGED.Address)} }}", QUIC_CONNECTION_EVENT_TYPE.PEER_ADDRESS_CHANGED - => $"{{ {nameof(PEER_ADDRESS_CHANGED.Address)} = {PEER_ADDRESS_CHANGED.Address->ToIPEndPoint()} }}", + => $"{{ {nameof(PEER_ADDRESS_CHANGED.Address)} = {MsQuicHelpers.QuicAddrToIPEndPoint(PEER_ADDRESS_CHANGED.Address)} }}", QUIC_CONNECTION_EVENT_TYPE.PEER_STREAM_STARTED => $"{{ {nameof(PEER_STREAM_STARTED.Flags)} = {PEER_STREAM_STARTED.Flags} }}", QUIC_CONNECTION_EVENT_TYPE.PEER_CERTIFICATE_RECEIVED diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs index bf454c047c3cdd..ad3ed5c134599d 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -35,23 +35,22 @@ internal static bool TryParse(this EndPoint endPoint, out string? host, out IPAd return false; } - internal static unsafe IPEndPoint ToIPEndPoint(this ref QuicAddr quicAddress, AddressFamily? addressFamilyOverride = null) + internal static unsafe IPEndPoint QuicAddrToIPEndPoint(QuicAddr* quicAddress, AddressFamily? addressFamilyOverride = null) { // MsQuic always uses storage size as if IPv6 was used - Span addressBytes = new Span((byte*)Unsafe.AsPointer(ref quicAddress), Internals.SocketAddress.IPv6AddressSize); - return new Internals.SocketAddress(addressFamilyOverride ?? SocketAddressPal.GetAddressFamily(addressBytes), addressBytes).GetIPEndPoint(); + Span addressBytes = new Span(quicAddress, SocketAddressPal.IPv6AddressSize); + if (addressFamilyOverride != null) + { + SocketAddressPal.SetAddressFamily(addressBytes, (AddressFamily)addressFamilyOverride!); + } + return IPEndPointExtensions.CreateIPEndPoint(addressBytes); } internal static unsafe QuicAddr ToQuicAddr(this IPEndPoint ipEndPoint) { - // TODO: is the layout same for SocketAddress.Buffer and QuicAddr on all platforms? QuicAddr result = default; Span rawAddress = MemoryMarshal.AsBytes(MemoryMarshal.CreateSpan(ref result, 1)); - - Internals.SocketAddress address = IPEndPointExtensions.Serialize(ipEndPoint); - Debug.Assert(address.Size <= rawAddress.Length); - - address.Buffer.AsSpan(0, address.Size).CopyTo(rawAddress); + ipEndPoint.Serialize(rawAddress); return result; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index b3ac83567cee41..4c82110c118368 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -239,8 +239,8 @@ internal unsafe QuicConnection(QUIC_HANDLE* handle, QUIC_NEW_CONNECTION_INFO* in throw; } - _remoteEndPoint = info->RemoteAddress->ToIPEndPoint(); - _localEndPoint = info->LocalAddress->ToIPEndPoint(); + _remoteEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(info->RemoteAddress); + _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(info->LocalAddress); #if DEBUG _tlsSecret = MsQuicTlsSecret.Create(_handle); #endif @@ -478,10 +478,10 @@ private unsafe int HandleEventConnected(ref CONNECTED_DATA data) _negotiatedApplicationProtocol = new SslApplicationProtocol(new Span(data.NegotiatedAlpn, data.NegotiatedAlpnLength).ToArray()); QuicAddr remoteAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS); - _remoteEndPoint = remoteAddress.ToIPEndPoint(); + _remoteEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&remoteAddress); QuicAddr localAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS); - _localEndPoint = localAddress.ToIPEndPoint(); + _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&localAddress); if (NetEventSource.Log.IsEnabled()) { @@ -514,12 +514,12 @@ private unsafe int HandleEventShutdownComplete() } private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data) { - _localEndPoint = data.Address->ToIPEndPoint(); + _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(data.Address); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventPeerAddressChanged(ref PEER_ADDRESS_CHANGED_DATA data) { - _remoteEndPoint = data.Address->ToIPEndPoint(); + _remoteEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(data.Address); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventPeerStreamStarted(ref PEER_STREAM_STARTED_DATA data) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index 7b4c6613afabf8..4116d4a609f795 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -157,7 +157,7 @@ private unsafe QuicListener(QuicListenerOptions options) // Get the actual listening endpoint. address = GetMsQuicParameter(_handle, QUIC_PARAM_LISTENER_LOCAL_ADDRESS); - LocalEndPoint = address.ToIPEndPoint(options.ListenEndPoint.AddressFamily); + LocalEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&address, options.ListenEndPoint.AddressFamily); } /// @@ -281,7 +281,7 @@ private unsafe int HandleEventNewConnection(ref NEW_CONNECTION_DATA data) { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Info(this, $"{this} Refusing connection from {data.Info->RemoteAddress->ToIPEndPoint()} due to backlog limit"); + NetEventSource.Info(this, $"{this} Refusing connection from {MsQuicHelpers.QuicAddrToIPEndPoint(data.Info->RemoteAddress)} due to backlog limit"); } Interlocked.Increment(ref _pendingConnectionsCapacity); diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index c9b67c966f7a0b..57de68d185656b 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -105,6 +105,8 @@ public partial class LingerOption public LingerOption(bool enable, int seconds) { } public bool Enabled { get { throw null; } set { } } public int LingerTime { get { throw null; } set { } } + public override bool Equals(object? comparand) { throw null; } + public override int GetHashCode() { throw null; } } public partial class MulticastOption { @@ -223,7 +225,7 @@ public enum ProtocolType public sealed partial class SafeSocketHandle : Microsoft.Win32.SafeHandles.SafeHandleMinusOneIsInvalid { public SafeSocketHandle() : base (default(bool)) { } - public SafeSocketHandle(System.IntPtr preexistingHandle, bool ownsHandle) : base (default(bool)) { } + public SafeSocketHandle(nint preexistingHandle, bool ownsHandle) : base (default(bool)) { } public override bool IsInvalid { get { throw null; } } protected override bool ReleaseHandle() { throw null; } } @@ -272,7 +274,7 @@ public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.Proto public bool DualMode { get { throw null; } set { } } public bool EnableBroadcast { get { throw null; } set { } } public bool ExclusiveAddressUse { get { throw null; } set { } } - public System.IntPtr Handle { get { throw null; } } + public nint Handle { get { throw null; } } public bool IsBound { get { throw null; } } [System.Diagnostics.CodeAnalysis.DisallowNullAttribute] public System.Net.Sockets.LingerOption? LingerState { get { throw null; } set { } } @@ -398,6 +400,7 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } + public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedSocketAddress) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -442,6 +445,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan preBuffer, Syst public int SendTo(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(System.ReadOnlySpan buffer, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } + public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress socketAddress) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } @@ -811,6 +815,8 @@ public sealed partial class UnixDomainSocketEndPoint : System.Net.EndPoint public UnixDomainSocketEndPoint(string path) { } public override System.Net.Sockets.AddressFamily AddressFamily { get { throw null; } } public override System.Net.EndPoint Create(System.Net.SocketAddress socketAddress) { throw null; } + public override bool Equals([System.Diagnostics.CodeAnalysis.NotNullWhenAttribute(true)] object? obj) { throw null; } + public override int GetHashCode() { throw null; } public override System.Net.SocketAddress Serialize() { throw null; } public override string ToString() { throw null; } } diff --git a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx index 72a5478e850ac2..7a0feca077c032 100644 --- a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx @@ -315,4 +315,7 @@ Handle is already used by another Socket. + + Provided SocketAddress is too small for given AddressFamily. + diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index c7467ee9d7b416..9b93f56b23cbc0 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -73,7 +73,7 @@ public Socket(SocketInformation socketInformation) Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep); unsafe { - fixed (byte* bufferPtr = socketAddress.Buffer) + fixed (byte* bufferPtr = socketAddress.InternalBuffer) fixed (int* sizePtr = &socketAddress.InternalSize) { errorCode = SocketPal.GetSockName(_handle, bufferPtr, sizePtr); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 6b724c9a1709cb..3f450bb2544ed2 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -304,7 +304,7 @@ public EndPoint? LocalEndPoint unsafe { - fixed (byte* buffer = socketAddress.Buffer) + fixed (byte* buffer = socketAddress.InternalBuffer) fixed (int* bufferSize = &socketAddress.InternalSize) { // This may throw ObjectDisposedException. @@ -346,7 +346,7 @@ public EndPoint? RemoteEndPoint // This may throw ObjectDisposedException. SocketError errorCode = SocketPal.GetPeerName( _handle, - socketAddress.Buffer, + socketAddress.InternalBuffer, ref socketAddress.InternalSize); if (errorCode != SocketError.Success) @@ -765,7 +765,7 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd SocketError errorCode = SocketPal.Bind( _handle, _protocolType, - socketAddress.Buffer, + socketAddress.InternalBuffer, socketAddress.Size); // Throw an appropriate SocketException if the native call fails. @@ -1014,13 +1014,15 @@ public Socket Accept() // This may throw ObjectDisposedException. SafeSocketHandle acceptedSocketHandle; SocketError errorCode; + int socketAddressLen; try { errorCode = SocketPal.Accept( _handle, socketAddress.Buffer, - ref socketAddress.InternalSize, + out socketAddressLen, out acceptedSocketHandle); + socketAddress.Size = socketAddressLen; } catch (Exception ex) { @@ -1282,7 +1284,7 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, Internals.SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, socketAddress.Size, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1354,7 +1356,7 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r Internals.SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, socketAddress.Size, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1375,6 +1377,42 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r return bytesTransferred; } + /// + /// Sends data to a specific endpoint using the specified . + /// + /// A span of bytes that contains the data to be sent. + /// A bitwise combination of the values. + /// The that represents the destination for the data. + /// The number of bytes sent. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, SocketAddress socketAddress) + { + ThrowIfDisposed(); + ArgumentNullException.ThrowIfNull(socketAddress); + + ValidateBlockingMode(); + + int bytesTransferred; + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, out bytesTransferred); + + // Throw an appropriate SocketException if the native call fails. + if (errorCode != SocketError.Success) + { + UpdateSendSocketErrorForDisposed(ref errorCode); + + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesSent(bytesTransferred); + if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent(); + } + + return bytesTransferred; + } + // Receives data from a connected socket. public int Receive(byte[] buffer, int size, SocketFlags socketFlags) { @@ -1681,7 +1719,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, ref socketAddress.InternalSize, out bytesTransferred); + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, out int socketAddressLength, out bytesTransferred); UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. @@ -1703,6 +1741,8 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } + socketAddress.Size = socketAddressLength; + if (!socketAddressOriginal.Equals(socketAddress)) { try @@ -1788,7 +1828,7 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.Buffer, ref socketAddress.InternalSize, out bytesTransferred); + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.Buffer, out int socketAddressLength, out bytesTransferred); UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. @@ -1810,6 +1850,7 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } + socketAddress.Size = socketAddressLength; if (!socketAddressOriginal.Equals(socketAddress)) { try @@ -1840,6 +1881,52 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint return bytesTransferred; } + /// + /// Receives a datagram into the data buffer, using the specified , and stores the endpoint. + /// + /// A span of bytes that is the storage location for received data. + /// A bitwise combination of the values. + /// An , that will be updated with value of the remote peer. + /// The number of bytes received. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress receivedSocketAddress) + { + ThrowIfDisposed(); + + if (receivedSocketAddress.Size < SocketAddress.GetMaximumAddressSize(AddressFamily)) + { + throw new ArgumentOutOfRangeException(nameof(receivedSocketAddress), SR.net_sockets_address_small); + } + + ValidateBlockingMode(); + + int bytesTransferred; + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, receivedSocketAddress.Buffer, out int socketAddressSize, out bytesTransferred); + if (socketAddressSize > 0) + { + receivedSocketAddress.Size = socketAddressSize; + } + UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); + // If the native call fails we'll throw a SocketException. + if (errorCode != SocketError.Success) + { + SocketException socketException = new SocketException((int)errorCode); + UpdateStatusAfterSocketError(socketException); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, socketException); + + throw socketException; + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesReceived(bytesTransferred); + if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); + } + + return bytesTransferred; + } + public int IOControl(int ioControlCode, byte[]? optionInValue, byte[]? optionOutValue) { ThrowIfDisposed(); @@ -3082,7 +3169,7 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket SocketError errorCode; try { - errorCode = SocketPal.Connect(_handle, socketAddress.Buffer, socketAddress.Size); + errorCode = SocketPal.Connect(_handle, socketAddress.Buffer); } catch (Exception ex) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 1926097b8e45d6..3f2e27bc784016 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -48,7 +48,7 @@ private void ReturnOperation(AcceptOperation operation) { operation.Reset(); operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedAcceptOperation, operation); // benign race condition } @@ -57,7 +57,7 @@ private void ReturnOperation(BufferMemoryReceiveOperation operation) operation.Reset(); operation.Buffer = default; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferMemoryReceiveOperation, operation); // benign race condition } @@ -66,7 +66,7 @@ private void ReturnOperation(BufferListReceiveOperation operation) operation.Reset(); operation.Buffers = null; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferListReceiveOperation, operation); // benign race condition } @@ -75,7 +75,7 @@ private void ReturnOperation(BufferMemorySendOperation operation) operation.Reset(); operation.Buffer = default; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferMemorySendOperation, operation); // benign race condition } @@ -84,7 +84,7 @@ private void ReturnOperation(BufferListSendOperation operation) operation.Reset(); operation.Buffers = null; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferListSendOperation, operation); // benign race condition } @@ -128,8 +128,7 @@ private enum State public readonly SocketAsyncContext AssociatedContext; public AsyncOperation Next = null!; // initialized by helper called from ctor public SocketError ErrorCode; - public byte[]? SocketAddress; - public int SocketAddressLen; + public Memory SocketAddress; public CancellationTokenRegistration CancellationRegistration; public ManualResetEventSlim? Event { get; set; } @@ -348,7 +347,7 @@ public WriteOperation(SocketAsyncContext context) : base(context) { } void IThreadPoolWorkItem.Execute() => AssociatedContext.ProcessAsyncWriteOperation(this); } - private abstract class SendOperation : WriteOperation + private abstract unsafe class SendOperation : WriteOperation { public SocketFlags Flags; public int BytesTransferred; @@ -357,10 +356,10 @@ private abstract class SendOperation : WriteOperation public SendOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, SocketError>? Callback { get; set; } - public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress, SocketAddressLen, SocketFlags.None, ErrorCode); + public override unsafe void InvokeCallback(bool allowPooling) => + Callback!(BytesTransferred, SocketAddress, SocketFlags.None, ErrorCode); } private sealed class BufferMemorySendOperation : SendOperation @@ -372,15 +371,14 @@ public BufferMemorySendOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) { int bufferIndex = 0; - return SocketPal.TryCompleteSendTo(context._socket, Buffer.Span, null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode); + return SocketPal.TryCompleteSendTo(context._socket, Buffer.Span, null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress.Span, ref BytesTransferred, out ErrorCode); } - public override void InvokeCallback(bool allowPooling) + public override unsafe void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketError ec = ErrorCode; if (allowPooling) @@ -388,7 +386,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, SocketFlags.None, ec); + cb(bt, sa, SocketFlags.None, ec); } } @@ -401,15 +399,14 @@ public BufferListSendOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) { - return SocketPal.TryCompleteSendTo(context._socket, default(ReadOnlySpan), Buffers, ref BufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode); + return SocketPal.TryCompleteSendTo(context._socket, default(ReadOnlySpan), Buffers, ref BufferIndex, ref Offset, ref Count, Flags, SocketAddress.Span, ref BytesTransferred, out ErrorCode); } public override void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketError ec = ErrorCode; if (allowPooling) @@ -417,7 +414,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, SocketFlags.None, ec); + cb(bt, sa, SocketFlags.None, ec); } } @@ -431,7 +428,7 @@ protected override bool DoTryComplete(SocketAsyncContext context) { int bufferIndex = 0; int bufferLength = Offset + Count; // TryCompleteSendTo expects the entire buffer, which it then indexes into with the ref Offset and ref Count arguments - return SocketPal.TryCompleteSendTo(context._socket, new ReadOnlySpan(BufferPtr, bufferLength), null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode); + return SocketPal.TryCompleteSendTo(context._socket, new ReadOnlySpan(BufferPtr, bufferLength), null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress.Span, ref BytesTransferred, out ErrorCode); } } @@ -443,10 +440,10 @@ private abstract class ReceiveOperation : ReadOperation public ReceiveOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, SocketError>? Callback { get; set; } public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress, SocketAddressLen, ReceivedFlags, ErrorCode); + Callback!(BytesTransferred, SocketAddress, ReceivedFlags, ErrorCode); } private sealed class BufferMemoryReceiveOperation : ReceiveOperation @@ -460,7 +457,7 @@ protected override bool DoTryComplete(SocketAsyncContext context) { // Zero byte read is performed to know when data is available. // We don't have to call receive, our caller is interested in the event. - if (Buffer.Length == 0 && Flags == SocketFlags.None && SocketAddress == null) + if (Buffer.Length == 0 && Flags == SocketFlags.None && SocketAddress.Length == 0) { BytesTransferred = 0; ReceivedFlags = SocketFlags.None; @@ -471,14 +468,16 @@ protected override bool DoTryComplete(SocketAsyncContext context) { if (!SetReceivedFlags) { - Debug.Assert(SocketAddress == null); + Debug.Assert(SocketAddress.Length == 0); ReceivedFlags = SocketFlags.None; return SocketPal.TryCompleteReceive(context._socket, Buffer.Span, Flags, out BytesTransferred, out ErrorCode); } else { - return SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + bool result = SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + return result; } } } @@ -487,8 +486,7 @@ public override void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketFlags rf = ReceivedFlags; SocketError ec = ErrorCode; @@ -497,7 +495,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, rf, ec); + cb(bt, sa, rf, ec); } } @@ -507,15 +505,21 @@ private sealed class BufferListReceiveOperation : ReceiveOperation public BufferListReceiveOperation(SocketAsyncContext context) : base(context) { } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveFrom(context._socket, default(Span), Buffers, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, default(Span), Buffers, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } public override void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketFlags rf = ReceivedFlags; SocketError ec = ErrorCode; @@ -524,7 +528,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, rf, ec); + cb(bt, sa, rf, ec); } } @@ -535,8 +539,15 @@ private sealed unsafe class BufferPtrReceiveOperation : ReceiveOperation public BufferPtrReceiveOperation(SocketAsyncContext context) : base(context) { } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } } private sealed class ReceiveMessageFromOperation : ReadOperation @@ -553,13 +564,20 @@ private sealed class ReceiveMessageFromOperation : ReadOperation public ReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress, out int socketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); + Callback!(BytesTransferred, SocketAddress, ReceivedFlags, IPPacketInformation, ErrorCode); } private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation @@ -576,13 +594,20 @@ private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, out int socketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); + Callback!(BytesTransferred, SocketAddress, ReceivedFlags, IPPacketInformation, ErrorCode); } private sealed class AcceptOperation : ReadOperation @@ -591,12 +616,16 @@ private sealed class AcceptOperation : ReadOperation public AcceptOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketError>? Callback { get; set; } protected override bool DoTryComplete(SocketAsyncContext context) { - bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress!, ref SocketAddressLen, out AcceptedFileDescriptor, out ErrorCode); + bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress, out int socketAddressLen, out AcceptedFileDescriptor, out ErrorCode); Debug.Assert(ErrorCode == SocketError.Success || AcceptedFileDescriptor == (IntPtr)(-1), $"Unexpected values: ErrorCode={ErrorCode}, AcceptedFileDescriptor={AcceptedFileDescriptor}"); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } return completed; } @@ -604,8 +633,7 @@ public override void InvokeCallback(bool allowPooling) { var cb = Callback!; IntPtr fd = AcceptedFileDescriptor; - byte[] sa = SocketAddress!; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketError ec = ErrorCode; if (allowPooling) @@ -613,7 +641,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(fd, sa, sal, ec); + cb(fd, sa, ec); } } @@ -1383,15 +1411,14 @@ private bool ShouldRetrySyncOperation(out SocketError errorCode) private void ProcessAsyncWriteOperation(WriteOperation op) => _sendQueue.ProcessAsyncOperation(op); - public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd) + public SocketError Accept(Memory socketAddress, out int socketAddressLen, out IntPtr acceptedFd) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteAccept(_socket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) + SocketPal.TryCompleteAccept(_socket, socketAddress, out socketAddressLen, out acceptedFd, out errorCode)) { Debug.Assert(errorCode == SocketError.Success || acceptedFd == (IntPtr)(-1), $"Unexpected values: errorCode={errorCode}, acceptedFd={acceptedFd}"); return errorCode; @@ -1400,20 +1427,18 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out In var operation = new AcceptOperation(this) { SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, }; PerformSyncOperation(ref _receiveQueue, operation, -1, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; acceptedFd = operation.AcceptedFileDescriptor; return operation.ErrorCode; } - public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action callback, CancellationToken cancellationToken) + public SocketError AcceptAsync(Memory socketAddress, out int socketAddressLen, out IntPtr acceptedFd, Action, SocketError> callback, CancellationToken cancellationToken) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); Debug.Assert(callback != null, "Expected non-null callback"); SetHandleNonBlocking(); @@ -1421,7 +1446,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteAccept(_socket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) + SocketPal.TryCompleteAccept(_socket, socketAddress, out socketAddressLen, out acceptedFd, out errorCode)) { Debug.Assert(errorCode == SocketError.Success || acceptedFd == (IntPtr)(-1), $"Unexpected values: errorCode={errorCode}, acceptedFd={acceptedFd}"); @@ -1431,11 +1456,10 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o AcceptOperation operation = RentAcceptOperation(); operation.Callback = callback; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; acceptedFd = operation.AcceptedFileDescriptor; errorCode = operation.ErrorCode; @@ -1444,13 +1468,13 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o } acceptedFd = (IntPtr)(-1); + socketAddressLen = 0; return SocketError.IOPending; } - public SocketError Connect(byte[] socketAddress, int socketAddressLen) + public SocketError Connect(Memory socketAddress) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); // Connect is different than the usual "readiness" pattern of other operations. // We need to call TryStartConnect to initiate the connect with the OS, @@ -1459,7 +1483,7 @@ public SocketError Connect(byte[] socketAddress, int socketAddressLen) SocketError errorCode; int observedSequenceNumber; _sendQueue.IsReady(this, out observedSequenceNumber); - if (SocketPal.TryStartConnect(_socket, socketAddress, socketAddressLen, out errorCode) || + if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode) || !ShouldRetrySyncOperation(out errorCode)) { _socket.RegisterConnectResult(errorCode); @@ -1469,7 +1493,6 @@ public SocketError Connect(byte[] socketAddress, int socketAddressLen) var operation = new ConnectOperation(this) { SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen }; PerformSyncOperation(ref _sendQueue, operation, -1, observedSequenceNumber); @@ -1477,10 +1500,9 @@ public SocketError Connect(byte[] socketAddress, int socketAddressLen) return operation.ErrorCode; } - public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Action callback) + public SocketError ConnectAsync(Memory socketAddress, Action callback) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); Debug.Assert(callback != null, "Expected non-null callback"); SetHandleNonBlocking(); @@ -1491,7 +1513,7 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti SocketError errorCode; int observedSequenceNumber; _sendQueue.IsReady(this, out observedSequenceNumber); - if (SocketPal.TryStartConnect(_socket, socketAddress, socketAddressLen, out errorCode)) + if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode)) { _socket.RegisterConnectResult(errorCode); return errorCode; @@ -1501,7 +1523,6 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti { Callback = callback, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen }; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) @@ -1514,23 +1535,20 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti public SocketError Receive(Memory buffer, SocketFlags flags, int timeout, out int bytesReceived) { - int socketAddressLen = 0; - return ReceiveFrom(buffer, ref flags, null, ref socketAddressLen, timeout, out bytesReceived); + return ReceiveFrom(buffer, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } public SocketError Receive(Span buffer, SocketFlags flags, int timeout, out int bytesReceived) { - int socketAddressLen = 0; - return ReceiveFrom(buffer, ref flags, null, ref socketAddressLen, timeout, out bytesReceived); + return ReceiveFrom(buffer, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } - public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action callback, CancellationToken cancellationToken) + public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken) { - int socketAddressLen = 0; - return ReceiveFromAsync(buffer, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback, cancellationToken); + return ReceiveFromAsync(buffer, flags, Memory.Empty, out int _, out bytesReceived, out receivedFlags, callback, cancellationToken); } - public SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived) + public unsafe SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, int timeout, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1538,7 +1556,7 @@ public SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, byte[ SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || + (SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1551,23 +1569,23 @@ public SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, byte[ Flags = flags, SetReceivedFlags = true, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; + socketAddressLen = operation.SocketAddress.Length; return operation.ErrorCode; } - public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived) + public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, int timeout, out int bytesReceived) { SocketFlags receivedFlags; SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveFrom(_socket, buffer, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || + (SocketPal.TryCompleteReceiveFrom(_socket, buffer, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1582,18 +1600,18 @@ public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, Length = buffer.Length, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; + socketAddressLen = operation.SocketAddress.Length; return operation.ErrorCode; } } - public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, Action callback, CancellationToken cancellationToken = default) + public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); @@ -1610,8 +1628,7 @@ public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int operation.Callback = callback; operation.Buffer = buffer; operation.Flags = flags; - operation.SocketAddress = null; - operation.SocketAddressLen = 0; + operation.SocketAddress = default; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { @@ -1626,14 +1643,14 @@ public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int return SocketError.IOPending; } - public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action callback, CancellationToken cancellationToken = default) + public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, Memory socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) + SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { return errorCode; } @@ -1644,35 +1661,35 @@ public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, byte operation.Buffer = buffer; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { receivedFlags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; errorCode = operation.ErrorCode; + socketAddressLen = operation.SocketAddress.Length; ReturnOperation(operation); return errorCode; } bytesReceived = 0; + socketAddressLen = 0; receivedFlags = SocketFlags.None; return SocketError.IOPending; } public SocketError Receive(IList> buffers, SocketFlags flags, int timeout, out int bytesReceived) { - return ReceiveFrom(buffers, ref flags, null, 0, timeout, out bytesReceived); + return ReceiveFrom(buffers, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } - public SocketError ReceiveAsync(IList> buffers, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action callback) + public SocketError ReceiveAsync(IList> buffers, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback) { - int socketAddressLen = 0; - return ReceiveFromAsync(buffers, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback); + return ReceiveFromAsync(buffers, flags, Memory.Empty, out int _, out bytesReceived, out receivedFlags, callback); } - public SocketError ReceiveFrom(IList> buffers, ref SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesReceived) + public unsafe SocketError ReceiveFrom(IList> buffers, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, int timeout, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1680,7 +1697,7 @@ public SocketError ReceiveFrom(IList> buffers, ref SocketFlag SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || + (SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1692,25 +1709,24 @@ public SocketError ReceiveFrom(IList> buffers, ref SocketFlag Buffers = buffers, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; return operation.ErrorCode; } - public SocketError ReceiveFromAsync(IList> buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action callback) + public SocketError ReceiveFromAsync(IList> buffers, SocketFlags flags, Memory socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback) { SetHandleNonBlocking(); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) + SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { // Synchronous success or failure return errorCode; @@ -1721,11 +1737,10 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla operation.Buffers = buffers; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) { - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; receivedFlags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; errorCode = operation.ErrorCode; @@ -1735,12 +1750,13 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla } receivedFlags = SocketFlags.None; + socketAddressLen = 0; bytesReceived = 0; return SocketError.IOPending; } public SocketError ReceiveMessageFrom( - Memory buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) + Memory buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1748,7 +1764,7 @@ public SocketError ReceiveMessageFrom( SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || + (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, out socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1761,14 +1777,13 @@ public SocketError ReceiveMessageFrom( Buffers = null, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, IsIPv4 = isIPv4, IsIPv6 = isIPv6, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; flags = operation.ReceivedFlags; ipPacketInformation = operation.IPPacketInformation; bytesReceived = operation.BytesTransferred; @@ -1776,7 +1791,7 @@ public SocketError ReceiveMessageFrom( } public unsafe SocketError ReceiveMessageFrom( - Span buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) + Span buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1784,7 +1799,7 @@ public unsafe SocketError ReceiveMessageFrom( SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || + (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, out socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1799,14 +1814,13 @@ public unsafe SocketError ReceiveMessageFrom( Length = buffer.Length, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, IsIPv4 = isIPv4, IsIPv6 = isIPv6, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; flags = operation.ReceivedFlags; ipPacketInformation = operation.IPPacketInformation; bytesReceived = operation.BytesTransferred; @@ -1814,14 +1828,14 @@ public unsafe SocketError ReceiveMessageFrom( } } - public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action callback, CancellationToken cancellationToken = default) + public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action, SocketFlags, IPPacketInformation, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode)) + SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, out socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode)) { return errorCode; } @@ -1833,14 +1847,13 @@ public SocketError ReceiveMessageFromAsync(Memory buffer, IList buffer, IList buffer, SocketFlags flags, int timeout, out int bytesSent) => - SendTo(buffer, flags, null, 0, timeout, out bytesSent); + SendTo(buffer, flags, Memory.Empty, timeout, out bytesSent); public SocketError Send(byte[] buffer, int offset, int count, SocketFlags flags, int timeout, out int bytesSent) { - return SendTo(buffer, offset, count, flags, null, 0, timeout, out bytesSent); + return SendTo(buffer, offset, count, flags, Memory.Empty, timeout, out bytesSent); } - public SocketError SendAsync(Memory buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action callback, CancellationToken cancellationToken) + public SocketError SendAsync(Memory buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken) { - int socketAddressLen = 0; - return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback, cancellationToken); + return SendToAsync(buffer, offset, count, flags, Memory.Empty, out bytesSent, callback, cancellationToken); } - public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesSent) + public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, Memory socketAddress, int timeout, out int bytesSent) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1875,7 +1888,7 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteSendTo(_socket, buffer, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode) || + (SocketPal.TryCompleteSendTo(_socket, buffer, ref offset, ref count, flags, socketAddress.Span, ref bytesSent, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { return errorCode; @@ -1888,7 +1901,6 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag Count = count, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, BytesTransferred = bytesSent }; @@ -1898,7 +1910,7 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag return operation.ErrorCode; } - public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesSent) + public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, Memory socketAddress, int timeout, out int bytesSent) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1907,7 +1919,7 @@ public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, b int bufferIndexIgnored = 0, offset = 0, count = buffer.Length; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteSendTo(_socket, buffer, null, ref bufferIndexIgnored, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode) || + (SocketPal.TryCompleteSendTo(_socket, buffer, null, ref bufferIndexIgnored, ref offset, ref count, flags, socketAddress.Span, ref bytesSent, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { return errorCode; @@ -1922,7 +1934,6 @@ public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, b Count = count, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, BytesTransferred = bytesSent }; @@ -1933,7 +1944,7 @@ public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, b } } - public SocketError SendToAsync(Memory buffer, int offset, int count, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesSent, Action callback, CancellationToken cancellationToken = default) + public SocketError SendToAsync(Memory buffer, int offset, int count, SocketFlags flags, Memory socketAddress, out int bytesSent, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); @@ -1941,7 +1952,7 @@ public SocketError SendToAsync(Memory buffer, int offset, int count, Socke SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteSendTo(_socket, buffer.Span, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode)) + SocketPal.TryCompleteSendTo(_socket, buffer.Span, ref offset, ref count, flags, socketAddress.Span, ref bytesSent, out errorCode)) { return errorCode; } @@ -1953,7 +1964,6 @@ public SocketError SendToAsync(Memory buffer, int offset, int count, Socke operation.Count = count; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; operation.BytesTransferred = bytesSent; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) @@ -1970,16 +1980,15 @@ public SocketError SendToAsync(Memory buffer, int offset, int count, Socke public SocketError Send(IList> buffers, SocketFlags flags, int timeout, out int bytesSent) { - return SendTo(buffers, flags, null, 0, timeout, out bytesSent); + return SendTo(buffers, flags, Memory.Empty, timeout, out bytesSent); } - public SocketError SendAsync(IList> buffers, SocketFlags flags, out int bytesSent, Action callback) + public SocketError SendAsync(IList> buffers, SocketFlags flags, out int bytesSent, Action, SocketFlags, SocketError> callback) { - int socketAddressLen = 0; - return SendToAsync(buffers, flags, null, ref socketAddressLen, out bytesSent, callback); + return SendToAsync(buffers, flags, Memory.Empty, out bytesSent, callback); } - public SocketError SendTo(IList> buffers, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesSent) + public SocketError SendTo(IList> buffers, SocketFlags flags, Memory socketAddress, int timeout, out int bytesSent) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1989,7 +1998,7 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode) || + (SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress.Span, ref bytesSent, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { return errorCode; @@ -2002,7 +2011,6 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, Offset = offset, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, BytesTransferred = bytesSent }; @@ -2012,7 +2020,7 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, return operation.ErrorCode; } - public SocketError SendToAsync(IList> buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesSent, Action callback) + public SocketError SendToAsync(IList> buffers, SocketFlags flags, Memory socketAddress, out int bytesSent, Action, SocketFlags, SocketError> callback) { SetHandleNonBlocking(); @@ -2022,7 +2030,7 @@ public SocketError SendToAsync(IList> buffers, SocketFlags fl SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode)) + SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress.Span, ref bytesSent, out errorCode)) { return errorCode; } @@ -2034,7 +2042,6 @@ public SocketError SendToAsync(IList> buffers, SocketFlags fl operation.Offset = offset; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; operation.BytesTransferred = bytesSent; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 1a09d751708e6c..a8efcdab763220 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -13,7 +13,7 @@ public partial class SocketAsyncEventArgs : EventArgs, IDisposable private IntPtr _acceptedFileDescriptor; private int _socketAddressSize; private SocketFlags _receivedFlags; - private Action? _transferCompletionCallback; + private Action, SocketFlags, SocketError>? _transferCompletionCallback; partial void InitializeInternals(); @@ -23,18 +23,18 @@ public partial class SocketAsyncEventArgs : EventArgs, IDisposable partial void CompleteCore(); - private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError) + private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, Memory socketAddress, SocketError socketError) { - CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketAddressSize); + CompleteAcceptOperation(acceptedFileDescriptor, socketAddress); CompletionCallback(0, SocketFlags.None, socketError); } - private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize) + private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, Memory socketAddress) { _acceptedFileDescriptor = acceptedFileDescriptor; - Debug.Assert(socketAddress == null || socketAddress == _acceptBuffer, $"Unexpected socketAddress: {socketAddress}"); - _acceptAddressBufferCount = socketAddressSize; + Debug.Assert(socketAddress.Length > 0); + _acceptAddressBufferCount = socketAddress.Length; } internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken) @@ -49,12 +49,11 @@ internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHan Debug.Assert(acceptHandle == null, $"Unexpected acceptHandle: {acceptHandle}"); IntPtr acceptedFd; - int socketAddressLen = _acceptAddressBufferCount / 2; - SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken); + SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, out int socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken); if (socketError != SocketError.IOPending) { - CompleteAcceptOperation(acceptedFd, _acceptBuffer!, socketAddressLen); + CompleteAcceptOperation(acceptedFd, new Memory(_acceptBuffer, 0, socketAddressLen)); FinishOperationSync(socketError, 0, SocketFlags.None); } @@ -71,7 +70,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocket internal unsafe SocketError DoOperationConnect(SafeSocketHandle handle) { - SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, _socketAddress.Size, ConnectCompletionCallback); + SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback); if (socketError != SocketError.IOPending) { FinishOperationSync(socketError, 0, SocketFlags.None); @@ -86,19 +85,18 @@ internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handl return socketError; } - private Action TransferCompletionCallback => + private Action, SocketFlags, SocketError> TransferCompletionCallback => _transferCompletionCallback ??= TransferCompletionCallbackCore; - private void TransferCompletionCallbackCore(int bytesTransferred, byte[]? socketAddress, int socketAddressSize, SocketFlags receivedFlags, SocketError socketError) + private void TransferCompletionCallbackCore(int bytesTransferred, Memory socketAddress, SocketFlags receivedFlags, SocketError socketError) { - CompleteTransferOperation(socketAddress, socketAddressSize, receivedFlags); + CompleteTransferOperation(socketAddress, socketAddress.Length, receivedFlags); CompletionCallback(bytesTransferred, receivedFlags, socketError); } - private void CompleteTransferOperation(byte[]? socketAddress, int socketAddressSize, SocketFlags receivedFlags) + private void CompleteTransferOperation(Memory _, int socketAddressSize, SocketFlags receivedFlags) { - Debug.Assert(socketAddress == null || socketAddress == _socketAddress!.Buffer, $"Unexpected socketAddress: {socketAddress}"); _socketAddressSize = socketAddressSize; _receivedFlags = receivedFlags; } @@ -147,14 +145,14 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, Canc SocketFlags flags; SocketError errorCode; int bytesReceived; - int socketAddressLen = _socketAddress!.Size; + int socketAddressLen; if (_bufferList == null) { - errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); + errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress!.Buffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); } else { - errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress!.Buffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); } if (errorCode != SocketError.IOPending) @@ -166,19 +164,18 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, Canc return errorCode; } - private void ReceiveMessageFromCompletionCallback(int bytesTransferred, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) + private void ReceiveMessageFromCompletionCallback(int bytesTransferred, Memory socketAddress, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) { - CompleteReceiveMessageFromOperation(socketAddress, socketAddressSize, receivedFlags, ipPacketInformation); + CompleteReceiveMessageFromOperation(socketAddress, socketAddress.Length, receivedFlags, ipPacketInformation); CompletionCallback(bytesTransferred, receivedFlags, errorCode); } - private void CompleteReceiveMessageFromOperation(byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation) + private void CompleteReceiveMessageFromOperation(Memory socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation) { - Debug.Assert(_socketAddress != null, "Expected non-null _socketAddress"); - Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socketAddress: {socketAddress}"); + Debug.Assert(socketAddress.Length == socketAddressSize); - _socketAddressSize = socketAddressSize; + _socketAddressSize = socketAddress.Length; _receivedFlags = receivedFlags; _receiveMessageFromPacketInfo = ipPacketInformation; } @@ -196,9 +193,10 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc int bytesReceived; SocketFlags receivedFlags; IPPacketInformation ipPacketInformation; - SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback, cancellationToken); + SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, out socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback, cancellationToken); if (socketError != SocketError.IOPending) { + _socketAddress.Size = socketAddressSize; CompleteReceiveMessageFromOperation(_socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation); FinishOperationSync(socketError, bytesReceived, receivedFlags); } @@ -295,20 +293,19 @@ internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToke _socketAddressSize = 0; int bytesSent; - int socketAddressLen = _socketAddress!.Size; SocketError errorCode; if (_bufferList == null) { - errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback, cancellationToken); + errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.Buffer, out bytesSent, TransferCompletionCallback, cancellationToken); } else { - errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress!.Buffer, out bytesSent, TransferCompletionCallback); } if (errorCode != SocketError.IOPending) { - CompleteTransferOperation(_socketAddress.Buffer, socketAddressLen, SocketFlags.None); + CompleteTransferOperation(_socketAddress.Buffer, _socketAddress.Size, SocketFlags.None); FinishOperationSync(errorCode, bytesSent, SocketFlags.None); } @@ -334,7 +331,7 @@ internal void LogBuffer(int size) private SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAddress) { - System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.Buffer, 0, _acceptAddressBufferCount); + System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.InternalBuffer, 0, _acceptAddressBufferCount); Socket acceptedSocket = _currentSocket!.CreateAcceptSocket( SocketPal.CreateSocket(_acceptedFileDescriptor), _currentSocket._rightEndPoint!.Create(remoteSocketAddress)); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 004b8dab0342bf..6f5f5e856c66dc 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -284,7 +284,7 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha internal SocketError DoOperationConnect(SafeSocketHandle handle) { // Called for connectionless protocols. - SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer, _socketAddress.Size); + SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer); FinishOperationSync(socketError, 0, SocketFlags.None); return socketError; } @@ -303,7 +303,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle { bool success = socket.ConnectEx( handle, - _socketAddress!.Buffer.AsSpan(), + _socketAddress!.InternalBuffer.AsSpan(), (IntPtr)(bufferPtr + _offset), _count, out int bytesTransferred, @@ -763,7 +763,7 @@ internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handl 1, out int bytesTransferred, _socketFlags, - _socketAddress!.Buffer.AsSpan(), + _socketAddress!.InternalBuffer.AsSpan(), overlapped, IntPtr.Zero); @@ -790,7 +790,7 @@ internal unsafe SocketError DoOperationSendToMultiBuffer(SafeSocketHandle handle _bufferListInternal!.Count, out int bytesTransferred, _socketFlags, - _socketAddress!.Buffer.AsSpan(), + _socketAddress!.InternalBuffer.AsSpan(), overlapped, IntPtr.Zero); @@ -873,7 +873,7 @@ private void PinSocketAddressBuffer() } // Pin down the new one. - _socketAddressGCHandle = GCHandle.Alloc(_socketAddress!.Buffer, GCHandleType.Pinned); + _socketAddressGCHandle = GCHandle.Alloc(_socketAddress!.InternalBuffer, GCHandleType.Pinned); _socketAddress.CopyAddressSizeIntoBuffer(); _pinnedSocketAddress = _socketAddress; } @@ -883,11 +883,11 @@ private unsafe IntPtr PtrSocketAddressBuffer get { Debug.Assert(_pinnedSocketAddress != null); - Debug.Assert(_pinnedSocketAddress.Buffer != null); - Debug.Assert(_pinnedSocketAddress.Buffer.Length > 0); + Debug.Assert(_pinnedSocketAddress.InternalBuffer != null); + Debug.Assert(_pinnedSocketAddress.InternalBuffer.Length > 0); Debug.Assert(_socketAddressGCHandle.IsAllocated); - Debug.Assert(_socketAddressGCHandle.Target == _pinnedSocketAddress.Buffer); - fixed (void* ptrSocketAddressBuffer = &_pinnedSocketAddress.Buffer[0]) + Debug.Assert(_socketAddressGCHandle.Target == _pinnedSocketAddress.InternalBuffer); + fixed (void* ptrSocketAddressBuffer = &_pinnedSocketAddress.InternalBuffer[0]) { return (IntPtr)ptrSocketAddressBuffer; } @@ -1075,7 +1075,7 @@ private unsafe SocketError FinishOperationAccept(Internals.SocketAddress remoteS out remoteSocketAddress.InternalSize ); - Marshal.Copy(remoteAddr, remoteSocketAddress.Buffer, 0, remoteSocketAddress.Size); + Marshal.Copy(remoteAddr, remoteSocketAddress.InternalBuffer, 0, remoteSocketAddress.Size); } socketError = Interop.Winsock.setsockopt( diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 3c81f6e4c81ce7..06ce39380a4b8d 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -60,7 +60,7 @@ public static unsafe SocketError CreateSocket(AddressFamily addressFamily, Socke IntPtr fd; SocketError errorCode; - Interop.Error error = Interop.Sys.Socket(addressFamily, socketType, protocolType, &fd); + Interop.Error error = Interop.Sys.Socket((int)addressFamily, (int)socketType, (int)protocolType, &fd); if (error == Interop.Error.SUCCESS) { Debug.Assert(fd != (IntPtr)(-1), "fd should not be -1"); @@ -138,16 +138,13 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, return received; } - private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span buffer, byte[]? socketAddress, ref int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) + private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span buffer, Span socketAddress, out int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) { Debug.Assert(socket.IsSocket); - Debug.Assert(socketAddress != null || socketAddressLen == 0, $"Unexpected values: socketAddress={socketAddress}, socketAddressLen={socketAddressLen}"); - long received = 0; - int sockAddrLen = socketAddress != null ? socketAddressLen : 0; - fixed (byte* sockAddr = socketAddress) + fixed (byte* sockAddr = &MemoryMarshal.GetReference(socketAddress)) fixed (byte* b = &MemoryMarshal.GetReference(buffer)) { var iov = new Interop.Sys.IOVector { @@ -157,7 +154,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = sockAddrLen, + SocketAddressLen = socketAddress.Length, IOVectors = &iov, IOVectorCount = 1 }; @@ -169,7 +166,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, &received); receivedFlags = messageHeader.Flags; - sockAddrLen = messageHeader.SocketAddressLen; + socketAddressLen = messageHeader.SocketAddressLen; } if (errno != Interop.Error.SUCCESS) @@ -177,7 +174,6 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, return -1; } - socketAddressLen = sockAddrLen; return checked((int)received); } @@ -237,7 +233,7 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re return sent; } - private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, ReadOnlySpan buffer, ref int offset, ref int count, byte[] socketAddress, int socketAddressLen, out Interop.Error errno) + private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, ReadOnlySpan buffer, ref int offset, ref int count, ReadOnlySpan socketAddress, out Interop.Error errno) { Debug.Assert(socket.IsSocket); @@ -256,7 +252,7 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = socketAddress != null ? socketAddressLen : 0, + SocketAddressLen = socketAddress.Length, IOVectors = &iov, IOVectorCount = 1 }; @@ -281,19 +277,13 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re return sent; } - private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IList> buffers, ref int bufferIndex, ref int offset, byte[]? socketAddress, int socketAddressLen, out Interop.Error errno) + private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IList> buffers, ref int bufferIndex, ref int offset, ReadOnlySpan socketAddress, out Interop.Error errno) { Debug.Assert(socket.IsSocket); // Pin buffers and set up iovecs. int startIndex = bufferIndex, startOffset = offset; - int sockAddrLen = 0; - if (socketAddress != null) - { - sockAddrLen = socketAddressLen; - } - int maxBuffers = buffers.Count - startIndex; bool allocOnStack = maxBuffers <= IovStackThreshold; Span handles = allocOnStack ? stackalloc GCHandle[IovStackThreshold] : new GCHandle[maxBuffers]; @@ -320,7 +310,7 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IL { var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = sockAddrLen, + SocketAddressLen = socketAddress.Length, IOVectors = iov, IOVectorCount = iovCount }; @@ -377,7 +367,7 @@ private static unsafe long SendFile(SafeSocketHandle socket, SafeFileHandle file return bytesSent; } - private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, IList> buffers, byte[]? socketAddress, ref int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) + private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, IList> buffers, Span socketAddress, out int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) { Debug.Assert(socket.IsSocket); @@ -392,6 +382,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, if (errno != Interop.Error.SUCCESS) { receivedFlags = 0; + socketAddressLen = 0; return -1; } if (available == 0) @@ -405,12 +396,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span handles = allocOnStack ? stackalloc GCHandle[IovStackThreshold] : new GCHandle[maxBuffers]; Span iovecs = allocOnStack ? stackalloc Interop.Sys.IOVector[IovStackThreshold] : new Interop.Sys.IOVector[maxBuffers]; - int sockAddrLen = 0; - if (socketAddress != null) - { - sockAddrLen = socketAddressLen; - } - + int sockAddrLen = socketAddress.Length; long received = 0; int toReceive = 0, iovCount = 0; try @@ -469,16 +455,17 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, } } + socketAddressLen = sockAddrLen; + if (errno != Interop.Error.SUCCESS) { return -1; } - socketAddressLen = sockAddrLen; return checked((int)received); } - private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketFlags flags, Span buffer, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno) + private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketFlags flags, Span buffer, Span socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno) { Debug.Assert(socket.IsSocket); Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); @@ -486,8 +473,6 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF int cmsgBufferLen = Interop.Sys.GetControlMessageBufferSize(Convert.ToInt32(isIPv4), Convert.ToInt32(isIPv6)); byte* cmsgBuffer = stackalloc byte[cmsgBufferLen]; - int sockAddrLen = socketAddressLen; - Interop.Sys.MessageHeader messageHeader; long received = 0; @@ -501,7 +486,7 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF messageHeader = new Interop.Sys.MessageHeader { SocketAddress = rawSocketAddress, - SocketAddressLen = sockAddrLen, + SocketAddressLen = socketAddress.Length, IOVectors = &iov, IOVectorCount = 1, ControlBuffer = cmsgBuffer, @@ -515,7 +500,7 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF &received); receivedFlags = messageHeader.Flags; - sockAddrLen = messageHeader.SocketAddressLen; + socketAddressLen = messageHeader.SocketAddressLen; } if (errno != Interop.Error.SUCCESS) @@ -524,14 +509,19 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF return -1; } + if (socketAddressLen == 0) + { + // We can fail to get peer address on TCP + socketAddressLen = socketAddress.Length; + SocketAddressPal.Clear(socketAddress); + } ipPacketInformation = GetIPPacketInformation(&messageHeader, isIPv4, isIPv6); - socketAddressLen = sockAddrLen; return checked((int)received); } private static unsafe int SysReceiveMessageFrom( SafeSocketHandle socket, SocketFlags flags, IList> buffers, - byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, + Span socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno) { Debug.Assert(socket.IsSocket); @@ -566,7 +556,7 @@ private static unsafe int SysReceiveMessageFrom( var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, IOVectors = iov, IOVectorCount = iovCount, ControlBuffer = cmsgBuffer, @@ -581,12 +571,18 @@ private static unsafe int SysReceiveMessageFrom( &received); receivedFlags = messageHeader.Flags; - int sockAddrLen = messageHeader.SocketAddressLen; + socketAddressLen = messageHeader.SocketAddressLen; if (errno == Interop.Error.SUCCESS) { ipPacketInformation = GetIPPacketInformation(&messageHeader, isIPv4, isIPv6); - socketAddressLen = sockAddrLen; + if (socketAddressLen == 0) + { + // We can fail to get peer address on TCP + socketAddressLen = socketAddress.Length; + SocketAddressPal.Clear(socketAddress); + } + return checked((int)received); } else @@ -606,22 +602,24 @@ private static unsafe int SysReceiveMessageFrom( } } - public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, out SocketError errorCode) + public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, Memory socketAddress, out int socketAddressLen, out IntPtr acceptedFd, out SocketError errorCode) { IntPtr fd = IntPtr.Zero; Interop.Error errno; - int sockAddrLen = socketAddressLen; - fixed (byte* rawSocketAddress = socketAddress) + int sockAddrLen = socketAddress.Length; + fixed (byte* rawSocketAddress = socketAddress.Span) { try { errno = Interop.Sys.Accept(socket, rawSocketAddress, &sockAddrLen, &fd); + socketAddressLen = sockAddrLen; } catch (ObjectDisposedException) { // The socket was closed, or is closing. errorCode = SocketError.OperationAborted; acceptedFd = (IntPtr)(-1); + socketAddressLen = 0; return true; } } @@ -630,7 +628,6 @@ public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, byte[] sock { Debug.Assert(fd != (IntPtr)(-1), "Expected fd != -1"); - socketAddressLen = sockAddrLen; errorCode = SocketError.Success; acceptedFd = fd; @@ -648,10 +645,9 @@ public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, byte[] sock return false; } - public static unsafe bool TryStartConnect(SafeSocketHandle socket, byte[] socketAddress, int socketAddressLen, out SocketError errorCode) + public static unsafe bool TryStartConnect(SafeSocketHandle socket, Memory socketAddress, out SocketError errorCode) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); if (socket.IsDisconnected) { @@ -660,9 +656,9 @@ public static unsafe bool TryStartConnect(SafeSocketHandle socket, byte[] socket } Interop.Error err; - fixed (byte* rawSocketAddress = socketAddress) + fixed (byte* rawSocketAddress = socketAddress.Span) { - err = Interop.Sys.Connect(socket, rawSocketAddress, socketAddressLen); + err = Interop.Sys.Connect(socket, rawSocketAddress, socketAddress.Length); } if (err == Interop.Error.SUCCESS) @@ -735,11 +731,11 @@ public static unsafe bool TryCompleteConnect(SafeSocketHandle socket, out Socket return true; } - public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => - TryCompleteReceiveFrom(socket, buffer, null, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); + public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, SocketFlags flags, Span socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => + TryCompleteReceiveFrom(socket, buffer, null, flags, socketAddress, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); - public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, IList> buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => - TryCompleteReceiveFrom(socket, default(Span), buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); + public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, IList> buffers, SocketFlags flags, Span socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => + TryCompleteReceiveFrom(socket, default(Span), buffers, flags, socketAddress, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); public static unsafe bool TryCompleteReceive(SafeSocketHandle socket, Span buffer, SocketFlags flags, out int bytesReceived, out SocketError errorCode) { @@ -800,7 +796,7 @@ public static unsafe bool TryCompleteReceive(SafeSocketHandle socket, Span } } - public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) + public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Span socketAddress, out int receivedSocketAddressLength, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) { try { @@ -815,11 +811,12 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span(&oneBytePeekBuffer, 1), socketAddress, ref socketAddressLen, out receivedFlags, out errno); + received = SysReceive(socket, flags | SocketFlags.Peek, new Span(&oneBytePeekBuffer, 1), socketAddress, out receivedSocketAddressLength, out receivedFlags, out errno); if (received > 0) { // Peeked for 1-byte, but the actual request was for 0. @@ -838,17 +835,24 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span 0 bytes into a single buffer - received = SysReceive(socket, flags, buffer, socketAddress, ref socketAddressLen, out receivedFlags, out errno); + received = SysReceive(socket, flags, buffer, socketAddress, out receivedSocketAddressLength, out receivedFlags, out errno); } if (received != -1) { bytesReceived = received; errorCode = SocketError.Success; + if (socketAddress.Length > 0 && receivedSocketAddressLength == 0) + { + // We can fail to get peer address on TCP + receivedSocketAddressLength = socketAddress.Length; + SocketAddressPal.Clear(socketAddress); + } return true; } bytesReceived = 0; + receivedSocketAddressLength = 0; if (errno != Interop.Error.EAGAIN && errno != Interop.Error.EWOULDBLOCK) { @@ -864,23 +868,30 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out SocketError errorCode) + public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Memory socketAddress, out int receivedSocketAddressLength, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out SocketError errorCode) { try { Interop.Error errno; int received = buffers == null ? - SysReceiveMessageFrom(socket, flags, buffer, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno) : - SysReceiveMessageFrom(socket, flags, buffers, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno); + SysReceiveMessageFrom(socket, flags, buffer, socketAddress.Span, out receivedSocketAddressLength, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno) : + SysReceiveMessageFrom(socket, flags, buffers, socketAddress.Span, out receivedSocketAddressLength, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno); if (received != -1) { + if (socketAddress.Length > 0 && receivedSocketAddressLength == 0) + { + // We can fail to get peer address on TCP + receivedSocketAddressLength = socketAddress.Length; + SocketAddressPal.Clear(socketAddress.Span); + } bytesReceived = received; errorCode = SocketError.Success; return true; @@ -902,31 +913,32 @@ public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, // The socket was closed, or is closing. bytesReceived = 0; receivedFlags = 0; + receivedSocketAddressLength = 0; ipPacketInformation = default(IPPacketInformation); errorCode = SocketError.OperationAborted; return true; } } - public static bool TryCompleteSendTo(SafeSocketHandle socket, Span buffer, ref int offset, ref int count, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, Span buffer, ref int offset, ref int count, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { int bufferIndex = 0; - return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode); + return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, ref bytesSent, out errorCode); } - public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { int bufferIndex = 0, offset = 0, count = buffer.Length; - return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode); + return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, ref bytesSent, out errorCode); } - public static bool TryCompleteSendTo(SafeSocketHandle socket, IList> buffers, ref int bufferIndex, ref int offset, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, IList> buffers, ref int bufferIndex, ref int offset, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { int count = 0; - return TryCompleteSendTo(socket, default(ReadOnlySpan), buffers, ref bufferIndex, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode); + return TryCompleteSendTo(socket, default(ReadOnlySpan), buffers, ref bufferIndex, ref offset, ref count, flags, socketAddress, ref bytesSent, out errorCode); } - public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, IList>? buffers, ref int bufferIndex, ref int offset, ref int count, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, IList>? buffers, ref int bufferIndex, ref int offset, ref int count, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { bool successfulSend = false; long start = socket.IsUnderlyingHandleBlocking && socket.SendTimeout > 0 ? Environment.TickCount64 : 0; // Get ticks only if timeout is set and socket is blocking. @@ -946,9 +958,9 @@ public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan else { sent = buffers != null ? - SysSend(socket, flags, buffers, ref bufferIndex, ref offset, socketAddress, socketAddressLen, out errno) : + SysSend(socket, flags, buffers, ref bufferIndex, ref offset, socketAddress, out errno) : socketAddress == null ? SysSend(socket, flags, buffer, ref offset, ref count, out errno) : - SysSend(socket, flags, buffer, ref offset, ref count, socketAddress, socketAddressLen, out errno); + SysSend(socket, flags, buffer, ref offset, ref count, socketAddress, out errno); } } catch (ObjectDisposedException) @@ -1094,7 +1106,7 @@ public static SocketError Listen(SafeSocketHandle handle, int backlog) return err == Interop.Error.SUCCESS ? SocketError.Success : GetSocketErrorForErrorCode(err); } - public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAddress, ref int socketAddressLen, out SafeSocketHandle socket) + public static SocketError Accept(SafeSocketHandle listenSocket, Memory socketAddress, out int socketAddressLen, out SafeSocketHandle socket) { socket = new SafeSocketHandle(); @@ -1102,11 +1114,11 @@ public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAdd SocketError errorCode; if (!listenSocket.IsNonBlocking) { - errorCode = listenSocket.AsyncContext.Accept(socketAddress, ref socketAddressLen, out acceptedFd); + errorCode = listenSocket.AsyncContext.Accept(socketAddress, out socketAddressLen, out acceptedFd); } else { - if (!TryCompleteAccept(listenSocket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) + if (!TryCompleteAccept(listenSocket, socketAddress, out socketAddressLen, out acceptedFd, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1119,15 +1131,15 @@ public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAdd return errorCode; } - public static SocketError Connect(SafeSocketHandle handle, byte[] socketAddress, int socketAddressLen) + public static SocketError Connect(SafeSocketHandle handle, Memory socketAddress) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.Connect(socketAddress, socketAddressLen); + return handle.AsyncContext.Connect(socketAddress); } SocketError errorCode; - bool completed = TryStartConnect(handle, socketAddress, socketAddressLen, out errorCode); + bool completed = TryStartConnect(handle, socketAddress, out errorCode); if (completed) { handle.RegisterConnectResult(errorCode); @@ -1151,7 +1163,7 @@ public static SocketError Send(SafeSocketHandle handle, IList int bufferIndex = 0; int offset = 0; SocketError errorCode; - TryCompleteSendTo(handle, bufferList, ref bufferIndex, ref offset, socketFlags, null, 0, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, bufferList, ref bufferIndex, ref offset, socketFlags, ReadOnlySpan.Empty, ref bytesTransferred, out errorCode); return errorCode; } @@ -1164,7 +1176,7 @@ public static SocketError Send(SafeSocketHandle handle, byte[] buffer, int offse bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, null, 0, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, ReadOnlySpan.Empty, ref bytesTransferred, out errorCode); return errorCode; } @@ -1177,7 +1189,7 @@ public static SocketError Send(SafeSocketHandle handle, ReadOnlySpan buffe bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, socketFlags, null, 0, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, socketFlags, ReadOnlySpan.Empty, ref bytesTransferred, out errorCode); return errorCode; } @@ -1197,29 +1209,29 @@ public static SocketError SendFile(SafeSocketHandle handle, SafeFileHandle fileH return completed ? errorCode : SocketError.WouldBlock; } - public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, int socketAddressLen, out int bytesTransferred) + public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Memory socketAddress, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.SendTo(buffer, offset, count, socketFlags, socketAddress, socketAddressLen, handle.SendTimeout, out bytesTransferred); + return handle.AsyncContext.SendTo(buffer, offset, count, socketFlags, socketAddress, handle.SendTimeout, out bytesTransferred); } bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, socketAddress, socketAddressLen, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, socketAddress.Span, ref bytesTransferred, out errorCode); return errorCode; } - public static SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, byte[] socketAddress, int socketAddressLen, out int bytesTransferred) + public static SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, Memory socketAddress, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.SendTo(buffer, socketFlags, socketAddress, socketAddressLen, handle.SendTimeout, out bytesTransferred); + return handle.AsyncContext.SendTo(buffer, socketFlags, socketAddress, handle.SendTimeout, out bytesTransferred); } bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, socketFlags, socketAddress, socketAddressLen, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, socketFlags, socketAddress.Span, ref bytesTransferred, out errorCode); return errorCode; } @@ -1232,8 +1244,7 @@ public static SocketError Receive(SafeSocketHandle handle, IList buffer, So public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { - byte[] socketAddressBuffer = socketAddress.Buffer; - int socketAddressLen = socketAddress.Size; + int socketAddressLen; bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6); @@ -1277,11 +1287,11 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han SocketError errorCode; if (!handle.IsNonBlocking) { - errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress.Buffer, out socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); } else { - if (!TryCompleteReceiveMessageFrom(handle, new Span(buffer, offset, count), null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) + if (!TryCompleteReceiveMessageFrom(handle, new Span(buffer, offset, count), null, socketFlags, socketAddress.Buffer, out socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1295,8 +1305,8 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { - byte[] socketAddressBuffer = socketAddress.Buffer; - int socketAddressLen = socketAddress.Size; + byte[] socketAddressBuffer = socketAddress.InternalBuffer; + int socketAddressLen; bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6); @@ -1304,11 +1314,11 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han SocketError errorCode; if (!handle.IsNonBlocking) { - errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, out socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); } else { - if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) + if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, out socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1319,27 +1329,27 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han return errorCode; } - public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred) + public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Memory socketAddress, out int socketAddressLen, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.ReceiveFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress, ref socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); + return handle.AsyncContext.ReceiveFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress, out socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); } SocketError errorCode; - bool completed = TryCompleteReceiveFrom(handle, new Span(buffer, offset, count), socketFlags, socketAddress, ref socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); + bool completed = TryCompleteReceiveFrom(handle, new Span(buffer, offset, count), socketFlags, socketAddress.Span, out socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); return completed ? errorCode : SocketError.WouldBlock; } - public static SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred) + public static SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, Memory socketAddress, out int socketAddressLen, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.ReceiveFrom(buffer, ref socketFlags, socketAddress, ref socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); + return handle.AsyncContext.ReceiveFrom(buffer, ref socketFlags, socketAddress, out socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); } SocketError errorCode; - bool completed = TryCompleteReceiveFrom(handle, buffer, socketFlags, socketAddress, ref socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); + bool completed = TryCompleteReceiveFrom(handle, buffer, socketFlags, socketAddress.Span, out socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); return completed ? errorCode : SocketError.WouldBlock; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 749c26a608ebc9..7ab0a867e64c7b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -176,22 +176,22 @@ public static SocketError Listen(SafeSocketHandle handle, int backlog) return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success; } - public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAddress, ref int socketAddressSize, out SafeSocketHandle socket) + public static SocketError Accept(SafeSocketHandle listenSocket, Memory socketAddress, out int socketAddressSize, out SafeSocketHandle socket) { socket = new SafeSocketHandle(); - Marshal.InitHandle(socket, Interop.Winsock.accept(listenSocket, socketAddress, ref socketAddressSize)); + socketAddressSize = socketAddress.Length; + Marshal.InitHandle(socket, Interop.Winsock.accept(listenSocket, socketAddress.Span, ref socketAddressSize)); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, socket); return socket.IsInvalid ? GetLastSocketError() : SocketError.Success; } - public static SocketError Connect(SafeSocketHandle handle, byte[] peerAddress, int peerAddressLen) + public static SocketError Connect(SafeSocketHandle handle, Memory peerAddress) { SocketError errorCode = Interop.Winsock.WSAConnect( handle, - peerAddress, - peerAddressLen, + peerAddress.Span, IntPtr.Zero, IntPtr.Zero, IntPtr.Zero, @@ -300,15 +300,15 @@ public static unsafe SocketError SendFile(SafeSocketHandle handle, SafeFileHandl } } - public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) => - SendTo(handle, buffer.AsSpan(offset, size), socketFlags, peerAddress, peerAddressSize, out bytesTransferred); + public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, ReadOnlyMemory peerAddress, out int bytesTransferred) => + SendTo(handle, buffer.AsSpan(offset, size), socketFlags, peerAddress, out bytesTransferred); - public static unsafe SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) + public static unsafe SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, ReadOnlyMemory peerAddress, out int bytesTransferred) { int bytesSent; fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) { - bytesSent = Interop.Winsock.sendto(handle, bufferPtr, buffer.Length, socketFlags, peerAddress, peerAddressSize); + bytesSent = Interop.Winsock.sendto(handle, bufferPtr, buffer.Length, socketFlags, peerAddress.Span, peerAddress.Length); } if (bytesSent == (int)SocketError.SocketError) @@ -447,7 +447,7 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan receiveAddress = socketAddress; ipPacketInformation = default(IPPacketInformation); fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) - fixed (byte* ptrSocketAddress = socketAddress.Buffer) + fixed (byte* ptrSocketAddress = &MemoryMarshal.GetReference(socketAddress.Buffer.Span)) { Interop.Winsock.WSAMsg wsaMsg; wsaMsg.socketAddress = (IntPtr)ptrSocketAddress; @@ -512,17 +512,15 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan return SocketError.Success; } - public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags _ /*socketFlags*/, byte[] socketAddress, ref int addressLength, out int bytesTransferred) => - ReceiveFrom(handle, buffer.AsSpan(offset, size), SocketFlags.None, socketAddress, ref addressLength, out bytesTransferred); + public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags _ /*socketFlags*/, Memory socketAddress, out int addressLength, out int bytesTransferred) => + ReceiveFrom(handle, buffer.AsSpan(offset, size), SocketFlags.None, socketAddress, out addressLength, out bytesTransferred); - public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, byte[] socketAddress, ref int addressLength, out int bytesTransferred) + public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, Memory socketAddress, out int addressLength, out int bytesTransferred) { int bytesReceived; - fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) - { - bytesReceived = Interop.Winsock.recvfrom(handle, bufferPtr, buffer.Length, socketFlags, socketAddress, ref addressLength); - } + addressLength = socketAddress.Length; + bytesReceived = Interop.Winsock.recvfrom(handle, buffer, buffer.Length, socketFlags, socketAddress.Span, ref addressLength); if (bytesReceived == (int)SocketError.SocketError) { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index 30f899edda0401..1a720df27d250c 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -154,6 +154,65 @@ public async Task ReceiveSent_UDP_Success(bool ipv4) } } + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ReceiveSent_SocketAddress_Success(bool ipv4) + { + const int DatagramSize = 256; + const int DatagramsToSend = 16; + + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + using Socket server = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket client = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + client.BindToAnonymousPort(address); + server.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[DatagramSize]; + byte[] receiveBuffer = new byte[DatagramSize]; + + SocketAddress serverSA = server.LocalEndPoint.Serialize(); + SocketAddress clientSA = client.LocalEndPoint.Serialize(); + SocketAddress sa = new SocketAddress(address.AddressFamily); + + Random rnd = new Random(0); + + for (int i = 0; i < DatagramsToSend; i++) + { + rnd.NextBytes(sendBuffer); + client.SendTo(sendBuffer.AsSpan(), SocketFlags.None, serverSA); + + int readBytes = server.ReceiveFrom(receiveBuffer, SocketFlags.None, sa); + Assert.Equal(sa, clientSA); + Assert.Equal(client.LocalEndPoint, client.LocalEndPoint.Create(sa)); + Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); + + // and send it back to make sure it works. + rnd.NextBytes(sendBuffer); + server.SendTo(sendBuffer, SocketFlags.None, sa); + readBytes = client.ReceiveFrom(receiveBuffer, SocketFlags.None, sa); + Assert.Equal(sa, serverSA); + Assert.Equal(server.LocalEndPoint, server.LocalEndPoint.Create(sa)); + Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); + + } + } + + [Fact] + public void ReceiveSent_SmallSocketAddress_Throws() + { + using Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + server.BindToAnonymousPort(IPAddress.Loopback); + + byte[] receiveBuffer = new byte[1]; + + SocketAddress serverSA = server.LocalEndPoint.Serialize(); + + SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork, 2); + Assert.Throws(() => server.ReceiveFrom(receiveBuffer, SocketFlags.None, sa)); + } + [Theory] [InlineData(true)] [InlineData(false)] diff --git a/src/native/libs/System.Native/entrypoints.c b/src/native/libs/System.Native/entrypoints.c index 394a39b0d00365..0b719b1545faeb 100644 --- a/src/native/libs/System.Native/entrypoints.c +++ b/src/native/libs/System.Native/entrypoints.c @@ -137,7 +137,7 @@ static const Entry s_sysNative[] = DllImportEntry(SystemNative_GetNameInfo) DllImportEntry(SystemNative_GetDomainName) DllImportEntry(SystemNative_GetHostName) - DllImportEntry(SystemNative_GetIPSocketAddressSizes) + DllImportEntry(SystemNative_GetSocketAddressSizes) DllImportEntry(SystemNative_GetAddressFamily) DllImportEntry(SystemNative_SetAddressFamily) DllImportEntry(SystemNative_GetPort) diff --git a/src/native/libs/System.Native/pal_networking.c b/src/native/libs/System.Native/pal_networking.c index ffd548835901f3..8dfc133f2ed788 100644 --- a/src/native/libs/System.Native/pal_networking.c +++ b/src/native/libs/System.Native/pal_networking.c @@ -659,15 +659,17 @@ static bool IsInBounds(const void* void_baseAddr, size_t len, const void* void_v return valueAddr >= baseAddr && (valueAddr + valueSize) <= (baseAddr + len); } -int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize) +int32_t SystemNative_GetSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize, int32_t* udsSocketAddressSize, int32_t* maxSocketAddressSize) { - if (ipv4SocketAddressSize == NULL || ipv6SocketAddressSize == NULL) + if (ipv4SocketAddressSize == NULL || ipv6SocketAddressSize == NULL || udsSocketAddressSize == NULL || maxSocketAddressSize == NULL) { return Error_EFAULT; } *ipv4SocketAddressSize = sizeof(struct sockaddr_in); *ipv6SocketAddressSize = sizeof(struct sockaddr_in6); + *udsSocketAddressSize = sizeof(struct sockaddr_un); + *maxSocketAddressSize = sizeof(struct sockaddr_storage); return Error_SUCCESS; } diff --git a/src/native/libs/System.Native/pal_networking.h b/src/native/libs/System.Native/pal_networking.h index 65148d67a646bf..0a46f1490aab96 100644 --- a/src/native/libs/System.Native/pal_networking.h +++ b/src/native/libs/System.Native/pal_networking.h @@ -312,7 +312,7 @@ PALEXPORT int32_t SystemNative_GetDomainName(uint8_t* name, int32_t nameLength); PALEXPORT int32_t SystemNative_GetHostName(uint8_t* name, int32_t nameLength); -PALEXPORT int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize); +PALEXPORT int32_t SystemNative_GetSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize, int32_t* udsSocketAddressSize, int32_t* maxSocketAddressSize); PALEXPORT int32_t SystemNative_GetAddressFamily(const uint8_t* socketAddress, int32_t socketAddressLen, int32_t* addressFamily); diff --git a/src/native/libs/System.Native/pal_networking_wasi.c b/src/native/libs/System.Native/pal_networking_wasi.c index 1cbd3655194104..baeb6494b31e7a 100644 --- a/src/native/libs/System.Native/pal_networking_wasi.c +++ b/src/native/libs/System.Native/pal_networking_wasi.c @@ -61,7 +61,7 @@ int32_t SystemNative_GetHostName(uint8_t* name, int32_t nameLength) return gethostname((char*)name, unsignedSize); } -int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize) +int32_t SystemNative_GetSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize, int32_t*udsSocketAddressSize, int* maxSocketAddressSize) { return Error_EFAULT; }