未验证 提交 97d2f374 编写于 作者: S Shay Rojansky 提交者: GitHub

Fix ExecuteUpdate with invalid reference to main table (#2554)

Fixes #2478
上级 8ee39e34
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Net.NetworkInformation;
using System.Text.RegularExpressions;
......@@ -271,6 +272,131 @@ protected virtual Expression VisitPostgresDelete(PostgresDeleteExpression pgDele
return pgDeleteExpression;
}
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitUpdate(UpdateExpression updateExpression)
{
var selectExpression = updateExpression.SelectExpression;
if (selectExpression.Offset == null
&& selectExpression.Limit == null
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Projection.Count == 0
&& (selectExpression.Tables.Count == 1
|| !ReferenceEquals(selectExpression.Tables[0], updateExpression.Table)
|| selectExpression.Tables[1] is InnerJoinExpression
|| selectExpression.Tables[1] is CrossJoinExpression))
{
Sql.Append("UPDATE ");
Visit(updateExpression.Table);
Sql.AppendLine();
Sql.Append("SET ");
Sql.Append(
$"{_sqlGenerationHelper.DelimitIdentifier(updateExpression.ColumnValueSetters[0].Column.Name)} = ");
Visit(updateExpression.ColumnValueSetters[0].Value);
using (Sql.Indent())
{
foreach (var columnValueSetter in updateExpression.ColumnValueSetters.Skip(1))
{
Sql.AppendLine(",");
Sql.Append($"{_sqlGenerationHelper.DelimitIdentifier(columnValueSetter.Column.Name)} = ");
Visit(columnValueSetter.Value);
}
}
var predicate = selectExpression.Predicate;
var firstTable = true;
OuterReferenceFindingExpressionVisitor? visitor = null;
if (selectExpression.Tables.Count > 1)
{
Sql.AppendLine().Append("FROM ");
for (var i = 0; i < selectExpression.Tables.Count; i++)
{
var table = selectExpression.Tables[i];
var joinExpression = table as JoinExpressionBase;
if (ReferenceEquals(updateExpression.Table, joinExpression?.Table ?? table))
{
LiftPredicate(table);
continue;
}
visitor ??= new OuterReferenceFindingExpressionVisitor(updateExpression.Table);
// PostgreSQL doesn't support referencing the main update table from anywhere except for the UPDATE WHERE clause.
// This specifically makes it impossible to have joins which reference the main table in their predicate (ON ...).
// Because of this, we detect all such inner joins and lift their predicates to the main WHERE clause (where a reference to the
// main table is allowed), producing UPDATE ... FROM x, y WHERE y.foreign_key = x.id instead of INNER JOIN ... ON.
if (firstTable)
{
LiftPredicate(table);
table = joinExpression?.Table ?? table;
}
else if (joinExpression is InnerJoinExpression innerJoinExpression
&& visitor.ContainsReferenceToMainTable(innerJoinExpression.JoinPredicate))
{
LiftPredicate(innerJoinExpression);
Sql.AppendLine(",");
using (Sql.Indent())
{
Visit(innerJoinExpression.Table);
}
continue;
}
if (firstTable)
{
firstTable = false;
}
else
{
Sql.AppendLine();
}
Visit(table);
void LiftPredicate(TableExpressionBase joinTable)
{
if (joinTable is PredicateJoinExpressionBase predicateJoinExpression)
{
Check.DebugAssert(joinExpression is not LeftJoinExpression, "Cannot lift predicate for left join");
predicate = predicate == null
? predicateJoinExpression.JoinPredicate
: new SqlBinaryExpression(
ExpressionType.AndAlso,
predicateJoinExpression.JoinPredicate,
predicate,
typeof(bool),
predicate.TypeMapping);
}
}
}
}
if (predicate != null)
{
Sql.AppendLine().Append("WHERE ");
Visit(predicate);
}
return updateExpression;
}
throw new InvalidOperationException(
RelationalStrings.ExecuteOperationWithUnsupportedOperatorInSqlGeneration(nameof(RelationalQueryableExtensions.ExecuteUpdate)));
}
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
......@@ -891,4 +1017,41 @@ private static bool RequiresBrackets(SqlExpression expression)
generationAction(items[i]);
}
}
private sealed class OuterReferenceFindingExpressionVisitor : ExpressionVisitor
{
private readonly TableExpression _mainTable;
private bool _containsReference;
public OuterReferenceFindingExpressionVisitor(TableExpression mainTable)
=> _mainTable = mainTable;
public bool ContainsReferenceToMainTable(SqlExpression sqlExpression)
{
_containsReference = false;
Visit(sqlExpression);
return _containsReference;
}
[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (_containsReference)
{
return expression;
}
if (expression is ColumnExpression columnExpression
&& columnExpression.Table == _mainTable)
{
_containsReference = true;
return expression;
}
return base.Visit(expression);
}
}
}
......@@ -24,6 +24,60 @@ public class NpgsqlQueryableMethodTranslatingExpressionVisitor : RelationalQuery
{
}
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool IsValidSelectExpressionForExecuteUpdate(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
[NotNullWhen(true)] out TableExpression? tableExpression)
{
if (!base.IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out tableExpression))
{
return false;
}
// PostgreSQL doesn't support referencing the main update table from anywhere except for the UPDATE WHERE clause.
// This specifically makes it impossible to have joins which reference the main table in their predicate (ON ...).
// Because of this, we detect all such inner joins and lift their predicates to the main WHERE clause (where a reference to the
// main table is allowed) - see NpgsqlQuerySqlGenerator.VisitUpdate.
// For any other type of join which contains a reference to the main table, we return false to trigger a subquery pushdown instead.
OuterReferenceFindingExpressionVisitor? visitor = null;
for (var i = 0; i < selectExpression.Tables.Count; i++)
{
var table = selectExpression.Tables[i];
if (ReferenceEquals(table, tableExpression))
{
continue;
}
visitor ??= new OuterReferenceFindingExpressionVisitor(tableExpression);
// For inner joins, if the predicate contains a reference to the main table, NpgsqlQuerySqlGenerator will lift the predicate
// to the WHERE clause; so we only need to check the inner join's table (i.e. subquery) for such a reference.
// Cross join and cross/outer apply (lateral joins) don't have predicates, so just check the entire join for a reference to
// the main table, and switch to subquery syntax if one is found.
// Left join does have a predicate, but it isn't possible to lift it to the main WHERE clause; so also check the entire
// join.
if (table is InnerJoinExpression innerJoin)
{
table = innerJoin.Table;
}
if (visitor.ContainsReferenceToMainTable(table))
{
return false;
}
}
return true;
}
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
......@@ -72,4 +126,41 @@ public class NpgsqlQueryableMethodTranslatingExpressionVisitor : RelationalQuery
tableExpression = null;
return false;
}
private sealed class OuterReferenceFindingExpressionVisitor : ExpressionVisitor
{
private readonly TableExpression _mainTable;
private bool _containsReference;
public OuterReferenceFindingExpressionVisitor(TableExpression mainTable)
=> _mainTable = mainTable;
public bool ContainsReferenceToMainTable(TableExpressionBase tableExpression)
{
_containsReference = false;
Visit(tableExpression);
return _containsReference;
}
[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (_containsReference)
{
return expression;
}
if (expression is ColumnExpression columnExpression
&& columnExpression.Table == _mainTable)
{
_containsReference = true;
return expression;
}
return base.Visit(expression);
}
}
}
using Microsoft.EntityFrameworkCore.BulkUpdates;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
namespace Npgsql.EntityFrameworkCore.PostgreSQL.BulkUpdates;
......@@ -1253,70 +1254,88 @@ public override async Task Update_with_outer_apply_set_constant(bool async)
""");
}
[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")]
[ConditionalTheory]
public override async Task Update_with_cross_join_left_join_set_constant(bool async)
{
await base.Update_with_cross_join_left_join_set_constant(async);
AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
LEFT JOIN (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE [c].[CustomerID] LIKE N'F%'");
}
[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")]
"""
UPDATE "Customers" AS c
SET "ContactName" = 'Updated'
FROM (
SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate"
FROM "Customers" AS c0
CROSS JOIN (
SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region"
FROM "Customers" AS c1
WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%')
) AS t
LEFT JOIN (
SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate"
FROM "Orders" AS o
WHERE o."OrderID" < 10300
) AS t0 ON c0."CustomerID" = t0."CustomerID"
WHERE c0."CustomerID" LIKE 'F%'
) AS t1
WHERE c."CustomerID" = t1."CustomerID"
""");
}
[ConditionalTheory]
public override async Task Update_with_cross_join_cross_apply_set_constant(bool async)
{
await base.Update_with_cross_join_cross_apply_set_constant(async);
AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
CROSS APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int)
) AS [t0]
WHERE [c].[CustomerID] LIKE N'F%'");
}
[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")]
"""
UPDATE "Customers" AS c
SET "ContactName" = 'Updated'
FROM (
SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t0."OrderID", t0."CustomerID" AS "CustomerID0", t0."EmployeeID", t0."OrderDate", t."CustomerID" AS "CustomerID1"
FROM "Customers" AS c0
CROSS JOIN (
SELECT c1."CustomerID"
FROM "Customers" AS c1
WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%')
) AS t
JOIN LATERAL (
SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate"
FROM "Orders" AS o
WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int
) AS t0 ON TRUE
WHERE c0."CustomerID" LIKE 'F%'
) AS t1
WHERE c."CustomerID" = t1."CustomerID"
""");
}
[ConditionalTheory]
public override async Task Update_with_cross_join_outer_apply_set_constant(bool async)
{
await base.Update_with_cross_join_outer_apply_set_constant(async);
AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
OUTER APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int)
) AS [t0]
WHERE [c].[CustomerID] LIKE N'F%'");
"""
UPDATE "Customers" AS c
SET "ContactName" = 'Updated'
FROM (
SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate"
FROM "Customers" AS c0
CROSS JOIN (
SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region"
FROM "Customers" AS c1
WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%')
) AS t
LEFT JOIN LATERAL (
SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate"
FROM "Orders" AS o
WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int
) AS t0 ON TRUE
WHERE c0."CustomerID" LIKE 'F%'
) AS t1
WHERE c."CustomerID" = t1."CustomerID"
""");
}
public override async Task Update_FromSql_set_constant(bool async)
......@@ -1399,6 +1418,30 @@ public override async Task Update_Where_Join_set_property_from_joined_single_res
""");
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Update_with_two_inner_joins(bool async)
{
await AssertUpdate(
async,
ss => ss
.Set<OrderDetail>()
.Where(od => od.Product.Discontinued && od.Order.OrderDate > new DateTime(1990, 1, 1)),
e => e,
s => s.SetProperty(od => od.Quantity, 1),
rowsAffectedCount: 228,
(b, a) => Assert.All(a, od => Assert.Equal(1, od.Quantity)));
AssertExecuteUpdateSql(
"""
UPDATE "Order Details" AS o
SET "Quantity" = 1::smallint
FROM "Products" AS p,
"Orders" AS o0
WHERE o."OrderID" = o0."OrderID" AND o."ProductID" = p."ProductID" AND p."Discontinued" AND o0."OrderDate" > TIMESTAMP '1990-01-01 00:00:00'
""");
}
[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册