NpgsqlMathTranslator.cs 15.2 KB
Newer Older
S
Shay Rojansky 已提交
1
using System.Numerics;
2
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;
S
Shay Rojansky 已提交
3
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;
4

5 6 7 8 9 10 11 12 13 14 15
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

/// <summary>
/// Provides translation services for static <see cref="Math"/>, <see cref="MathF"/> and <see cref="BigInteger"/> methods,
/// as well as for the <see cref="double"/>/<see cref="float"/> NaN and infinity checks.
/// </summary>
/// <remarks>
/// See:
///   - https://www.postgresql.org/docs/current/static/functions-math.html
///   - https://www.postgresql.org/docs/current/static/functions-conditional.html#FUNCTIONS-GREATEST-LEAST
/// </remarks>
public class NpgsqlMathTranslator : IMethodCallTranslator
{
    /// <summary>
    /// .NET methods that translate directly to a PostgreSQL function taking the same (one or two) arguments,
    /// keyed by method and mapped to the PostgreSQL function name.
    /// </summary>
    private static readonly Dictionary<MethodInfo, string> SupportedMethodTranslations = new()
    {
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(decimal) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(double) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(float) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(int) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(long) })!, "abs" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(short) })!, "abs" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Abs), new[] { typeof(float) })!, "abs" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Abs), new[] { typeof(BigInteger) })!, "abs" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(decimal) })!, "ceiling" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(double) })!, "ceiling" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), new[] { typeof(float) })!, "ceiling" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(decimal) })!, "floor" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(double) })!, "floor" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), new[] { typeof(float) })!, "floor" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) })!, "power" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), new[] { typeof(float), typeof(float) })!, "power" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Pow), new[] { typeof(BigInteger), typeof(int) })!, "power" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Exp), new[] { typeof(double) })!, "exp" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), new[] { typeof(float) })!, "exp" },

        // PostgreSQL's single-argument log() is base 10.
        { typeof(Math).GetRuntimeMethod(nameof(Math.Log10), new[] { typeof(double) })!, "log" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), new[] { typeof(float) })!, "log" },

        // Natural logarithm.
        { typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double) })!, "ln" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), new[] { typeof(float) })!, "ln" },
        // Note: PostgreSQL has log(x,y) but only for decimal, whereas .NET has it only for double/float

        { typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), new[] { typeof(double) })!, "sqrt" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), new[] { typeof(float) })!, "sqrt" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Acos), new[] { typeof(double) })!, "acos" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), new[] { typeof(float) })!, "acos" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Asin), new[] { typeof(double) })!, "asin" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), new[] { typeof(float) })!, "asin" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Atan), new[] { typeof(double) })!, "atan" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), new[] { typeof(float) })!, "atan" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) })!, "atan2" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), new[] { typeof(float), typeof(float) })!, "atan2" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Cos), new[] { typeof(double) })!, "cos" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), new[] { typeof(float) })!, "cos" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Sin), new[] { typeof(double) })!, "sin" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), new[] { typeof(float) })!, "sin" },

        { typeof(Math).GetRuntimeMethod(nameof(Math.Tan), new[] { typeof(double) })!, "tan" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), new[] { typeof(float) })!, "tan" },

        // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(decimal), typeof(decimal) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(double), typeof(double) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(float), typeof(float) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(int), typeof(int) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(long), typeof(long) })!, "greatest" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(short), typeof(short) })!, "greatest" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Max), new[] { typeof(float), typeof(float) })!, "greatest" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Max), new[] { typeof(BigInteger), typeof(BigInteger) })!, "greatest" },

        // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(decimal), typeof(decimal) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(double), typeof(double) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(float), typeof(float) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(int), typeof(int) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(long), typeof(long) })!, "least" },
        { typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(short), typeof(short) })!, "least" },
        { typeof(MathF).GetRuntimeMethod(nameof(MathF.Min), new[] { typeof(float), typeof(float) })!, "least" },
        { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Min), new[] { typeof(BigInteger), typeof(BigInteger) })!, "least" },
    };

    /// <summary>Single-argument <c>Truncate</c> overloads, translated to PostgreSQL <c>trunc</c>.</summary>
    private static readonly IEnumerable<MethodInfo> TruncateMethodInfos = new[]
    {
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Truncate), typeof(decimal)),
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Truncate), typeof(double)),
        typeof(MathF).GetRequiredRuntimeMethod(nameof(MathF.Truncate), typeof(float))
    };

    /// <summary>Single-argument <c>Round</c> overloads, translated to PostgreSQL <c>round</c>.</summary>
    private static readonly IEnumerable<MethodInfo> RoundMethodInfos = new[]
    {
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Round), typeof(decimal)),
        typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Round), typeof(double)),
        typeof(MathF).GetRequiredRuntimeMethod(nameof(MathF.Round), typeof(float))
    };

    /// <summary><c>Sign</c> overloads, translated to PostgreSQL <c>sign</c> cast to <c>int</c>.</summary>
    private static readonly IEnumerable<MethodInfo> SignMethodInfos = new[]
    {
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(decimal) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(double) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(float) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(int) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(long) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(sbyte) })!,
        typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(short) })!,
        typeof(MathF).GetRuntimeMethod(nameof(MathF.Sign), new[] { typeof(float) })!,
    };

    private static readonly MethodInfo RoundDecimalTwoParams
        = typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(decimal), typeof(int) })!;

    private static readonly MethodInfo DoubleIsNanMethodInfo
        = typeof(double).GetRuntimeMethod(nameof(double.IsNaN), new[] { typeof(double) })!;
    private static readonly MethodInfo DoubleIsPositiveInfinityMethodInfo
        = typeof(double).GetRuntimeMethod(nameof(double.IsPositiveInfinity), new[] { typeof(double) })!;
    private static readonly MethodInfo DoubleIsNegativeInfinityMethodInfo
        = typeof(double).GetRuntimeMethod(nameof(double.IsNegativeInfinity), new[] { typeof(double) })!;

    private static readonly MethodInfo FloatIsNanMethodInfo
        = typeof(float).GetRuntimeMethod(nameof(float.IsNaN), new[] { typeof(float) })!;
    private static readonly MethodInfo FloatIsPositiveInfinityMethodInfo
        = typeof(float).GetRuntimeMethod(nameof(float.IsPositiveInfinity), new[] { typeof(float) })!;
    private static readonly MethodInfo FloatIsNegativeInfinityMethodInfo
        = typeof(float).GetRuntimeMethod(nameof(float.IsNegativeInfinity), new[] { typeof(float) })!;

    private readonly ISqlExpressionFactory _sqlExpressionFactory;
    private readonly RelationalTypeMapping _intTypeMapping;
    private readonly RelationalTypeMapping _decimalTypeMapping;

    /// <summary>
    /// Creates the translator, resolving the <c>int</c> and <c>decimal</c> type mappings for <paramref name="model"/>
    /// up front; they are used for the results of <c>sign</c> and two-argument <c>round</c>.
    /// </summary>
    public NpgsqlMathTranslator(
        IRelationalTypeMappingSource typeMappingSource,
        ISqlExpressionFactory sqlExpressionFactory,
        IModel model)
    {
        _sqlExpressionFactory = sqlExpressionFactory;
        _intTypeMapping = typeMappingSource.FindMapping(typeof(int), model)!;
        _decimalTypeMapping = typeMappingSource.FindMapping(typeof(decimal), model)!;
    }

    /// <inheritdoc />
    public virtual SqlExpression? Translate(
        SqlExpression? instance,
        MethodInfo method,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        if (SupportedMethodTranslations.TryGetValue(method, out var sqlFunctionName))
        {
            // All arguments share one inferred type mapping, so e.g. Math.Max(column, 3) renders the constant in the
            // column's store type.
            var typeMapping = arguments.Count == 1
                ? ExpressionExtensions.InferTypeMapping(arguments[0])
                : ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]);

            var newArguments = new SqlExpression[arguments.Count];
            newArguments[0] = _sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping);

            if (arguments.Count == 2)
            {
                newArguments[1] = _sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping);
            }

            // GREATEST/LEAST ignore NULL arguments and only return NULL if *all* arguments are NULL, so a single NULL
            // argument must not be assumed to make the result NULL (otherwise `greatest(a, b) IS NULL` could be rewritten
            // to `a IS NULL OR b IS NULL`). The remaining functions here return NULL whenever an argument is NULL.
            var argumentsPropagateNullability = sqlFunctionName is "greatest" or "least"
                ? new bool[newArguments.Length]
                : TrueArrays[newArguments.Length];

            return _sqlExpressionFactory.Function(
                sqlFunctionName,
                newArguments,
                nullable: true,
                argumentsPropagateNullability: argumentsPropagateNullability,
                method.ReturnType,
                typeMapping);
        }

        if (TruncateMethodInfos.Contains(method))
        {
            return TranslateSingleArgumentRounding("trunc", arguments[0]);
        }

        if (RoundMethodInfos.Contains(method))
        {
            return TranslateSingleArgumentRounding("round", arguments[0]);
        }

        // PostgreSQL sign() returns 1, 0, -1, but in the same type as the argument, so the function is typed after its
        // argument and the result is then converted to int.
        if (SignMethodInfos.Contains(method))
        {
            var argument = arguments[0];

            return _sqlExpressionFactory.Convert(
                _sqlExpressionFactory.Function(
                    "sign",
                    arguments,
                    nullable: true,
                    argumentsPropagateNullability: TrueArrays[1],
                    argument.Type,
                    argument.TypeMapping),
                typeof(int),
                _intTypeMapping);
        }

        if (method == RoundDecimalTwoParams)
        {
            return _sqlExpressionFactory.Function(
                "round",
                new[]
                {
                    _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[0]),
                    _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1])
                },
                nullable: true,
                argumentsPropagateNullability: TrueArrays[2],
                method.ReturnType,
                _decimalTypeMapping);
        }

        // PostgreSQL treats NaN values as equal (against IEEE754), so the NaN/infinity checks below are plain equality
        // comparisons against the corresponding constant.
        if (method == DoubleIsNanMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.NaN));
        }

        if (method == FloatIsNanMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.NaN));
        }

        if (method == DoubleIsPositiveInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.PositiveInfinity));
        }

        if (method == FloatIsPositiveInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.PositiveInfinity));
        }

        if (method == DoubleIsNegativeInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.NegativeInfinity));
        }

        if (method == FloatIsNegativeInfinityMethodInfo)
        {
            return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.NegativeInfinity));
        }

        return null;
    }

    /// <summary>
    /// Translates a single-argument rounding call (<c>Truncate</c> or <c>Round</c>) to the given PostgreSQL function.
    /// </summary>
    /// <param name="functionName">The PostgreSQL function to call (<c>trunc</c> or <c>round</c>).</param>
    /// <param name="argument">The single argument of the .NET call.</param>
    /// <returns>The function call, cast back to <c>float</c> for a float argument, carrying the argument's type mapping.</returns>
    private SqlExpression? TranslateSingleArgumentRounding(string functionName, SqlExpression argument)
    {
        // C# has these overloads over decimal/double/float only, so the argument is one of those types (the compiler inserts
        // a convert node otherwise). In the database the result has the same type, except for float, which yields double
        // and therefore needs to be cast back to float.
        var isFloat = argument.Type == typeof(float);

        var result = (SqlExpression)_sqlExpressionFactory.Function(
            functionName,
            new[] { argument },
            nullable: true,
            argumentsPropagateNullability: TrueArrays[1],
            isFloat ? typeof(double) : argument.Type);

        if (isFloat)
        {
            result = _sqlExpressionFactory.Convert(result, typeof(float));
        }

        return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping);
    }
}