Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
<Project>
<PropertyGroup>
<LangVersion>12.0</LangVersion>
<DotNetTargetFrameworks>net7.0;net8.0;net9.0</DotNetTargetFrameworks>
<TargetFrameworks>$(DotNetTargetFrameworks)</TargetFrameworks>
<TargetFrameworks>net8.0</TargetFrameworks>
</PropertyGroup>
<PropertyGroup>
<MobileTargetFrameworks>net8.0-ios;net8.0-android;net8.0-macos;net8.0-tvos;net9.0-ios;net9.0-android;net9.0-macos;net9.0-tvos</MobileTargetFrameworks>
Expand Down
105 changes: 63 additions & 42 deletions src/SuperSocket.MySQL/MySQLConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,53 +42,67 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)

var endPoint = new DnsEndPoint(_host, _port, AddressFamily.InterNetwork);

var connected = await ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);
try
{
var connected = await ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);

if (!connected)
throw new InvalidOperationException($"Failed to connect to MySQL server at {_host}:{_port}");
if (!connected)
throw new InvalidOperationException($"MySQL authentication failed: Unable to connect to server at {_host}:{_port}");

// Wait for server's handshake packet
var packet = await ReceiveAsync().ConfigureAwait(false);
if (!(packet is HandshakePacket handshakePacket))
throw new InvalidOperationException("Expected handshake packet from server.");
// Wait for server's handshake packet
var packet = await ReceiveAsync().ConfigureAwait(false);
if (!(packet is HandshakePacket handshakePacket))
throw new InvalidOperationException("MySQL authentication failed: Expected handshake packet from server but received unexpected packet");

// Prepare handshake response
var handshakeResponse = new HandshakeResponsePacket
// Prepare handshake response
var handshakeResponse = new HandshakeResponsePacket
{
CapabilityFlags = (uint)(ClientCapabilities.CLIENT_PROTOCOL_41 |
ClientCapabilities.CLIENT_SECURE_CONNECTION |
ClientCapabilities.CLIENT_PLUGIN_AUTH),
MaxPacketSize = 16777216, // 16MB
CharacterSet = 0x21, // utf8_general_ci
Username = _userName,
Database = string.Empty, // Can be set later if needed
AuthPluginName = "mysql_native_password"
};

// Generate authentication response
handshakeResponse.AuthResponse = GenerateAuthResponse(handshakePacket);
handshakeResponse.SequenceId = packet.SequenceId + 1;

// Send handshake response
await SendAsync(PacketEncoder, handshakeResponse).ConfigureAwait(false);

// Wait for authentication result (OK packet or Error packet)
var authResult = await ReceiveAsync().ConfigureAwait(false);

switch (authResult)
{
case OKPacket okPacket:
// Authentication successful
IsAuthenticated = true;
break;
case ErrorPacket errorPacket:
// Authentication failed
var errorMsg = !string.IsNullOrEmpty(errorPacket.ErrorMessage)
? errorPacket.ErrorMessage
: "authentication failed";
throw new InvalidOperationException($"MySQL authentication failed: {errorMsg} (Error {errorPacket.ErrorCode})");
default:
// Any other response during authentication is also an authentication failure
throw new InvalidOperationException($"MySQL authentication failed: Unexpected packet received during authentication: {authResult?.GetType().Name ?? "null"}");
}
}
catch (InvalidOperationException)
{
CapabilityFlags = (uint)(ClientCapabilities.CLIENT_PROTOCOL_41 |
ClientCapabilities.CLIENT_SECURE_CONNECTION |
ClientCapabilities.CLIENT_PLUGIN_AUTH),
MaxPacketSize = 16777216, // 16MB
CharacterSet = 0x21, // utf8_general_ci
Username = _userName,
Database = string.Empty, // Can be set later if needed
AuthPluginName = "mysql_native_password"
};

// Generate authentication response
handshakeResponse.AuthResponse = GenerateAuthResponse(handshakePacket);
handshakeResponse.SequenceId = packet.SequenceId + 1;

// Send handshake response
await SendAsync(PacketEncoder, handshakeResponse).ConfigureAwait(false);

// Wait for authentication result (OK packet or Error packet)
var authResult = await ReceiveAsync().ConfigureAwait(false);

switch (authResult)
// Re-throw InvalidOperationException as-is (these are our authentication failures)
throw;
}
catch (Exception ex)
{
case OKPacket okPacket:
// Authentication successful
IsAuthenticated = true;
break;
case ErrorPacket errorPacket:
// Authentication failed
var errorMsg = !string.IsNullOrEmpty(errorPacket.ErrorMessage)
? errorPacket.ErrorMessage
: "Authentication failed";
throw new InvalidOperationException($"MySQL authentication failed: {errorMsg} (Error {errorPacket.ErrorCode})");
default:
throw new InvalidOperationException($"Unexpected packet received during authentication: {authResult?.GetType().Name ?? "null"}");
// Convert any other exception during authentication to authentication failure
throw new InvalidOperationException($"MySQL authentication failed: {ex.Message}", ex);
}
}

Expand Down Expand Up @@ -160,8 +174,15 @@ public async Task DisconnectAsync()
{
try
{
// Always attempt to close, but catch any exceptions to prevent
// NullReferenceException when connection was never established
await CloseAsync();
}
catch (Exception)
{
// Ignore any exceptions during cleanup - we're disconnecting anyway
// This ensures DisconnectAsync is always safe to call
}
finally
{
IsAuthenticated = false;
Expand Down
107 changes: 107 additions & 0 deletions tests/SuperSocket.MySQL.Test/AuthenticationErrorHandlingTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using System;
using System.Threading.Tasks;
using Xunit;
using SuperSocket.MySQL;

namespace SuperSocket.MySQL.Test
{
/// <summary>
/// Tests specifically for authentication error handling scenarios
/// mentioned in the problem statement
/// </summary>
public class AuthenticationErrorHandlingTest
{
[Fact]
public async Task ConnectAsync_WhenServerUnreachable_ShouldContainAuthenticationFailed()
{
// Arrange
var connection = new MySQLConnection("unreachable-host-that-does-not-exist", 3306, "user", "password");

// Act & Assert
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ConnectAsync()
);

Assert.Contains("authentication failed", exception.Message.ToLower());
}

[Fact]
public async Task ConnectAsync_WhenPortClosed_ShouldContainAuthenticationFailed()
{
// Arrange - Use a port that's unlikely to be open
var connection = new MySQLConnection("localhost", 9999, "user", "password");

// Act & Assert
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ConnectAsync()
);

Assert.Contains("authentication failed", exception.Message.ToLower());
}

[Fact]
public async Task ConnectAsync_WhenNetworkError_ShouldContainAuthenticationFailed()
{
// Arrange - Use an invalid hostname that should cause DNS resolution failure
var connection = new MySQLConnection("invalid.invalid.invalid", 3306, "user", "password");

// Act & Assert
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ConnectAsync()
);

Assert.Contains("authentication failed", exception.Message.ToLower());
}

[Fact]
public async Task DisconnectAsync_WhenConnectionNeverEstablished_ShouldNotThrowNullReference()
{
// Arrange
var connection = new MySQLConnection("localhost", 3306, "user", "password");

// Act & Assert - This should not throw NullReferenceException
await connection.DisconnectAsync();

// Verify authentication state is reset
Assert.False(connection.IsAuthenticated);
}

[Fact]
public async Task DisconnectAsync_AfterFailedConnection_ShouldNotThrowNullReference()
{
// Arrange
var connection = new MySQLConnection("unreachable-host", 3306, "user", "password");

try
{
// Try to connect (this will fail)
await connection.ConnectAsync();
}
catch (InvalidOperationException)
{
// Expected failure
}

// Act & Assert - This should not throw NullReferenceException
await connection.DisconnectAsync();

// Verify authentication state is reset
Assert.False(connection.IsAuthenticated);
}

[Fact]
public async Task DisconnectAsync_CalledMultipleTimes_ShouldNotThrow()
{
// Arrange
var connection = new MySQLConnection("localhost", 3306, "user", "password");

// Act & Assert - Multiple calls should not throw
await connection.DisconnectAsync();
await connection.DisconnectAsync();
await connection.DisconnectAsync();

// Verify authentication state is reset
Assert.False(connection.IsAuthenticated);
}
}
}
2 changes: 1 addition & 1 deletion tests/SuperSocket.MySQL.Test/SuperSocket.MySQL.Test.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net9.0</TargetFrameworks>
<TargetFrameworks>net8.0</TargetFrameworks>
<IsPackable>false</IsPackable>
</PropertyGroup>

Expand Down