未验证 提交 1eed21e0 编写于 作者: S Shay Rojansky 提交者: GitHub

Implement sum and average over TimeSpan (#2423)

Closes #2339
上级 82a758cd
// ReSharper disable once CheckNamespace
namespace Microsoft.EntityFrameworkCore;
/// <summary>
/// Provides extension methods supporting NodaTime function translation for PostgreSQL.
/// </summary>
public static class NpgsqlNodaTimeDbFunctionsExtensions
{
/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Period? Sum(this DbFunctions _, IEnumerable<Period> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));
/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Duration? Sum(this DbFunctions _, IEnumerable<Duration> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));
/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Period? Average(this DbFunctions _, IEnumerable<Period> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));
/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static Duration? Average(this DbFunctions _, IEnumerable<Duration> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));
}
......@@ -22,6 +22,7 @@ public static class NpgsqlNodaTimeServiceCollectionExtensions
new EntityFrameworkRelationalServicesBuilder(serviceCollection)
.TryAdd<IRelationalTypeMappingSourcePlugin, NpgsqlNodaTimeTypeMappingSourcePlugin>()
.TryAdd<IMethodCallTranslatorPlugin, NpgsqlNodaTimeMethodCallTranslatorPlugin>()
.TryAdd<IAggregateMethodCallTranslatorPlugin, NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin>()
.TryAdd<IMemberTranslatorPlugin, NpgsqlNodaTimeMemberTranslatorPlugin>()
.TryAdd<IEvaluatableExpressionFilterPlugin, NpgsqlNodaTimeEvaluatableExpressionFilterPlugin>();
......
using Npgsql.EntityFrameworkCore.PostgreSQL.Query;
namespace Npgsql.EntityFrameworkCore.PostgreSQL.NodaTime.Query.Internal;
public class NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin : IAggregateMethodCallTranslatorPlugin
{
public NpgsqlNodaTimeAggregateMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory)
{
if (sqlExpressionFactory is not NpgsqlSqlExpressionFactory npgsqlSqlExpressionFactory)
{
throw new ArgumentException($"Must be an {nameof(NpgsqlSqlExpressionFactory)}", nameof(sqlExpressionFactory));
}
Translators = new IAggregateMethodCallTranslator[]
{
new NpgsqlNodaTimeAggregateMethodTranslator(npgsqlSqlExpressionFactory)
};
}
public virtual IEnumerable<IAggregateMethodCallTranslator> Translators { get; }
}
public class NpgsqlNodaTimeAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly bool[][] FalseArrays = { Array.Empty<bool>(), new[] { false } };
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
public NpgsqlNodaTimeAggregateMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory)
=> _sqlExpressionFactory = sqlExpressionFactory;
public virtual SqlExpression? Translate(
MethodInfo method,
EnumerableExpression source,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (source.Selector is not SqlExpression sqlExpression || method.DeclaringType != typeof(NpgsqlNodaTimeDbFunctionsExtensions))
{
return null;
}
switch (method.Name)
{
case nameof(NpgsqlNodaTimeDbFunctionsExtensions.Sum):
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);
case nameof(NpgsqlNodaTimeDbFunctionsExtensions.Average):
return _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);
default:
return null;
}
}
}
......@@ -36,6 +36,24 @@ public static T[] JsonAgg<T>(this DbFunctions _, IEnumerable<T> input)
public static T[] JsonbAgg<T>(this DbFunctions _, IEnumerable<T> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(JsonbAgg)));
/// <summary>
/// Computes the sum of the non-null input intervals. Corresponds to the PostgreSQL <c>sum</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be summed.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static TimeSpan? Sum(this DbFunctions _, IEnumerable<TimeSpan> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Sum)));
/// <summary>
/// Computes the average (arithmetic mean) of the non-null input intervals. Corresponds to the PostgreSQL <c>avg</c> aggregate function.
/// </summary>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="input">The input values to be computed into an average.</param>
/// <seealso href="https://www.postgresql.org/docs/current/functions-aggregate.html">PostgreSQL documentation for aggregate functions.</seealso>
public static TimeSpan? Average(this DbFunctions _, IEnumerable<TimeSpan> input)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Average)));
#region Range
/// <summary>
......
......@@ -128,6 +128,26 @@ public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslato
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);
case nameof(NpgsqlAggregateDbFunctionsExtensions.Sum):
return _sqlExpressionFactory.AggregateFunction(
"sum",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);
case nameof(NpgsqlAggregateDbFunctionsExtensions.Average):
return _sqlExpressionFactory.AggregateFunction(
"avg",
new[] { sqlExpression },
source,
nullable: true,
argumentsPropagateNullability: FalseArrays[1],
returnType: sqlExpression.Type,
sqlExpression.TypeMapping);
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg):
case nameof(NpgsqlAggregateDbFunctionsExtensions.JsonObjectAgg):
var isJsonb = method.Name == nameof(NpgsqlAggregateDbFunctionsExtensions.JsonbObjectAgg);
......
......@@ -199,6 +199,15 @@ PostgresUnknownBinaryExpression postgresUnknownBinaryExpression
&& visitedBase is SqlFunctionExpression { Name: "COALESCE", Arguments: { } } coalesceExpression
&& coalesceExpression.Arguments[0] is PostgresFunctionExpression wrappedFunctionExpression)
{
// The base logic assumes sum is operating over numbers, which breaks sum over PG interval.
// Detect that case and remove the coalesce entirely (note that we don't need coalescing since sum function is in
// EF.Functions.Sum, and returns nullable. This is a temporary hack until #38158 is fixed.
if (sqlFunctionExpression.Type == typeof(TimeSpan)
|| sqlFunctionExpression.Type.FullName is "NodaTime.Period" or "NodaTime.Duration")
{
return coalesceExpression.Arguments[0];
}
var visitedArguments = coalesceExpression.Arguments!.ToArray();
visitedArguments[0] = VisitPostgresFunctionComponents(wrappedFunctionExpression);
......
......@@ -346,6 +346,44 @@ public async Task Where_TimeSpan_TotalMilliseconds(bool async)
WHERE (date_part('epoch', m.""Duration"") / 0.001) < 3700000.0");
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task GroupBy_Property_Select_Sum_over_TimeSpan(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Duration))),
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => (TimeSpan?)new TimeSpan(g.Sum(o => o.Duration.Ticks))));
AssertSql(
@"SELECT sum(m.""Duration"")
FROM ""Missions"" AS m
GROUP BY m.""Id""");
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task GroupBy_Property_Select_Average_over_TimeSpan(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Duration))),
ss => ss.Set<Mission>()
.GroupBy(o => o.Id)
.Select(g => (TimeSpan?)new TimeSpan((long)g.Average(o => o.Duration.Ticks))));
AssertSql(
@"SELECT avg(m.""Duration"")
FROM ""Missions"" AS m
GROUP BY m.""Id""");
}
#endregion TimeSpan
#region DateOnly
......
......@@ -13,7 +13,7 @@ public class NorthwindFunctionsQueryNpgsqlTest : NorthwindFunctionsQueryRelation
: base(fixture)
{
ClearLog();
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}
public override async Task IsNullOrWhiteSpace_in_predicate(bool async)
......
......@@ -3163,6 +3163,8 @@ into g
GROUP BY o.""CustomerID""");
}
// See aggregate tests over TimeSpan in GearsOfWarQueryNpsgqlTest
private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
......
......@@ -755,6 +755,48 @@ public Task Period_FromTicks_is_not_translated()
() => ctx.Set<NodaTimeTypes>().Where(t => Period.FromNanoseconds(t.Id).Seconds == 1).ToListAsync());
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Sum_over_Period(bool async)
{
using var ctx = CreateContext();
// Note: Unlike Duration, Period can't be converted to total ticks (because its absolute time varies).
var query = ctx.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Period)));
_ = async
? await query.ToListAsync()
: query.ToList();
AssertSql(
@"SELECT sum(n.""Period"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Average_over_Period(bool async)
{
using var ctx = CreateContext();
// Note: Unlike Duration, Period can't be converted to total ticks (because its absolute time varies).
var query = ctx.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Period)));
_ = async
? await query.ToListAsync()
: query.ToList();
AssertSql(
@"SELECT avg(n.""Period"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}
#endregion Period
#region Duration
......@@ -894,6 +936,44 @@ public async Task Duration_Seconds(bool async)
WHERE floor(date_part('second', n.""Duration""))::int = 8");
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Sum_over_Duration(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Sum(g.Select(o => o.Duration))),
expectedQuery: ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => (Duration?)Duration.FromTicks(g.Sum(o => o.Duration.TotalTicks))));
AssertSql(
@"SELECT sum(n.""Duration"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task GroupBy_Property_Select_Average_over_Duration(bool async)
{
await AssertQueryScalar(
async,
ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => EF.Functions.Average(g.Select(o => o.Duration))),
expectedQuery: ss => ss.Set<NodaTimeTypes>()
.GroupBy(o => o.Id)
.Select(g => (Duration?)Duration.FromTicks((long)g.Average(o => o.Duration.TotalTicks))));
AssertSql(
@"SELECT avg(n.""Duration"")
FROM ""NodaTimeTypes"" AS n
GROUP BY n.""Id""");
}
#endregion
#region Interval
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册