Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ protected virtual void GenerateLike(LikeExpression likeExpression, bool negated)
}

_relationalCommandBuilder.Append(" LIKE ");

Visit(likeExpression.Pattern);

if (likeExpression.EscapeChar != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio
/// <inheritdoc />
protected override Expression VisitParameter(ParameterExpression parameterExpression)
=> parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) == true
? new SqlParameterExpression(parameterExpression, null)
? new SqlParameterExpression(parameterExpression.Name, parameterExpression.Type, null)
: throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print()));

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,30 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;
/// <summary>
/// An expression that represents a parameter in a SQL tree.
/// </summary>
/// <remarks>
/// This is a simple wrapper around a <see cref="ParameterExpression" /> in the SQL tree.
/// Instances of this type cannot be constructed by application or database provider code. If this is a problem for your
/// application or provider, then please file an issue at
/// <see href="https://github.com/dotnet/efcore">github.com/dotnet/efcore</see>.
/// </remarks>
public sealed class SqlParameterExpression : SqlExpression
{
private readonly ParameterExpression _parameterExpression;
private readonly string _name;

internal SqlParameterExpression(ParameterExpression parameterExpression, RelationalTypeMapping? typeMapping)
: base(parameterExpression.Type.UnwrapNullableType(), typeMapping)
/// <summary>
/// Creates a new instance of the <see cref="SqlParameterExpression" /> class.
/// </summary>
/// <param name="name">The parameter name.</param>
/// <param name="type">The <see cref="Type" /> of the expression.</param>
/// <param name="typeMapping">The <see cref="RelationalTypeMapping" /> associated with the expression.</param>
public SqlParameterExpression(string name, Type type, RelationalTypeMapping? typeMapping)
: this(name, type.UnwrapNullableType(), type.IsNullableType(), typeMapping)
{
Check.DebugAssert(parameterExpression.Name != null, "Parameter must have name.");
}

_parameterExpression = parameterExpression;
_name = parameterExpression.Name;
IsNullable = parameterExpression.Type.IsNullableType();
private SqlParameterExpression(string name, Type type, bool nullable, RelationalTypeMapping? typeMapping)
: base(type, typeMapping)
{
Name = name;
IsNullable = nullable;
}

/// <summary>
/// The name of the parameter.
/// </summary>
public string Name
=> _name;
public string Name { get; }

/// <summary>
/// The bool value indicating if this parameter can have null values.
Expand All @@ -44,15 +42,15 @@ public string Name
/// <param name="typeMapping">A relational type mapping to apply.</param>
/// <returns>A new expression which has supplied type mapping.</returns>
public SqlExpression ApplyTypeMapping(RelationalTypeMapping? typeMapping)
=> new SqlParameterExpression(_parameterExpression, typeMapping);
=> new SqlParameterExpression(Name, Type, IsNullable, typeMapping);

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> this;

/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
=> expressionPrinter.Append("@" + _parameterExpression.Name);
=> expressionPrinter.Append("@" + Name);

/// <inheritdoc />
public override bool Equals(object? obj)
Expand Down
124 changes: 96 additions & 28 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ protected virtual TableExpressionBase Visit(TableExpressionBase tableExpressionB
var newTable = Visit(innerJoinExpression.Table);
var newJoinPredicate = ProcessJoinPredicate(innerJoinExpression.JoinPredicate);

return TryGetBoolConstantValue(newJoinPredicate) == true
return IsTrue(newJoinPredicate)
? new CrossJoinExpression(newTable)
: innerJoinExpression.Update(newTable, newJoinPredicate);
}
Expand Down Expand Up @@ -301,7 +301,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
var predicate = Visit(selectExpression.Predicate, allowOptimizedExpansion: true, out _);
changed |= predicate != selectExpression.Predicate;

if (TryGetBoolConstantValue(predicate) == true)
if (IsTrue(predicate))
{
predicate = null;
changed = true;
Expand Down Expand Up @@ -333,7 +333,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
var having = Visit(selectExpression.Having, allowOptimizedExpansion: true, out _);
changed |= having != selectExpression.Having;

if (TryGetBoolConstantValue(having) == true)
if (IsTrue(having))
{
having = null;
changed = true;
Expand Down Expand Up @@ -519,20 +519,17 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
var test = Visit(
whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (TryGetBoolConstantValue(test) is bool testConstantBool)
if (IsTrue(test))
{
if (testConstantBool)
{
testEvaluatesToTrue = true;
}
else
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNullValueColumnsCount);
testEvaluatesToTrue = true;
}
else if (IsFalse(test))
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNullValueColumnsCount);

continue;
}
continue;
}

var newResult = Visit(whenClause.Result, out var resultNullable);
Expand Down Expand Up @@ -570,7 +567,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
// if there is only one When clause and it's test evaluates to 'true' AND there is no else block, simply return the result
return elseResult == null
&& whenClauses.Count == 1
&& TryGetBoolConstantValue(whenClauses[0].Test) == true
&& IsTrue(whenClauses[0].Test)
? whenClauses[0].Result
: caseExpression.Update(operand, whenClauses, elseResult);
}
Expand Down Expand Up @@ -635,7 +632,7 @@ protected virtual SqlExpression VisitExists(

// if subquery has predicate which evaluates to false, we can simply return false
// if the exists is negated we need to return true instead
return TryGetBoolConstantValue(subquery.Predicate) == false
return IsFalse(subquery.Predicate)
? _sqlExpressionFactory.Constant(false, existsExpression.TypeMapping)
: existsExpression.Update(subquery);
}
Expand All @@ -658,7 +655,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
var subquery = Visit(inExpression.Subquery);

// a IN (SELECT * FROM table WHERE false) => false
if (TryGetBoolConstantValue(subquery.Predicate) == false)
if (IsFalse(subquery.Predicate))
{
nullable = false;

Expand Down Expand Up @@ -967,9 +964,64 @@ protected virtual SqlExpression VisitLike(LikeExpression likeExpression, bool al
var pattern = Visit(likeExpression.Pattern, out var patternNullable);
var escapeChar = Visit(likeExpression.EscapeChar, out var escapeCharNullable);

nullable = matchNullable || patternNullable || escapeCharNullable;
SqlExpression result = likeExpression.Update(match, pattern, escapeChar);

if (UseRelationalNulls)
{
nullable = matchNullable || patternNullable || escapeCharNullable;

return result;
}

nullable = false;

// The null semantics behavior we implement for LIKE is that it only returns true when both sides are non-null and match; any other
// input returns false:
// foo LIKE f% -> true
// foo LIKE null -> false
// null LIKE f% -> false
// null LIKE null -> false

if (IsNull(match) || IsNull(pattern) || IsNull(escapeChar))
{
return _sqlExpressionFactory.Constant(false, likeExpression.TypeMapping);
}

// A constant match-all pattern (%) returns true for all cases, except where the match string is null:
// nullable_foo LIKE % -> foo IS NOT NULL
// non_nullable_foo LIKE % -> true
if (pattern is SqlConstantExpression { Value: "%" })
{
return matchNullable
? _sqlExpressionFactory.IsNotNull(match)
: _sqlExpressionFactory.Constant(true, likeExpression.TypeMapping);
}

return likeExpression.Update(match, pattern, escapeChar);
if (!allowOptimizedExpansion)
{
if (matchNullable)
{
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(match));
}

if (patternNullable)
{
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(pattern));
}

if (escapeChar is not null && escapeCharNullable)
{
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(escapeChar));
}
}

return result;

SqlExpression GenerateNotNullCheck(SqlExpression operand)
=> OptimizeNonNullableNotExpression(
_sqlExpressionFactory.Not(
ProcessNullNotNull(
_sqlExpressionFactory.IsNull(operand), operandNullable: true)));
}

/// <summary>
Expand Down Expand Up @@ -1395,8 +1447,28 @@ protected virtual SqlExpression VisitJsonScalar(
/// </summary>
protected virtual bool PreferExistsToComplexIn => false;

private static bool? TryGetBoolConstantValue(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: bool boolValue } ? boolValue : null;
// Note that we can check parameter values for null since we cache by the parameter nullability; but we cannot do the same for bool.
private bool IsNull(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: null }
|| expression is SqlParameterExpression { Name: string parameterName } && ParameterValues[parameterName] is null;

private bool IsTrue(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: true };

private bool IsFalse(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: false };

private bool TryGetBool(SqlExpression? expression, out bool value)
{
if (expression is SqlConstantExpression { Value: bool b })
{
value = b;
return true;
}

value = false;
return false;
}

private void RestoreNonNullableColumnsList(int counter)
{
Expand Down Expand Up @@ -1486,7 +1558,7 @@ private SqlExpression OptimizeComparison(
return result;
}

if (TryGetBoolConstantValue(right) is bool rightBoolValue
if (TryGetBool(right, out var rightBoolValue)
&& !leftNullable
&& left.TypeMapping!.Converter == null)
{
Expand All @@ -1502,7 +1574,7 @@ private SqlExpression OptimizeComparison(
: left;
}

if (TryGetBoolConstantValue(left) is bool leftBoolValue
if (TryGetBool(left, out var leftBoolValue)
&& !rightNullable
&& right.TypeMapping!.Converter == null)
{
Expand Down Expand Up @@ -2069,10 +2141,6 @@ private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression,
private static bool IsLogicalNot(SqlUnaryExpression? sqlUnaryExpression)
=> sqlUnaryExpression is { OperatorType: ExpressionType.Not } && sqlUnaryExpression.Type == typeof(bool);

private bool IsNull(SqlExpression expression)
=> expression is SqlConstantExpression { Value: null }
|| expression is SqlParameterExpression { Name: string parameterName } && ParameterValues[parameterName] is null;

// ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null))
//
// a | b | F1 = a == b | F2 = (a != null && b != null) | F3 = F1 && F2 |
Expand Down
Loading