未验证 提交 8c5576d9 编写于 作者: Andrey Pechkurov 提交者: GitHub

chore(sql): improve strpos function (#1535)

上级 4c99a2e8
......@@ -26,13 +26,12 @@ package org.questdb;
import io.questdb.cairo.sql.Function;
import io.questdb.cairo.sql.Record;
import io.questdb.griffin.engine.functions.CharFunction;
import io.questdb.griffin.engine.functions.StrFunction;
import io.questdb.griffin.engine.functions.constants.CharConstant;
import io.questdb.griffin.engine.functions.str.StrPosCharFunctionFactory;
import io.questdb.griffin.engine.functions.str.StrPosFunctionFactory;
import io.questdb.std.Rnd;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.profile.GCProfiler;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
......@@ -49,14 +48,15 @@ public class StrPosBenchmark {
private final Record[] records;
private final String[] strings;
private final Function strposStrFunc;
private final Function strposCharFunc;
private final Function strFunc;
private final Function strConstFunc;
private final Function charFunc;
private final Function charConstFunc;
private final Rnd rnd = new Rnd();
public static void main(String[] args) throws RunnerException {
Options opt = new OptionsBuilder()
.include(StrPosBenchmark.class.getSimpleName())
.addProfiler(GCProfiler.class)
.warmupIterations(2)
.measurementIterations(3)
.forks(1)
......@@ -70,7 +70,7 @@ public class StrPosBenchmark {
strings = new String[N];
for (int i = 0; i < N; i++) {
builder.setLength(0);
int startLen = rnd.nextInt(1000);
int startLen = rnd.nextInt(32);
for (int j = 0; j < startLen; j++) {
builder.append('a');
}
......@@ -89,7 +89,7 @@ public class StrPosBenchmark {
};
}
Function strFunc = new StrFunction() {
final Function strInputFunc = new StrFunction() {
@Override
public CharSequence getStr(Record rec) {
return rec.getStr(0);
......@@ -100,9 +100,28 @@ public class StrPosBenchmark {
return rec.getStr(0);
}
};
Function substrFunc = new CharConstant(',');
strposStrFunc = new StrPosFunctionFactory.Func(strFunc, substrFunc);
strposCharFunc = new StrPosCharFunctionFactory.Func(strFunc, substrFunc);
final Function substrStrInputFunc = new StrFunction() {
@Override
public CharSequence getStr(Record rec) {
return ",";
}
@Override
public CharSequence getStrB(Record rec) {
return ",";
}
};
final Function substrCharInputFunc = new CharFunction() {
@Override
public char getChar(Record rec) {
return ',';
}
};
strFunc = new StrPosFunctionFactory.Func(strInputFunc, substrStrInputFunc);
strConstFunc = new StrPosFunctionFactory.ConstFunc(strInputFunc, ",");
charFunc = new StrPosCharFunctionFactory.Func(strInputFunc, substrCharInputFunc);
charConstFunc = new StrPosCharFunctionFactory.ConstFunc(strInputFunc, ',');
}
@Benchmark
......@@ -113,12 +132,24 @@ public class StrPosBenchmark {
@Benchmark
public int testStrOverload() {
int i = rnd.nextInt(N);
return strposStrFunc.getInt(records[i]);
return strFunc.getInt(records[i]);
}
/**
 * Benchmarks the string overload built with a fixed "," pattern
 * ({@code StrPosFunctionFactory.ConstFunc}) against a randomly chosen record.
 */
@Benchmark
public int testStrConstOverload() {
    final Record rec = records[rnd.nextInt(N)];
    return strConstFunc.getInt(rec);
}
@Benchmark
public int testCharOverload() {
int i = rnd.nextInt(N);
return strposCharFunc.getInt(records[i]);
return charFunc.getInt(records[i]);
}
/**
 * Benchmarks the char overload built with a fixed ',' pattern
 * ({@code StrPosCharFunctionFactory.ConstFunc}) against a randomly chosen record.
 */
@Benchmark
public int testCharConstOverload() {
    final Record rec = records[rnd.nextInt(N)];
    return charConstFunc.getInt(rec);
}
}
......@@ -32,6 +32,9 @@ import io.questdb.griffin.SqlException;
import io.questdb.griffin.SqlExecutionContext;
import io.questdb.griffin.engine.functions.BinaryFunction;
import io.questdb.griffin.engine.functions.IntFunction;
import io.questdb.griffin.engine.functions.UnaryFunction;
import io.questdb.griffin.engine.functions.constants.CharConstant;
import io.questdb.griffin.engine.functions.constants.IntConstant;
import io.questdb.std.IntList;
import io.questdb.std.Numbers;
import io.questdb.std.ObjList;
......@@ -52,7 +55,29 @@ public class StrPosCharFunctionFactory implements FunctionFactory {
CairoConfiguration configuration,
SqlExecutionContext sqlExecutionContext
) throws SqlException {
return new Func(args.getQuick(0), args.getQuick(1));
final Function substrFunc = args.getQuick(1);
if (substrFunc.isConstant()) {
char substr = substrFunc.getChar(null);
if (substr == CharConstant.ZERO.getChar(null)) {
return IntConstant.NULL;
}
return new ConstFunc(args.getQuick(0), substr);
}
return new Func(args.getQuick(0), substrFunc);
}
/**
 * Finds the first occurrence of a character in a string.
 *
 * @param str    string to scan
 * @param substr character to look for
 * @return 1-based position of the first match, or 0 when the character does
 *         not occur (which includes the case of an empty {@code str})
 */
private static int strpos(@NotNull CharSequence str, char substr) {
    // An empty string needs no special case: the loop body never runs and 0 is returned.
    for (int pos = 0, len = str.length(); pos < len; pos++) {
        if (str.charAt(pos) == substr) {
            return pos + 1;
        }
    }
    return 0;
}
public static class Func extends IntFunction implements BinaryFunction {
......@@ -72,21 +97,10 @@ public class StrPosCharFunctionFactory implements FunctionFactory {
return Numbers.INT_NaN;
}
final char substr = this.substrFunc.getChar(rec);
return strpos(str, substr);
}
private int strpos(@NotNull CharSequence str, char substr) {
final int strLen = str.length();
if (strLen < 1) {
return 0;
}
for (int i = 0; i < strLen; i++) {
if (str.charAt(i) == substr) {
return i + 1;
}
if (substr == CharConstant.ZERO.getChar(null)) {
return Numbers.INT_NaN;
}
return 0;
return strpos(str, substr);
}
@Override
......@@ -99,4 +113,29 @@ public class StrPosCharFunctionFactory implements FunctionFactory {
return substrFunc;
}
}
/**
 * strpos implementation for a character pattern that is fixed at construction
 * time, so only the string argument is evaluated for each record.
 */
public static class ConstFunc extends IntFunction implements UnaryFunction {
    private final Function strFunc;
    private final char substr;

    /**
     * @param strFunc function producing the string to search in
     * @param substr  character to search for
     */
    public ConstFunc(Function strFunc, char substr) {
        this.strFunc = strFunc;
        this.substr = substr;
    }

    @Override
    public Function getArg() {
        return strFunc;
    }

    /**
     * @return {@code Numbers.INT_NaN} when the string is null, otherwise the
     *         result of {@code strpos} for the string and the fixed character
     */
    @Override
    public int getInt(Record rec) {
        final CharSequence value = strFunc.getStr(rec);
        return value == null ? Numbers.INT_NaN : strpos(value, substr);
    }
}
}
......@@ -32,6 +32,8 @@ import io.questdb.griffin.SqlException;
import io.questdb.griffin.SqlExecutionContext;
import io.questdb.griffin.engine.functions.BinaryFunction;
import io.questdb.griffin.engine.functions.IntFunction;
import io.questdb.griffin.engine.functions.UnaryFunction;
import io.questdb.griffin.engine.functions.constants.IntConstant;
import io.questdb.std.IntList;
import io.questdb.std.Numbers;
import io.questdb.std.ObjList;
......@@ -52,7 +54,40 @@ public class StrPosFunctionFactory implements FunctionFactory {
CairoConfiguration configuration,
SqlExecutionContext sqlExecutionContext
) throws SqlException {
return new Func(args.getQuick(0), args.getQuick(1));
final Function substrFunc = args.getQuick(1);
if (substrFunc.isConstant()) {
CharSequence substr = substrFunc.getStr(null);
if (substr == null) {
return IntConstant.NULL;
}
return new ConstFunc(args.getQuick(0), substr);
}
return new Func(args.getQuick(0), substrFunc);
}
/**
 * Finds the first occurrence of {@code substr} inside {@code str}.
 *
 * @param str    string to scan
 * @param substr pattern to look for
 * @return 1 when {@code substr} is empty; otherwise the 1-based position of
 *         the first match, or 0 when there is none
 */
private static int strpos(@NotNull CharSequence str, @NotNull CharSequence substr) {
    final int needleLen = substr.length();
    if (needleLen == 0) {
        // The empty-pattern check comes first, so it wins even when str is empty too.
        return 1;
    }
    final int haystackLen = str.length();
    if (haystackLen == 0) {
        return 0;
    }
    // Last start index at which the pattern still fits; negative when the
    // pattern is longer than the string, in which case the loop is skipped.
    final int lastStart = haystackLen - needleLen;
    for (int start = 0; start <= lastStart; start++) {
        int matched = 0;
        while (matched < needleLen && str.charAt(start + matched) == substr.charAt(matched)) {
            matched++;
        }
        if (matched == needleLen) {
            return start + 1;
        }
    }
    return 0;
}
public static class Func extends IntFunction implements BinaryFunction {
......@@ -78,31 +113,6 @@ public class StrPosFunctionFactory implements FunctionFactory {
return strpos(str, substr);
}
private int strpos(@NotNull CharSequence str, @NotNull CharSequence substr) {
final int substrLen = substr.length();
if (substrLen < 1) {
return 1;
}
final int strLen = str.length();
if (strLen < 1) {
return 0;
}
OUTER:
for (int i = 0; i < strLen - substrLen + 1; i++) {
final char c = str.charAt(i);
if (c == substr.charAt(0)) {
for (int k = 1; k < substrLen; k++) {
if (str.charAt(i + k) != substr.charAt(k)) {
continue OUTER;
}
}
return i + 1;
}
}
return 0;
}
@Override
public Function getLeft() {
return strFunc;
......@@ -113,4 +123,29 @@ public class StrPosFunctionFactory implements FunctionFactory {
return substrFunc;
}
}
/**
 * strpos implementation for a string pattern that is fixed at construction
 * time, so only the string argument is evaluated for each record.
 */
public static class ConstFunc extends IntFunction implements UnaryFunction {
    private final Function strFunc;
    private final CharSequence substr;

    /**
     * @param strFunc function producing the string to search in
     * @param substr  pattern to search for
     */
    public ConstFunc(Function strFunc, CharSequence substr) {
        this.strFunc = strFunc;
        this.substr = substr;
    }

    @Override
    public Function getArg() {
        return strFunc;
    }

    /**
     * @return {@code Numbers.INT_NaN} when the string is null, otherwise the
     *         result of {@code strpos} for the string and the fixed pattern
     */
    @Override
    public int getInt(Record rec) {
        final CharSequence value = strFunc.getStr(rec);
        return value == null ? Numbers.INT_NaN : strpos(value, substr);
    }
}
}
......@@ -30,28 +30,81 @@ import org.junit.Test;
public class StrPosFunctionFactoryTest extends AbstractGriffinTest {
@Test
public void testSimple() throws Exception {
public void testVarStr() throws Exception {
assertQuery(
"substr\tstr\tstrpos\n" +
"XYZ\tABC XYZ XYZ\t5\n" +
"C\tXYZ\t0\n" +
"\tXYZ\tNaN\n" +
"\tXYZ\tNaN\n" +
"C\tXYW\t0\n" +
"C\tXYW\t0\n" +
"XYZ\tABC XYZ XYZ\t5\n" +
"XYZ\tXYZ\t1\n" +
"C\tXYZ\t0\n" +
"\tABC XYZ XYZ\tNaN\n" +
"C\tXYW\t0\n" +
"XYZ\tXYZ\t1\n" +
"C\tABC XYZ XYZ\t3\n" +
"XYZ\tABC XYZ XYZ\t5\n" +
"C\tXYZ\t0\n" +
"C\tABC XYZ XYZ\t3\n" +
"XYZ\tABC XYZ XYZ\t5\n" +
"XYZ\tXYW\t0\n" +
"XYZ\tXYZ\t1\n",
"XYZ\t\tNaN\n" +
"\tABC XYZ XYZ\tNaN\n" +
"C\t\tNaN\n" +
"C\tXYW\t0\n" +
"XYZ\t\tNaN\n",
"select substr,str,strpos(str,substr) from x",
"create table x as (" +
"select rnd_str('ABC XYZ XYZ','XYZ','XYW',NULL) as str\n" +
", rnd_str('XYZ','C',NULL) as substr\n" +
"from long_sequence(15)" +
")",
null,
true,
false,
true
);
}
/**
 * Checks strpos over a string column when the pattern is the literal 'XYZ'.
 */
@Test
public void testVarStrConstSubstr() throws Exception {
    // Header row followed by one row per generated string.
    final String expected = "str\tstrpos\n" +
            "ABC XYZ XYZ\t5\n" +
            "ABC XYZ XYZ\t5\n" +
            "XYZ\t1\n" +
            "XYW\t0\n" +
            "XYW\t0\n";
    // The pattern is given as a literal rather than read from a column.
    final String query = "select str,strpos(str,'XYZ') from x";
    // The substr column is created here but not referenced by the query above.
    final String ddl = "create table x as (" +
            "select rnd_str('ABC XYZ XYZ','XYZ','XYW') as str\n" +
            ", rnd_str('XYZ','C') as substr\n" +
            "from long_sequence(5)" +
            ")";
    // NOTE(review): the trailing null/boolean arguments are passed through unchanged;
    // their meaning is defined by assertQuery in the test base class.
    assertQuery(expected, query, ddl, null, true, false, true);
}
@Test
public void testVarChar() throws Exception {
assertQuery(
"substr\tstr\tstrpos\n" +
"T\tTEST\t1\n" +
"W\tA X X\t0\n" +
"P\tA X X\t0\n" +
"W\tCDE\t0\n" +
"Y\tCDE\t0\n" +
"X\tTEST\t0\n" +
"E\tCDE\t3\n" +
"N\tA X X\t0\n" +
"X\tTEST\t0\n" +
"Z\tA X X\t0\n" +
"X\t\tNaN\n" +
"X\tTEST\t0\n" +
"B\t\tNaN\n" +
"T\tCDE\t0\n" +
"P\t\tNaN\n",
"select substr,str,strpos(str,substr) from x",
"create table x as (" +
"select rnd_str('TEST','A X X','CDE',NULL) as str\n" +
", rnd_char() as substr\n" +
"from long_sequence(15)" +
")",
null,
......@@ -61,12 +114,33 @@ public class StrPosFunctionFactoryTest extends AbstractGriffinTest {
);
}
/**
 * Checks strpos over a string column when the pattern is the single-character literal 'C'.
 */
@Test
public void testVarCharConstSubstr() throws Exception {
    // Header row followed by one row per generated string.
    final String expected = "str\tstrpos\n" +
            "ABC XYZ XYZ\t3\n" +
            "ABC XYZ XYZ\t3\n" +
            "CBA\t1\n" +
            "XYZ\t0\n" +
            "XYZ\t0\n";
    // NOTE(review): presumably 'C' resolves to the char overload of strpos - confirm.
    final String query = "select str,strpos(str,'C') from x";
    final String ddl = "create table x as (" +
            "select rnd_str('ABC XYZ XYZ','CBA','XYZ') as str\n" +
            "from long_sequence(5)" +
            ")";
    // NOTE(review): the trailing null/boolean arguments are passed through unchanged;
    // their meaning is defined by assertQuery in the test base class.
    assertQuery(expected, query, ddl, null, true, false, true);
}
@Test
public void testConstantNull() throws Exception {
assertQuery(
"pos1\tpos2\n" +
"NaN\tNaN\n",
"select strpos(null,'a') pos1, strpos('a',null) pos2",
"pos1\tpos2\tpos3\n" +
"NaN\tNaN\tNaN\n",
"select strpos(null,'a') pos1, strpos(null,'abc') pos2, strpos('a',null) pos3",
null,
null,
true,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册