diff --git a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunction.java b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunction.java index 0073d79ba59a6afe9047ea3c96f7b47bf239a992..a873da3e837a9dfcfb62a8d422e8c1aeb175b242 100644 --- a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunction.java +++ b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunction.java @@ -32,6 +32,7 @@ import io.questdb.cairo.sql.Record; import io.questdb.griffin.engine.functions.DoubleFunction; import io.questdb.griffin.engine.functions.GroupByFunction; import io.questdb.griffin.engine.functions.UnaryFunction; +import io.questdb.std.Numbers; import org.jetbrains.annotations.NotNull; public class AvgDoubleGroupByFunction extends DoubleFunction implements GroupByFunction, UnaryFunction { @@ -46,7 +47,7 @@ public class AvgDoubleGroupByFunction extends DoubleFunction implements GroupByF @Override public void computeFirst(MapValue mapValue, Record record) { final double d = arg.getDouble(record); - if (d == d) { + if (Numbers.isFinite(d)) { mapValue.putDouble(valueIndex, d); mapValue.putLong(valueIndex + 1, 1L); } else { @@ -58,7 +59,7 @@ public class AvgDoubleGroupByFunction extends DoubleFunction implements GroupByF @Override public void computeNext(MapValue mapValue, Record record) { final double d = arg.getDouble(record); - if (d == d) { + if (Numbers.isFinite(d)) { mapValue.addDouble(valueIndex, d); mapValue.addLong(valueIndex + 1, 1L); } diff --git a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/KSumDoubleGroupByFunction.java b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/KSumDoubleGroupByFunction.java index 10af59dd98c309299c2131828a71066e149b3ead..900184eed10b006bfd6fdfd1834a42cb457c885c 100644 --- a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/KSumDoubleGroupByFunction.java +++ b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/KSumDoubleGroupByFunction.java @@ -32,6 +32,7 @@ import io.questdb.cairo.sql.Record; import io.questdb.griffin.engine.functions.DoubleFunction; import io.questdb.griffin.engine.functions.GroupByFunction; import io.questdb.griffin.engine.functions.UnaryFunction; +import io.questdb.std.Numbers; import org.jetbrains.annotations.NotNull; public class KSumDoubleGroupByFunction extends DoubleFunction implements GroupByFunction, UnaryFunction { @@ -46,7 +47,7 @@ public class KSumDoubleGroupByFunction extends DoubleFunction implements GroupBy @Override public void computeFirst(MapValue mapValue, Record record) { final double value = arg.getDouble(record); - if (value == value) { + if (Numbers.isFinite(value)) { mapValue.putDouble(valueIndex, value); mapValue.putDouble(valueIndex + 1, value - value); mapValue.putLong(valueIndex + 2, 1); @@ -60,7 +61,7 @@ public class KSumDoubleGroupByFunction extends DoubleFunction implements GroupBy @Override public void computeNext(MapValue mapValue, Record record) { final double value = arg.getDouble(record); - if (value == value) { + if (Numbers.isFinite(value)) { double sum = mapValue.getDouble(valueIndex); double c = mapValue.getDouble(valueIndex + 1); double y = value - c; diff --git a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/NSumDoubleGroupByFunction.java b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/NSumDoubleGroupByFunction.java index f2fefd30f6a6b34aa6051f08ae9644748f79846d..a8eef75e6d55114a670ae63e351286bfc7f3c776 100644 --- a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/NSumDoubleGroupByFunction.java +++ b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/NSumDoubleGroupByFunction.java @@ -32,6 +32,7 @@ import io.questdb.cairo.sql.Record; import io.questdb.griffin.engine.functions.DoubleFunction; import io.questdb.griffin.engine.functions.GroupByFunction; import io.questdb.griffin.engine.functions.UnaryFunction; +import io.questdb.std.Numbers; import org.jetbrains.annotations.NotNull; public class NSumDoubleGroupByFunction extends DoubleFunction implements GroupByFunction, UnaryFunction { @@ -46,7 +47,7 @@ public class NSumDoubleGroupByFunction extends DoubleFunction implements GroupBy @Override public void computeFirst(MapValue mapValue, Record record) { final double value = arg.getDouble(record); - if (value == value) { + if (Numbers.isFinite(value)) { sum(mapValue, value, 0, 0); mapValue.putLong(valueIndex + 2, 1); } else { @@ -59,7 +60,7 @@ public class NSumDoubleGroupByFunction extends DoubleFunction implements GroupBy @Override public void computeNext(MapValue mapValue, Record record) { final double value = arg.getDouble(record); - if (value == value) { + if (Numbers.isFinite(value)) { sum(mapValue, value, mapValue.getDouble(valueIndex), mapValue.getDouble(valueIndex + 1)); mapValue.addLong(valueIndex + 2, 1); } diff --git a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/SumDoubleGroupByFunction.java b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/SumDoubleGroupByFunction.java index 04c4221dd8724d733ea0226ac09cbcdb1c421eab..6818a979928c56275eaa329dcec55f02ca368463 100644 --- a/core/src/main/java/io/questdb/griffin/engine/functions/groupby/SumDoubleGroupByFunction.java +++ b/core/src/main/java/io/questdb/griffin/engine/functions/groupby/SumDoubleGroupByFunction.java @@ -32,6 +32,7 @@ import io.questdb.cairo.sql.Record; import io.questdb.griffin.engine.functions.DoubleFunction; import io.questdb.griffin.engine.functions.GroupByFunction; import io.questdb.griffin.engine.functions.UnaryFunction; +import io.questdb.std.Numbers; import org.jetbrains.annotations.NotNull; public class SumDoubleGroupByFunction extends DoubleFunction implements GroupByFunction, UnaryFunction { @@ -46,7 +47,7 @@ public class SumDoubleGroupByFunction extends DoubleFunction implements GroupByF @Override public void computeFirst(MapValue mapValue, Record record) { final double value = arg.getDouble(record); - if (value == value) { + if (Numbers.isFinite(value)) { mapValue.putDouble(valueIndex, value); mapValue.putLong(valueIndex + 1, 1); } else { @@ -58,7 +59,7 @@ public class SumDoubleGroupByFunction extends DoubleFunction implements GroupByF @Override public void computeNext(MapValue mapValue, Record record) { final double value = arg.getDouble(record); - if (value == value) { + if (Numbers.isFinite(value)) { mapValue.addDouble(valueIndex, value); mapValue.addLong(valueIndex + 1, 1); } diff --git a/core/src/main/java/io/questdb/std/Numbers.java b/core/src/main/java/io/questdb/std/Numbers.java index 96b2f98ca3652d1ccec7b933f7e6eedd1d936c7c..80a5f3342a7b7ce8bf3fc882ea9309cc902a5a54 100644 --- a/core/src/main/java/io/questdb/std/Numbers.java +++ b/core/src/main/java/io/questdb/std/Numbers.java @@ -1073,6 +1073,10 @@ public final class Numbers { return Double.longBitsToDouble(Double.doubleToRawLongBits(roundUp00PosScale(absValue, scale)) | signMask); } + public static boolean isFinite(double d) { + return ((Double.doubleToRawLongBits(d) & EXP_BIT_MASK) != EXP_BIT_MASK); + } + private static void appendLongHex4(CharSink sink, long value) { appendLongHexPad(sink, hexDigits[(int) ((value) & 0xf)]); } diff --git a/core/src/test/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunctionFactoryTest.java b/core/src/test/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunctionFactoryTest.java index 270882452b65402178d88aea8c0669e45a3bb47c..9f8eb476b165379e30851d2e4fee83320fdd47c5 100644 --- a/core/src/test/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunctionFactoryTest.java +++ b/core/src/test/java/io/questdb/griffin/engine/functions/groupby/AvgDoubleGroupByFunctionFactoryTest.java @@ -28,11 +28,19 @@ import io.questdb.cairo.sql.RecordCursor; import io.questdb.cairo.sql.RecordCursorFactory; import io.questdb.griffin.AbstractGriffinTest; import io.questdb.griffin.CompiledQuery; +import io.questdb.griffin.engine.functions.rnd.SharedRandom; +import io.questdb.std.Rnd; import io.questdb.test.tools.TestUtils; +import org.junit.Before; import org.junit.Test; public class AvgDoubleGroupByFunctionFactoryTest extends AbstractGriffinTest { + @Before + public void setUp3() { + SharedRandom.RANDOM.set(new Rnd()); + } + @Test public void testAll() throws Exception { assertMemoryLeak(() -> { @@ -51,4 +59,37 @@ public class AvgDoubleGroupByFunctionFactoryTest extends AbstractGriffinTest { }); } + @Test + public void testAvgWithInfinity() throws Exception { + assertMemoryLeak(() -> { + compiler.compile("create table test2 as(select case when rnd_double() > 0.6 then 1.0 else 0.0 end val from long_sequence(100));", sqlExecutionContext); + CompiledQuery cq = compiler.compile("select avg(1/val) from test2", sqlExecutionContext); + + try (RecordCursorFactory factory = cq.getRecordCursorFactory()) { + sink.clear(); + try (RecordCursor cursor = factory.getCursor(sqlExecutionContext)) { + printer.print(cursor, factory.getMetadata(), true); + } + } + TestUtils.assertEquals("avg\n1.0\n", sink); + }); + } + + @Test + public void testAllWithInfinity() throws Exception { + assertMemoryLeak(() -> { + compiler.compile("create table test2 as(select case when rnd_double() > 0.6 then 1.0 else 0.0 end val from long_sequence(100));", sqlExecutionContext); + CompiledQuery cq = compiler.compile("select sum(1/val) , avg(1/val), max(1/val), min(1/val), ksum(1/val), nsum(1/val) from test2", sqlExecutionContext); + + try (RecordCursorFactory factory = cq.getRecordCursorFactory()) { + sink.clear(); + try (RecordCursor cursor = factory.getCursor(sqlExecutionContext)) { + printer.print(cursor, factory.getMetadata(), true); + } + } + TestUtils.assertEquals("sum\tavg\tmax\tmin\tksum\tnsum\n" + + "44.0\t1.0\tInfinity\t1.0\t44.0\t44.0\n", + sink); + }); + } } \ No newline at end of file