ExecExprVisitor.cpp 75.9 KB
Newer Older
F
FluorineDog 已提交
1 2 3 4 5 6 7 8 9 10 11
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

Y
yah01 已提交
12 13 14
#include "query/generated/ExecExprVisitor.h"

#include <boost/variant.hpp>
15
#include <boost/utility/binary.hpp>
16
#include <ctime>
17
#include <deque>
N
neza2017 已提交
18
#include <optional>
19 20 21
#include <string>
#include <string_view>
#include <type_traits>
22
#include <unordered_set>
23 24
#include <utility>

25
#include "arrow/type_fwd.h"
26 27 28 29
#include "common/Json.h"
#include "common/Types.h"
#include "exceptions/EasyAssert.h"
#include "pb/plan.pb.h"
G
GuoRentong 已提交
30
#include "query/ExprImpl.h"
31
#include "query/Relational.h"
Y
yah01 已提交
32 33
#include "query/Utils.h"
#include "segcore/SegmentGrowingImpl.h"
34
#include "simdjson/error.h"
35
#include "query/PlanProto.h"
N
neza2017 已提交
36 37 38 39 40 41
namespace milvus::query {
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR
namespace impl {
class ExecExprVisitor : ExprVisitor {
 public:
Y
yah01 已提交
42 43 44
    ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
                    int64_t row_count,
                    Timestamp timestamp)
45
        : segment_(segment), row_count_(row_count), timestamp_(timestamp) {
N
neza2017 已提交
46
    }
47 48

    BitsetType
N
neza2017 已提交
49
    call_child(Expr& expr) {
Y
yah01 已提交
50 51
        AssertInfo(!bitset_opt_.has_value(),
                   "[ExecExprVisitor]Bitset already has value before accept");
N
neza2017 已提交
52
        expr.accept(*this);
Y
yah01 已提交
53 54
        AssertInfo(bitset_opt_.has_value(),
                   "[ExecExprVisitor]Bitset doesn't have value after accept");
55 56
        auto res = std::move(bitset_opt_);
        bitset_opt_ = std::nullopt;
57
        return std::move(res.value());
N
neza2017 已提交
58 59
    }

G
GuoRentong 已提交
60
 public:
F
FluorineDog 已提交
61
    template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
62
    auto
Y
yah01 已提交
63 64 65
    ExecRangeVisitorImpl(FieldId field_id,
                         IndexFunc func,
                         ElementFunc element_func) -> BitsetType;
G
GuoRentong 已提交
66 67 68

    template <typename T>
    auto
69
    ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType;
70

71 72
    template <typename T>
    auto
Y
yah01 已提交
73 74
    ExecBinaryArithOpEvalRangeVisitorDispatcher(
        BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
75

76 77
    template <typename T>
    auto
78
    ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType;
G
GuoRentong 已提交
79

S
sunby 已提交
80 81
    template <typename T>
    auto
82
    ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType;
S
sunby 已提交
83

84 85 86 87
    template <typename T>
    auto
    ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType;

88 89
    template <typename CmpFunc>
    auto
Y
yah01 已提交
90 91
    ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
        -> BitsetType;
92

N
neza2017 已提交
93
 private:
94 95
    const segcore::SegmentInternalInterface& segment_;
    int64_t row_count_;
96
    Timestamp timestamp_;
97
    BitsetTypeOpt bitset_opt_;
N
neza2017 已提交
98 99 100 101
};
}  // namespace impl

void
F
FluorineDog 已提交
102 103
ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
    using OpType = LogicalUnaryExpr::OpType;
104
    auto child_res = call_child(*expr.child_);
105
    BitsetType res = std::move(child_res);
106 107 108 109 110 111 112
    switch (expr.op_type_) {
        case OpType::LogicalNot: {
            res.flip();
            break;
        }
        default: {
            PanicInfo("Invalid Unary Op");
F
FluorineDog 已提交
113 114
        }
    }
Y
yah01 已提交
115 116
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
117
    bitset_opt_ = std::move(res);
N
neza2017 已提交
118 119 120
}

void
F
FluorineDog 已提交
121 122
ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
    using OpType = LogicalBinaryExpr::OpType;
F
FluorineDog 已提交
123 124
    auto left = call_child(*expr.left_);
    auto right = call_child(*expr.right_);
Y
yah01 已提交
125 126
    AssertInfo(left.size() == right.size(),
               "[ExecExprVisitor]Left size not equal to right size");
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
    auto res = std::move(left);
    switch (expr.op_type_) {
        case OpType::LogicalAnd: {
            res &= right;
            break;
        }
        case OpType::LogicalOr: {
            res |= right;
            break;
        }
        case OpType::LogicalXor: {
            res ^= right;
            break;
        }
        case OpType::LogicalMinus: {
            res -= right;
            break;
        }
        default: {
            PanicInfo("Invalid Binary Op");
        }
    }
Y
yah01 已提交
149 150
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
151
    bitset_opt_ = std::move(res);
152
}
F
FluorineDog 已提交
153

154
static auto
155 156
Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
    BitsetType res;
157

158 159 160 161
    if (srcs.size() == 1) {
        return srcs[0];
    }

162 163 164 165 166 167 168 169 170 171
    int64_t total_size = 0;
    for (auto& chunk : srcs) {
        total_size += chunk.size();
    }
    res.resize(total_size);

    int64_t counter = 0;
    for (auto& chunk : srcs) {
        for (int64_t i = 0; i < chunk.size(); ++i) {
            res[counter + i] = chunk[i];
F
FluorineDog 已提交
172
        }
173
        counter += chunk.size();
F
FluorineDog 已提交
174
    }
175
    return res;
N
neza2017 已提交
176 177
}

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
// Appends the booleans in `chunk_res` to the end of `result`.
//
// Bits are pushed one at a time until `result` reaches a block boundary,
// then packed BITSET_BLOCK_BIT_SIZE at a time into whole blocks (one append
// per block instead of one per bit), and any remaining tail is pushed bit by
// bit again. All counts are size_t to match the container sizes.
void
AppendOneChunk(BitsetType& result, const FixedVector<bool>& chunk_res) {
    // Packs `n_blocks` consecutive groups of bools into blocks and appends
    // each block. Bool i of a group ends up as bit i of the block
    // (byte i / 8, bit i % 8 within that byte).
    auto append_blocks = [&result](const bool* ptr, size_t n_blocks) {
        for (size_t i = 0; i < n_blocks; ++i) {
            BitSetBlockType val = 0;
            // Assemble the block byte by byte; this loop shape is friendly
            // to compiler auto-vectorization.
            uint8_t vals[BITSET_BLOCK_SIZE] = {0};
            for (size_t j = 0; j < 8; ++j) {
                for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) {
                    vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j;
                }
            }
            for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) {
                val |= BitSetBlockType(vals[j]) << (8 * j);
            }
            result.append(val);
            ptr += BITSET_BLOCK_SIZE * 8;
        }
    };
    // Pushes `n` bools one bit at a time; used for the parts of the chunk
    // that cannot be packed into a whole block (head and tail).
    auto append_bits = [&result](const bool* ptr, size_t n) {
        for (size_t i = 0; i < n; ++i) {
            result.push_back(ptr[i]);
        }
    };

    const size_t res_len = result.size();
    const size_t chunk_len = chunk_res.size();
    const bool* chunk_ptr = chunk_res.data();

    // Bits needed to complete the partially-filled last block of `result`
    // (0 when already block-aligned), capped by what the chunk provides.
    const size_t misalign = res_len % BITSET_BLOCK_BIT_SIZE;
    const size_t n_prefix =
        misalign == 0
            ? 0
            : std::min<size_t>(BITSET_BLOCK_BIT_SIZE - misalign, chunk_len);

    append_bits(chunk_ptr, n_prefix);
    if (n_prefix == chunk_len) {
        return;
    }

    const size_t remaining = chunk_len - n_prefix;
    const size_t n_block = remaining / BITSET_BLOCK_BIT_SIZE;
    const size_t n_suffix = remaining % BITSET_BLOCK_BIT_SIZE;

    append_blocks(chunk_ptr + n_prefix, n_block);
    append_bits(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE,
                n_suffix);
}

// Concatenates per-chunk boolean results, in chunk order, into one bitset.
BitsetType
AssembleChunk(const std::vector<FixedVector<bool>>& results) {
    BitsetType merged;
    for (const auto& chunk_result : results) {
        AppendOneChunk(merged, chunk_result);
    }
    return merged;
}

F
FluorineDog 已提交
241
template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
242
auto
Y
yah01 已提交
243 244 245
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
                                      IndexFunc index_func,
                                      ElementFunc element_func) -> BitsetType {
G
GuoRentong 已提交
246
    auto& schema = segment_.get_schema();
247 248
    auto& field_meta = schema[field_id];
    auto indexing_barrier = segment_.num_chunk_index(field_id);
B
BossZou 已提交
249 250
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
251
    std::vector<FixedVector<bool>> results;
252 253
    results.reserve(num_chunk);

Y
yah01 已提交
254 255 256
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
257
    using Index = index::ScalarIndex<IndexInnerType>;
F
FluorineDog 已提交
258
    for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
Y
yah01 已提交
259 260
        const Index& indexing =
            segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
261 262 263
        // NOTE: knowhere is not const-ready
        // This is a dirty workaround
        auto data = index_func(const_cast<Index*>(&indexing));
264
        AssertInfo(data.size() == size_per_chunk,
Y
yah01 已提交
265
                   "[ExecExprVisitor]Data size not equal to size_per_chunk");
266
        results.emplace_back(std::move(data));
F
FluorineDog 已提交
267
    }
268
    for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
Y
yah01 已提交
269 270 271
        auto this_size = chunk_id == num_chunk - 1
                             ? row_count_ - chunk_id * size_per_chunk
                             : size_per_chunk;
272
        FixedVector<bool> chunk_res(this_size);
273
        auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
G
GuoRentong 已提交
274
        const T* data = chunk.data();
275
        // Can use CPU SIMD optimazation to speed up
276
        for (int index = 0; index < this_size; ++index) {
277
            chunk_res[index] = element_func(data[index]);
G
GuoRentong 已提交
278
        }
279
        results.emplace_back(std::move(chunk_res));
G
GuoRentong 已提交
280
    }
281
    auto final_result = AssembleChunk(results);
Y
yah01 已提交
282 283
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Final result size not equal to row count");
284
    return final_result;
G
GuoRentong 已提交
285
}
286

287
template <typename T, typename IndexFunc, typename ElementFunc>
288
auto
Y
yah01 已提交
289 290 291
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
                                          IndexFunc index_func,
                                          ElementFunc element_func)
292
    -> BitsetType {
293
    auto& schema = segment_.get_schema();
294
    auto& field_meta = schema[field_id];
295 296
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
297 298 299 300
    auto indexing_barrier = segment_.num_chunk_index(field_id);
    auto data_barrier = segment_.num_chunk_data(field_id);
    AssertInfo(std::max(data_barrier, indexing_barrier) == num_chunk,
               "max(data_barrier, index_barrier) not equal to num_chunk");
301
    std::vector<FixedVector<bool>> results;
302
    results.reserve(num_chunk);
303

304 305 306 307 308
    // for growing segment, indexing_barrier will always less than data_barrier
    // so growing segment will always execute expr plan using raw data
    // if sealed segment has loaded raw data on this field, then index_barrier = 0 and data_barrier = 1
    // in this case, sealed segment execute expr plan using raw data
    for (auto chunk_id = 0; chunk_id < data_barrier; ++chunk_id) {
Y
yah01 已提交
309 310 311
        auto this_size = chunk_id == num_chunk - 1
                             ? row_count_ - chunk_id * size_per_chunk
                             : size_per_chunk;
312
        FixedVector<bool> result(this_size);
313
        auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
314 315 316 317
        const T* data = chunk.data();
        for (int index = 0; index < this_size; ++index) {
            result[index] = element_func(data[index]);
        }
318 319 320
        AssertInfo(result.size() == this_size,
                   "[ExecExprVisitor]Chunk result size not equal to "
                   "expected size");
321 322
        results.emplace_back(std::move(result));
    }
323 324 325

    // if sealed segment has loaded scalar index for this field, then index_barrier = 1 and data_barrier = 0
    // in this case, sealed segment execute expr plan using scalar index
Y
yah01 已提交
326 327 328
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
329
    using Index = index::ScalarIndex<IndexInnerType>;
Y
yah01 已提交
330 331 332 333
    for (auto chunk_id = data_barrier; chunk_id < indexing_barrier;
         ++chunk_id) {
        auto& indexing =
            segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
334
        auto this_size = const_cast<Index*>(&indexing)->Count();
335
        FixedVector<bool> result(this_size);
336 337 338 339 340 341
        for (int offset = 0; offset < this_size; ++offset) {
            result[offset] = index_func(const_cast<Index*>(&indexing), offset);
        }
        results.emplace_back(std::move(result));
    }

342
    auto final_result = AssembleChunk(results);
Y
yah01 已提交
343 344
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Final result size not equal to row count");
345 346 347
    return final_result;
}

G
GuoRentong 已提交
348 349 350 351
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
Y
yah01 已提交
352 353 354 355 356
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw)
    -> BitsetType {
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
357 358 359
    using Index = index::ScalarIndex<IndexInnerType>;
    auto& expr = static_cast<UnaryRangeExprImpl<IndexInnerType>&>(expr_raw);

360
    auto op = expr.op_type_;
Y
yah01 已提交
361
    auto val = IndexInnerType(expr.value_);
362
    auto field_id = expr.column_.field_id;
363 364
    switch (op) {
        case OpType::Equal: {
365
            auto index_func = [&](Index* index) { return index->In(1, &val); };
366
            auto elem_func = [&](MayConstRef<T> x) { return (x == val); };
367
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
368 369
        }
        case OpType::NotEqual: {
370
            auto index_func = [&](Index* index) {
Y
yah01 已提交
371 372
                return index->NotIn(1, &val);
            };
373
            auto elem_func = [&](MayConstRef<T> x) { return (x != val); };
374
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
375 376
        }
        case OpType::GreaterEqual: {
377
            auto index_func = [&](Index* index) {
Y
yah01 已提交
378 379
                return index->Range(val, OpType::GreaterEqual);
            };
380
            auto elem_func = [&](MayConstRef<T> x) { return (x >= val); };
381
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
G
GuoRentong 已提交
382
        }
383
        case OpType::GreaterThan: {
384
            auto index_func = [&](Index* index) {
Y
yah01 已提交
385 386
                return index->Range(val, OpType::GreaterThan);
            };
387
            auto elem_func = [&](MayConstRef<T> x) { return (x > val); };
388
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
389 390
        }
        case OpType::LessEqual: {
391
            auto index_func = [&](Index* index) {
Y
yah01 已提交
392 393
                return index->Range(val, OpType::LessEqual);
            };
394
            auto elem_func = [&](MayConstRef<T> x) { return (x <= val); };
395
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
396 397
        }
        case OpType::LessThan: {
398
            auto index_func = [&](Index* index) {
Y
yah01 已提交
399 400
                return index->Range(val, OpType::LessThan);
            };
401
            auto elem_func = [&](MayConstRef<T> x) { return (x < val); };
402
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
403 404
        }
        case OpType::PrefixMatch: {
405
            auto index_func = [&](Index* index) {
P
presburger 已提交
406
                auto dataset = std::make_unique<Dataset>();
407 408
                dataset->Set(milvus::index::OPERATOR_TYPE, OpType::PrefixMatch);
                dataset->Set(milvus::index::PREFIX_VALUE, val);
409 410
                return index->Query(std::move(dataset));
            };
411 412 413
            auto elem_func = [&](MayConstRef<T> x) {
                return Match(x, val, op);
            };
414
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
415 416
        }
        // TODO: PostfixMatch
417
        default: {
G
GuoRentong 已提交
418 419
            PanicInfo("unsupported range node");
        }
420 421 422 423
    }
}
#pragma clang diagnostic pop

424 425 426 427 428 429 430 431 432
// Dispatches a unary range expression on a value nested inside a JSON
// column. Each row's JSON is probed at the expression's nested path.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw)
    -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    auto& expr = static_cast<UnaryRangeExprImpl<ExprValueType>&>(expr_raw);

    auto op = expr.op_type_;
    auto val = expr.value_;
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
    auto field_id = expr.column_.field_id;
    // The index callback yields an empty bitmap.
    // NOTE(review): ExecRangeVisitorImpl asserts that index results hold
    // size_per_chunk entries, so this presumes JSON fields have no indexed
    // chunks — confirm.
    auto index_func = [=](Index* index) { return TargetBitmap{}; };
    // Strings are read from the JSON as std::string_view; every other
    // expression type is read as itself.
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    // Both macros bind `x` to the value at `pointer` and evaluate `cmp` on
    // it. If an int64 lookup errors, the value is retried as a double. When
    // the lookup errors, the plain form yields false and the NotEqual form
    // yields true.
#define UnaryRangeJSONCompare(cmp)                            \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return !x.error() && (cmp);                   \
            }                                                 \
            return false;                                     \
        }                                                     \
        return (cmp);                                         \
    } while (false)

#define UnaryRangeJSONCompareNotEqual(cmp)                    \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return x.error() || (cmp);                    \
            }                                                 \
            return true;                                      \
        }                                                     \
        return (cmp);                                         \
    } while (false)

    // Evaluates `elem_func` on every row of the JSON column.
    auto scan = [&](auto elem_func) {
        return this->ExecRangeVisitorImpl<milvus::Json>(
            field_id, index_func, elem_func);
    };

    switch (op) {
        case OpType::Equal:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() == val);
            });
        case OpType::NotEqual:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompareNotEqual(x.value() != val);
            });
        case OpType::GreaterEqual:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() >= val);
            });
        case OpType::GreaterThan:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() > val);
            });
        case OpType::LessEqual:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() <= val);
            });
        case OpType::LessThan:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() < val);
            });
        case OpType::PrefixMatch:
            return scan([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(
                    Match(ExprValueType(x.value()), val, op));
            });
        // TODO: PostfixMatch
        default: {
            PanicInfo("unsupported range node");
        }
    }
}

524 525 526 527
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
Y
yah01 已提交
528 529
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(
    BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
530
    auto& expr = static_cast<BinaryArithOpEvalRangeExprImpl<T>&>(expr_raw);
531
    using Index = index::ScalarIndex<T>;
532 533 534 535
    auto arith_op = expr.arith_op_;
    auto right_operand = expr.right_operand_;
    auto op = expr.op_type_;
    auto val = expr.value_;
536
    auto& nested_path = expr.column_.nested_path;
537 538 539 540 541

    switch (op) {
        case OpType::Equal: {
            switch (arith_op) {
                case ArithOpType::Add: {
Y
yah01 已提交
542 543
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
544 545 546
                        auto x = index->Reverse_Lookup(offset);
                        return (x + right_operand) == val;
                    };
547 548 549 550 551 552
                    auto elem_func =
                        [val, right_operand, &nested_path](MayConstRef<T> x) {
                            // visit the nested field
                            // now it must be Json
                            return ((x + right_operand) == val);
                        };
Y
yah01 已提交
553
                    return ExecDataRangeVisitorImpl<T>(
554
                        expr.column_.field_id, index_func, elem_func);
555 556
                }
                case ArithOpType::Sub: {
Y
yah01 已提交
557 558
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
559 560 561
                        auto x = index->Reverse_Lookup(offset);
                        return (x - right_operand) == val;
                    };
562
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
563 564 565
                        return ((x - right_operand) == val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
566
                        expr.column_.field_id, index_func, elem_func);
567 568
                }
                case ArithOpType::Mul: {
Y
yah01 已提交
569 570
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
571 572 573
                        auto x = index->Reverse_Lookup(offset);
                        return (x * right_operand) == val;
                    };
574
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
575 576 577
                        return ((x * right_operand) == val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
578
                        expr.column_.field_id, index_func, elem_func);
579 580
                }
                case ArithOpType::Div: {
Y
yah01 已提交
581 582
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
583 584 585
                        auto x = index->Reverse_Lookup(offset);
                        return (x / right_operand) == val;
                    };
586
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
587 588 589
                        return ((x / right_operand) == val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
590
                        expr.column_.field_id, index_func, elem_func);
591 592
                }
                case ArithOpType::Mod: {
Y
yah01 已提交
593 594
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
595 596 597
                        auto x = index->Reverse_Lookup(offset);
                        return static_cast<T>(fmod(x, right_operand)) == val;
                    };
598
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
599 600
                        return (static_cast<T>(fmod(x, right_operand)) == val);
                    };
Y
yah01 已提交
601
                    return ExecDataRangeVisitorImpl<T>(
602
                        expr.column_.field_id, index_func, elem_func);
603 604 605 606 607 608 609 610 611
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        case OpType::NotEqual: {
            switch (arith_op) {
                case ArithOpType::Add: {
Y
yah01 已提交
612 613
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
614 615 616
                        auto x = index->Reverse_Lookup(offset);
                        return (x + right_operand) != val;
                    };
617
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
618 619 620
                        return ((x + right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
621
                        expr.column_.field_id, index_func, elem_func);
622 623
                }
                case ArithOpType::Sub: {
Y
yah01 已提交
624 625
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
626 627 628
                        auto x = index->Reverse_Lookup(offset);
                        return (x - right_operand) != val;
                    };
629
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
630 631 632
                        return ((x - right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
633
                        expr.column_.field_id, index_func, elem_func);
634 635
                }
                case ArithOpType::Mul: {
Y
yah01 已提交
636 637
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
638 639 640
                        auto x = index->Reverse_Lookup(offset);
                        return (x * right_operand) != val;
                    };
641
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
642 643 644
                        return ((x * right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
645
                        expr.column_.field_id, index_func, elem_func);
646 647
                }
                case ArithOpType::Div: {
Y
yah01 已提交
648 649
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
650 651 652
                        auto x = index->Reverse_Lookup(offset);
                        return (x / right_operand) != val;
                    };
653
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
654 655 656
                        return ((x / right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
657
                        expr.column_.field_id, index_func, elem_func);
658 659
                }
                case ArithOpType::Mod: {
Y
yah01 已提交
660 661
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
662 663 664
                        auto x = index->Reverse_Lookup(offset);
                        return static_cast<T>(fmod(x, right_operand)) != val;
                    };
665
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
666 667
                        return (static_cast<T>(fmod(x, right_operand)) != val);
                    };
Y
yah01 已提交
668
                    return ExecDataRangeVisitorImpl<T>(
669
                        expr.column_.field_id, index_func, elem_func);
670 671 672 673 674 675 676 677 678 679 680 681 682
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        default: {
            PanicInfo("unsupported range node with arithmetic operation");
        }
    }
}
#pragma clang diagnostic pop

683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698
/// Evaluates `json_path <arith_op> right_operand <op> value`
/// (e.g. `json["a"] + 1 == 3`) over a JSON column, one bit per row.
///
/// JSON columns have no scalar index yet, so the index callback never
/// matches and every row is evaluated by an element callback. A row whose
/// value at `nested_path` cannot be read as GetType fails an `Equal` test and
/// passes a `NotEqual` test. When GetType is int64_t, the value is first
/// retried as a double before being treated as unreadable.
///
/// @tparam ExprValueType literal type of the expression; `visit` dispatches
///         bool, int64_t and double here.
/// @param expr_raw must actually be a
///        BinaryArithOpEvalRangeExprImpl<ExprValueType>.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
    BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
    auto& expr =
        static_cast<BinaryArithOpEvalRangeExprImpl<ExprValueType>&>(expr_raw);
    using Index = index::ScalarIndex<milvus::Json>;
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto arith_op = expr.arith_op_;
    auto right_operand = expr.right_operand_;
    auto op = expr.op_type_;
    auto val = expr.value_;
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);

    // There is no JSON index yet, so the index path never matches. One
    // capture-less callback serves every case below.
    auto index_func = [](Index* /*index*/, size_t /*offset*/) {
        return false;
    };

    // Reads the value at `pointer` as `x` and evaluates `cmp`. An unreadable
    // value yields false; int64 expressions retry the read as a double.
#define BinaryArithRangeJSONCompare(cmp)                      \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return !x.error() && (cmp);                   \
            }                                                 \
            return false;                                     \
        }                                                     \
        return (cmp);                                         \
    } while (false)

    // Same as above, but an unreadable value yields true (for NotEqual).
#define BinaryArithRangeJSONCompareNotEqual(cmp)              \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return x.error() || (cmp);                    \
            }                                                 \
            return true;                                      \
        }                                                     \
        return (cmp);                                         \
    } while (false)

    switch (op) {
        case OpType::Equal: {
            switch (arith_op) {
                case ArithOpType::Add: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() + right_operand ==
                                                    val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Sub: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() - right_operand ==
                                                    val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Mul: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() * right_operand ==
                                                    val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Div: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() / right_operand ==
                                                    val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Mod: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(
                            static_cast<ExprValueType>(
                                fmod(x.value(), right_operand)) == val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        case OpType::NotEqual: {
            switch (arith_op) {
                case ArithOpType::Add: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() + right_operand != val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Sub: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() - right_operand != val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Mul: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() * right_operand != val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Div: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() / right_operand != val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                case ArithOpType::Mod: {
                    auto elem_func = [&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            static_cast<ExprValueType>(
                                fmod(x.value(), right_operand)) != val);
                    };
                    return ExecDataRangeVisitorImpl<milvus::Json>(
                        expr.column_.field_id, index_func, elem_func);
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        default: {
            PanicInfo("unsupported range node with arithmetic operation");
        }
    }
}
869

870 871 872 873
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
Y
yah01 已提交
874 875 876 877 878
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw)
    -> BitsetType {
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
879 880 881
    using Index = index::ScalarIndex<IndexInnerType>;
    auto& expr = static_cast<BinaryRangeExprImpl<IndexInnerType>&>(expr_raw);

882 883
    bool lower_inclusive = expr.lower_inclusive_;
    bool upper_inclusive = expr.upper_inclusive_;
884 885
    IndexInnerType val1 = expr.lower_value_;
    IndexInnerType val2 = expr.upper_value_;
886

887
    auto index_func = [&](Index* index) {
Y
yah01 已提交
888 889
        return index->Range(val1, lower_inclusive, val2, upper_inclusive);
    };
890
    if (lower_inclusive && upper_inclusive) {
891 892 893
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 <= x && x <= val2);
        };
894 895
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
896
    } else if (lower_inclusive && !upper_inclusive) {
897 898 899
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 <= x && x < val2);
        };
900 901
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
902
    } else if (!lower_inclusive && upper_inclusive) {
903 904 905
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 < x && x <= val2);
        };
906 907
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
G
GuoRentong 已提交
908
    } else {
909 910 911
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 < x && x < val2);
        };
912 913
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
G
GuoRentong 已提交
914 915 916 917
    }
}
#pragma clang diagnostic pop

918 919 920 921 922 923 924 925 926 927 928 929 930 931 932
/// Evaluates `lower <(=) json_path <(=) upper` over a JSON column.
///
/// There is no JSON index, so the index callback returns an empty bitmap and
/// every row is evaluated from raw data. A value at `nested_path` that cannot
/// be read as GetType compares false. When GetType is int64_t, the read is
/// retried as a double first.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw)
    -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto& expr = static_cast<BinaryRangeExprImpl<ExprValueType>&>(expr_raw);
    bool lower_inclusive = expr.lower_inclusive_;
    bool upper_inclusive = expr.upper_inclusive_;
    ExprValueType lower = expr.lower_value_;
    ExprValueType upper = expr.upper_value_;
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);

    // No JSON index yet: the index path contributes an empty bitmap.
    auto no_index = [](Index* /*index*/) { return TargetBitmap{}; };

    // Binds the JSON value at `pointer` as `value` and evaluates `cmp`;
    // unreadable values compare false.
#define BinaryRangeJSONCompare(cmp)                           \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                if (!x.error()) {                             \
                    auto value = x.value();                   \
                    return (cmp);                             \
                }                                             \
            }                                                 \
            return false;                                     \
        }                                                     \
        auto value = x.value();                               \
        return (cmp);                                         \
    } while (false)

    if (lower_inclusive) {
        if (upper_inclusive) {
            // [lower, upper]
            auto in_range = [&](const milvus::Json& json) {
                BinaryRangeJSONCompare(lower <= value && value <= upper);
            };
            return ExecRangeVisitorImpl<milvus::Json>(
                expr.column_.field_id, no_index, in_range);
        }
        // [lower, upper)
        auto in_range = [&](const milvus::Json& json) {
            BinaryRangeJSONCompare(lower <= value && value < upper);
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, no_index, in_range);
    }

    if (upper_inclusive) {
        // (lower, upper]
        auto in_range = [&](const milvus::Json& json) {
            BinaryRangeJSONCompare(lower < value && value <= upper);
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, no_index, in_range);
    }

    // (lower, upper)
    auto in_range = [&](const milvus::Json& json) {
        BinaryRangeJSONCompare(lower < value && value < upper);
    };
    return ExecRangeVisitorImpl<milvus::Json>(
        expr.column_.field_id, no_index, in_range);
}

N
neza2017 已提交
982
void
983
ExecExprVisitor::visit(UnaryRangeExpr& expr) {
984 985
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
986
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");
987
    BitsetType res;
988
    switch (expr.column_.data_type) {
N
neza2017 已提交
989
        case DataType::BOOL: {
990
            res = ExecUnaryRangeVisitorDispatcher<bool>(expr);
N
neza2017 已提交
991 992
            break;
        }
G
GuoRentong 已提交
993
        case DataType::INT8: {
994
            res = ExecUnaryRangeVisitorDispatcher<int8_t>(expr);
G
GuoRentong 已提交
995 996 997
            break;
        }
        case DataType::INT16: {
998
            res = ExecUnaryRangeVisitorDispatcher<int16_t>(expr);
G
GuoRentong 已提交
999 1000 1001
            break;
        }
        case DataType::INT32: {
1002
            res = ExecUnaryRangeVisitorDispatcher<int32_t>(expr);
G
GuoRentong 已提交
1003 1004 1005
            break;
        }
        case DataType::INT64: {
1006
            res = ExecUnaryRangeVisitorDispatcher<int64_t>(expr);
G
GuoRentong 已提交
1007 1008 1009
            break;
        }
        case DataType::FLOAT: {
1010
            res = ExecUnaryRangeVisitorDispatcher<float>(expr);
G
GuoRentong 已提交
1011 1012 1013
            break;
        }
        case DataType::DOUBLE: {
1014 1015 1016
            res = ExecUnaryRangeVisitorDispatcher<double>(expr);
            break;
        }
1017
        case DataType::VARCHAR: {
Y
yah01 已提交
1018 1019 1020 1021 1022
            if (segment_.type() == SegmentType::Growing) {
                res = ExecUnaryRangeVisitorDispatcher<std::string>(expr);
            } else {
                res = ExecUnaryRangeVisitorDispatcher<std::string_view>(expr);
            }
1023 1024
            break;
        }
1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045
        case DataType::JSON: {
            switch (expr.val_case_) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    res = ExecUnaryRangeVisitorDispatcherJson<bool>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    res = ExecUnaryRangeVisitorDispatcherJson<int64_t>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    res = ExecUnaryRangeVisitorDispatcherJson<double>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    res =
                        ExecUnaryRangeVisitorDispatcherJson<std::string>(expr);
                    break;
                default:
                    PanicInfo(
                        fmt::format("unknown data type: {}", expr.val_case_));
            }
            break;
        }
1046
        default:
1047 1048
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
1049
    }
Y
yah01 已提交
1050 1051
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
1052
    bitset_opt_ = std::move(res);
1053 1054
}

1055 1056
/// Evaluates `field <arith_op> operand <op> value` and stores the resulting
/// bitset in `bitset_opt_`.
///
/// Numeric columns dispatch on their element type. JSON columns dispatch on
/// the literal's value case (bool / int64 / float); every other column or
/// value type panics.
void
ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
    const auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");

    BitsetType bitset;
    switch (expr.column_.data_type) {
        case DataType::INT8:
            bitset = ExecBinaryArithOpEvalRangeVisitorDispatcher<int8_t>(expr);
            break;
        case DataType::INT16:
            bitset = ExecBinaryArithOpEvalRangeVisitorDispatcher<int16_t>(expr);
            break;
        case DataType::INT32:
            bitset = ExecBinaryArithOpEvalRangeVisitorDispatcher<int32_t>(expr);
            break;
        case DataType::INT64:
            bitset = ExecBinaryArithOpEvalRangeVisitorDispatcher<int64_t>(expr);
            break;
        case DataType::FLOAT:
            bitset = ExecBinaryArithOpEvalRangeVisitorDispatcher<float>(expr);
            break;
        case DataType::DOUBLE:
            bitset = ExecBinaryArithOpEvalRangeVisitorDispatcher<double>(expr);
            break;
        case DataType::JSON: {
            using ValCase = proto::plan::GenericValue::ValCase;
            const auto value_case = expr.val_case_;
            if (value_case == ValCase::kBoolVal) {
                bitset =
                    ExecBinaryArithOpEvalRangeVisitorDispatcherJson<bool>(expr);
            } else if (value_case == ValCase::kInt64Val) {
                bitset =
                    ExecBinaryArithOpEvalRangeVisitorDispatcherJson<int64_t>(
                        expr);
            } else if (value_case == ValCase::kFloatVal) {
                bitset =
                    ExecBinaryArithOpEvalRangeVisitorDispatcherJson<double>(
                        expr);
            } else {
                PanicInfo(
                    fmt::format("unsupported value type {} in expression",
                                expr.val_case_));
            }
            break;
        }
        default:
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
    }

    AssertInfo(bitset.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(bitset);
}

1121 1122
/// Evaluates a two-sided range expression (`lower <(=) field <(=) upper`) and
/// stores the resulting bitset in `bitset_opt_`.
///
/// Scalar columns dispatch on their element type. VARCHAR is read as
/// std::string for growing segments and std::string_view otherwise. JSON
/// columns dispatch on the bounds' value case.
void
ExecExprVisitor::visit(BinaryRangeExpr& expr) {
    const auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");

    BitsetType bitset;
    switch (expr.column_.data_type) {
        case DataType::BOOL:
            bitset = ExecBinaryRangeVisitorDispatcher<bool>(expr);
            break;
        case DataType::INT8:
            bitset = ExecBinaryRangeVisitorDispatcher<int8_t>(expr);
            break;
        case DataType::INT16:
            bitset = ExecBinaryRangeVisitorDispatcher<int16_t>(expr);
            break;
        case DataType::INT32:
            bitset = ExecBinaryRangeVisitorDispatcher<int32_t>(expr);
            break;
        case DataType::INT64:
            bitset = ExecBinaryRangeVisitorDispatcher<int64_t>(expr);
            break;
        case DataType::FLOAT:
            bitset = ExecBinaryRangeVisitorDispatcher<float>(expr);
            break;
        case DataType::DOUBLE:
            bitset = ExecBinaryRangeVisitorDispatcher<double>(expr);
            break;
        case DataType::VARCHAR:
            bitset =
                segment_.type() == SegmentType::Growing
                    ? ExecBinaryRangeVisitorDispatcher<std::string>(expr)
                    : ExecBinaryRangeVisitorDispatcher<std::string_view>(expr);
            break;
        case DataType::JSON: {
            using ValCase = proto::plan::GenericValue::ValCase;
            const auto value_case = expr.val_case_;
            if (value_case == ValCase::kBoolVal) {
                bitset = ExecBinaryRangeVisitorDispatcherJson<bool>(expr);
            } else if (value_case == ValCase::kInt64Val) {
                bitset = ExecBinaryRangeVisitorDispatcherJson<int64_t>(expr);
            } else if (value_case == ValCase::kFloatVal) {
                bitset = ExecBinaryRangeVisitorDispatcherJson<double>(expr);
            } else if (value_case == ValCase::kStringVal) {
                bitset =
                    ExecBinaryRangeVisitorDispatcherJson<std::string>(expr);
            } else {
                PanicInfo(
                    fmt::format("unsupported value type {} in expression",
                                expr.val_case_));
            }
            break;
        }
        default:
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
    }

    AssertInfo(bitset.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(bitset);
}

1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213
/// Visitor-style functor that applies the comparison `Op` to two operands.
///
/// The binary overload only participates in overload resolution when
/// `Op{}(a, b)` is well-formed and convertible to bool. Any other call,
/// including operand pairs `Op` cannot compare, falls back to the variadic
/// overload, which panics with "incompatible operands".
template <typename Op>
struct relational {
    // Constrained on the comparison itself. With a plain `bool` return type
    // this overload is always viable for two arguments (and is preferred by
    // partial ordering). Its body would then be instantiated for
    // incompatible pairs and fail to compile instead of reaching the
    // fallback below.
    template <typename T, typename U>
    auto
    operator()(T const& a, U const& b) const
        -> decltype(static_cast<bool>(Op{}(a, b))) {
        return Op{}(a, b);
    }
    // Fallback for every call the binary overload cannot handle.
    template <typename... T>
    bool
    operator()(T const&...) const {
        PanicInfo("incompatible operands");
    }
};

1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244
/// Compares one chunk of the left column element-wise against the same chunk
/// of the right column, read as type U.
///
/// @param left_raw_data    raw data of chunk `current_chunk_id` of the left
///                         field, already read as T
/// @param right_field_id   field whose chunk `current_chunk_id` is read as U
/// @param current_chunk_id chunk being compared; only the last chunk may hold
///                         fewer than `size_per_chunk` rows
/// @param cmp_func         predicate applied to (left[i], right[i])
/// @return one bit per row of the chunk
template <typename T, typename U, typename CmpFunc>
TargetBitmap
ExecExprVisitor::ExecCompareRightType(const T* left_raw_data,
                                      const FieldId& right_field_id,
                                      const int64_t current_chunk_id,
                                      CmpFunc cmp_func) {
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunks = upper_div(row_count_, size_per_chunk);
    // Every chunk is full except possibly the last one.
    auto size = current_chunk_id == num_chunks - 1
                    ? row_count_ - current_chunk_id * size_per_chunk
                    : size_per_chunk;

    TargetBitmap result(size);
    const U* right_raw_data =
        segment_.chunk_data<U>(right_field_id, current_chunk_id).data();

    // Use a 64-bit counter to match `size`. An `int` counter is narrower
    // than the bound it is compared against and would overflow (UB) on a
    // chunk larger than INT_MAX rows.
    for (int64_t i = 0; i < size; ++i) {
        result[i] = cmp_func(left_raw_data[i], right_raw_data[i]);
    }

    return result;
}

/// Compares the left field (read as T) against the right field (element type
/// chosen from `right_field_type`) chunk by chunk. The per-chunk results are
/// combined with AssembleChunk.
///
/// @return one bit per row; its size is asserted to equal `row_count_`.
template <typename T, typename CmpFunc>
BitsetType
ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id,
                                     const FieldId& right_field_id,
                                     const DataType& right_field_type,
                                     CmpFunc cmp_func) {
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunks = upper_div(row_count_, size_per_chunk);
    std::vector<FixedVector<bool>> results;
    results.reserve(num_chunks);

    for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
        FixedVector<bool> result;
        const T* left_raw_data =
            segment_.chunk_data<T>(left_field_id, chunk_id).data();

        switch (right_field_type) {
            case DataType::BOOL:
                result = ExecCompareRightType<T, bool, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT8:
                result = ExecCompareRightType<T, int8_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT16:
                result = ExecCompareRightType<T, int16_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT32:
                result = ExecCompareRightType<T, int32_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT64:
                result = ExecCompareRightType<T, int64_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::FLOAT:
                result = ExecCompareRightType<T, float, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::DOUBLE:
                result = ExecCompareRightType<T, double, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            default:
                // This switch is on the RIGHT operand's type.
                PanicInfo("unsupported right datatype of compare expr");
        }
        // `result` is not used again; move it instead of copying the chunk.
        results.push_back(std::move(result));
    }
    auto final_result = AssembleChunk(results);
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    return final_result;
}

/// Entry point for comparing two raw (non-indexed) numeric/bool fields.
/// Resolves the left field's element type, then delegates to
/// ExecCompareLeftType, which resolves the right field's type.
template <typename CmpFunc>
BitsetType
ExecExprVisitor::ExecCompareExprDispatcherForNonIndexedSegment(
    CompareExpr& expr, CmpFunc cmp_func) {
    switch (expr.left_data_type_) {
        case DataType::BOOL:
            return ExecCompareLeftType<bool, CmpFunc>(expr.left_field_id_,
                                                      expr.right_field_id_,
                                                      expr.right_data_type_,
                                                      cmp_func);
        case DataType::INT8:
            return ExecCompareLeftType<int8_t, CmpFunc>(expr.left_field_id_,
                                                        expr.right_field_id_,
                                                        expr.right_data_type_,
                                                        cmp_func);
        case DataType::INT16:
            return ExecCompareLeftType<int16_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::INT32:
            return ExecCompareLeftType<int32_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::INT64:
            return ExecCompareLeftType<int64_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::FLOAT:
            return ExecCompareLeftType<float, CmpFunc>(expr.left_field_id_,
                                                       expr.right_field_id_,
                                                       expr.right_data_type_,
                                                       cmp_func);
        case DataType::DOUBLE:
            return ExecCompareLeftType<double, CmpFunc>(expr.left_field_id_,
                                                        expr.right_field_id_,
                                                        expr.right_data_type_,
                                                        cmp_func);
        default:
            // This switch is on the LEFT operand's type.
            PanicInfo("unsupported left datatype of compare expr");
    }
}

1338 1339
template <typename Op>
auto
Y
yah01 已提交
1340 1341 1342 1343 1344 1345 1346 1347 1348 1349
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
    -> BitsetType {
    using number = boost::variant<bool,
                                  int8_t,
                                  int16_t,
                                  int32_t,
                                  int64_t,
                                  float,
                                  double,
                                  std::string>;
1350 1351 1352 1353 1354
    auto is_string_expr = [&expr]() -> bool {
        return expr.left_data_type_ == DataType::VARCHAR ||
               expr.right_data_type_ == DataType::VARCHAR;
    };

1355 1356
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
1357
    std::deque<BitsetType> bitsets;
1358 1359 1360 1361

    // check for sealed segment, load either raw field data or index
    auto left_indexing_barrier = segment_.num_chunk_index(expr.left_field_id_);
    auto left_data_barrier = segment_.num_chunk_data(expr.left_field_id_);
1362 1363 1364
    AssertInfo(std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
               "max(left_data_barrier, left_indexing_barrier) not equal to "
               "num_chunk");
1365

Y
yah01 已提交
1366 1367
    auto right_indexing_barrier =
        segment_.num_chunk_index(expr.right_field_id_);
1368
    auto right_data_barrier = segment_.num_chunk_data(expr.right_field_id_);
Y
yah01 已提交
1369 1370 1371 1372
    AssertInfo(
        std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
        "max(right_data_barrier, right_indexing_barrier) not equal to "
        "num_chunk");
1373

1374 1375 1376 1377 1378 1379 1380 1381
    // For segment both fields has no index, can use SIMD to speed up.
    // Avoiding too much call stack that blocks SIMD.
    if (left_indexing_barrier == 0 && right_indexing_barrier == 0 &&
        !is_string_expr()) {
        return ExecCompareExprDispatcherForNonIndexedSegment<Op>(expr, op);
    }

    // TODO: refactoring the code that contains too much call stack.
1382
    for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
Y
yah01 已提交
1383 1384 1385 1386 1387 1388
        auto size = chunk_id == num_chunk - 1
                        ? row_count_ - chunk_id * size_per_chunk
                        : size_per_chunk;
        auto getChunkData =
            [&, chunk_id](DataType type, FieldId field_id, int64_t data_barrier)
            -> std::function<const number(int)> {
1389 1390
            switch (type) {
                case DataType::BOOL: {
1391
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1392 1393 1394 1395 1396 1397
                        auto chunk_data =
                            segment_.chunk_data<bool>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1398 1399
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1400 1401 1402 1403 1404
                        auto& indexing = segment_.chunk_scalar_index<bool>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1405
                    }
1406 1407
                }
                case DataType::INT8: {
1408
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1409 1410 1411 1412 1413 1414
                        auto chunk_data =
                            segment_.chunk_data<int8_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1415 1416
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1417 1418 1419 1420 1421
                        auto& indexing = segment_.chunk_scalar_index<int8_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1422
                    }
1423 1424
                }
                case DataType::INT16: {
1425
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1426 1427 1428 1429 1430 1431
                        auto chunk_data =
                            segment_.chunk_data<int16_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1432 1433
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1434 1435 1436 1437 1438
                        auto& indexing = segment_.chunk_scalar_index<int16_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1439
                    }
1440 1441
                }
                case DataType::INT32: {
1442
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1443 1444 1445 1446 1447 1448
                        auto chunk_data =
                            segment_.chunk_data<int32_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1449 1450
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1451 1452 1453 1454 1455
                        auto& indexing = segment_.chunk_scalar_index<int32_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1456
                    }
1457 1458
                }
                case DataType::INT64: {
1459
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1460 1461 1462 1463 1464 1465
                        auto chunk_data =
                            segment_.chunk_data<int64_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1466 1467
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1468 1469 1470 1471 1472
                        auto& indexing = segment_.chunk_scalar_index<int64_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1473
                    }
1474 1475
                }
                case DataType::FLOAT: {
1476
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1477 1478 1479 1480 1481 1482
                        auto chunk_data =
                            segment_.chunk_data<float>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1483 1484
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1485 1486 1487 1488 1489
                        auto& indexing = segment_.chunk_scalar_index<float>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1490
                    }
1491 1492
                }
                case DataType::DOUBLE: {
1493
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1494 1495 1496 1497 1498 1499
                        auto chunk_data =
                            segment_.chunk_data<double>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1500 1501
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1502 1503 1504 1505 1506
                        auto& indexing = segment_.chunk_scalar_index<double>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1507
                    }
1508 1509
                }
                case DataType::VARCHAR: {
1510
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1511
                        if (segment_.type() == SegmentType::Growing) {
Y
yah01 已提交
1512 1513 1514 1515 1516 1517 1518
                            auto chunk_data =
                                segment_
                                    .chunk_data<std::string>(field_id, chunk_id)
                                    .data();
                            return [chunk_data](int i) -> const number {
                                return chunk_data[i];
                            };
Y
yah01 已提交
1519
                        } else {
Y
yah01 已提交
1520 1521 1522 1523 1524 1525 1526
                            auto chunk_data = segment_
                                                  .chunk_data<std::string_view>(
                                                      field_id, chunk_id)
                                                  .data();
                            return [chunk_data](int i) -> const number {
                                return std::string(chunk_data[i]);
                            };
Y
yah01 已提交
1527
                        }
1528 1529
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1530 1531 1532 1533 1534 1535
                        auto& indexing =
                            segment_.chunk_scalar_index<std::string>(field_id,
                                                                     chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1536
                    }
1537 1538
                }
                default:
1539
                    PanicInfo(fmt::format("unsupported data type: {}", type));
1540 1541
            }
        };
Y
yah01 已提交
1542 1543 1544 1545
        auto left = getChunkData(
            expr.left_data_type_, expr.left_field_id_, left_data_barrier);
        auto right = getChunkData(
            expr.right_data_type_, expr.right_field_id_, right_data_barrier);
1546

1547
        BitsetType bitset(size);
1548
        for (int i = 0; i < size; ++i) {
Y
yah01 已提交
1549 1550
            bool is_in = boost::apply_visitor(
                Relational<decltype(op)>{}, left(i), right(i));
1551 1552 1553 1554
            bitset[i] = is_in;
        }
        bitsets.emplace_back(std::move(bitset));
    }
1555
    auto final_result = Assemble(bitsets);
Y
yah01 已提交
1556 1557
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
1558
    return final_result;
1559 1560 1561 1562 1563
}

/// Visitor entry for field-vs-field comparisons. Checks that the declared
/// left/right data types match the schema, evaluates the operator selected by
/// expr.op_type_, and stores the per-row result in bitset_opt_.
void
ExecExprVisitor::visit(CompareExpr& expr) {
    const auto& schema = segment_.get_schema();
    const auto& left_field_meta = schema[expr.left_field_id_];
    const auto& right_field_meta = schema[expr.right_field_id_];

    AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(),
               "[ExecExprVisitor]Left data type not equal to left field "
               "meta type");
    AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(),
               "[ExecExprVisitor]right data type not equal to right field "
               "meta type");

    // Each operator is a distinct functor type, so the dispatcher is
    // instantiated once per supported operator.
    auto evaluate = [&]() -> BitsetType {
        switch (expr.op_type_) {
            case OpType::Equal:
                return ExecCompareExprDispatcher(expr, std::equal_to<>{});
            case OpType::NotEqual:
                return ExecCompareExprDispatcher(expr, std::not_equal_to<>{});
            case OpType::GreaterEqual:
                return ExecCompareExprDispatcher(expr,
                                                 std::greater_equal<>{});
            case OpType::GreaterThan:
                return ExecCompareExprDispatcher(expr, std::greater<>{});
            case OpType::LessEqual:
                return ExecCompareExprDispatcher(expr, std::less_equal<>{});
            case OpType::LessThan:
                return ExecCompareExprDispatcher(expr, std::less<>{});
            case OpType::PrefixMatch:
                return ExecCompareExprDispatcher(
                    expr, MatchOp<OpType::PrefixMatch>{});
            // OpType::PostfixMatch is not supported yet.
            default:
                PanicInfo("unsupported optype");
        }
        // Only reached if PanicInfo returns; mirrors an unset result.
        return BitsetType{};
    };

    auto res = evaluate();
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

S
sunby 已提交
1615 1616
template <typename T>
auto
1617
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
S
sunby 已提交
1618 1619
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
    auto& schema = segment_.get_schema();
1620
    auto primary_filed_id = schema.get_primary_field_id();
1621
    auto field_id = expr_raw.column_.field_id;
1622
    auto& field_meta = schema[field_id];
1623 1624

    bool use_pk_index = false;
1625
    if (primary_filed_id.has_value()) {
Y
yah01 已提交
1626 1627
        use_pk_index = primary_filed_id.value() == field_id &&
                       IsPrimaryKeyDataType(field_meta.get_data_type());
1628 1629 1630 1631
    }

    if (use_pk_index) {
        auto id_array = std::make_unique<IdArray>();
1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649
        switch (field_meta.get_data_type()) {
            case DataType::INT64: {
                auto dst_ids = id_array->mutable_int_id();
                for (const auto& id : expr.terms_) {
                    dst_ids->add_data((int64_t&)id);
                }
                break;
            }
            case DataType::VARCHAR: {
                auto dst_ids = id_array->mutable_str_id();
                for (const auto& id : expr.terms_) {
                    dst_ids->add_data((std::string&)id);
                }
                break;
            }
            default: {
                PanicInfo("unsupported type");
            }
1650
        }
1651

1652 1653 1654 1655 1656 1657
        auto [uids, seg_offsets] = segment_.search_ids(*id_array, timestamp_);
        BitsetType bitset(row_count_);
        for (const auto& offset : seg_offsets) {
            auto _offset = (int64_t)offset.get();
            bitset[_offset] = true;
        }
Y
yah01 已提交
1658 1659
        AssertInfo(bitset.size() == row_count_,
                   "[ExecExprVisitor]Size of results not equal row count");
1660 1661 1662
        return bitset;
    }

1663
    return ExecTermVisitorImplTemplate<T>(expr_raw);
S
sunby 已提交
1664 1665
}

1666 1667
template <>
auto
Y
yah01 已提交
1668 1669
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw)
    -> BitsetType {
1670 1671 1672
    return ExecTermVisitorImplTemplate<std::string>(expr_raw);
}

Y
yah01 已提交
1673 1674
template <>
auto
Y
yah01 已提交
1675 1676
ExecExprVisitor::ExecTermVisitorImpl<std::string_view>(TermExpr& expr_raw)
    -> BitsetType {
Y
yah01 已提交
1677 1678 1679
    return ExecTermVisitorImplTemplate<std::string_view>(expr_raw);
}

1680 1681 1682
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
Y
yah01 已提交
1683 1684 1685
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
1686 1687
    using Index = index::ScalarIndex<IndexInnerType>;
    auto& expr = static_cast<TermExprImpl<IndexInnerType>&>(expr_raw);
Y
yah01 已提交
1688 1689
    const std::vector<IndexInnerType> terms(expr.terms_.begin(),
                                            expr.terms_.end());
1690 1691 1692
    auto n = terms.size();
    std::unordered_set<T> term_set(expr.terms_.begin(), expr.terms_.end());

Y
yah01 已提交
1693 1694 1695
    auto index_func = [&terms, n](Index* index) {
        return index->In(n, terms.data());
    };
1696
    auto elem_func = [&terms, &term_set](MayConstRef<T> x) {
1697 1698 1699 1700 1701
        //// terms has already been sorted.
        // return std::binary_search(terms.begin(), terms.end(), x);
        return term_set.find(x) != term_set.end();
    };

1702 1703
    return ExecRangeVisitorImpl<T>(
        expr.column_.field_id, index_func, elem_func);
1704 1705
}

1706 1707 1708
// TODO: bool is so ugly here.
template <>
auto
Y
yah01 已提交
1709 1710
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw)
    -> BitsetType {
1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728
    using T = bool;
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
    using Index = index::ScalarIndex<T>;
    const auto& terms = expr.terms_;
    auto n = terms.size();
    std::unordered_set<T> term_set(expr.terms_.begin(), expr.terms_.end());

    auto index_func = [&terms, n](Index* index) {
        auto bool_arr_copy = new bool[terms.size()];
        int it = 0;
        for (auto elem : terms) {
            bool_arr_copy[it++] = elem;
        }
        auto bitset = index->In(n, bool_arr_copy);
        delete[] bool_arr_copy;
        return bitset;
    };

1729
    auto elem_func = [&terms, &term_set](MayConstRef<T> x) {
1730 1731 1732 1733 1734
        //// terms has already been sorted.
        // return std::binary_search(terms.begin(), terms.end(), x);
        return term_set.find(x) != term_set.end();
    };

1735 1736 1737 1738 1739 1740 1741 1742 1743 1744
    return ExecRangeVisitorImpl<T>(
        expr.column_.field_id, index_func, elem_func);
}

/// TermExpr evaluation for a path inside a JSON column.
///
/// ExprValueType is the C++ type of the literal terms. Each row's JSON value
/// is read at expr.column_.nested_path. A row matches only if that read
/// succeeds and the value is one of the terms. No JSON scalar index is
/// consulted: index_func yields an empty TargetBitmap.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw)
    -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    // String terms are read from the JSON as views and materialized as
    // std::string only for the set lookup.
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto& expr = static_cast<TermExprImpl<ExprValueType>&>(expr_raw);
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
    auto index_func = [](Index* /*index*/) { return TargetBitmap{}; };

    std::unordered_set<ExprValueType> term_set(expr.terms_.begin(),
                                               expr.terms_.end());

    if (term_set.empty()) {
        // An empty term list can never match.
        auto never_matches = [](const milvus::Json& /*json*/) {
            return false;
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, never_matches);
    }

    auto elem_func = [&term_set, &pointer](const milvus::Json& json) {
        auto value = json.template at<GetType>(pointer);
        if (value.error()) {
            return false;
        }
        return term_set.count(ExprValueType(value.value())) > 0;
    };

    return ExecRangeVisitorImpl<milvus::Json>(
        expr.column_.field_id, index_func, elem_func);
}

S
sunby 已提交
1773 1774
void
ExecExprVisitor::visit(TermExpr& expr) {
1775 1776
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
1777 1778
               "[ExecExprVisitor]DataType of expr isn't field_meta "
               "data type ");
1779
    BitsetType res;
1780
    switch (expr.column_.data_type) {
S
sunby 已提交
1781
        case DataType::BOOL: {
1782
            res = ExecTermVisitorImpl<bool>(expr);
S
sunby 已提交
1783 1784 1785
            break;
        }
        case DataType::INT8: {
1786
            res = ExecTermVisitorImpl<int8_t>(expr);
S
sunby 已提交
1787 1788 1789
            break;
        }
        case DataType::INT16: {
1790
            res = ExecTermVisitorImpl<int16_t>(expr);
S
sunby 已提交
1791 1792 1793
            break;
        }
        case DataType::INT32: {
1794
            res = ExecTermVisitorImpl<int32_t>(expr);
S
sunby 已提交
1795 1796 1797
            break;
        }
        case DataType::INT64: {
1798
            res = ExecTermVisitorImpl<int64_t>(expr);
S
sunby 已提交
1799 1800 1801
            break;
        }
        case DataType::FLOAT: {
1802
            res = ExecTermVisitorImpl<float>(expr);
S
sunby 已提交
1803 1804 1805
            break;
        }
        case DataType::DOUBLE: {
1806
            res = ExecTermVisitorImpl<double>(expr);
S
sunby 已提交
1807 1808
            break;
        }
1809
        case DataType::VARCHAR: {
Y
yah01 已提交
1810 1811 1812 1813 1814
            if (segment_.type() == SegmentType::Growing) {
                res = ExecTermVisitorImpl<std::string>(expr);
            } else {
                res = ExecTermVisitorImpl<std::string_view>(expr);
            }
1815 1816
            break;
        }
1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839
        case DataType::JSON: {
            switch (expr.val_case_) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    res = ExecTermVisitorImplTemplateJson<bool>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    res = ExecTermVisitorImplTemplateJson<int64_t>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    res = ExecTermVisitorImplTemplateJson<double>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    res = ExecTermVisitorImplTemplateJson<std::string>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::VAL_NOT_SET:
                    res = ExecTermVisitorImplTemplateJson<bool>(expr);
                    break;
                default:
                    PanicInfo(
                        fmt::format("unknown data type: {}", expr.val_case_));
            }
            break;
        }
S
sunby 已提交
1840
        default:
1841 1842
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
S
sunby 已提交
1843
    }
Y
yah01 已提交
1844 1845
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
1846
    bitset_opt_ = std::move(res);
S
sunby 已提交
1847
}
1848 1849 1850 1851 1852 1853 1854

void
ExecExprVisitor::visit(ExistsExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");
    BitsetType res;
1855
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
1856 1857 1858
    switch (expr.column_.data_type) {
        case DataType::JSON: {
            using Index = index::ScalarIndex<milvus::Json>;
1859 1860 1861
            auto index_func = [&](Index* index) { return TargetBitmap{}; };
            auto elem_func = [&](const milvus::Json& json) {
                auto x = json.exist(pointer);
1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876
                return x;
            };
            res = ExecRangeVisitorImpl<milvus::Json>(
                expr.column_.field_id, index_func, elem_func);
            break;
        }
        default:
            PanicInfo(fmt::format("unsupported data type {}",
                                  expr.column_.data_type));
    }
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

N
neza2017 已提交
1877
}  // namespace milvus::query