using System.Numerics;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;
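/// <summary>
///     Provides translation services for static methods of <see cref="Math" />, <see cref="MathF" />,
///     <see cref="BigInteger" />, <see cref="double" /> and <see cref="float" />.
/// </summary>
/// <remarks>
///     See:
///     - https://www.postgresql.org/docs/current/static/functions-math.html
///     - https://www.postgresql.org/docs/current/static/functions-conditional.html#FUNCTIONS-GREATEST-LEAST
/// </remarks>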
public class NpgsqlMathTranslator : IMethodCallTranslator
{
    // Methods that map directly onto a PostgreSQL function taking the same (one or two) arguments.
    // Translate() infers a single type mapping from the arguments and applies it to every argument and
    // to the function result.
    private static readonly Dictionary<MethodInfo, string> SupportedMethodTranslations = new()
    {
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(decimal) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(double) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(float) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(int) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(long) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(short) })!, "abs" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Abs), new[] { typeof(float) })!, "abs" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Abs), new[] { typeof(BigInteger) })!, "abs" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(decimal) })!, "ceiling" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(double) })!, "ceiling" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), new[] { typeof(float) })!, "ceiling" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(decimal) })!, "floor" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(double) })!, "floor" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), new[] { typeof(float) })!, "floor" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) })!, "power" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), new[] { typeof(float), typeof(float) })!, "power" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Pow), new[] { typeof(BigInteger), typeof(int) })!, "power" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Exp), new[] { typeof(double) })!, "exp" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), new[] { typeof(float) })!, "exp" },

        // PostgreSQL's single-argument log() is base 10; ln() is the natural logarithm.
        { typeof(Math).GetRuntimeMethod(nameof(Math.Log10), new[] { typeof(double) })!, "log" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), new[] { typeof(float) })!, "log" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double) })!, "ln" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), new[] { typeof(float) })!, "ln" },
        // Note: PostgreSQL has log(x,y) but only for decimal, whereas .NET has it only for double/float

        { typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), new[] { typeof(double) })!, "sqrt" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), new[] { typeof(float) })!, "sqrt" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Acos), new[] { typeof(double) })!, "acos" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), new[] { typeof(float) })!, "acos" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Asin), new[] { typeof(double) })!, "asin" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), new[] { typeof(float) })!, "asin" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Atan), new[] { typeof(double) })!, "atan" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), new[] { typeof(float) })!, "atan" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) })!, "atan2" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), new[] { typeof(float), typeof(float) })!, "atan2" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Cos), new[] { typeof(double) })!, "cos" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), new[] { typeof(float) })!, "cos" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Sin), new[] { typeof(double) })!, "sin" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), new[] { typeof(float) })!, "sin" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Tan), new[] { typeof(double) })!, "tan" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), new[] { typeof(float) })!, "tan" },

        // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(decimal), typeof(decimal) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(double), typeof(double) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(float), typeof(float) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(int), typeof(int) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(long), typeof(long) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(short), typeof(short) })!, "greatest" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Max), new[] { typeof(float), typeof(float) })!, "greatest" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Max), new[] { typeof(BigInteger), typeof(BigInteger) })!, "greatest" },

        // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(decimal), typeof(decimal) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(double), typeof(double) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(float), typeof(float) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(int), typeof(int) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(long), typeof(long) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(short), typeof(short) })!, "least" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Min), new[] { typeof(float), typeof(float) })!, "least" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Min), new[] { typeof(BigInteger), typeof(BigInteger) })!, "least" },
    };

    // Single-argument Truncate overloads, translated to PostgreSQL trunc().
    private static readonly IEnumerable<MethodInfo> TruncateMethodInfos = new[]
    {
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Truncate), typeof(decimal)),
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Truncate), typeof(double)),
        typeof(MathF).GetRequiredRuntimeMethod(nameof(MathF.Truncate), typeof(float))
    };

    // Single-argument Round overloads, translated to PostgreSQL round().
    private static readonly IEnumerable<MethodInfo> RoundMethodInfos = new[]
    {
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Round), typeof(decimal)),
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Round), typeof(double)),
        typeof(MathF).GetRequiredRuntimeMethod(nameof(MathF.Round), typeof(float))
    };

    // Sign overloads, translated to PostgreSQL sign() with the result cast to int.
    private static readonly IEnumerable<MethodInfo> SignMethodInfos = new[]
    {
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(decimal) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(double) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(float) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(int) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(long) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(sbyte) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(short) })!,
        typeof(MathF).GetRuntimeMethod(nameof(MathF.Sign), new[] { typeof(float) })!,
    };

    // Math.Round(decimal, int): translated to the two-argument PostgreSQL round().
    private static readonly MethodInfo RoundDecimalTwoParams
        = typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(decimal), typeof(int) })!;

    private static readonly MethodInfo DoubleIsNanMethodInfo
        = typeof(double).GetRuntimeMethod(nameof(double.IsNaN), new[] { typeof(double) })!;

    private static readonly MethodInfo DoubleIsPositiveInfinityMethodInfo
        = typeof(double).GetRuntimeMethod(nameof(double.IsPositiveInfinity), new[] { typeof(double) })!;

    private static readonly MethodInfo DoubleIsNegativeInfinityMethodInfo
        = typeof(double).GetRuntimeMethod(nameof(double.IsNegativeInfinity), new[] { typeof(double) })!;

    private static readonly MethodInfo FloatIsNanMethodInfo
        = typeof(float).GetRuntimeMethod(nameof(float.IsNaN), new[] { typeof(float) })!;

    private static readonly MethodInfo FloatIsPositiveInfinityMethodInfo
        = typeof(float).GetRuntimeMethod(nameof(float.IsPositiveInfinity), new[] { typeof(float) })!;

    private static readonly MethodInfo FloatIsNegativeInfinityMethodInfo
        = typeof(float).GetRuntimeMethod(nameof(float.IsNegativeInfinity), new[] { typeof(float) })!;

    private readonly ISqlExpressionFactory _sqlExpressionFactory;
    private readonly RelationalTypeMapping _intTypeMapping;
    private readonly RelationalTypeMapping _decimalTypeMapping;

    /// <summary>
    ///     Creates the translator, resolving the <see cref="int" /> and <see cref="decimal" /> type mappings
    ///     used for the sign() and two-argument round() results.
    /// </summary>
    /// <param name="typeMappingSource">Source used to look up the int and decimal type mappings.</param>
    /// <param name="sqlExpressionFactory">Factory used to build the translated SQL expressions.</param>
    /// <param name="model">The model passed to the type mapping lookups.</param>
    public NpgsqlMathTranslator(
        IRelationalTypeMappingSource typeMappingSource,
        ISqlExpressionFactory sqlExpressionFactory,
        IModel model)
    {
        _sqlExpressionFactory = sqlExpressionFactory;
        _intTypeMapping = typeMappingSource.FindMapping(typeof(int), model)!;
        _decimalTypeMapping = typeMappingSource.FindMapping(typeof(decimal), model)!;
    }

    /// <inheritdoc />
    public virtual SqlExpression? Translate(
        SqlExpression? instance,
        MethodInfo method,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        if (SupportedMethodTranslations.TryGetValue(method, out var sqlFunctionName))
        {
            // Infer a single type mapping from the argument(s) and apply it to every argument and to the result.
            var typeMapping = arguments.Count == 1
                ? ExpressionExtensions.InferTypeMapping(arguments[0])
                : ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]);

            var newArguments = new SqlExpression[arguments.Count];
            newArguments[0] = _sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping);

            if (arguments.Count == 2)
            {
                newArguments[1] = _sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping);
            }

            // Note: GREATEST/LEAST only return NULL if *all* arguments are null, but we currently can't
            // convey this.
            return _sqlExpressionFactory.Function(
                sqlFunctionName,
                newArguments,
                nullable: true,
                argumentsPropagateNullability: TrueArrays[newArguments.Length],
                method.ReturnType,
                typeMapping);
        }

        if (TruncateMethodInfos.Contains(method))
        {
            return TranslateTruncateOrRound("trunc", arguments[0]);
        }

        if (RoundMethodInfos.Contains(method))
        {
            return TranslateTruncateOrRound("round", arguments[0]);
        }

        // PostgreSQL sign() returns 1, 0, -1, but in the same type as the argument, so we need to convert
        // the return type to int.
        if (SignMethodInfos.Contains(method))
        {
            return _sqlExpressionFactory.Convert(
                _sqlExpressionFactory.Function(
                    "sign",
                    arguments,
                    nullable: true,
                    argumentsPropagateNullability: TrueArrays[1],
                    method.ReturnType),
                typeof(int),
                _intTypeMapping);
        }

        if (method == RoundDecimalTwoParams)
        {
            return _sqlExpressionFactory.Function(
                "round",
                new[]
                {
                    _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[0]),
                    _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1])
                },
                nullable: true,
                argumentsPropagateNullability: TrueArrays[2],
                method.ReturnType,
                _decimalTypeMapping);
        }

        // PostgreSQL treats NaN values as equal (unlike IEEE 754), so IsNaN can be translated to an equality
        // comparison against a NaN constant; the infinity checks use the same equality form.
        if (method == DoubleIsNanMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.NaN));
        }

        if (method == FloatIsNanMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.NaN));
        }

        if (method == DoubleIsPositiveInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.PositiveInfinity));
        }

        if (method == FloatIsPositiveInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.PositiveInfinity));
        }

        if (method == DoubleIsNegativeInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.NegativeInfinity));
        }

        if (method == FloatIsNegativeInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.NegativeInfinity));
        }

        return null;
    }

    /// <summary>
    ///     Translates a single-argument Truncate/Round call to the PostgreSQL function <paramref name="functionName" />.
    /// </summary>
    /// <remarks>
    ///     C# has these overloads over decimal/double/float only, so the argument is one of those types (the compiler
    ///     inserts a convert node otherwise). In the database the result has the same type, except for float, which
    ///     comes back as double and is therefore cast back to float.
    /// </remarks>
    /// <param name="functionName">The PostgreSQL function to call (<c>trunc</c> or <c>round</c>).</param>
    /// <param name="argument">The single argument of the .NET call.</param>
    /// <returns>The translated expression, carrying the argument's type mapping.</returns>
    private SqlExpression TranslateTruncateOrRound(string functionName, SqlExpression argument)
    {
        var isFloat = argument.Type == typeof(float);

        // One argument, so exactly one nullability-propagation flag (previously the arrays passed here had
        // 3 and 2 elements respectively, which did not match the argument list).
        SqlExpression result = _sqlExpressionFactory.Function(
            functionName,
            new[] { argument },
            nullable: true,
            argumentsPropagateNullability: TrueArrays[1],
            isFloat ? typeof(double) : argument.Type);

        if (isFloat)
        {
            result = _sqlExpressionFactory.Convert(result, typeof(float));
        }

        return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping);
    }
}