未验证 提交 0b4cfe66 编写于 作者: Vlad Ilyushchenko 提交者: GitHub

feat(griffin): vectorized count() execution (#382)

上级 1831d781
...@@ -49,6 +49,9 @@ extern "C" { \ ...@@ -49,6 +49,9 @@ extern "C" { \
JNIEXPORT jdouble JNICALL Java_io_questdb_std_Vect_ ## func(JNIEnv *env, jclass cl, jlong pDouble, jlong size) { \ JNIEXPORT jdouble JNICALL Java_io_questdb_std_Vect_ ## func(JNIEnv *env, jclass cl, jlong pDouble, jlong size) { \
return func((double *) pDouble, size); \ return func((double *) pDouble, size); \
}\ }\
JNIEXPORT jdouble JNICALL JavaCritical_io_questdb_std_Vect_ ## func(jlong pDouble, jlong size) { \
return func((double *) pDouble, size); \
}\
\ \
} }
......
...@@ -194,6 +194,56 @@ Java_io_questdb_std_Rosti_keyedIntDistinct(JNIEnv *env, jclass cl, jlong pRosti, ...@@ -194,6 +194,56 @@ Java_io_questdb_std_Rosti_keyedIntDistinct(JNIEnv *env, jclass cl, jlong pRosti,
} }
} }
/**
 * JNI entry point for Rosti.keyedIntCount: tallies how many times each int32 key
 * occurs in a key vector, accumulating the tallies in a rosti map.
 *
 * @param env         JNI environment (unused)
 * @param cl          calling Java class (unused)
 * @param pRosti      address of the rosti_t map to update
 * @param pKeys       address of the int32 key vector
 * @param count       number of keys readable at pKeys
 * @param valueOffset index into the map's value_offsets_ table; the entry found there is
 *                    the offset, relative to a slot, of the jlong counter
 */
JNIEXPORT void JNICALL
Java_io_questdb_std_Rosti_keyedIntCount(JNIEnv *env, jclass cl, jlong pRosti, jlong pKeys, jlong count,
                                        jint valueOffset) {
    auto map = reinterpret_cast<rosti_t *>(pRosti);
    const auto *pi = reinterpret_cast<int32_t *>(pKeys);
    const auto value_offset = map->value_offsets_[valueOffset];
    // The index has the same width as `count`; an `int` index would overflow
    // before reaching a count larger than INT_MAX.
    for (jlong i = 0; i < count; i++) {
        // Prefetch relative to the current position so the hint moves with the
        // cursor; a fixed `pi + 16` would request the same address every iteration.
        _mm_prefetch(pi + i + 16, _MM_HINT_T0);
        const int32_t v = pi[i];
        auto res = find(map, v);
        auto dest = map->slots_ + res.first;
        // res.second is treated as "slot is new for this key": the key is written
        // and the counter starts at 1; otherwise the existing counter is bumped.
        if (PREDICT_FALSE(res.second)) {
            *reinterpret_cast<int32_t *>(dest) = v;
            *reinterpret_cast<jlong *>(dest + value_offset) = 1;
        } else {
            (*reinterpret_cast<jlong *>(dest + value_offset))++;
        }
    }
}
/**
 * JNI entry point for Rosti.keyedIntCountMerge: folds the per-key counters of map B
 * into map A. Map B is only read; map A receives new keys and increased counters.
 *
 * @param env         JNI environment (unused)
 * @param cl          calling Java class (unused)
 * @param pRostiA     address of the destination rosti_t map
 * @param pRostiB     address of the source rosti_t map
 * @param valueOffset index into value_offsets_ selecting the jlong counter within a slot
 */
JNIEXPORT void JNICALL
Java_io_questdb_std_Rosti_keyedIntCountMerge(JNIEnv *env, jclass cl, jlong pRostiA, jlong pRostiB,
                                             jint valueOffset) {
    auto target = reinterpret_cast<rosti_t *>(pRostiA);
    auto source = reinterpret_cast<rosti_t *>(pRostiB);

    const auto source_slots = source->slots_;
    const auto source_ctrl = source->ctrl_;
    const auto source_capacity = source->capacity_;
    const auto slot_shift = source->slot_size_shift_;
    // The counter offset is looked up in map B only and then applied to slots of
    // both maps, so the two maps must share the same slot layout.
    const auto counter_offset = source->value_offsets_[valueOffset];

    for (size_t slot = 0; slot < source_capacity; slot++) {
        // only control values greater than -1 mark a slot that is processed
        if (source_ctrl[slot] < 0) {
            continue;
        }

        auto entry = source_slots + (slot << slot_shift);
        auto key = *reinterpret_cast<int32_t *>(entry);
        auto delta = *reinterpret_cast<jlong *>(entry + counter_offset);

        auto lookup = find(target, key);
        auto dest = target->slots_ + lookup.first;
        auto counter = reinterpret_cast<jlong *>(dest + counter_offset);

        if (PREDICT_FALSE(lookup.second)) {
            // key not yet in map A: store the key and take B's count as is
            *reinterpret_cast<int32_t *>(dest) = key;
            *counter = delta;
        } else {
            // key already in map A: add B's count to it
            *counter += delta;
        }
    }
}
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_io_questdb_std_Rosti_keyedIntKSumDoubleMerge(JNIEnv *env, jclass cl, jlong pRostiA, jlong pRostiB, Java_io_questdb_std_Rosti_keyedIntKSumDoubleMerge(JNIEnv *env, jclass cl, jlong pRostiA, jlong pRostiB,
jint valueOffset) { jint valueOffset) {
......
...@@ -541,6 +541,9 @@ public class SqlCodeGenerator implements Mutable { ...@@ -541,6 +541,9 @@ public class SqlCodeGenerator implements Mutable {
addVaf(new SumTimestampVectorAggregateFunction(ast.rhs.position, columnIndex), qc.getName()); addVaf(new SumTimestampVectorAggregateFunction(ast.rhs.position, columnIndex), qc.getName());
continue; continue;
} }
} else if (ast.type == FUNCTION && ast.paramCount == 0 && Chars.equals(ast.token, "count")) {
addVaf(new CountVectorAggregateFunction(ast.position), qc.getName());
continue;
} else if (isSingleColumnFunction(ast, "ksum")) { } else if (isSingleColumnFunction(ast, "ksum")) {
final int columnIndex = metadata.getColumnIndex(ast.rhs.token); final int columnIndex = metadata.getColumnIndex(ast.rhs.token);
final int type = metadata.getColumnType(columnIndex); final int type = metadata.getColumnType(columnIndex);
...@@ -1546,7 +1549,6 @@ public class SqlCodeGenerator implements Mutable { ...@@ -1546,7 +1549,6 @@ public class SqlCodeGenerator implements Mutable {
final RecordCursorFactory factory = generateSubQuery(model, executionContext); final RecordCursorFactory factory = generateSubQuery(model, executionContext);
try { try {
// generate special case plan for "select count() from somewhere" // generate special case plan for "select count() from somewhere"
ObjList<QueryColumn> columns = model.getBottomUpColumns(); ObjList<QueryColumn> columns = model.getBottomUpColumns();
if (columns.size() == 1) { if (columns.size() == 1) {
......
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2020 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.griffin.engine.groupby.vect;
import io.questdb.cairo.ArrayColumnTypes;
import io.questdb.cairo.ColumnType;
import io.questdb.cairo.sql.Record;
import io.questdb.griffin.engine.functions.LongFunction;
import io.questdb.std.Rosti;
import io.questdb.std.Unsafe;
import java.util.concurrent.atomic.LongAdder;
/**
 * Vector aggregate implementation of {@code count()}.
 * <p>
 * For non-keyed aggregation the row counts passed to
 * {@link #aggregate(long, long, int)} are summed in a {@link LongAdder}. For keyed
 * aggregation the work is delegated to the native {@code Rosti.keyedIntCount} and
 * {@code Rosti.keyedIntCountMerge} routines, using a LONG value column registered in
 * {@link #pushValueTypes(ArrayColumnTypes)}.
 */
public class CountVectorAggregateFunction extends LongFunction implements VectorAggregateFunction {

    // running total for the non-keyed case
    private final LongAdder totalRows = new LongAdder();
    // position of this function's LONG value among the map value columns
    private int valueOffset;

    /**
     * @param position position forwarded to the {@link LongFunction} constructor
     */
    public CountVectorAggregateFunction(int position) {
        super(position);
    }

    // ---- non-keyed aggregation ----

    /**
     * Adds {@code count} to the running total; {@code address} and {@code workerId}
     * are not used because counting does not read column data.
     */
    @Override
    public void aggregate(long address, long count, int workerId) {
        totalRows.add(count);
    }

    /**
     * @return the sum of all counts added since construction or the last {@link #clear()}
     */
    @Override
    public long getLong(Record rec) {
        return totalRows.sum();
    }

    /**
     * Resets the non-keyed running total to zero.
     */
    @Override
    public void clear() {
        totalRows.reset();
    }

    // ---- keyed aggregation ----

    /**
     * Registers one LONG value column and remembers its index, which is the number of
     * columns present in {@code types} before the addition.
     */
    @Override
    public void pushValueTypes(ArrayColumnTypes types) {
        final int offset = types.getColumnCount();
        types.add(ColumnType.LONG);
        this.valueOffset = offset;
    }

    /**
     * @return the value column index assigned in {@link #pushValueTypes(ArrayColumnTypes)}
     */
    @Override
    public int getValueOffset() {
        return valueOffset;
    }

    /**
     * Writes 0 into the map's initial value slot for this function's value column.
     */
    @Override
    public void initRosti(long pRosti) {
        final long slotAddress = Rosti.getInitialValueSlot(pRosti, valueOffset);
        Unsafe.getUnsafe().putLong(slotAddress, 0);
    }

    /**
     * Counts keys found at {@code keyAddress} into the map; {@code valueAddress} and
     * {@code workerId} are not used.
     */
    @Override
    public void aggregate(long pRosti, long keyAddress, long valueAddress, long count, int workerId) {
        Rosti.keyedIntCount(pRosti, keyAddress, count, valueOffset);
    }

    /**
     * Merges the counts held in map B into map A.
     */
    @Override
    public void merge(long pRostiA, long pRostiB) {
        Rosti.keyedIntCountMerge(pRostiA, pRostiB, valueOffset);
    }

    /**
     * No post-processing is performed for count.
     */
    @Override
    public void wrapUp(long pRosti) {
    }

    /**
     * Always returns 0.
     * NOTE(review): count() takes no argument column, so this looks like a placeholder
     * index - confirm against how callers use it.
     */
    @Override
    public int getColumnIndex() {
        return 0;
    }
}
...@@ -61,6 +61,10 @@ public final class Rosti { ...@@ -61,6 +61,10 @@ public final class Rosti {
public static native void keyedIntDistinct(long pRosti, long pKeys, long count); public static native void keyedIntDistinct(long pRosti, long pKeys, long count);
public static native void keyedIntCount(long pRosti, long pKeys, long count, int valueOffset);
public static native void keyedIntCountMerge(long pRostiA, long pRostiB, int valueOffset);
// sum double // sum double
public static native void keyedIntSumDouble(long pRosti, long pKeys, long pDouble, long count, int valueOffset); public static native void keyedIntSumDouble(long pRosti, long pKeys, long pDouble, long count, int valueOffset);
......
...@@ -357,6 +357,34 @@ public class KeyedAggregationTest extends AbstractGriffinTest { ...@@ -357,6 +357,34 @@ public class KeyedAggregationTest extends AbstractGriffinTest {
}); });
} }
/**
 * Checks keyed count() over a symbol column of a table that had a column added
 * between two bulk inserts: one million rows are created, column {@code val} is
 * added, then another million rows are inserted with a different symbol set.
 */
@Test
public void testIntSymbolAddValueMidTableCount() throws Exception {
    assertMemoryLeak(() -> {
        // rows written before column "val" exists
        compiler.compile("create table tab as (select rnd_symbol('s1','s2','s3', null) s1 from long_sequence(1000000))", sqlExecutionContext);
        compiler.compile("alter table tab add column val long", sqlExecutionContext);
        // rows written after column "val" was added
        compiler.compile("insert into tab select rnd_symbol('a1','a2','a3', null), rnd_long(33, 889992, 2) from long_sequence(1000000)", sqlExecutionContext);

        final String query = "select s1, count() from tab order by s1";
        // one line per symbol, null symbol first
        final String expected = "s1\tcount\n"
                + "\t500194\n"
                + "a1\t248976\n"
                + "a2\t250638\n"
                + "a3\t250099\n"
                + "s1\t249898\n"
                + "s2\t250010\n"
                + "s3\t250185\n";

        try (
                RecordCursorFactory factory = compiler.compile(query, sqlExecutionContext).getRecordCursorFactory();
                RecordCursor cursor = factory.getCursor(sqlExecutionContext)
        ) {
            sink.clear();
            printer.print(cursor, factory.getMetadata(), true);
            TestUtils.assertEquals(expected, sink);
        }
    });
}
@Test @Test
public void testIntSymbolAddValueMidTableMaxDate() throws Exception { public void testIntSymbolAddValueMidTableMaxDate() throws Exception {
assertMemoryLeak(() -> { assertMemoryLeak(() -> {
......
package io.questdb.griffin; /*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2020 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
import org.junit.Assert; package io.questdb.griffin;
import org.junit.BeforeClass;
import org.junit.Test;
import io.questdb.cairo.security.CairoSecurityContextImpl; import io.questdb.cairo.security.CairoSecurityContextImpl;
import io.questdb.cairo.sql.InsertMethod; import io.questdb.cairo.sql.InsertMethod;
import io.questdb.cairo.sql.InsertStatement; import io.questdb.cairo.sql.InsertStatement;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
public class SecurityTest extends AbstractGriffinTest { public class SecurityTest extends AbstractGriffinTest {
protected static SqlExecutionContext readOnlyExecutionContext; protected static SqlExecutionContext readOnlyExecutionContext;
...@@ -363,6 +386,15 @@ public class SecurityTest extends AbstractGriffinTest { ...@@ -363,6 +386,15 @@ public class SecurityTest extends AbstractGriffinTest {
@Test @Test
public void testMaxInMemoryRowsWithImplicitGroupBy() throws Exception { public void testMaxInMemoryRowsWithImplicitGroupBy() throws Exception {
SqlExecutionContext readOnlyExecutionContext = new SqlExecutionContextImpl(
messageBus,
1,
engine)
.with(
new CairoSecurityContextImpl(false,
3),
bindVariableService,
null);
assertMemoryLeak(() -> { assertMemoryLeak(() -> {
sqlExecutionContext.getRandom().reset(); sqlExecutionContext.getRandom().reset();
compiler.compile("create table tb1 as (select" + compiler.compile("create table tb1 as (select" +
...@@ -373,18 +405,18 @@ public class SecurityTest extends AbstractGriffinTest { ...@@ -373,18 +405,18 @@ public class SecurityTest extends AbstractGriffinTest {
" from long_sequence(1000)) timestamp(ts)", sqlExecutionContext); " from long_sequence(1000)) timestamp(ts)", sqlExecutionContext);
assertQuery( assertQuery(
"sym2\tcount\nGZ\t509\nRX\t491\n", "sym2\tcount\nGZ\t509\nRX\t491\n",
"select sym2, count() from tb1", "select sym2, count() from tb1 order by sym2",
null, null,
true, readOnlyExecutionContext); true, readOnlyExecutionContext);
try { try {
assertQuery( assertQuery(
"sym1\tcount\nPEHN\t265\nCPSW\t231\nHYRX\t262\nVTJW\t242\n", "sym1\tcount\nPEHN\t265\nCPSW\t231\nHYRX\t262\nVTJW\t242\n",
"select sym1, count() from tb1", "select sym1, count() from tb1 order by sym1",
null, null,
true, readOnlyExecutionContext); true, readOnlyExecutionContext);
Assert.fail(); Assert.fail();
} catch (Exception ex) { } catch (Exception ex) {
Assert.assertTrue(ex.toString().contains("limit of 2 exceeded")); Assert.assertTrue(ex.toString().contains("limit of 3 exceeded"));
} }
}); });
} }
......
...@@ -289,11 +289,11 @@ public class SampleByTest extends AbstractGriffinTest { ...@@ -289,11 +289,11 @@ public class SampleByTest extends AbstractGriffinTest {
@Test @Test
public void testGroupByCount() throws Exception { public void testGroupByCount() throws Exception {
assertQuery("c\tcount\n" + assertQuery("c\tcount\n" +
"XY\t6\n" +
"\t5\n" + "\t5\n" +
"ZP\t5\n" + "UU\t4\n" +
"UU\t4\n", "XY\t6\n" +
"select c, count() from x", "ZP\t5\n",
"select c, count() from x order by c",
"create table x as " + "create table x as " +
"(" + "(" +
"select" + "select" +
...@@ -311,12 +311,12 @@ public class SampleByTest extends AbstractGriffinTest { ...@@ -311,12 +311,12 @@ public class SampleByTest extends AbstractGriffinTest {
" long_sequence(5)" + " long_sequence(5)" +
")", ")",
"c\tcount\n" + "c\tcount\n" +
"XY\t6\n" +
"\t5\n" + "\t5\n" +
"ZP\t5\n" + "KK\t1\n" +
"UU\t4\n" +
"PL\t4\n" + "PL\t4\n" +
"KK\t1\n", "UU\t4\n" +
"XY\t6\n" +
"ZP\t5\n",
true); true);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册