ExecExprVisitor.cpp 37.3 KB
Newer Older
F
FluorineDog 已提交
1 2 3 4 5 6 7 8 9 10 11
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

12
#include <deque>
N
neza2017 已提交
13
#include <optional>
14
#include <unordered_set>
15
#include <utility>
16
#include <boost/variant.hpp>
17

G
GuoRentong 已提交
18
#include "query/ExprImpl.h"
N
neza2017 已提交
19
#include "query/generated/ExecExprVisitor.h"
20
#include "segcore/SegmentGrowingImpl.h"
21 22
#include "query/Utils.h"
#include "query/Relational.h"
N
neza2017 已提交
23 24 25 26 27 28 29

namespace milvus::query {
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR
namespace impl {
class ExecExprVisitor : ExprVisitor {
 public:
30 31
    ExecExprVisitor(const segcore::SegmentInternalInterface& segment, int64_t row_count, Timestamp timestamp)
        : segment_(segment), row_count_(row_count), timestamp_(timestamp) {
N
neza2017 已提交
32
    }
33 34

    BitsetType
N
neza2017 已提交
35
    call_child(Expr& expr) {
36
        AssertInfo(!bitset_opt_.has_value(), "[ExecExprVisitor]Bitset already has value before accept");
N
neza2017 已提交
37
        expr.accept(*this);
38 39 40
        AssertInfo(bitset_opt_.has_value(), "[ExecExprVisitor]Bitset doesn't have value after accept");
        auto res = std::move(bitset_opt_);
        bitset_opt_ = std::nullopt;
41
        return std::move(res.value());
N
neza2017 已提交
42 43
    }

G
GuoRentong 已提交
44
 public:
F
FluorineDog 已提交
45
    template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
46
    auto
47
    ExecRangeVisitorImpl(FieldId field_id, IndexFunc func, ElementFunc element_func) -> BitsetType;
G
GuoRentong 已提交
48 49 50

    template <typename T>
    auto
51
    ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType;
52

53 54 55 56
    template <typename T>
    auto
    ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;

57 58
    template <typename T>
    auto
59
    ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType;
G
GuoRentong 已提交
60

S
sunby 已提交
61 62
    template <typename T>
    auto
63
    ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType;
S
sunby 已提交
64

65 66 67 68
    template <typename T>
    auto
    ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType;

69 70
    template <typename CmpFunc>
    auto
71
    ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) -> BitsetType;
72

N
neza2017 已提交
73
 private:
74 75
    const segcore::SegmentInternalInterface& segment_;
    int64_t row_count_;
76
    Timestamp timestamp_;
77
    BitsetTypeOpt bitset_opt_;
N
neza2017 已提交
78 79 80 81
};
}  // namespace impl

void
F
FluorineDog 已提交
82 83
ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
    using OpType = LogicalUnaryExpr::OpType;
84
    auto child_res = call_child(*expr.child_);
85
    BitsetType res = std::move(child_res);
86 87 88 89 90 91 92
    switch (expr.op_type_) {
        case OpType::LogicalNot: {
            res.flip();
            break;
        }
        default: {
            PanicInfo("Invalid Unary Op");
F
FluorineDog 已提交
93 94
        }
    }
95
    AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
96
    bitset_opt_ = std::move(res);
N
neza2017 已提交
97 98 99
}

void
F
FluorineDog 已提交
100 101
ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
    using OpType = LogicalBinaryExpr::OpType;
F
FluorineDog 已提交
102 103
    auto left = call_child(*expr.left_);
    auto right = call_child(*expr.right_);
104
    AssertInfo(left.size() == right.size(), "[ExecExprVisitor]Left size not equal to right size");
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
    auto res = std::move(left);
    switch (expr.op_type_) {
        case OpType::LogicalAnd: {
            res &= right;
            break;
        }
        case OpType::LogicalOr: {
            res |= right;
            break;
        }
        case OpType::LogicalXor: {
            res ^= right;
            break;
        }
        case OpType::LogicalMinus: {
            res -= right;
            break;
        }
        default: {
            PanicInfo("Invalid Binary Op");
        }
    }
127
    AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
128
    bitset_opt_ = std::move(res);
129
}
F
FluorineDog 已提交
130

131
static auto
132 133
Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
    BitsetType res;
134 135 136 137 138 139 140 141 142 143 144

    int64_t total_size = 0;
    for (auto& chunk : srcs) {
        total_size += chunk.size();
    }
    res.resize(total_size);

    int64_t counter = 0;
    for (auto& chunk : srcs) {
        for (int64_t i = 0; i < chunk.size(); ++i) {
            res[counter + i] = chunk[i];
F
FluorineDog 已提交
145
        }
146
        counter += chunk.size();
F
FluorineDog 已提交
147
    }
148
    return res;
N
neza2017 已提交
149 150
}

F
FluorineDog 已提交
151
template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
152
auto
153
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func) -> BitsetType {
G
GuoRentong 已提交
154
    auto& schema = segment_.get_schema();
155 156
    auto& field_meta = schema[field_id];
    auto indexing_barrier = segment_.num_chunk_index(field_id);
B
BossZou 已提交
157 158
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
159
    std::deque<BitsetType> results;
160

161
    using Index = index::ScalarIndex<T>;
F
FluorineDog 已提交
162
    for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
163
        const Index& indexing = segment_.chunk_scalar_index<T>(field_id, chunk_id);
164 165 166
        // NOTE: knowhere is not const-ready
        // This is a dirty workaround
        auto data = index_func(const_cast<Index*>(&indexing));
167
        AssertInfo(data->size() == size_per_chunk, "[ExecExprVisitor]Data size not equal to size_per_chunk");
168
        results.emplace_back(std::move(*data));
F
FluorineDog 已提交
169
    }
170
    for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
171
        auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
172
        BitsetType result(this_size);
173
        auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
G
GuoRentong 已提交
174
        const T* data = chunk.data();
175
        for (int index = 0; index < this_size; ++index) {
F
FluorineDog 已提交
176
            result[index] = element_func(data[index]);
G
GuoRentong 已提交
177
        }
178
        AssertInfo(result.size() == this_size, "");
179
        results.emplace_back(std::move(result));
G
GuoRentong 已提交
180
    }
181
    auto final_result = Assemble(results);
182
    AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
183
    return final_result;
G
GuoRentong 已提交
184
}
185

186
template <typename T, typename IndexFunc, typename ElementFunc>
187
auto
188 189
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func)
    -> BitsetType {
190
    auto& schema = segment_.get_schema();
191
    auto& field_meta = schema[field_id];
192 193
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
194 195 196 197
    auto indexing_barrier = segment_.num_chunk_index(field_id);
    auto data_barrier = segment_.num_chunk_data(field_id);
    AssertInfo(std::max(data_barrier, indexing_barrier) == num_chunk,
               "max(data_barrier, index_barrier) not equal to num_chunk");
198 199
    std::deque<BitsetType> results;

200 201 202 203 204
    // for growing segment, indexing_barrier will always less than data_barrier
    // so growing segment will always execute expr plan using raw data
    // if sealed segment has loaded raw data on this field, then index_barrier = 0 and data_barrier = 1
    // in this case, sealed segment execute expr plan using raw data
    for (auto chunk_id = 0; chunk_id < data_barrier; ++chunk_id) {
205 206
        auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
        BitsetType result(this_size);
207
        auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
208 209 210 211 212 213 214
        const T* data = chunk.data();
        for (int index = 0; index < this_size; ++index) {
            result[index] = element_func(data[index]);
        }
        AssertInfo(result.size() == this_size, "[ExecExprVisitor]Chunk result size not equal to expected size");
        results.emplace_back(std::move(result));
    }
215 216 217

    // if sealed segment has loaded scalar index for this field, then index_barrier = 1 and data_barrier = 0
    // in this case, sealed segment execute expr plan using scalar index
218
    using Index = index::ScalarIndex<T>;
219 220 221 222 223 224 225 226 227 228
    for (auto chunk_id = data_barrier; chunk_id < indexing_barrier; ++chunk_id) {
        auto& indexing = segment_.chunk_scalar_index<T>(field_id, chunk_id);
        auto this_size = const_cast<Index*>(&indexing)->Count();
        BitsetType result(this_size);
        for (int offset = 0; offset < this_size; ++offset) {
            result[offset] = index_func(const_cast<Index*>(&indexing), offset);
        }
        results.emplace_back(std::move(result));
    }

229 230 231 232 233
    auto final_result = Assemble(results);
    AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
    return final_result;
}

G
GuoRentong 已提交
234 235 236 237
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
238
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType {
239
    auto& expr = static_cast<UnaryRangeExprImpl<T>&>(expr_raw);
240
    using Index = index::ScalarIndex<T>;
241 242 243 244 245 246
    auto op = expr.op_type_;
    auto val = expr.value_;
    switch (op) {
        case OpType::Equal: {
            auto index_func = [val](Index* index) { return index->In(1, &val); };
            auto elem_func = [val](T x) { return (x == val); };
247
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
248 249 250 251
        }
        case OpType::NotEqual: {
            auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
            auto elem_func = [val](T x) { return (x != val); };
252
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
253 254
        }
        case OpType::GreaterEqual: {
255
            auto index_func = [val](Index* index) { return index->Range(val, OpType::GreaterEqual); };
256
            auto elem_func = [val](T x) { return (x >= val); };
257
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
G
GuoRentong 已提交
258
        }
259
        case OpType::GreaterThan: {
260
            auto index_func = [val](Index* index) { return index->Range(val, OpType::GreaterThan); };
261
            auto elem_func = [val](T x) { return (x > val); };
262
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
263 264
        }
        case OpType::LessEqual: {
265
            auto index_func = [val](Index* index) { return index->Range(val, OpType::LessEqual); };
266
            auto elem_func = [val](T x) { return (x <= val); };
267
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
268 269
        }
        case OpType::LessThan: {
270
            auto index_func = [val](Index* index) { return index->Range(val, OpType::LessThan); };
271
            auto elem_func = [val](T x) { return (x < val); };
272 273 274 275 276
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
        }
        case OpType::PrefixMatch: {
            auto index_func = [val](Index* index) {
                auto dataset = std::make_unique<knowhere::Dataset>();
277 278
                dataset->Set(milvus::index::OPERATOR_TYPE, OpType::PrefixMatch);
                dataset->Set(milvus::index::PREFIX_VALUE, val);
279 280 281 282 283 284
                return index->Query(std::move(dataset));
            };
            auto elem_func = [val, op](T x) { return Match(x, val, op); };
            return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
        }
        // TODO: PostfixMatch
285
        default: {
G
GuoRentong 已提交
286 287
            PanicInfo("unsupported range node");
        }
288 289 290 291
    }
}
#pragma clang diagnostic pop

292 293 294 295 296 297
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
// Evaluates `(field <arith_op> right_operand) <op> value` for element type T.
//
// Supported comparisons are Equal and NotEqual; supported arithmetic operators are
// Add, Sub, Mul, Div and Mod. The comparison operator is validated first, then the
// arithmetic operator; an unsupported value of either one panics. The arithmetic
// and comparison steps are small functors that are composed once and reused for
// both the index-backed path (value fetched via Reverse_Lookup) and the raw-data path.
//
// NOTE(review): Mod goes through fmod and casts back to T, so very large 64-bit
// integers may lose precision; Div/Mod do not guard against a zero right operand.
// Both match the previous behaviour -- confirm whether the plan layer rejects these.
template <typename T>
auto
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
    auto& expr = static_cast<BinaryArithOpEvalRangeExprImpl<T>&>(expr_raw);
    using Index = index::ScalarIndex<T>;
    auto arith_op = expr.arith_op_;
    auto right_operand = expr.right_operand_;
    auto op = expr.op_type_;
    auto val = expr.value_;

    // Builds both predicates for `cmp(arith(x), val)` and runs them over the field.
    // `this->` is spelled out so the member call is unambiguous inside a generic lambda.
    auto evaluate = [this, &expr, val](auto arith, auto cmp) -> BitsetType {
        auto index_func = [arith, cmp, val](Index* index, size_t offset) {
            auto x = index->Reverse_Lookup(offset);
            return cmp(arith(x), val);
        };
        auto elem_func = [arith, cmp, val](T x) { return cmp(arith(x), val); };
        return this->ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
    };

    // Picks the arithmetic functor for `arith_op` and evaluates it with comparison `cmp`.
    auto dispatch_arith = [&](auto cmp) -> BitsetType {
        switch (arith_op) {
            case ArithOpType::Add: {
                return evaluate([right_operand](auto x) { return x + right_operand; }, cmp);
            }
            case ArithOpType::Sub: {
                return evaluate([right_operand](auto x) { return x - right_operand; }, cmp);
            }
            case ArithOpType::Mul: {
                return evaluate([right_operand](auto x) { return x * right_operand; }, cmp);
            }
            case ArithOpType::Div: {
                return evaluate([right_operand](auto x) { return x / right_operand; }, cmp);
            }
            case ArithOpType::Mod: {
                return evaluate([right_operand](auto x) { return static_cast<T>(fmod(x, right_operand)); }, cmp);
            }
            default: {
                PanicInfo("unsupported arithmetic operation");
            }
        }
    };

    switch (op) {
        case OpType::Equal: {
            return dispatch_arith([](const auto& lhs, const auto& rhs) { return lhs == rhs; });
        }
        case OpType::NotEqual: {
            return dispatch_arith([](const auto& lhs, const auto& rhs) { return lhs != rhs; });
        }
        default: {
            PanicInfo("unsupported range node with arithmetic operation");
        }
    }
}
#pragma clang diagnostic pop

410 411 412 413
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
414
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType {
415
    auto& expr = static_cast<BinaryRangeExprImpl<T>&>(expr_raw);
416
    using Index = index::ScalarIndex<T>;
417 418 419 420
    bool lower_inclusive = expr.lower_inclusive_;
    bool upper_inclusive = expr.upper_inclusive_;
    T val1 = expr.lower_value_;
    T val2 = expr.upper_value_;
421

422 423 424
    auto index_func = [=](Index* index) { return index->Range(val1, lower_inclusive, val2, upper_inclusive); };
    if (lower_inclusive && upper_inclusive) {
        auto elem_func = [val1, val2](T x) { return (val1 <= x && x <= val2); };
425
        return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
426 427
    } else if (lower_inclusive && !upper_inclusive) {
        auto elem_func = [val1, val2](T x) { return (val1 <= x && x < val2); };
428
        return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
429 430
    } else if (!lower_inclusive && upper_inclusive) {
        auto elem_func = [val1, val2](T x) { return (val1 < x && x <= val2); };
431
        return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
G
GuoRentong 已提交
432
    } else {
433
        auto elem_func = [val1, val2](T x) { return (val1 < x && x < val2); };
434
        return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
G
GuoRentong 已提交
435 436 437 438
    }
}
#pragma clang diagnostic pop

N
neza2017 已提交
439
void
440
ExecExprVisitor::visit(UnaryRangeExpr& expr) {
441
    auto& field_meta = segment_.get_schema()[expr.field_id_];
442 443
    AssertInfo(expr.data_type_ == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");
444
    BitsetType res;
G
GuoRentong 已提交
445
    switch (expr.data_type_) {
N
neza2017 已提交
446
        case DataType::BOOL: {
447
            res = ExecUnaryRangeVisitorDispatcher<bool>(expr);
N
neza2017 已提交
448 449
            break;
        }
G
GuoRentong 已提交
450
        case DataType::INT8: {
451
            res = ExecUnaryRangeVisitorDispatcher<int8_t>(expr);
G
GuoRentong 已提交
452 453 454
            break;
        }
        case DataType::INT16: {
455
            res = ExecUnaryRangeVisitorDispatcher<int16_t>(expr);
G
GuoRentong 已提交
456 457 458
            break;
        }
        case DataType::INT32: {
459
            res = ExecUnaryRangeVisitorDispatcher<int32_t>(expr);
G
GuoRentong 已提交
460 461 462
            break;
        }
        case DataType::INT64: {
463
            res = ExecUnaryRangeVisitorDispatcher<int64_t>(expr);
G
GuoRentong 已提交
464 465 466
            break;
        }
        case DataType::FLOAT: {
467
            res = ExecUnaryRangeVisitorDispatcher<float>(expr);
G
GuoRentong 已提交
468 469 470
            break;
        }
        case DataType::DOUBLE: {
471 472 473
            res = ExecUnaryRangeVisitorDispatcher<double>(expr);
            break;
        }
474 475 476 477
        case DataType::VARCHAR: {
            res = ExecUnaryRangeVisitorDispatcher<std::string>(expr);
            break;
        }
478 479 480
        default:
            PanicInfo("unsupported");
    }
481
    AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
482
    bitset_opt_ = std::move(res);
483 484
}

485 486
// Checks that the expression's data type matches the field's schema type, then
// forwards to the arithmetic range dispatcher for that numeric type. Only the
// integer and floating-point types listed below are handled; anything else panics.
// The result is published through bitset_opt_ and must contain row_count_ bits.
void
ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.field_id_];
    AssertInfo(expr.data_type_ == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");

    BitsetType bits;
    switch (expr.data_type_) {
        case DataType::INT8:
            bits = ExecBinaryArithOpEvalRangeVisitorDispatcher<int8_t>(expr);
            break;
        case DataType::INT16:
            bits = ExecBinaryArithOpEvalRangeVisitorDispatcher<int16_t>(expr);
            break;
        case DataType::INT32:
            bits = ExecBinaryArithOpEvalRangeVisitorDispatcher<int32_t>(expr);
            break;
        case DataType::INT64:
            bits = ExecBinaryArithOpEvalRangeVisitorDispatcher<int64_t>(expr);
            break;
        case DataType::FLOAT:
            bits = ExecBinaryArithOpEvalRangeVisitorDispatcher<float>(expr);
            break;
        case DataType::DOUBLE:
            bits = ExecBinaryArithOpEvalRangeVisitorDispatcher<double>(expr);
            break;
        default:
            PanicInfo("unsupported");
    }

    AssertInfo(bits.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(bits);
}

523 524
// Checks that the expression's data type matches the field's schema type, then
// forwards to the two-sided range dispatcher instantiated for that type. Types
// not listed below panic. The result is published through bitset_opt_ and must
// contain row_count_ bits.
void
ExecExprVisitor::visit(BinaryRangeExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.field_id_];
    AssertInfo(expr.data_type_ == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");

    BitsetType bits;
    switch (expr.data_type_) {
        case DataType::BOOL:
            bits = ExecBinaryRangeVisitorDispatcher<bool>(expr);
            break;
        case DataType::INT8:
            bits = ExecBinaryRangeVisitorDispatcher<int8_t>(expr);
            break;
        case DataType::INT16:
            bits = ExecBinaryRangeVisitorDispatcher<int16_t>(expr);
            break;
        case DataType::INT32:
            bits = ExecBinaryRangeVisitorDispatcher<int32_t>(expr);
            break;
        case DataType::INT64:
            bits = ExecBinaryRangeVisitorDispatcher<int64_t>(expr);
            break;
        case DataType::FLOAT:
            bits = ExecBinaryRangeVisitorDispatcher<float>(expr);
            break;
        case DataType::DOUBLE:
            bits = ExecBinaryRangeVisitorDispatcher<double>(expr);
            break;
        case DataType::VARCHAR:
            bits = ExecBinaryRangeVisitorDispatcher<std::string>(expr);
            break;
        default:
            PanicInfo("unsupported");
    }

    AssertInfo(bits.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(bits);
}

569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
// Binary functor adapter: applies a default-constructed `Op` to two operands.
//
// The two-argument overload takes part in overload resolution only when
// `Op{}(a, b)` is a valid expression convertible to bool. When it is not (for
// example, comparing a string with a number through a transparent comparator),
// the variadic fallback is selected instead and panics at run time. Without this
// constraint the binary overload always wins for two arguments, so such operand
// pairs are a hard compile error and the fallback is never reached.
//
// The constraint only helps when Op's call operator is itself SFINAE-friendly;
// otherwise the behaviour is the same as an unconstrained overload.
template <typename Op>
struct relational {
    template <typename T, typename U>
    auto
    operator()(T const& a, U const& b) const -> decltype(static_cast<bool>(Op{}(a, b))) {
        return static_cast<bool>(Op{}(a, b));
    }
    // Fallback for operand combinations Op cannot compare.
    template <typename... T>
    bool
    operator()(T const&...) const {
        PanicInfo("incompatible operands");
    }
};

template <typename Op>
auto
585
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) -> BitsetType {
586
    using number = boost::variant<bool, int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
587 588
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
589
    std::deque<BitsetType> bitsets;
590 591 592 593 594 595 596 597 598 599 600 601

    // check for sealed segment, load either raw field data or index
    auto left_indexing_barrier = segment_.num_chunk_index(expr.left_field_id_);
    auto left_data_barrier = segment_.num_chunk_data(expr.left_field_id_);
    AssertInfo(std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
               "max(left_data_barrier, left_indexing_barrier) not equal to num_chunk");

    auto right_indexing_barrier = segment_.num_chunk_index(expr.right_field_id_);
    auto right_data_barrier = segment_.num_chunk_data(expr.right_field_id_);
    AssertInfo(std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
               "max(right_data_barrier, right_indexing_barrier) not equal to num_chunk");

602 603
    for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
        auto size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
604 605
        auto getChunkData = [&, chunk_id](DataType type, FieldId field_id,
                                          int64_t data_barrier) -> std::function<const number(int)> {
606 607
            switch (type) {
                case DataType::BOOL: {
608 609 610 611 612 613 614 615
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<bool>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<bool>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
616 617
                }
                case DataType::INT8: {
618 619 620 621 622 623 624 625
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<int8_t>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<int8_t>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
626 627
                }
                case DataType::INT16: {
628 629 630 631 632 633 634 635
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<int16_t>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<int16_t>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
636 637
                }
                case DataType::INT32: {
638 639 640 641 642 643 644 645
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<int32_t>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<int32_t>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
646 647
                }
                case DataType::INT64: {
648 649 650 651 652 653 654 655
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<int64_t>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<int64_t>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
656 657
                }
                case DataType::FLOAT: {
658 659 660 661 662 663 664 665
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<float>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<float>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
666 667
                }
                case DataType::DOUBLE: {
668 669 670 671 672 673 674 675
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<double>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<double>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
676 677
                }
                case DataType::VARCHAR: {
678 679 680 681 682 683 684 685
                    if (chunk_id < data_barrier) {
                        auto chunk_data = segment_.chunk_data<std::string>(field_id, chunk_id).data();
                        return [chunk_data](int i) -> const number { return chunk_data[i]; };
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
                        auto& indexing = segment_.chunk_scalar_index<std::string>(field_id, chunk_id);
                        return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
                    }
686 687 688 689 690
                }
                default:
                    PanicInfo("unsupported datatype");
            }
        };
691 692
        auto left = getChunkData(expr.left_data_type_, expr.left_field_id_, left_data_barrier);
        auto right = getChunkData(expr.right_data_type_, expr.right_field_id_, right_data_barrier);
693

694
        BitsetType bitset(size);
695
        for (int i = 0; i < size; ++i) {
696
            bool is_in = boost::apply_visitor(Relational<decltype(op)>{}, left(i), right(i));
697 698 699 700
            bitset[i] = is_in;
        }
        bitsets.emplace_back(std::move(bitset));
    }
701
    auto final_result = Assemble(bitsets);
702
    AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
703
    return final_result;
704 705 706 707 708
}

// Visitor entry point for CompareExpr (comparison between two fields).
//
// Checks that the data types recorded in the expression agree with the
// segment schema, picks the comparison functor that corresponds to
// expr.op_type_, and publishes the resulting bitset in bitset_opt_.
void
ExecExprVisitor::visit(CompareExpr& expr) {
    auto& schema = segment_.get_schema();
    auto& left_field_meta = schema[expr.left_field_id_];
    auto& right_field_meta = schema[expr.right_field_id_];
    AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(),
               "[ExecExprVisitor]Left data type not equal to left field mata type");
    AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(),
               "[ExecExprVisitor]right data type not equal to right field mata type");

    BitsetType res;
    switch (expr.op_type_) {
        case OpType::Equal:
            res = ExecCompareExprDispatcher(expr, std::equal_to<>{});
            break;
        case OpType::NotEqual:
            res = ExecCompareExprDispatcher(expr, std::not_equal_to<>{});
            break;
        case OpType::GreaterEqual:
            res = ExecCompareExprDispatcher(expr, std::greater_equal<>{});
            break;
        case OpType::GreaterThan:
            res = ExecCompareExprDispatcher(expr, std::greater<>{});
            break;
        case OpType::LessEqual:
            res = ExecCompareExprDispatcher(expr, std::less_equal<>{});
            break;
        case OpType::LessThan:
            res = ExecCompareExprDispatcher(expr, std::less<>{});
            break;
        case OpType::PrefixMatch:
            res = ExecCompareExprDispatcher(expr, MatchOp<OpType::PrefixMatch>{});
            break;
        // OpType::PostfixMatch has no case of its own and ends up in the default branch.
        default:
            PanicInfo("unsupported optype");
    }

    AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

S
sunby 已提交
756 757
template <typename T>
auto
758
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
S
sunby 已提交
759 760
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
    auto& schema = segment_.get_schema();
761 762 763
    auto primary_filed_id = schema.get_primary_field_id();
    auto field_id = expr_raw.field_id_;
    auto& field_meta = schema[field_id];
764 765

    bool use_pk_index = false;
766 767
    if (primary_filed_id.has_value()) {
        use_pk_index = primary_filed_id.value() == field_id && IsPrimaryKeyDataType(field_meta.get_data_type());
768 769 770 771
    }

    if (use_pk_index) {
        auto id_array = std::make_unique<IdArray>();
772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789
        switch (field_meta.get_data_type()) {
            case DataType::INT64: {
                auto dst_ids = id_array->mutable_int_id();
                for (const auto& id : expr.terms_) {
                    dst_ids->add_data((int64_t&)id);
                }
                break;
            }
            case DataType::VARCHAR: {
                auto dst_ids = id_array->mutable_str_id();
                for (const auto& id : expr.terms_) {
                    dst_ids->add_data((std::string&)id);
                }
                break;
            }
            default: {
                PanicInfo("unsupported type");
            }
790
        }
791

792 793 794 795 796 797 798 799 800 801
        auto [uids, seg_offsets] = segment_.search_ids(*id_array, timestamp_);
        BitsetType bitset(row_count_);
        for (const auto& offset : seg_offsets) {
            auto _offset = (int64_t)offset.get();
            bitset[_offset] = true;
        }
        AssertInfo(bitset.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
        return bitset;
    }

802
    return ExecTermVisitorImplTemplate<T>(expr_raw);
S
sunby 已提交
803 804
}

805 806 807
// std::string specialization: forwards straight to
// ExecTermVisitorImplTemplate<std::string>, so the primary-key lookup branch
// of the generic ExecTermVisitorImpl is not taken for string terms.
template <>
auto
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw) -> BitsetType {
    return ExecTermVisitorImplTemplate<std::string>(expr_raw);
}

// Generic term ("field IN (...)") evaluation for element type T.
//
// Hands two strategies to ExecRangeVisitorImpl<T>:
//  - index_func: asks a scalar index for all rows whose value is one of the terms;
//  - elem_func:  tests a single value for membership in the term set.
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
    using Index = index::ScalarIndex<T>;
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);

    const auto& terms = expr.terms_;
    const auto term_count = terms.size();
    // Hash set used for the per-element membership test.
    const std::unordered_set<T> term_set(terms.begin(), terms.end());

    auto index_func = [&terms, term_count](Index* index) { return index->In(term_count, terms.data()); };
    auto elem_func = [&term_set](T x) { return term_set.count(x) != 0; };

    return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
}

830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860
// TODO: bool is so ugly here.
// bool specialization of the term ("field IN (...)") evaluation.
//
// Unlike the generic version, the terms are first copied element by element
// into a plain bool array, and a pointer to that array is what gets passed to
// the index's In().
template <>
auto
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw) -> BitsetType {
    using T = bool;
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
    using Index = index::ScalarIndex<T>;
    const auto& terms = expr.terms_;
    auto n = terms.size();
    std::unordered_set<T> term_set(expr.terms_.begin(), expr.terms_.end());

    auto index_func = [&terms, n](Index* index) {
        // Owned by a unique_ptr so the buffer is released even when In() throws.
        auto bool_arr_copy = std::make_unique<bool[]>(n);
        size_t pos = 0;
        for (bool elem : terms) {
            bool_arr_copy[pos++] = elem;
        }
        return index->In(n, bool_arr_copy.get());
    };

    auto elem_func = [&term_set](T x) { return term_set.find(x) != term_set.end(); };

    return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
}

S
sunby 已提交
861 862
// Visitor entry point for TermExpr ("field IN (...)").
//
// Checks that the expression's data type agrees with the schema, runs the
// ExecTermVisitorImpl instantiation that matches that type, and publishes the
// resulting bitset in bitset_opt_.
void
ExecExprVisitor::visit(TermExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.field_id_];
    AssertInfo(expr.data_type_ == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type ");

    BitsetType res;
    switch (expr.data_type_) {
        case DataType::BOOL:
            res = ExecTermVisitorImpl<bool>(expr);
            break;
        case DataType::INT8:
            res = ExecTermVisitorImpl<int8_t>(expr);
            break;
        case DataType::INT16:
            res = ExecTermVisitorImpl<int16_t>(expr);
            break;
        case DataType::INT32:
            res = ExecTermVisitorImpl<int32_t>(expr);
            break;
        case DataType::INT64:
            res = ExecTermVisitorImpl<int64_t>(expr);
            break;
        case DataType::FLOAT:
            res = ExecTermVisitorImpl<float>(expr);
            break;
        case DataType::DOUBLE:
            res = ExecTermVisitorImpl<double>(expr);
            break;
        case DataType::VARCHAR:
            res = ExecTermVisitorImpl<std::string>(expr);
            break;
        default:
            PanicInfo("unsupported");
    }

    AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}
N
neza2017 已提交
906
}  // namespace milvus::query