ExecExprVisitor.cpp 80.7 KB
Newer Older
F
FluorineDog 已提交
1 2 3 4 5 6 7 8 9 10 11
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

Y
yah01 已提交
12 13 14
#include "query/generated/ExecExprVisitor.h"

#include <boost/variant.hpp>
15
#include <boost/utility/binary.hpp>
16 17
#include <cmath>
#include <cstdint>
18
#include <ctime>
19
#include <deque>
N
neza2017 已提交
20
#include <optional>
21 22 23
#include <string>
#include <string_view>
#include <type_traits>
24
#include <unordered_set>
25 26
#include <utility>

27
#include "arrow/type_fwd.h"
28 29 30 31
#include "common/Json.h"
#include "common/Types.h"
#include "exceptions/EasyAssert.h"
#include "pb/plan.pb.h"
G
GuoRentong 已提交
32
#include "query/ExprImpl.h"
33
#include "query/Relational.h"
Y
yah01 已提交
34 35
#include "query/Utils.h"
#include "segcore/SegmentGrowingImpl.h"
36
#include "simdjson/error.h"
37
#include "query/PlanProto.h"
N
neza2017 已提交
38 39 40 41 42 43
namespace milvus::query {
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR
namespace impl {
class ExecExprVisitor : ExprVisitor {
 public:
Y
yah01 已提交
44 45 46
    ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
                    int64_t row_count,
                    Timestamp timestamp)
47
        : segment_(segment), row_count_(row_count), timestamp_(timestamp) {
N
neza2017 已提交
48
    }
49 50

    BitsetType
N
neza2017 已提交
51
    call_child(Expr& expr) {
Y
yah01 已提交
52 53
        AssertInfo(!bitset_opt_.has_value(),
                   "[ExecExprVisitor]Bitset already has value before accept");
N
neza2017 已提交
54
        expr.accept(*this);
Y
yah01 已提交
55 56
        AssertInfo(bitset_opt_.has_value(),
                   "[ExecExprVisitor]Bitset doesn't have value after accept");
57 58
        auto res = std::move(bitset_opt_);
        bitset_opt_ = std::nullopt;
59
        return std::move(res.value());
N
neza2017 已提交
60 61
    }

G
GuoRentong 已提交
62
 public:
F
FluorineDog 已提交
63
    template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
64
    auto
Y
yah01 已提交
65 66 67
    ExecRangeVisitorImpl(FieldId field_id,
                         IndexFunc func,
                         ElementFunc element_func) -> BitsetType;
G
GuoRentong 已提交
68

69 70 71 72
    template <typename T>
    auto
    ExecUnaryRangeVisitorDispatcherImpl(UnaryRangeExpr& expr_raw) -> BitsetType;

G
GuoRentong 已提交
73 74
    template <typename T>
    auto
75
    ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType;
76

77 78
    template <typename T>
    auto
Y
yah01 已提交
79 80
    ExecBinaryArithOpEvalRangeVisitorDispatcher(
        BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
81

82 83
    template <typename T>
    auto
84
    ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType;
G
GuoRentong 已提交
85

S
sunby 已提交
86 87
    template <typename T>
    auto
88
    ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType;
S
sunby 已提交
89

90 91 92 93
    template <typename T>
    auto
    ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType;

94 95
    template <typename CmpFunc>
    auto
Y
yah01 已提交
96 97
    ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
        -> BitsetType;
98

N
neza2017 已提交
99
 private:
100 101
    const segcore::SegmentInternalInterface& segment_;
    int64_t row_count_;
102
    Timestamp timestamp_;
103
    BitsetTypeOpt bitset_opt_;
N
neza2017 已提交
104 105 106 107
};
}  // namespace impl

void
F
FluorineDog 已提交
108 109
ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
    using OpType = LogicalUnaryExpr::OpType;
110
    auto child_res = call_child(*expr.child_);
111
    BitsetType res = std::move(child_res);
112 113 114 115 116 117 118
    switch (expr.op_type_) {
        case OpType::LogicalNot: {
            res.flip();
            break;
        }
        default: {
            PanicInfo("Invalid Unary Op");
F
FluorineDog 已提交
119 120
        }
    }
Y
yah01 已提交
121 122
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
123
    bitset_opt_ = std::move(res);
N
neza2017 已提交
124 125 126
}

void
F
FluorineDog 已提交
127 128
ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
    using OpType = LogicalBinaryExpr::OpType;
F
FluorineDog 已提交
129 130
    auto left = call_child(*expr.left_);
    auto right = call_child(*expr.right_);
Y
yah01 已提交
131 132
    AssertInfo(left.size() == right.size(),
               "[ExecExprVisitor]Left size not equal to right size");
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
    auto res = std::move(left);
    switch (expr.op_type_) {
        case OpType::LogicalAnd: {
            res &= right;
            break;
        }
        case OpType::LogicalOr: {
            res |= right;
            break;
        }
        case OpType::LogicalXor: {
            res ^= right;
            break;
        }
        case OpType::LogicalMinus: {
            res -= right;
            break;
        }
        default: {
            PanicInfo("Invalid Binary Op");
        }
    }
Y
yah01 已提交
155 156
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
157
    bitset_opt_ = std::move(res);
158
}
F
FluorineDog 已提交
159

160
static auto
161 162
Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
    BitsetType res;
163

164 165 166 167
    if (srcs.size() == 1) {
        return srcs[0];
    }

168 169 170 171 172 173 174 175 176 177
    int64_t total_size = 0;
    for (auto& chunk : srcs) {
        total_size += chunk.size();
    }
    res.resize(total_size);

    int64_t counter = 0;
    for (auto& chunk : srcs) {
        for (int64_t i = 0; i < chunk.size(); ++i) {
            res[counter + i] = chunk[i];
F
FluorineDog 已提交
178
        }
179
        counter += chunk.size();
F
FluorineDog 已提交
180
    }
181
    return res;
N
neza2017 已提交
182 183
}

184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
// Appends the booleans of `chunk_res` to the end of `result`, one bit per
// element, preserving order.
//
// The copy runs in three phases so that most of the data can be appended a
// whole bitset block at a time:
//   1. single bits until `result` reaches a block boundary (prefix),
//   2. whole blocks packed from BITSET_BLOCK_BIT_SIZE booleans each,
//   3. single bits for whatever is left over (suffix).
//
// All counts are kept as size_t: the sizes come from size() calls, and the
// previous int counters narrowed them and compared signed with unsigned.
void
AppendOneChunk(BitsetType& result, const FixedVector<bool>& chunk_res) {
    // Packs `n` consecutive groups of BITSET_BLOCK_SIZE * 8 booleans into
    // blocks and appends each block at once instead of bit by bit.
    auto AppendBlock = [&result](const bool* ptr, size_t n) {
        for (size_t i = 0; i < n; ++i) {
            BitSetBlockType val = 0;
            // Build each byte of the block separately: bit j of byte k is
            // taken from ptr[k * 8 + j]. The j-outer / k-inner order keeps
            // the inner loop simple (this can use CPU SIMD optimization).
            uint8_t vals[BITSET_BLOCK_SIZE] = {0};
            for (size_t j = 0; j < 8; ++j) {
                for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) {
                    vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j;
                }
            }
            // Byte k occupies bits [8k, 8k + 8) of the block.
            for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) {
                val |= BitSetBlockType(vals[j]) << (8 * j);
            }
            result.append(val);
            ptr += BITSET_BLOCK_SIZE * 8;
        }
    };
    // Appends `n` booleans one bit at a time; used for the runs that cannot
    // be grouped into a full block (usually fewer than
    // BITSET_BLOCK_BIT_SIZE elements).
    auto AppendBit = [&result](const bool* ptr, size_t n) {
        for (size_t i = 0; i < n; ++i) {
            bool bit = *ptr++;
            result.push_back(bit);
        }
    };

    const size_t res_len = result.size();
    const size_t chunk_len = chunk_res.size();
    const bool* chunk_ptr = chunk_res.data();

    // Bits needed to bring `result` to the next block boundary, capped by
    // the amount of data actually available in this chunk.
    const size_t n_prefix =
        res_len % BITSET_BLOCK_BIT_SIZE == 0
            ? 0
            : std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE,
                       chunk_len);

    AppendBit(chunk_ptr, n_prefix);

    // The whole chunk fitted into the prefix; nothing left to append.
    if (n_prefix == chunk_len) {
        return;
    }

    const size_t n_block = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE;
    const size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE;

    AppendBlock(chunk_ptr + n_prefix, n_block);

    AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE, n_suffix);
}

// Builds one bitset from the per-chunk boolean results by appending every
// chunk, in order, with AppendOneChunk.
BitsetType
AssembleChunk(const std::vector<FixedVector<bool>>& results) {
    BitsetType merged;
    for (const auto& chunk_bits : results) {
        AppendOneChunk(merged, chunk_bits);
    }
    return merged;
}

F
FluorineDog 已提交
247
template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
248
auto
Y
yah01 已提交
249 250 251
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
                                      IndexFunc index_func,
                                      ElementFunc element_func) -> BitsetType {
G
GuoRentong 已提交
252
    auto& schema = segment_.get_schema();
253 254
    auto& field_meta = schema[field_id];
    auto indexing_barrier = segment_.num_chunk_index(field_id);
B
BossZou 已提交
255 256
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
257
    std::vector<FixedVector<bool>> results;
258 259
    results.reserve(num_chunk);

Y
yah01 已提交
260 261 262
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
263
    using Index = index::ScalarIndex<IndexInnerType>;
F
FluorineDog 已提交
264
    for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
Y
yah01 已提交
265 266
        const Index& indexing =
            segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
267 268 269
        // NOTE: knowhere is not const-ready
        // This is a dirty workaround
        auto data = index_func(const_cast<Index*>(&indexing));
270
        AssertInfo(data.size() == size_per_chunk,
Y
yah01 已提交
271
                   "[ExecExprVisitor]Data size not equal to size_per_chunk");
272
        results.emplace_back(std::move(data));
F
FluorineDog 已提交
273
    }
274
    for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
Y
yah01 已提交
275 276 277
        auto this_size = chunk_id == num_chunk - 1
                             ? row_count_ - chunk_id * size_per_chunk
                             : size_per_chunk;
278
        FixedVector<bool> chunk_res(this_size);
279
        auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
G
GuoRentong 已提交
280
        const T* data = chunk.data();
281
        // Can use CPU SIMD optimazation to speed up
282
        for (int index = 0; index < this_size; ++index) {
283
            chunk_res[index] = element_func(data[index]);
G
GuoRentong 已提交
284
        }
285
        results.emplace_back(std::move(chunk_res));
G
GuoRentong 已提交
286
    }
287
    auto final_result = AssembleChunk(results);
Y
yah01 已提交
288 289
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Final result size not equal to row count");
290
    return final_result;
G
GuoRentong 已提交
291
}
292

293
template <typename T, typename IndexFunc, typename ElementFunc>
294
auto
Y
yah01 已提交
295 296 297
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
                                          IndexFunc index_func,
                                          ElementFunc element_func)
298
    -> BitsetType {
299
    auto& schema = segment_.get_schema();
300
    auto& field_meta = schema[field_id];
301 302
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
303 304 305 306
    auto indexing_barrier = segment_.num_chunk_index(field_id);
    auto data_barrier = segment_.num_chunk_data(field_id);
    AssertInfo(std::max(data_barrier, indexing_barrier) == num_chunk,
               "max(data_barrier, index_barrier) not equal to num_chunk");
307
    std::vector<FixedVector<bool>> results;
308
    results.reserve(num_chunk);
309

310 311 312 313 314
    // for growing segment, indexing_barrier will always less than data_barrier
    // so growing segment will always execute expr plan using raw data
    // if sealed segment has loaded raw data on this field, then index_barrier = 0 and data_barrier = 1
    // in this case, sealed segment execute expr plan using raw data
    for (auto chunk_id = 0; chunk_id < data_barrier; ++chunk_id) {
Y
yah01 已提交
315 316 317
        auto this_size = chunk_id == num_chunk - 1
                             ? row_count_ - chunk_id * size_per_chunk
                             : size_per_chunk;
318
        FixedVector<bool> result(this_size);
319
        auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
320 321 322 323
        const T* data = chunk.data();
        for (int index = 0; index < this_size; ++index) {
            result[index] = element_func(data[index]);
        }
324 325 326
        AssertInfo(result.size() == this_size,
                   "[ExecExprVisitor]Chunk result size not equal to "
                   "expected size");
327 328
        results.emplace_back(std::move(result));
    }
329 330 331

    // if sealed segment has loaded scalar index for this field, then index_barrier = 1 and data_barrier = 0
    // in this case, sealed segment execute expr plan using scalar index
Y
yah01 已提交
332 333 334
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
335
    using Index = index::ScalarIndex<IndexInnerType>;
Y
yah01 已提交
336 337 338 339
    for (auto chunk_id = data_barrier; chunk_id < indexing_barrier;
         ++chunk_id) {
        auto& indexing =
            segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
340
        auto this_size = const_cast<Index*>(&indexing)->Count();
341
        FixedVector<bool> result(this_size);
342 343 344 345 346 347
        for (int offset = 0; offset < this_size; ++offset) {
            result[offset] = index_func(const_cast<Index*>(&indexing), offset);
        }
        results.emplace_back(std::move(result));
    }

348
    auto final_result = AssembleChunk(results);
Y
yah01 已提交
349 350
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Final result size not equal to row count");
351 352 353
    return final_result;
}

G
GuoRentong 已提交
354 355 356 357
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
358
ExecExprVisitor::ExecUnaryRangeVisitorDispatcherImpl(UnaryRangeExpr& expr_raw)
Y
yah01 已提交
359 360 361 362
    -> BitsetType {
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
363 364 365
    using Index = index::ScalarIndex<IndexInnerType>;
    auto& expr = static_cast<UnaryRangeExprImpl<IndexInnerType>&>(expr_raw);

366
    auto op = expr.op_type_;
Y
yah01 已提交
367
    auto val = IndexInnerType(expr.value_);
368
    auto field_id = expr.column_.field_id;
369 370
    switch (op) {
        case OpType::Equal: {
371
            auto index_func = [&](Index* index) { return index->In(1, &val); };
372
            auto elem_func = [&](MayConstRef<T> x) { return (x == val); };
373
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
374 375
        }
        case OpType::NotEqual: {
376
            auto index_func = [&](Index* index) {
Y
yah01 已提交
377 378
                return index->NotIn(1, &val);
            };
379
            auto elem_func = [&](MayConstRef<T> x) { return (x != val); };
380
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
381 382
        }
        case OpType::GreaterEqual: {
383
            auto index_func = [&](Index* index) {
Y
yah01 已提交
384 385
                return index->Range(val, OpType::GreaterEqual);
            };
386
            auto elem_func = [&](MayConstRef<T> x) { return (x >= val); };
387
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
G
GuoRentong 已提交
388
        }
389
        case OpType::GreaterThan: {
390
            auto index_func = [&](Index* index) {
Y
yah01 已提交
391 392
                return index->Range(val, OpType::GreaterThan);
            };
393
            auto elem_func = [&](MayConstRef<T> x) { return (x > val); };
394
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
395 396
        }
        case OpType::LessEqual: {
397
            auto index_func = [&](Index* index) {
Y
yah01 已提交
398 399
                return index->Range(val, OpType::LessEqual);
            };
400
            auto elem_func = [&](MayConstRef<T> x) { return (x <= val); };
401
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
402 403
        }
        case OpType::LessThan: {
404
            auto index_func = [&](Index* index) {
Y
yah01 已提交
405 406
                return index->Range(val, OpType::LessThan);
            };
407
            auto elem_func = [&](MayConstRef<T> x) { return (x < val); };
408
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
409 410
        }
        case OpType::PrefixMatch: {
411
            auto index_func = [&](Index* index) {
P
presburger 已提交
412
                auto dataset = std::make_unique<Dataset>();
413 414
                dataset->Set(milvus::index::OPERATOR_TYPE, OpType::PrefixMatch);
                dataset->Set(milvus::index::PREFIX_VALUE, val);
415 416
                return index->Query(std::move(dataset));
            };
417 418 419
            auto elem_func = [&](MayConstRef<T> x) {
                return Match(x, val, op);
            };
420
            return ExecRangeVisitorImpl<T>(field_id, index_func, elem_func);
421 422
        }
        // TODO: PostfixMatch
423
        default: {
G
GuoRentong 已提交
424 425
            PanicInfo("unsupported range node");
        }
426 427 428 429
    }
}
#pragma clang diagnostic pop

430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481
// Entry point for unary range nodes on a field of type T.
//
// For integer fields the literal is read as int64_t, so it may not fit into
// the (narrower) field type T. In that case no stored value can equal it and
// every ordering comparison has the same outcome for all rows, so the result
// is produced directly without touching the data. Everything else is
// forwarded to ExecUnaryRangeVisitorDispatcherImpl.
//
// bool is excluded from the integer path: std::is_integral_v<bool> is true,
// but a bool literal cannot overflow, and the Impl for T = bool views the
// node as UnaryRangeExprImpl<bool>. Reading the same node through
// UnaryRangeExprImpl<int64_t> here would be inconsistent with that.
template <typename T>
auto
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw)
    -> BitsetType {
    if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>) {
        auto& expr = static_cast<UnaryRangeExprImpl<int64_t>&>(expr_raw);
        auto val = expr.value_;

        if (!out_of_range<T>(val)) {
            return ExecUnaryRangeVisitorDispatcherImpl<T>(expr_raw);
        }

        // see also: https://github.com/milvus-io/milvus/issues/23646.
        // NOTE(review): lt_lb / gt_ub presumably mean "below T's lower
        // bound" / "above T's upper bound" -- confirm against their
        // definitions.
        switch (expr.op_type_) {
            case proto::plan::GreaterThan:
            case proto::plan::GreaterEqual: {
                // All rows match only when lt_lb<T>(val) holds.
                BitsetType r(row_count_);
                if (lt_lb<T>(val)) {
                    r.set();
                }
                return r;
            }

            case proto::plan::LessThan:
            case proto::plan::LessEqual: {
                // All rows match only when gt_ub<T>(val) holds.
                BitsetType r(row_count_);
                if (gt_ub<T>(val)) {
                    r.set();
                }
                return r;
            }

            case proto::plan::Equal: {
                // An out-of-range literal matches no row.
                BitsetType r(row_count_);
                r.reset();
                return r;
            }

            case proto::plan::NotEqual: {
                // An out-of-range literal differs from every row.
                BitsetType r(row_count_);
                r.set();
                return r;
            }

            default: {
                PanicInfo("unsupported range node");
            }
        }
    }
    return ExecUnaryRangeVisitorDispatcherImpl<T>(expr_raw);
}

482 483 484 485 486 487 488 489 490
// Evaluates "json_field[nested_path] <op> value" for a unary range node.
//
// The index predicate returns an empty TargetBitmap. The element predicate
// extracts the value at `pointer` as GetType and compares it with the
// literal. When extraction fails and GetType is int64_t, the value is
// re-read as double before giving up. A value that cannot be extracted
// fails every comparison except NotEqual, where it passes.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw)
    -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    // Strings are fetched from the JSON document as string_view.
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto& expr = static_cast<UnaryRangeExprImpl<ExprValueType>&>(expr_raw);

    auto op = expr.op_type_;
    auto val = expr.value_;
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
    auto field_id = expr.column_.field_id;
    auto index_func = [=](Index* index) { return TargetBitmap{}; };

// Both macros expand inside a lambda taking `const milvus::Json& json`;
// `cmp` is an expression over the extracted `x` and the captured literal.
#define UnaryRangeJSONCompare(cmp)                            \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return !x.error() && (cmp);                   \
            }                                                 \
            return false;                                     \
        }                                                     \
        return (cmp);                                         \
    } while (false)

#define UnaryRangeJSONCompareNotEqual(cmp)                    \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return x.error() || (cmp);                    \
            }                                                 \
            return true;                                      \
        }                                                     \
        return (cmp);                                         \
    } while (false)

    // Runs one element predicate through the shared chunk walker.
    auto run = [this, field_id, &index_func](auto elem_func) -> BitsetType {
        return this->ExecRangeVisitorImpl<milvus::Json>(
            field_id, index_func, elem_func);
    };

    switch (op) {
        case OpType::Equal:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() == val);
            });
        case OpType::NotEqual:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompareNotEqual(x.value() != val);
            });
        case OpType::GreaterEqual:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() >= val);
            });
        case OpType::GreaterThan:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() > val);
            });
        case OpType::LessEqual:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() <= val);
            });
        case OpType::LessThan:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(x.value() < val);
            });
        case OpType::PrefixMatch:
            return run([&](const milvus::Json& json) {
                UnaryRangeJSONCompare(Match(ExprValueType(x.value()), val, op));
            });
        // TODO: PostfixMatch
        default: {
            PanicInfo("unsupported range node");
        }
    }
}

582 583 584 585
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
Y
yah01 已提交
586 587
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(
    BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
588 589 590 591 592 593 594
    // see also: https://github.com/milvus-io/milvus/issues/23646.
    typedef std::conditional_t<std::is_integral_v<T>, int64_t, T>
        HighPrecisionType;

    auto& expr =
        static_cast<BinaryArithOpEvalRangeExprImpl<HighPrecisionType>&>(
            expr_raw);
595
    using Index = index::ScalarIndex<T>;
596 597 598 599 600 601 602 603 604
    auto arith_op = expr.arith_op_;
    auto right_operand = expr.right_operand_;
    auto op = expr.op_type_;
    auto val = expr.value_;

    switch (op) {
        case OpType::Equal: {
            switch (arith_op) {
                case ArithOpType::Add: {
Y
yah01 已提交
605 606
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
607 608 609
                        auto x = index->Reverse_Lookup(offset);
                        return (x + right_operand) == val;
                    };
610 611 612
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
                        return ((x + right_operand) == val);
                    };
Y
yah01 已提交
613
                    return ExecDataRangeVisitorImpl<T>(
614
                        expr.column_.field_id, index_func, elem_func);
615 616
                }
                case ArithOpType::Sub: {
Y
yah01 已提交
617 618
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
619 620 621
                        auto x = index->Reverse_Lookup(offset);
                        return (x - right_operand) == val;
                    };
622
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
623 624 625
                        return ((x - right_operand) == val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
626
                        expr.column_.field_id, index_func, elem_func);
627 628
                }
                case ArithOpType::Mul: {
Y
yah01 已提交
629 630
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
631 632 633
                        auto x = index->Reverse_Lookup(offset);
                        return (x * right_operand) == val;
                    };
634
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
635 636 637
                        return ((x * right_operand) == val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
638
                        expr.column_.field_id, index_func, elem_func);
639 640
                }
                case ArithOpType::Div: {
Y
yah01 已提交
641 642
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
643 644 645
                        auto x = index->Reverse_Lookup(offset);
                        return (x / right_operand) == val;
                    };
646
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
647 648 649
                        return ((x / right_operand) == val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
650
                        expr.column_.field_id, index_func, elem_func);
651 652
                }
                case ArithOpType::Mod: {
Y
yah01 已提交
653 654
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
655 656 657
                        auto x = index->Reverse_Lookup(offset);
                        return static_cast<T>(fmod(x, right_operand)) == val;
                    };
658
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
659 660
                        return (static_cast<T>(fmod(x, right_operand)) == val);
                    };
Y
yah01 已提交
661
                    return ExecDataRangeVisitorImpl<T>(
662
                        expr.column_.field_id, index_func, elem_func);
663 664 665 666 667 668 669 670 671
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        case OpType::NotEqual: {
            switch (arith_op) {
                case ArithOpType::Add: {
Y
yah01 已提交
672 673
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
674 675 676
                        auto x = index->Reverse_Lookup(offset);
                        return (x + right_operand) != val;
                    };
677
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
678 679 680
                        return ((x + right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
681
                        expr.column_.field_id, index_func, elem_func);
682 683
                }
                case ArithOpType::Sub: {
Y
yah01 已提交
684 685
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
686 687 688
                        auto x = index->Reverse_Lookup(offset);
                        return (x - right_operand) != val;
                    };
689
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
690 691 692
                        return ((x - right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
693
                        expr.column_.field_id, index_func, elem_func);
694 695
                }
                case ArithOpType::Mul: {
Y
yah01 已提交
696 697
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
698 699 700
                        auto x = index->Reverse_Lookup(offset);
                        return (x * right_operand) != val;
                    };
701
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
702 703 704
                        return ((x * right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
705
                        expr.column_.field_id, index_func, elem_func);
706 707
                }
                case ArithOpType::Div: {
Y
yah01 已提交
708 709
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
710 711 712
                        auto x = index->Reverse_Lookup(offset);
                        return (x / right_operand) != val;
                    };
713
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
Y
yah01 已提交
714 715 716
                        return ((x / right_operand) != val);
                    };
                    return ExecDataRangeVisitorImpl<T>(
717
                        expr.column_.field_id, index_func, elem_func);
718 719
                }
                case ArithOpType::Mod: {
Y
yah01 已提交
720 721
                    auto index_func = [val, right_operand](Index* index,
                                                           size_t offset) {
722 723 724
                        auto x = index->Reverse_Lookup(offset);
                        return static_cast<T>(fmod(x, right_operand)) != val;
                    };
725
                    auto elem_func = [val, right_operand](MayConstRef<T> x) {
726 727
                        return (static_cast<T>(fmod(x, right_operand)) != val);
                    };
Y
yah01 已提交
728
                    return ExecDataRangeVisitorImpl<T>(
729
                        expr.column_.field_id, index_func, elem_func);
730 731 732 733 734 735 736 737 738 739 740 741 742
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        default: {
            PanicInfo("unsupported range node with arithmetic operation");
        }
    }
}
#pragma clang diagnostic pop

743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758
/// Evaluates `(json[path] <arith_op> right_operand) <op> value` for a JSON
/// column.
///
/// `ExprValueType` is the type of the literal carried by the expression; the
/// JSON value at `expr.column_.nested_path` is read as `GetType` (the same
/// type, except that std::string is read as std::string_view).
///
/// Only OpType::Equal and OpType::NotEqual combined with Add/Sub/Mul/Div/Mod
/// are handled; any other combination ends in PanicInfo.
///
/// Rows whose value cannot be read as `GetType`:
///   - Equal:    the row does not match (false);
///   - NotEqual: the row matches (true);
///   - when `GetType` is int64_t the value is first re-read as double and the
///     comparison is carried out on that double if the re-read succeeds.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
    BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
    auto& expr =
        static_cast<BinaryArithOpEvalRangeExprImpl<ExprValueType>&>(expr_raw);
    using Index = index::ScalarIndex<milvus::Json>;
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto arith_op = expr.arith_op_;
    auto right_operand = expr.right_operand_;
    auto op = expr.op_type_;
    auto val = expr.value_;
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);

// Body of an "equal" element predicate. Expects `json` and `pointer` in scope;
// `cmp` must be written in terms of `x.value()`. The inner `x` (double)
// deliberately shadows the outer one so that `cmp` is re-evaluated on the
// double re-read.
#define BinaryArithRangeJSONCompare(cmp)                      \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return !x.error() && (cmp);                   \
            }                                                 \
            return false;                                     \
        }                                                     \
        return (cmp);                                         \
    } while (false)

// Same as above for "not equal": an unreadable value counts as a match.
#define BinaryArithRangeJSONCompareNotEqual(cmp)              \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                return x.error() || (cmp);                    \
            }                                                 \
            return true;                                      \
        }                                                     \
        return (cmp);                                         \
    } while (false)

    // The index-side predicate is the same for every operator: it reports
    // "no match" unconditionally and needs no captured state.
    auto index_func = [](Index* /*index*/, size_t /*offset*/) {
        return false;
    };

    // Runs the shared range visitor with the given per-row JSON predicate.
    auto run = [&](auto elem_func) {
        return this->ExecDataRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, elem_func);
    };

    switch (op) {
        case OpType::Equal: {
            switch (arith_op) {
                case ArithOpType::Add: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() + right_operand ==
                                                    val);
                    });
                }
                case ArithOpType::Sub: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() - right_operand ==
                                                    val);
                    });
                }
                case ArithOpType::Mul: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() * right_operand ==
                                                    val);
                    });
                }
                case ArithOpType::Div: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(x.value() / right_operand ==
                                                    val);
                    });
                }
                case ArithOpType::Mod: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompare(
                            static_cast<ExprValueType>(
                                fmod(x.value(), right_operand)) == val);
                    });
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        case OpType::NotEqual: {
            switch (arith_op) {
                case ArithOpType::Add: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() + right_operand != val);
                    });
                }
                case ArithOpType::Sub: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() - right_operand != val);
                    });
                }
                case ArithOpType::Mul: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() * right_operand != val);
                    });
                }
                case ArithOpType::Div: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() / right_operand != val);
                    });
                }
                case ArithOpType::Mod: {
                    return run([&](const milvus::Json& json) {
                        BinaryArithRangeJSONCompareNotEqual(
                            static_cast<ExprValueType>(
                                fmod(x.value(), right_operand)) != val);
                    });
                }
                default: {
                    PanicInfo("unsupported arithmetic operation");
                }
            }
        }
        default: {
            PanicInfo("unsupported range node with arithmetic operation");
        }
    }
}
929

930 931 932 933
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
Y
yah01 已提交
934 935 936 937 938
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw)
    -> BitsetType {
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
Y
yah01 已提交
939 940 941
    using Index = index::ScalarIndex<IndexInnerType>;
    auto& expr = static_cast<BinaryRangeExprImpl<IndexInnerType>&>(expr_raw);

942 943
    bool lower_inclusive = expr.lower_inclusive_;
    bool upper_inclusive = expr.upper_inclusive_;
944 945 946 947 948 949 950 951 952

    // see also: https://github.com/milvus-io/milvus/issues/23646.
    typedef std::conditional_t<std::is_integral_v<IndexInnerType>,
                               int64_t,
                               IndexInnerType>
        HighPrecisionType;

    auto val1 = static_cast<HighPrecisionType>(expr.lower_value_);
    auto val2 = static_cast<HighPrecisionType>(expr.upper_value_);
953

954
    auto index_func = [&](Index* index) {
955 956 957 958 959 960 961 962 963 964 965 966 967
        if constexpr (std::is_integral_v<T>) {
            if (gt_ub<T>(val1)) {
                return TargetBitmap(index->Size(), false);
            } else if (lt_lb<T>(val1)) {
                val1 = std::numeric_limits<T>::min();
            }

            if (gt_ub<T>(val2)) {
                val2 = std::numeric_limits<T>::max();
            } else if (lt_lb<T>(val2)) {
                return TargetBitmap(index->Size(), false);
            }
        }
Y
yah01 已提交
968 969
        return index->Range(val1, lower_inclusive, val2, upper_inclusive);
    };
970

971
    if (lower_inclusive && upper_inclusive) {
972 973 974
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 <= x && x <= val2);
        };
975 976
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
977
    } else if (lower_inclusive && !upper_inclusive) {
978 979 980
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 <= x && x < val2);
        };
981 982
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
983
    } else if (!lower_inclusive && upper_inclusive) {
984 985 986
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 < x && x <= val2);
        };
987 988
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
G
GuoRentong 已提交
989
    } else {
990 991 992
        auto elem_func = [val1, val2](MayConstRef<T> x) {
            return (val1 < x && x < val2);
        };
993 994
        return ExecRangeVisitorImpl<T>(
            expr.column_.field_id, index_func, elem_func);
G
GuoRentong 已提交
995 996 997 998
    }
}
#pragma clang diagnostic pop

999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013
/// Evaluates a two-sided range filter (`lower <op> json[path] <op> upper`) on
/// a JSON column.
///
/// `ExprValueType` is the type of both bounds; the JSON value at
/// `expr.column_.nested_path` is read as `GetType` (the same type, except
/// that std::string is read as std::string_view).
///
/// A row whose value cannot be read as `GetType` does not match. When
/// `GetType` is int64_t the value is first re-read as double and compared
/// as such if that succeeds.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw)
    -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;

    auto& expr = static_cast<BinaryRangeExprImpl<ExprValueType>&>(expr_raw);
    bool lower_inclusive = expr.lower_inclusive_;
    bool upper_inclusive = expr.upper_inclusive_;
    ExprValueType val1 = expr.lower_value_;
    ExprValueType val2 = expr.upper_value_;
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);

    // no json index now
    auto index_func = [](Index* /*index*/) { return TargetBitmap{}; };

// Body of an element predicate. Expects `json` and `pointer` in scope; `cmp`
// must be written in terms of `value`, which is bound to the decoded JSON
// value (or to its double re-read in the int64_t fallback).
#define BinaryRangeJSONCompare(cmp)                           \
    do {                                                      \
        auto x = json.template at<GetType>(pointer);          \
        if (x.error()) {                                      \
            if constexpr (std::is_same_v<GetType, int64_t>) { \
                auto x = json.template at<double>(pointer);   \
                if (!x.error()) {                             \
                    auto value = x.value();                   \
                    return (cmp);                             \
                }                                             \
            }                                                 \
            return false;                                     \
        }                                                     \
        auto value = x.value();                               \
        return (cmp);                                         \
    } while (false)

    // One predicate per inclusivity combination keeps the per-row comparison
    // free of extra branching.
    if (lower_inclusive && upper_inclusive) {
        auto elem_func = [&](const milvus::Json& json) {
            BinaryRangeJSONCompare(val1 <= value && value <= val2);
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, elem_func);
    } else if (lower_inclusive && !upper_inclusive) {
        auto elem_func = [&](const milvus::Json& json) {
            BinaryRangeJSONCompare(val1 <= value && value < val2);
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, elem_func);
    } else if (!lower_inclusive && upper_inclusive) {
        auto elem_func = [&](const milvus::Json& json) {
            BinaryRangeJSONCompare(val1 < value && value <= val2);
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, elem_func);
    } else {
        auto elem_func = [&](const milvus::Json& json) {
            BinaryRangeJSONCompare(val1 < value && value < val2);
        };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, elem_func);
    }
}

N
neza2017 已提交
1063
void
1064
ExecExprVisitor::visit(UnaryRangeExpr& expr) {
1065 1066
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
1067
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");
1068
    BitsetType res;
1069
    switch (expr.column_.data_type) {
N
neza2017 已提交
1070
        case DataType::BOOL: {
1071
            res = ExecUnaryRangeVisitorDispatcher<bool>(expr);
N
neza2017 已提交
1072 1073
            break;
        }
G
GuoRentong 已提交
1074
        case DataType::INT8: {
1075
            res = ExecUnaryRangeVisitorDispatcher<int8_t>(expr);
G
GuoRentong 已提交
1076 1077 1078
            break;
        }
        case DataType::INT16: {
1079
            res = ExecUnaryRangeVisitorDispatcher<int16_t>(expr);
G
GuoRentong 已提交
1080 1081 1082
            break;
        }
        case DataType::INT32: {
1083
            res = ExecUnaryRangeVisitorDispatcher<int32_t>(expr);
G
GuoRentong 已提交
1084 1085 1086
            break;
        }
        case DataType::INT64: {
1087
            res = ExecUnaryRangeVisitorDispatcher<int64_t>(expr);
G
GuoRentong 已提交
1088 1089 1090
            break;
        }
        case DataType::FLOAT: {
1091
            res = ExecUnaryRangeVisitorDispatcher<float>(expr);
G
GuoRentong 已提交
1092 1093 1094
            break;
        }
        case DataType::DOUBLE: {
1095 1096 1097
            res = ExecUnaryRangeVisitorDispatcher<double>(expr);
            break;
        }
1098
        case DataType::VARCHAR: {
Y
yah01 已提交
1099 1100 1101 1102 1103
            if (segment_.type() == SegmentType::Growing) {
                res = ExecUnaryRangeVisitorDispatcher<std::string>(expr);
            } else {
                res = ExecUnaryRangeVisitorDispatcher<std::string_view>(expr);
            }
1104 1105
            break;
        }
1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126
        case DataType::JSON: {
            switch (expr.val_case_) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    res = ExecUnaryRangeVisitorDispatcherJson<bool>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    res = ExecUnaryRangeVisitorDispatcherJson<int64_t>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    res = ExecUnaryRangeVisitorDispatcherJson<double>(expr);
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    res =
                        ExecUnaryRangeVisitorDispatcherJson<std::string>(expr);
                    break;
                default:
                    PanicInfo(
                        fmt::format("unknown data type: {}", expr.val_case_));
            }
            break;
        }
1127
        default:
1128 1129
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
1130
    }
Y
yah01 已提交
1131 1132
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
1133
    bitset_opt_ = std::move(res);
1134 1135
}

1136 1137
/// Evaluates an arithmetic-then-compare expression and stores the resulting
/// bitset in `bitset_opt_`.
///
/// The column's declared data type must equal the schema's type for that
/// field (checked with AssertInfo). Integer and floating-point columns are
/// forwarded to the dispatcher instantiated for their element type; JSON
/// columns are dispatched on the literal's value case (bool, int64 or
/// float). Anything else ends in PanicInfo, and the result must contain
/// exactly `row_count_` bits.
void
ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");
    BitsetType res;
    switch (expr.column_.data_type) {
        case DataType::INT8: {
            res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int8_t>(expr);
            break;
        }
        case DataType::INT16: {
            res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int16_t>(expr);
            break;
        }
        case DataType::INT32: {
            res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int32_t>(expr);
            break;
        }
        case DataType::INT64: {
            res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int64_t>(expr);
            break;
        }
        case DataType::FLOAT: {
            res = ExecBinaryArithOpEvalRangeVisitorDispatcher<float>(expr);
            break;
        }
        case DataType::DOUBLE: {
            res = ExecBinaryArithOpEvalRangeVisitorDispatcher<double>(expr);
            break;
        }
        case DataType::JSON: {
            // The JSON column itself is untyped: dispatch on the literal.
            switch (expr.val_case_) {
                case proto::plan::GenericValue::ValCase::kBoolVal: {
                    res = ExecBinaryArithOpEvalRangeVisitorDispatcherJson<bool>(
                        expr);
                    break;
                }
                case proto::plan::GenericValue::ValCase::kInt64Val: {
                    res = ExecBinaryArithOpEvalRangeVisitorDispatcherJson<
                        int64_t>(expr);
                    break;
                }
                case proto::plan::GenericValue::ValCase::kFloatVal: {
                    res =
                        ExecBinaryArithOpEvalRangeVisitorDispatcherJson<double>(
                            expr);
                    break;
                }
                default: {
                    PanicInfo(
                        fmt::format("unsupported value type {} in expression",
                                    expr.val_case_));
                }
            }
            break;
        }
        default:
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
    }
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

1202 1203
/// Evaluates a two-sided range expression and stores the resulting bitset in
/// `bitset_opt_`.
///
/// The column's declared data type must equal the schema's type for that
/// field (checked with AssertInfo). The work is forwarded to the dispatcher
/// instantiated for the column's element type; for JSON columns the
/// instantiation is chosen from the literal's value case instead.
/// Unsupported column types or literal kinds end in PanicInfo, and the result
/// must contain exactly `row_count_` bits.
void
ExecExprVisitor::visit(BinaryRangeExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");
    BitsetType res;
    switch (expr.column_.data_type) {
        case DataType::BOOL: {
            res = ExecBinaryRangeVisitorDispatcher<bool>(expr);
            break;
        }
        case DataType::INT8: {
            res = ExecBinaryRangeVisitorDispatcher<int8_t>(expr);
            break;
        }
        case DataType::INT16: {
            res = ExecBinaryRangeVisitorDispatcher<int16_t>(expr);
            break;
        }
        case DataType::INT32: {
            res = ExecBinaryRangeVisitorDispatcher<int32_t>(expr);
            break;
        }
        case DataType::INT64: {
            res = ExecBinaryRangeVisitorDispatcher<int64_t>(expr);
            break;
        }
        case DataType::FLOAT: {
            res = ExecBinaryRangeVisitorDispatcher<float>(expr);
            break;
        }
        case DataType::DOUBLE: {
            res = ExecBinaryRangeVisitorDispatcher<double>(expr);
            break;
        }
        case DataType::VARCHAR: {
            // Growing segments are visited as std::string, all other
            // segment types as std::string_view.
            if (segment_.type() == SegmentType::Growing) {
                res = ExecBinaryRangeVisitorDispatcher<std::string>(expr);
            } else {
                res = ExecBinaryRangeVisitorDispatcher<std::string_view>(expr);
            }
            break;
        }
        case DataType::JSON: {
            // The JSON column itself is untyped: dispatch on the literal.
            switch (expr.val_case_) {
                case proto::plan::GenericValue::ValCase::kBoolVal: {
                    res = ExecBinaryRangeVisitorDispatcherJson<bool>(expr);
                    break;
                }
                case proto::plan::GenericValue::ValCase::kInt64Val: {
                    res = ExecBinaryRangeVisitorDispatcherJson<int64_t>(expr);
                    break;
                }
                case proto::plan::GenericValue::ValCase::kFloatVal: {
                    res = ExecBinaryRangeVisitorDispatcherJson<double>(expr);
                    break;
                }
                case proto::plan::GenericValue::ValCase::kStringVal: {
                    res =
                        ExecBinaryRangeVisitorDispatcherJson<std::string>(expr);
                    break;
                }
                default: {
                    PanicInfo(
                        fmt::format("unsupported value type {} in expression",
                                    expr.val_case_));
                }
            }
            break;
        }
        default:
            PanicInfo(fmt::format("unsupported data type: {}",
                                  expr.column_.data_type));
    }
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294
/// Callable adaptor around a binary comparison functor `Op`.
///
/// A call with exactly two operands is forwarded to a default-constructed
/// `Op`. Any other call shape selects the variadic overload, which reports
/// the operands as incompatible through PanicInfo.
template <typename Op>
struct relational {
    // Two operands: perform the comparison.
    template <typename Lhs, typename Rhs>
    bool
    operator()(const Lhs& lhs, const Rhs& rhs) const {
        Op compare{};
        return compare(lhs, rhs);
    }

    // Catch-all for every other argument list.
    template <typename... Args>
    bool
    operator()(const Args&...) const {
        PanicInfo("incompatible operands");
    }
};

1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325
/// Compares one chunk of a left column against the same chunk of a right
/// column, element by element.
///
/// @param left_raw_data     raw values of the left column for this chunk
/// @param right_field_id    field whose chunk supplies the right-hand values
/// @param current_chunk_id  index of the chunk being compared
/// @param cmp_func          binary predicate invoked as cmp_func(left, right)
/// @return a bitmap with one entry per row of the chunk. Every chunk holds
///         `size_per_chunk` rows except the last one, which holds whatever
///         remains of `row_count_`.
template <typename T, typename U, typename CmpFunc>
TargetBitmap
ExecExprVisitor::ExecCompareRightType(const T* left_raw_data,
                                      const FieldId& right_field_id,
                                      const int64_t current_chunk_id,
                                      CmpFunc cmp_func) {
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunks = upper_div(row_count_, size_per_chunk);
    auto size = current_chunk_id == num_chunks - 1
                    ? row_count_ - current_chunk_id * size_per_chunk
                    : size_per_chunk;

    TargetBitmap result(size);
    const U* right_raw_data =
        segment_.chunk_data<U>(right_field_id, current_chunk_id).data();

    // The index has the same type as `size`, so the loop neither narrows the
    // bound nor mixes integer widths in the comparison.
    for (decltype(size) i = 0; i < size; ++i) {
        result[i] = cmp_func(left_raw_data[i], right_raw_data[i]);
    }

    return result;
}

/// Compares a left column of element type `T` against a right column whose
/// element type is only known at run time.
///
/// For every chunk the left raw data is fetched once and
/// ExecCompareRightType is instantiated for the right column's type. The
/// per-chunk results are then joined with AssembleChunk.
///
/// @param left_field_id     field supplying the left-hand values
/// @param right_field_id    field supplying the right-hand values
/// @param right_field_type  run-time data type of the right column; only
///                          BOOL, INT8/16/32/64, FLOAT and DOUBLE are
///                          handled, anything else ends in PanicInfo
/// @param cmp_func          binary predicate invoked as cmp_func(left, right)
/// @return a bitset holding exactly `row_count_` bits (checked with
///         AssertInfo)
template <typename T, typename CmpFunc>
BitsetType
ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id,
                                     const FieldId& right_field_id,
                                     const DataType& right_field_type,
                                     CmpFunc cmp_func) {
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunks = upper_div(row_count_, size_per_chunk);
    std::vector<FixedVector<bool>> results;
    results.reserve(num_chunks);

    for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
        FixedVector<bool> result;
        const T* left_raw_data =
            segment_.chunk_data<T>(left_field_id, chunk_id).data();

        switch (right_field_type) {
            case DataType::BOOL:
                result = ExecCompareRightType<T, bool, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT8:
                result = ExecCompareRightType<T, int8_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT16:
                result = ExecCompareRightType<T, int16_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT32:
                result = ExecCompareRightType<T, int32_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT64:
                result = ExecCompareRightType<T, int64_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::FLOAT:
                result = ExecCompareRightType<T, float, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::DOUBLE:
                result = ExecCompareRightType<T, double, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            default:
                // This switch inspects the right-hand column's type.
                PanicInfo("unsupported right datatype of compare expr");
        }
        // `result` is not used again in this iteration: hand its storage
        // over instead of copying the whole chunk.
        results.push_back(std::move(result));
    }
    auto final_result = AssembleChunk(results);
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    return final_result;
}

/// Column-to-column comparison entry point for the case handled by
/// ExecCompareLeftType: selects the instantiation that matches the left
/// column's run-time data type.
///
/// @param expr      compare expression; its left/right field ids and data
///                  types are read
/// @param cmp_func  binary predicate invoked as cmp_func(left, right)
/// @return the bitset produced by ExecCompareLeftType
///
/// Only BOOL, INT8/16/32/64, FLOAT and DOUBLE left columns are handled;
/// anything else ends in PanicInfo.
template <typename CmpFunc>
BitsetType
ExecExprVisitor::ExecCompareExprDispatcherForNonIndexedSegment(
    CompareExpr& expr, CmpFunc cmp_func) {
    switch (expr.left_data_type_) {
        case DataType::BOOL:
            return ExecCompareLeftType<bool, CmpFunc>(expr.left_field_id_,
                                                      expr.right_field_id_,
                                                      expr.right_data_type_,
                                                      cmp_func);
        case DataType::INT8:
            return ExecCompareLeftType<int8_t, CmpFunc>(expr.left_field_id_,
                                                        expr.right_field_id_,
                                                        expr.right_data_type_,
                                                        cmp_func);
        case DataType::INT16:
            return ExecCompareLeftType<int16_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::INT32:
            return ExecCompareLeftType<int32_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::INT64:
            return ExecCompareLeftType<int64_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::FLOAT:
            return ExecCompareLeftType<float, CmpFunc>(expr.left_field_id_,
                                                       expr.right_field_id_,
                                                       expr.right_data_type_,
                                                       cmp_func);
        case DataType::DOUBLE:
            return ExecCompareLeftType<double, CmpFunc>(expr.left_field_id_,
                                                        expr.right_field_id_,
                                                        expr.right_data_type_,
                                                        cmp_func);
        default:
            // This switch inspects the left-hand column's type.
            PanicInfo("unsupported left datatype of compare expr");
    }
}

1419 1420
template <typename Op>
auto
Y
yah01 已提交
1421 1422 1423 1424 1425 1426 1427 1428 1429 1430
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
    -> BitsetType {
    using number = boost::variant<bool,
                                  int8_t,
                                  int16_t,
                                  int32_t,
                                  int64_t,
                                  float,
                                  double,
                                  std::string>;
1431 1432 1433 1434 1435
    auto is_string_expr = [&expr]() -> bool {
        return expr.left_data_type_ == DataType::VARCHAR ||
               expr.right_data_type_ == DataType::VARCHAR;
    };

1436 1437
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
1438
    std::deque<BitsetType> bitsets;
1439 1440 1441 1442

    // check for sealed segment, load either raw field data or index
    auto left_indexing_barrier = segment_.num_chunk_index(expr.left_field_id_);
    auto left_data_barrier = segment_.num_chunk_data(expr.left_field_id_);
1443 1444 1445
    AssertInfo(std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
               "max(left_data_barrier, left_indexing_barrier) not equal to "
               "num_chunk");
1446

Y
yah01 已提交
1447 1448
    auto right_indexing_barrier =
        segment_.num_chunk_index(expr.right_field_id_);
1449
    auto right_data_barrier = segment_.num_chunk_data(expr.right_field_id_);
Y
yah01 已提交
1450 1451 1452 1453
    AssertInfo(
        std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
        "max(right_data_barrier, right_indexing_barrier) not equal to "
        "num_chunk");
1454

1455 1456 1457 1458 1459 1460 1461 1462
    // For segment both fields has no index, can use SIMD to speed up.
    // Avoiding too much call stack that blocks SIMD.
    if (left_indexing_barrier == 0 && right_indexing_barrier == 0 &&
        !is_string_expr()) {
        return ExecCompareExprDispatcherForNonIndexedSegment<Op>(expr, op);
    }

    // TODO: refactoring the code that contains too much call stack.
1463
    for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
Y
yah01 已提交
1464 1465 1466 1467 1468 1469
        auto size = chunk_id == num_chunk - 1
                        ? row_count_ - chunk_id * size_per_chunk
                        : size_per_chunk;
        auto getChunkData =
            [&, chunk_id](DataType type, FieldId field_id, int64_t data_barrier)
            -> std::function<const number(int)> {
1470 1471
            switch (type) {
                case DataType::BOOL: {
1472
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1473 1474 1475 1476 1477 1478
                        auto chunk_data =
                            segment_.chunk_data<bool>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1479 1480
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1481 1482 1483 1484 1485
                        auto& indexing = segment_.chunk_scalar_index<bool>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1486
                    }
1487 1488
                }
                case DataType::INT8: {
1489
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1490 1491 1492 1493 1494 1495
                        auto chunk_data =
                            segment_.chunk_data<int8_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1496 1497
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1498 1499 1500 1501 1502
                        auto& indexing = segment_.chunk_scalar_index<int8_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1503
                    }
1504 1505
                }
                case DataType::INT16: {
1506
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1507 1508 1509 1510 1511 1512
                        auto chunk_data =
                            segment_.chunk_data<int16_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1513 1514
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1515 1516 1517 1518 1519
                        auto& indexing = segment_.chunk_scalar_index<int16_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1520
                    }
1521 1522
                }
                case DataType::INT32: {
1523
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1524 1525 1526 1527 1528 1529
                        auto chunk_data =
                            segment_.chunk_data<int32_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1530 1531
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1532 1533 1534 1535 1536
                        auto& indexing = segment_.chunk_scalar_index<int32_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1537
                    }
1538 1539
                }
                case DataType::INT64: {
1540
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1541 1542 1543 1544 1545 1546
                        auto chunk_data =
                            segment_.chunk_data<int64_t>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1547 1548
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1549 1550 1551 1552 1553
                        auto& indexing = segment_.chunk_scalar_index<int64_t>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1554
                    }
1555 1556
                }
                case DataType::FLOAT: {
1557
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1558 1559 1560 1561 1562 1563
                        auto chunk_data =
                            segment_.chunk_data<float>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1564 1565
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1566 1567 1568 1569 1570
                        auto& indexing = segment_.chunk_scalar_index<float>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1571
                    }
1572 1573
                }
                case DataType::DOUBLE: {
1574
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1575 1576 1577 1578 1579 1580
                        auto chunk_data =
                            segment_.chunk_data<double>(field_id, chunk_id)
                                .data();
                        return [chunk_data](int i) -> const number {
                            return chunk_data[i];
                        };
1581 1582
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1583 1584 1585 1586 1587
                        auto& indexing = segment_.chunk_scalar_index<double>(
                            field_id, chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1588
                    }
1589 1590
                }
                case DataType::VARCHAR: {
1591
                    if (chunk_id < data_barrier) {
Y
yah01 已提交
1592
                        if (segment_.type() == SegmentType::Growing) {
Y
yah01 已提交
1593 1594 1595 1596 1597 1598 1599
                            auto chunk_data =
                                segment_
                                    .chunk_data<std::string>(field_id, chunk_id)
                                    .data();
                            return [chunk_data](int i) -> const number {
                                return chunk_data[i];
                            };
Y
yah01 已提交
1600
                        } else {
Y
yah01 已提交
1601 1602 1603 1604 1605 1606 1607
                            auto chunk_data = segment_
                                                  .chunk_data<std::string_view>(
                                                      field_id, chunk_id)
                                                  .data();
                            return [chunk_data](int i) -> const number {
                                return std::string(chunk_data[i]);
                            };
Y
yah01 已提交
1608
                        }
1609 1610
                    } else {
                        // for case, sealed segment has loaded index for scalar field instead of raw data
Y
yah01 已提交
1611 1612 1613 1614 1615 1616
                        auto& indexing =
                            segment_.chunk_scalar_index<std::string>(field_id,
                                                                     chunk_id);
                        return [&indexing](int i) -> const number {
                            return indexing.Reverse_Lookup(i);
                        };
1617
                    }
1618 1619
                }
                default:
1620
                    PanicInfo(fmt::format("unsupported data type: {}", type));
1621 1622
            }
        };
Y
yah01 已提交
1623 1624 1625 1626
        auto left = getChunkData(
            expr.left_data_type_, expr.left_field_id_, left_data_barrier);
        auto right = getChunkData(
            expr.right_data_type_, expr.right_field_id_, right_data_barrier);
1627

1628
        BitsetType bitset(size);
1629
        for (int i = 0; i < size; ++i) {
Y
yah01 已提交
1630 1631
            bool is_in = boost::apply_visitor(
                Relational<decltype(op)>{}, left(i), right(i));
1632 1633 1634 1635
            bitset[i] = is_in;
        }
        bitsets.emplace_back(std::move(bitset));
    }
1636
    auto final_result = Assemble(bitsets);
Y
yah01 已提交
1637 1638
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
1639
    return final_result;
1640 1641 1642 1643 1644
}

// Executes a field-vs-field comparison expression and stores the resulting
// bitset in bitset_opt_. The declared operand types must agree with the
// segment schema; unsupported operators panic.
void
ExecExprVisitor::visit(CompareExpr& expr) {
    auto& schema = segment_.get_schema();
    auto& left_field_meta = schema[expr.left_field_id_];
    auto& right_field_meta = schema[expr.right_field_id_];
    AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(),
               "[ExecExprVisitor]Left data type not equal to left field "
               "meta type");
    AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(),
               "[ExecExprVisitor]right data type not equal to right field "
               "meta type");

    // Map the operator enum onto the functor handed to the dispatcher.
    auto run_compare = [&]() -> BitsetType {
        switch (expr.op_type_) {
            case OpType::Equal:
                return ExecCompareExprDispatcher(expr, std::equal_to<>{});
            case OpType::NotEqual:
                return ExecCompareExprDispatcher(expr, std::not_equal_to<>{});
            case OpType::GreaterEqual:
                return ExecCompareExprDispatcher(expr,
                                                 std::greater_equal<>{});
            case OpType::GreaterThan:
                return ExecCompareExprDispatcher(expr, std::greater<>{});
            case OpType::LessEqual:
                return ExecCompareExprDispatcher(expr, std::less_equal<>{});
            case OpType::LessThan:
                return ExecCompareExprDispatcher(expr, std::less<>{});
            case OpType::PrefixMatch:
                return ExecCompareExprDispatcher(
                    expr, MatchOp<OpType::PrefixMatch>{});
            // OpType::PostfixMatch is not handled here.
            default:
                PanicInfo("unsupported optype");
        }
    };

    BitsetType res = run_compare();
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

S
sunby 已提交
1696 1697
template <typename T>
auto
1698
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
S
sunby 已提交
1699 1700
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
    auto& schema = segment_.get_schema();
1701
    auto primary_filed_id = schema.get_primary_field_id();
1702
    auto field_id = expr_raw.column_.field_id;
1703
    auto& field_meta = schema[field_id];
1704 1705

    bool use_pk_index = false;
1706
    if (primary_filed_id.has_value()) {
Y
yah01 已提交
1707 1708
        use_pk_index = primary_filed_id.value() == field_id &&
                       IsPrimaryKeyDataType(field_meta.get_data_type());
1709 1710 1711 1712
    }

    if (use_pk_index) {
        auto id_array = std::make_unique<IdArray>();
1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730
        switch (field_meta.get_data_type()) {
            case DataType::INT64: {
                auto dst_ids = id_array->mutable_int_id();
                for (const auto& id : expr.terms_) {
                    dst_ids->add_data((int64_t&)id);
                }
                break;
            }
            case DataType::VARCHAR: {
                auto dst_ids = id_array->mutable_str_id();
                for (const auto& id : expr.terms_) {
                    dst_ids->add_data((std::string&)id);
                }
                break;
            }
            default: {
                PanicInfo("unsupported type");
            }
1731
        }
1732

1733 1734
        auto [uids, seg_offsets] = segment_.search_ids(*id_array, timestamp_);
        BitsetType bitset(row_count_);
1735
        std::vector<int64_t> cached_offsets;
1736 1737 1738
        for (const auto& offset : seg_offsets) {
            auto _offset = (int64_t)offset.get();
            bitset[_offset] = true;
1739 1740 1741 1742 1743 1744
            cached_offsets.push_back(_offset);
        }
        // If enable plan_visitor pk index cache, pass offsets to it
        if (plan_visitor_ != nullptr) {
            plan_visitor_->SetExprUsePkIndex(true);
            plan_visitor_->SetExprCacheOffsets(std::move(cached_offsets));
1745
        }
Y
yah01 已提交
1746 1747
        AssertInfo(bitset.size() == row_count_,
                   "[ExecExprVisitor]Size of results not equal row count");
1748 1749 1750
        return bitset;
    }

1751
    return ExecTermVisitorImplTemplate<T>(expr_raw);
S
sunby 已提交
1752 1753
}

1754 1755
template <>
auto
Y
yah01 已提交
1756 1757
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw)
    -> BitsetType {
1758 1759 1760
    return ExecTermVisitorImplTemplate<std::string>(expr_raw);
}

Y
yah01 已提交
1761 1762
template <>
auto
Y
yah01 已提交
1763 1764
ExecExprVisitor::ExecTermVisitorImpl<std::string_view>(TermExpr& expr_raw)
    -> BitsetType {
Y
yah01 已提交
1765 1766 1767
    return ExecTermVisitorImplTemplate<std::string_view>(expr_raw);
}

1768 1769 1770
// Generic term ("field IN (...)") evaluation for element type T.
//
// Indexed chunks are answered by the scalar index's In() over a contiguous
// copy of the terms; raw chunks are answered element by element through a
// hash-set membership test. For T == std::string_view the expression and the
// index both operate on std::string.
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
    using IndexInnerType =
        std::conditional_t<std::is_same_v<T, std::string_view>,
                           std::string,
                           T>;
    using Index = index::ScalarIndex<IndexInnerType>;

    auto& expr = static_cast<TermExprImpl<IndexInnerType>&>(expr_raw);
    const auto& raw_terms = expr.terms_;

    // Contiguous copy handed to the index, plus a set for per-element probes.
    const std::vector<IndexInnerType> index_terms(raw_terms.begin(),
                                                  raw_terms.end());
    const auto term_count = index_terms.size();
    const std::unordered_set<T> lookup(raw_terms.begin(), raw_terms.end());

    auto index_func = [&index_terms, term_count](Index* index) {
        return index->In(term_count, index_terms.data());
    };
    auto elem_func = [&lookup](MayConstRef<T> x) {
        return lookup.count(x) != 0;
    };

    return ExecRangeVisitorImpl<T>(
        expr.column_.field_id, index_func, elem_func);
}

1794 1795 1796
// TODO: bool is so ugly here.
template <>
auto
Y
yah01 已提交
1797 1798
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw)
    -> BitsetType {
1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816
    using T = bool;
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
    using Index = index::ScalarIndex<T>;
    const auto& terms = expr.terms_;
    auto n = terms.size();
    std::unordered_set<T> term_set(expr.terms_.begin(), expr.terms_.end());

    auto index_func = [&terms, n](Index* index) {
        auto bool_arr_copy = new bool[terms.size()];
        int it = 0;
        for (auto elem : terms) {
            bool_arr_copy[it++] = elem;
        }
        auto bitset = index->In(n, bool_arr_copy);
        delete[] bool_arr_copy;
        return bitset;
    };

1817
    auto elem_func = [&terms, &term_set](MayConstRef<T> x) {
1818 1819 1820 1821 1822
        //// terms has already been sorted.
        // return std::binary_search(terms.begin(), terms.end(), x);
        return term_set.find(x) != term_set.end();
    };

1823 1824 1825 1826 1827 1828
    return ExecRangeVisitorImpl<T>(
        expr.column_.field_id, index_func, elem_func);
}

template <typename ExprValueType>
auto
1829
ExecExprVisitor::ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType {
1830 1831
    using Index = index::ScalarIndex<milvus::Json>;
    auto& expr = static_cast<TermExprImpl<ExprValueType>&>(expr_raw);
1832
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
1833
    auto index_func = [](Index* index) { return TargetBitmap{}; };
1834 1835 1836 1837 1838 1839 1840 1841 1842 1843

    std::unordered_set<ExprValueType> term_set(expr.terms_.begin(),
                                               expr.terms_.end());

    if (term_set.empty()) {
        auto elem_func = [=](const milvus::Json& json) { return false; };
        return ExecRangeVisitorImpl<milvus::Json>(
            expr.column_.field_id, index_func, elem_func);
    }

1844
    auto elem_func = [&term_set, &pointer](const milvus::Json& json) {
1845 1846 1847 1848
        using GetType =
            std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                               std::string_view,
                               ExprValueType>;
1849
        auto x = json.template at<GetType>(pointer);
1850
        if (x.error()) {
1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861
            if constexpr (std::is_same_v<GetType, std::int64_t>) {
                auto x = json.template at<double>(pointer);
                if (x.error()) {
                    return false;
                }

                auto value = x.value();
                // if the term set is {1}, and the value is 1.1, we should not return true.
                return std::floor(value) == value &&
                       term_set.find(ExprValueType(value)) != term_set.end();
            }
1862 1863 1864 1865 1866 1867 1868
            return false;
        }
        return term_set.find(ExprValueType(x.value())) != term_set.end();
    };

    return ExecRangeVisitorImpl<milvus::Json>(
        expr.column_.field_id, index_func, elem_func);
1869 1870
}

1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917
// Evaluates `value IN json_field[path]` where the JSON path is expected to
// hold an array. Exactly one term is required. A row matches when an array
// element read as ExprValueType (strings as std::string_view) equals that
// term. The index callback returns an empty bitmap.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType {
    using Index = index::ScalarIndex<milvus::Json>;
    auto& expr = static_cast<TermExprImpl<ExprValueType>&>(expr_raw);
    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
    auto index_func = [](Index* index) { return TargetBitmap{}; };

    AssertInfo(expr.terms_.size() == 1,
               "element length in json array must be one");
    ExprValueType target_val = expr.terms_[0];

    auto elem_func = [&target_val, &pointer](const milvus::Json& json) {
        using GetType =
            std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                               std::string_view,
                               ExprValueType>;
        auto doc = json.doc();
        auto array = doc.at_pointer(pointer).get_array();
        // The path is missing or does not hold an array.
        if (array.error()) {
            return false;
        }
        for (auto&& element : array) {
            auto val = element.template get<GetType>();
            // NOTE(review): an element of another type ends the scan with
            // `false`, even if a later element would match -- confirm that
            // this is the intended semantics for mixed-type arrays.
            if (val.error()) {
                return false;
            }
            if (val.value() == target_val) {
                return true;
            }
        }
        return false;
    };

    return ExecRangeVisitorImpl<milvus::Json>(
        expr.column_.field_id, index_func, elem_func);
}

// Routes a JSON term expression by direction: `value IN json_field` when
// is_in_field_ is set, `json_field IN (values...)` otherwise.
template <typename ExprValueType>
auto
ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw)
    -> BitsetType {
    return expr_raw.is_in_field_
               ? ExecTermJsonVariableInField<ExprValueType>(expr_raw)
               : ExecTermJsonFieldInVariable<ExprValueType>(expr_raw);
}

S
sunby 已提交
1918 1919
// Executes a term ("IN") expression and stores the resulting bitset in
// bitset_opt_. The implementation is chosen by the column's data type; for
// JSON columns it is chosen by the kind of literal carried in the expression.
void
ExecExprVisitor::visit(TermExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta "
               "data type ");

    // JSON columns: dispatch on the literal's value case. An unset value case
    // is evaluated with the bool instantiation.
    auto run_json_term = [&]() -> BitsetType {
        switch (expr.val_case_) {
            case proto::plan::GenericValue::ValCase::kBoolVal:
            case proto::plan::GenericValue::ValCase::VAL_NOT_SET:
                return ExecTermVisitorImplTemplateJson<bool>(expr);
            case proto::plan::GenericValue::ValCase::kInt64Val:
                return ExecTermVisitorImplTemplateJson<int64_t>(expr);
            case proto::plan::GenericValue::ValCase::kFloatVal:
                return ExecTermVisitorImplTemplateJson<double>(expr);
            case proto::plan::GenericValue::ValCase::kStringVal:
                return ExecTermVisitorImplTemplateJson<std::string>(expr);
            default:
                PanicInfo(
                    fmt::format("unknown data type: {}", expr.val_case_));
        }
    };

    auto run_term = [&]() -> BitsetType {
        switch (expr.column_.data_type) {
            case DataType::BOOL:
                return ExecTermVisitorImpl<bool>(expr);
            case DataType::INT8:
                return ExecTermVisitorImpl<int8_t>(expr);
            case DataType::INT16:
                return ExecTermVisitorImpl<int16_t>(expr);
            case DataType::INT32:
                return ExecTermVisitorImpl<int32_t>(expr);
            case DataType::INT64:
                return ExecTermVisitorImpl<int64_t>(expr);
            case DataType::FLOAT:
                return ExecTermVisitorImpl<float>(expr);
            case DataType::DOUBLE:
                return ExecTermVisitorImpl<double>(expr);
            case DataType::VARCHAR:
                // Growing segments are evaluated as std::string, all other
                // segment types as std::string_view.
                return segment_.type() == SegmentType::Growing
                           ? ExecTermVisitorImpl<std::string>(expr)
                           : ExecTermVisitorImpl<std::string_view>(expr);
            case DataType::JSON:
                return run_json_term();
            default:
                PanicInfo(fmt::format("unsupported data type: {}",
                                      expr.column_.data_type));
        }
    };

    BitsetType res = run_term();
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}
1993 1994 1995 1996 1997 1998 1999

// Executes an "exists" expression: a row matches when the JSON document
// contains the expression's nested path. Only JSON columns are supported;
// any other column type panics. The result is stored in bitset_opt_.
void
ExecExprVisitor::visit(ExistsExpr& expr) {
    auto& field_meta = segment_.get_schema()[expr.column_.field_id];
    AssertInfo(expr.column_.data_type == field_meta.get_data_type(),
               "[ExecExprVisitor]DataType of expr isn't field_meta data type");

    auto pointer = milvus::Json::pointer(expr.column_.nested_path);
    if (expr.column_.data_type != DataType::JSON) {
        PanicInfo(fmt::format("unsupported data type {}",
                              expr.column_.data_type));
    }

    using Index = index::ScalarIndex<milvus::Json>;
    // The index callback contributes an empty bitmap.
    auto index_func = [](Index* index) { return TargetBitmap{}; };
    auto elem_func = [&pointer](const milvus::Json& json) {
        return json.exist(pointer);
    };

    BitsetType res = ExecRangeVisitorImpl<milvus::Json>(
        expr.column_.field_id, index_func, elem_func);
    AssertInfo(res.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    bitset_opt_ = std::move(res);
}

N
neza2017 已提交
2022
}  // namespace milvus::query