ExecExprVisitor.cpp 17.3 KB
Newer Older
F
FluorineDog 已提交
1 2 3 4 5 6 7 8 9 10 11
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

N
neza2017 已提交
12
#include <optional>
F
FluorineDog 已提交
13
#include <boost/dynamic_bitset.hpp>
14
#include <boost/variant.hpp>
F
FluorineDog 已提交
15 16
#include <utility>
#include <deque>
17
#include "segcore/SegmentGrowingImpl.h"
G
GuoRentong 已提交
18
#include "query/ExprImpl.h"
N
neza2017 已提交
19 20 21 22 23 24 25 26 27
#include "query/generated/ExecExprVisitor.h"

namespace milvus::query {
#if 1
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR
namespace impl {
class ExecExprVisitor : ExprVisitor {
 public:
28
    using RetType = boost::dynamic_bitset<>;
29 30
    ExecExprVisitor(const segcore::SegmentInternalInterface& segment, int64_t row_count, Timestamp timestamp)
        : segment_(segment), row_count_(row_count), timestamp_(timestamp) {
N
neza2017 已提交
31 32 33 34 35 36
    }
    RetType
    call_child(Expr& expr) {
        Assert(!ret_.has_value());
        expr.accept(*this);
        Assert(ret_.has_value());
37
        auto res = std::move(ret_);
N
neza2017 已提交
38
        ret_ = std::nullopt;
39
        return std::move(res.value());
N
neza2017 已提交
40 41
    }

G
GuoRentong 已提交
42
 public:
F
FluorineDog 已提交
43
    template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
44
    auto
F
FluorineDog 已提交
45
    ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc func, ElementFunc element_func) -> RetType;
G
GuoRentong 已提交
46 47 48 49 50

    template <typename T>
    auto
    ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;

S
sunby 已提交
51 52 53 54
    template <typename T>
    auto
    ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;

55 56 57 58
    template <typename CmpFunc>
    auto
    ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) -> RetType;

N
neza2017 已提交
59
 private:
60 61
    const segcore::SegmentInternalInterface& segment_;
    int64_t row_count_;
N
neza2017 已提交
62
    std::optional<RetType> ret_;
63
    Timestamp timestamp_;
N
neza2017 已提交
64 65 66 67 68
};
}  // namespace impl
#endif

void
F
FluorineDog 已提交
69 70
ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
    using OpType = LogicalUnaryExpr::OpType;
71 72 73 74 75 76 77 78 79
    auto child_res = call_child(*expr.child_);
    RetType res = std::move(child_res);
    switch (expr.op_type_) {
        case OpType::LogicalNot: {
            res.flip();
            break;
        }
        default: {
            PanicInfo("Invalid Unary Op");
F
FluorineDog 已提交
80 81
        }
    }
82 83
    Assert(res.size() == row_count_);
    ret_ = std::move(res);
N
neza2017 已提交
84 85 86
}

void
F
FluorineDog 已提交
87 88
ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
    using OpType = LogicalBinaryExpr::OpType;
F
FluorineDog 已提交
89 90 91
    auto left = call_child(*expr.left_);
    auto right = call_child(*expr.right_);
    Assert(left.size() == right.size());
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
    auto res = std::move(left);
    switch (expr.op_type_) {
        case OpType::LogicalAnd: {
            res &= right;
            break;
        }
        case OpType::LogicalOr: {
            res |= right;
            break;
        }
        case OpType::LogicalXor: {
            res ^= right;
            break;
        }
        case OpType::LogicalMinus: {
            res -= right;
            break;
        }
        default: {
            PanicInfo("Invalid Binary Op");
        }
    }
    Assert(res.size() == row_count_);
    ret_ = std::move(res);
}
F
FluorineDog 已提交
117

118 119 120 121 122 123 124 125 126 127 128 129 130 131
// Concatenates per-chunk bitmaps into one bitmap, preserving chunk order.
//
// @param srcs  chunk bitmaps, in the order their bits should appear in the result
// @return      a bitmap whose size is the sum of all chunk sizes; bit (offset + i)
//              equals bit i of the chunk that starts at `offset`. Empty input
//              yields an empty bitmap.
static auto
Assemble(const std::deque<boost::dynamic_bitset<>>& srcs) -> boost::dynamic_bitset<> {
    using Bitset = boost::dynamic_bitset<>;
    // Use the container's own unsigned size type for all counters so sizes and
    // indices are never mixed with signed integers.
    using size_type = Bitset::size_type;

    size_type total_size = 0;
    for (const auto& chunk : srcs) {
        total_size += chunk.size();
    }

    // Allocate the full result once; all bits start cleared.
    Bitset res(total_size);

    size_type offset = 0;
    for (const auto& chunk : srcs) {
        const size_type chunk_size = chunk.size();
        for (size_type i = 0; i < chunk_size; ++i) {
            res[offset + i] = chunk[i];
        }
        offset += chunk_size;
    }
    return res;
}

F
FluorineDog 已提交
138
template <typename T, typename IndexFunc, typename ElementFunc>
G
GuoRentong 已提交
139
auto
F
FluorineDog 已提交
140 141
ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_func, ElementFunc element_func)
    -> RetType {
G
GuoRentong 已提交
142
    auto& schema = segment_.get_schema();
G
GuoRentong 已提交
143
    auto field_offset = expr.field_offset_;
G
GuoRentong 已提交
144
    auto& field_meta = schema[field_offset];
B
BossZou 已提交
145 146 147
    auto indexing_barrier = segment_.num_chunk_index(field_offset);
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);
148
    std::deque<boost::dynamic_bitset<>> results;
149 150

    using Index = knowhere::scalar::StructuredIndex<T>;
F
FluorineDog 已提交
151
    for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
152 153 154 155
        const Index& indexing = segment_.chunk_scalar_index<T>(field_offset, chunk_id);
        // NOTE: knowhere is not const-ready
        // This is a dirty workaround
        auto data = index_func(const_cast<Index*>(&indexing));
B
BossZou 已提交
156
        Assert(data->size() == size_per_chunk);
157
        results.emplace_back(std::move(*data));
F
FluorineDog 已提交
158
    }
159
    for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
160 161
        auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
        boost::dynamic_bitset<> result(this_size);
162
        auto chunk = segment_.chunk_data<T>(field_offset, chunk_id);
G
GuoRentong 已提交
163
        const T* data = chunk.data();
164
        for (int index = 0; index < this_size; ++index) {
F
FluorineDog 已提交
165
            result[index] = element_func(data[index]);
G
GuoRentong 已提交
166
        }
167
        Assert(result.size() == this_size);
168
        results.emplace_back(std::move(result));
G
GuoRentong 已提交
169
    }
170 171 172
    auto final_result = Assemble(results);
    Assert(final_result.size() == row_count_);
    return final_result;
G
GuoRentong 已提交
173
}
174

G
GuoRentong 已提交
175 176 177 178 179 180 181 182
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
// Translates the conditions of a range expression over T into a pair of
// predicates (one for scalar indexes, one for raw values) and evaluates them
// with ExecRangeVisitorImpl.
//
// Supported shapes:
//   - one condition:  ==, !=, >=, >, <=, <
//   - two conditions: a (> or >=) bound combined with a (< or <=) bound
// Anything else panics with "unsupported range node".
template <typename T>
auto
ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
    using Index = knowhere::scalar::StructuredIndex<T>;
    using Operator = knowhere::scalar::OperatorType;

    auto& expr = static_cast<RangeExprImpl<T>&>(expr_raw);
    // Work on a sorted copy so the expression itself is left untouched.
    // NOTE(review): the two-condition branch treats conditions[0] as the lower
    // bound; this presumably relies on OpType's enumerator order - confirm.
    auto conditions = expr.conditions_;
    std::sort(conditions.begin(), conditions.end());

    if (conditions.size() == 1) {
        // std::get is used instead of a structured binding because the values are
        // captured by the lambdas below, which C++17 does not allow for bindings.
        const auto& only = conditions[0];
        auto op = std::get<0>(only);
        auto val = std::get<1>(only);
        switch (op) {
            case OpType::Equal: {
                auto by_index = [val](Index* index) { return index->In(1, &val); };
                auto by_value = [val](T x) { return (x == val); };
                return ExecRangeVisitorImpl(expr, by_index, by_value);
            }
            case OpType::NotEqual: {
                auto by_index = [val](Index* index) { return index->NotIn(1, &val); };
                auto by_value = [val](T x) { return (x != val); };
                return ExecRangeVisitorImpl(expr, by_index, by_value);
            }
            case OpType::GreaterEqual: {
                auto by_index = [val](Index* index) { return index->Range(val, Operator::GE); };
                auto by_value = [val](T x) { return (x >= val); };
                return ExecRangeVisitorImpl(expr, by_index, by_value);
            }
            case OpType::GreaterThan: {
                auto by_index = [val](Index* index) { return index->Range(val, Operator::GT); };
                auto by_value = [val](T x) { return (x > val); };
                return ExecRangeVisitorImpl(expr, by_index, by_value);
            }
            case OpType::LessEqual: {
                auto by_index = [val](Index* index) { return index->Range(val, Operator::LE); };
                auto by_value = [val](T x) { return (x <= val); };
                return ExecRangeVisitorImpl(expr, by_index, by_value);
            }
            case OpType::LessThan: {
                auto by_index = [val](Index* index) { return index->Range(val, Operator::LT); };
                auto by_value = [val](T x) { return (x < val); };
                return ExecRangeVisitorImpl(expr, by_index, by_value);
            }
            default: {
                PanicInfo("unsupported range node");
            }
        }
    }

    if (conditions.size() == 2) {
        OpType op1, op2;
        T val1, val2;
        std::tie(op1, val1) = conditions[0];
        std::tie(op2, val2) = conditions[1];

        // TODO: disable check?
        // An inverted interval cannot match any row; this is decided before the
        // operators are validated.
        if (val1 > val2) {
            RetType none(row_count_, false);
            return none;
        }

        if (op1 == OpType::GreaterThan && op2 == OpType::LessThan) {
            auto by_index = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); };
            auto by_value = [val1, val2](T x) { return (val1 < x && x < val2); };
            return ExecRangeVisitorImpl(expr, by_index, by_value);
        }
        if (op1 == OpType::GreaterThan && op2 == OpType::LessEqual) {
            auto by_index = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); };
            auto by_value = [val1, val2](T x) { return (val1 < x && x <= val2); };
            return ExecRangeVisitorImpl(expr, by_index, by_value);
        }
        if (op1 == OpType::GreaterEqual && op2 == OpType::LessThan) {
            auto by_index = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); };
            auto by_value = [val1, val2](T x) { return (val1 <= x && x < val2); };
            return ExecRangeVisitorImpl(expr, by_index, by_value);
        }
        if (op1 == OpType::GreaterEqual && op2 == OpType::LessEqual) {
            auto by_index = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); };
            auto by_value = [val1, val2](T x) { return (val1 <= x && x <= val2); };
            return ExecRangeVisitorImpl(expr, by_index, by_value);
        }
        PanicInfo("unsupported range node");
    }

    PanicInfo("unsupported range node");
}
#pragma clang diagnostic pop

N
neza2017 已提交
257 258
// Evaluates a range expression: checks that the expression's data type matches
// the field's schema type, dispatches on that type, and publishes the bitmap in ret_.
void
ExecExprVisitor::visit(RangeExpr& expr) {
    auto& meta = segment_.get_schema()[expr.field_offset_];
    Assert(expr.data_type_ == meta.get_data_type());

    RetType bitmap;
    switch (expr.data_type_) {
        case DataType::BOOL:
            bitmap = ExecRangeVisitorDispatcher<bool>(expr);
            break;
        case DataType::INT8:
            bitmap = ExecRangeVisitorDispatcher<int8_t>(expr);
            break;
        case DataType::INT16:
            bitmap = ExecRangeVisitorDispatcher<int16_t>(expr);
            break;
        case DataType::INT32:
            bitmap = ExecRangeVisitorDispatcher<int32_t>(expr);
            break;
        case DataType::INT64:
            bitmap = ExecRangeVisitorDispatcher<int64_t>(expr);
            break;
        case DataType::FLOAT:
            bitmap = ExecRangeVisitorDispatcher<float>(expr);
            break;
        case DataType::DOUBLE:
            bitmap = ExecRangeVisitorDispatcher<double>(expr);
            break;
        default:
            PanicInfo("unsupported");
    }

    Assert(bitmap.size() == row_count_);
    ret_ = std::move(bitmap);
}

298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
// Visitor for boost::apply_visitor that compares two operands with the
// comparison functor type Op.
template <typename Op>
struct relational {
    // Two operands: compare them with a default-constructed Op.
    template <typename Lhs, typename Rhs>
    bool
    operator()(Lhs const& lhs, Rhs const& rhs) const {
        return Op{}(lhs, rhs);
    }

    // Catch-all overload for any other argument list.
    template <typename... Args>
    bool
    operator()(Args const&...) const {
        PanicInfo("incompatible operands");
    }
};

// A single scalar value of any of the listed element types, so that values
// read from differently typed fields can be handled through one type.
using number = boost::variant<bool, int8_t, int16_t, int32_t, int64_t, float, double>;

// Evaluates a field-to-field comparison chunk by chunk and returns a bitmap of
// row_count_ bits. Row i of a chunk is set when Op holds between the value of
// field 0 and the value of field 1 at that row.
template <typename Op>
auto
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) -> RetType {
    // Type-erased accessor: row index within a chunk -> value wrapped in `number`.
    using Reader = std::function<const number(int)>;

    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);

    // Wraps a typed chunk view; the view is copied into the returned accessor.
    auto as_reader = [](auto span) -> Reader {
        return [span](int i) -> const number { return span.data()[i]; };
    };

    // Opens one chunk of the field at `offset` with the element type named by `type`.
    auto open_chunk = [&](DataType type, FieldOffset offset, int64_t chunk_id) -> Reader {
        switch (type) {
            case DataType::BOOL:
                return as_reader(segment_.chunk_data<bool>(offset, chunk_id));
            case DataType::INT8:
                return as_reader(segment_.chunk_data<int8_t>(offset, chunk_id));
            case DataType::INT16:
                return as_reader(segment_.chunk_data<int16_t>(offset, chunk_id));
            case DataType::INT32:
                return as_reader(segment_.chunk_data<int32_t>(offset, chunk_id));
            case DataType::INT64:
                return as_reader(segment_.chunk_data<int64_t>(offset, chunk_id));
            case DataType::FLOAT:
                return as_reader(segment_.chunk_data<float>(offset, chunk_id));
            case DataType::DOUBLE:
                return as_reader(segment_.chunk_data<double>(offset, chunk_id));
            default:
                PanicInfo("unsupported datatype");
        }
    };

    std::deque<RetType> pieces;
    for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
        // Only the last chunk may be partially filled.
        auto rows_here = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
        auto lhs = open_chunk(expr.data_types_[0], expr.field_offsets_[0], chunk_id);
        auto rhs = open_chunk(expr.data_types_[1], expr.field_offsets_[1], chunk_id);

        boost::dynamic_bitset<> hits(rows_here);
        for (int i = 0; i < rows_here; ++i) {
            // Binary visitation resolves the concrete element types of both sides.
            const bool matched = boost::apply_visitor(relational<decltype(op)>{}, lhs(i), rhs(i));
            hits[i] = matched;
        }
        pieces.emplace_back(std::move(hits));
    }

    auto merged = Assemble(pieces);
    Assert(merged.size() == row_count_);
    return merged;
}

// Evaluates a comparison between two fields: validates the operand description
// against the schema, dispatches on the comparison operator, and publishes the
// bitmap in ret_.
void
ExecExprVisitor::visit(CompareExpr& expr) {
    Assert(expr.data_types_.size() == expr.field_offsets_.size());
    Assert(expr.data_types_.size() == 2);

    // Every operand's declared type must match the type recorded in the schema.
    auto& schema = segment_.get_schema();
    for (int idx = 0; idx < expr.field_offsets_.size(); ++idx) {
        auto& meta = schema[expr.field_offsets_[idx]];
        Assert(expr.data_types_[idx] == meta.get_data_type());
    }

    RetType bitmap;
    switch (expr.op) {
        case OpType::Equal:
            bitmap = ExecCompareExprDispatcher(expr, std::equal_to<>{});
            break;
        case OpType::NotEqual:
            bitmap = ExecCompareExprDispatcher(expr, std::not_equal_to<>{});
            break;
        case OpType::GreaterEqual:
            bitmap = ExecCompareExprDispatcher(expr, std::greater_equal<>{});
            break;
        case OpType::GreaterThan:
            bitmap = ExecCompareExprDispatcher(expr, std::greater<>{});
            break;
        case OpType::LessEqual:
            bitmap = ExecCompareExprDispatcher(expr, std::less_equal<>{});
            break;
        case OpType::LessThan:
            bitmap = ExecCompareExprDispatcher(expr, std::less<>{});
            break;
        default:
            PanicInfo("unsupported optype");
    }

    Assert(bitmap.size() == row_count_);
    ret_ = std::move(bitmap);
}

S
sunby 已提交
416 417 418 419 420
// Evaluates a term (membership) expression over T chunk by chunk and returns a
// bitmap of row_count_ bits; a bit is set when the row's value is found in terms_.
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType {
    auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);

    auto field_offset = expr_raw.field_offset_;
    // The metadata itself is not read below; the schema lookup for this offset is kept.
    auto& field_meta = segment_.get_schema()[field_offset];

    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunk = upper_div(row_count_, size_per_chunk);

    // std::binary_search needs terms_ to be sorted; nothing here sorts them.
    // NOTE(review): presumably sorted when the expression is built - confirm.
    const auto terms_begin = expr.terms_.begin();
    const auto terms_end = expr.terms_.end();

    std::deque<RetType> pieces;
    for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
        Span<T> chunk = segment_.chunk_data<T>(field_offset, chunk_id);
        const auto* values = chunk.data();

        // Only the last chunk may be partially filled.
        auto rows_here = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;

        boost::dynamic_bitset<> hits(rows_here);
        for (int i = 0; i < rows_here; ++i) {
            auto value = values[i];
            hits[i] = std::binary_search(terms_begin, terms_end, value);
        }
        pieces.emplace_back(std::move(hits));
    }

    auto merged = Assemble(pieces);
    Assert(merged.size() == row_count_);
    return merged;
}

// Evaluates a term expression: checks that the expression's data type matches
// the field's schema type, dispatches on that type, and publishes the bitmap in ret_.
void
ExecExprVisitor::visit(TermExpr& expr) {
    auto& meta = segment_.get_schema()[expr.field_offset_];
    Assert(expr.data_type_ == meta.get_data_type());

    RetType bitmap;
    switch (expr.data_type_) {
        case DataType::BOOL:
            bitmap = ExecTermVisitorImpl<bool>(expr);
            break;
        case DataType::INT8:
            bitmap = ExecTermVisitorImpl<int8_t>(expr);
            break;
        case DataType::INT16:
            bitmap = ExecTermVisitorImpl<int16_t>(expr);
            break;
        case DataType::INT32:
            bitmap = ExecTermVisitorImpl<int32_t>(expr);
            break;
        case DataType::INT64:
            bitmap = ExecTermVisitorImpl<int64_t>(expr);
            break;
        case DataType::FLOAT:
            bitmap = ExecTermVisitorImpl<float>(expr);
            break;
        case DataType::DOUBLE:
            bitmap = ExecTermVisitorImpl<double>(expr);
            break;
        default:
            PanicInfo("unsupported");
    }

    Assert(bitmap.size() == row_count_);
    ret_ = std::move(bitmap);
}
N
neza2017 已提交
485
}  // namespace milvus::query