// search_test.cpp
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "db/scheduler/task/SearchTask.h"
#include "server/ServerConfig.h"
#include "utils/TimeRecorder.h"

#include <gtest/gtest.h>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <map>
#include <vector>
#include <src/scheduler/task/SearchTask.h>

using namespace zilliz::milvus;

namespace {

constexpr uint64_t NQ = 15;
constexpr uint64_t TOP_K = 64;

// Fills output_ids / output_distence with nq * topk random entries laid out
// row-major: query i occupies indices [i * topk, (i + 1) * topk).
// Ids are drand48() scaled by 100000 and truncated to long. Within a row the
// distances grow with j when `ascending` is true (j + noise) and shrink with j
// otherwise ((topk - j) + noise), where the noise is one drand48() draw.
// Both outputs are cleared and resized; previous contents are discarded.
void BuildResult(uint64_t nq,
                 uint64_t topk,
                 bool ascending,
                 std::vector<long> &output_ids,
                 std::vector<float> &output_distence) {
    output_ids.clear();
    output_ids.resize(nq * topk);
    output_distence.clear();
    output_distence.resize(nq * topk);

    for (uint64_t i = 0; i < nq; i++) {
        for (uint64_t j = 0; j < topk; j++) {
            // The id is drawn before the distance; callers' random sequences depend on this order.
            output_ids[i * topk + j] = static_cast<long>(drand48() * 100000);
            output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
        }
    }
}

// Validates a merged result `target`:
//  - its distances are ordered (non-decreasing if ascending, non-increasing otherwise);
//  - every id in it exists in src_1 or src_2 with (numerically) the same distance.
//    When an id appears in both sources, src_1's distance is the one compared.
void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
        const engine::SearchContext::Id2DistanceMap& src_2,
        const engine::SearchContext::Id2DistanceMap& target,
        bool ascending) {
    // `i + 1 < size()` rather than `i < size() - 1`: the latter wraps around for
    // an empty target and would index out of bounds.
    for (size_t i = 0; i + 1 < target.size(); i++) {
        if (ascending) {
            ASSERT_LE(target[i].second, target[i + 1].second);
        } else {
            ASSERT_GE(target[i].second, target[i + 1].second);
        }
    }

    using ID2DistMap = std::map<long, float>;
    ID2DistMap src_map_1, src_map_2;
    // insert() keeps the first occurrence of a duplicated id.
    for (const auto& pair : src_1) {
        src_map_1.insert(pair);
    }
    for (const auto& pair : src_2) {
        src_map_2.insert(pair);
    }

    for (const auto& pair : target) {
        // Single lookup per map; src_1 takes precedence over src_2.
        auto it = src_map_1.find(pair.first);
        if (it == src_map_1.end()) {
            it = src_map_2.find(pair.first);
            ASSERT_TRUE(it != src_map_2.end());
        }
        ASSERT_LT(std::fabs(pair.second - it->second), std::numeric_limits<float>::epsilon());
    }
}

// Spot-checks a clustered result set: it must contain nq rows of exactly topk
// entries each, and every non-empty row's first and last ids must equal the
// matching entries of the flat target_ids (row i starts at index i * topk).
// NOTE: target_distence is accepted for symmetry with BuildResult but is not checked.
void CheckCluster(const std::vector<long>& target_ids,
        const std::vector<float>& target_distence,
        const engine::SearchContext::ResultSet& src_result,
        int64_t nq,
        int64_t topk) {
    ASSERT_EQ(src_result.size(), static_cast<size_t>(nq));
    for (int64_t i = 0; i < nq; i++) {
        auto& res = src_result[i];
        ASSERT_EQ(res.size(), static_cast<size_t>(topk));

        if (res.empty()) {
            continue;
        }

        ASSERT_EQ(res[0].first, target_ids[i * topk]);
        ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]);
    }
}

// Checks that a top-k result set has nq rows of exactly topk entries, each row
// ordered by distance (non-decreasing if ascending, non-increasing otherwise).
void CheckTopkResult(const engine::SearchContext::ResultSet& src_result,
                     bool ascending,
                     int64_t nq,
                     int64_t topk) {
    ASSERT_EQ(src_result.size(), static_cast<size_t>(nq));
    for (int64_t i = 0; i < nq; i++) {
        auto& res = src_result[i];
        ASSERT_EQ(res.size(), static_cast<size_t>(topk));

        if (res.empty()) {
            continue;
        }

        for (int64_t k = 0; k < topk - 1; k++) {
            if (ascending) {
                ASSERT_LE(res[k].second, res[k + 1].second);
            } else {
                ASSERT_GE(res[k].second, res[k + 1].second);
            }
        }
    }
}

}  // namespace

// Exercises XSearchTask::ClusterResult and XSearchTask::TopkResult on
// NQ x TOP_K result sets. Covers:
//  - clustering empty input (expected to fail);
//  - TopkResult with an nq mismatch (expected to fail);
//  - merging sources with fewer (TOP_K - 10) and more (TOP_K + 10) entries per
//    query than requested; the merged rows must still hold exactly TOP_K entries.
TEST(DBSearchTest, TOPK_TEST) {
    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    // Empty ids/distances cannot be clustered into NQ x TOP_K.
    auto status = engine::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_FALSE(status.ok());
    ASSERT_TRUE(src_result.empty());

    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
    status = engine::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);

    engine::SearchContext::ResultSet target_result;
    // Empty source into empty target succeeds.
    status = engine::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());

    // Empty source into an NQ-row target is rejected.
    status = engine::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
    ASSERT_FALSE(status.ok());

    // Full source into empty target: the source is consumed.
    status = engine::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    ASSERT_TRUE(src_result.empty());
    ASSERT_EQ(target_result.size(), NQ);

    // Source rows shorter than TOP_K.
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    uint64_t wrong_topk = TOP_K - 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    status = engine::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = engine::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for (uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }

    // Source rows longer than TOP_K.
    wrong_topk = TOP_K + 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    // The freshly built data must actually be clustered; otherwise TopkResult
    // below only sees whatever was left in src_result and this scenario is never tested.
    src_result.clear();
    status = engine::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);

    status = engine::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for (uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }
}

// Exercises XSearchTask::MergeResult on a single query. A 5-entry list and an
// 8-entry list are clustered, then merged in four ways:
//  1. both lists, capped at 10;
//  2. the 5-entry list into an empty target;
//  3. both lists, cap 30 (no truncation);
//  4. as 3 but with the source/target roles swapped.
// After each merge, CheckResult validates ordering and id/distance provenance.
TEST(DBSearchTest, MERGE_TEST) {
    bool ascending = true;
    uint64_t src_count = 5;
    uint64_t target_count = 8;

    std::vector<long> small_ids;
    std::vector<float> small_dists;
    std::vector<long> large_ids;
    std::vector<float> large_dists;
    // Call order matters: both builds draw from the same drand48() stream.
    BuildResult(1, src_count, ascending, small_ids, small_dists);
    BuildResult(1, target_count, ascending, large_ids, large_dists);

    engine::SearchContext::ResultSet small_set;
    engine::SearchContext::ResultSet large_set;
    auto status = engine::XSearchTask::ClusterResult(small_ids, small_dists, 1, src_count, small_set);
    ASSERT_TRUE(status.ok());
    status = engine::XSearchTask::ClusterResult(large_ids, large_dists, 1, target_count, large_set);
    ASSERT_TRUE(status.ok());

    // The clustered sets are never modified below; each case merges private copies.
    const engine::SearchContext::Id2DistanceMap& small_list = small_set[0];
    const engine::SearchContext::Id2DistanceMap& large_list = large_set[0];

    // Case 1: cap of 10 is below 5 + 8, so exactly 10 entries survive.
    {
        engine::SearchContext::Id2DistanceMap from = small_list;
        engine::SearchContext::Id2DistanceMap into = large_list;
        status = engine::XSearchTask::MergeResult(from, into, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(into.size(), 10);
        CheckResult(small_list, large_list, into, ascending);
    }

    // Case 2: merging into an empty target moves every source entry over.
    {
        engine::SearchContext::Id2DistanceMap from = small_list;
        engine::SearchContext::Id2DistanceMap into;
        status = engine::XSearchTask::MergeResult(from, into, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(into.size(), src_count);
        ASSERT_TRUE(from.empty());
        CheckResult(small_list, large_list, into, ascending);
    }

    // Case 3: a cap of 30 keeps all 5 + 8 entries.
    {
        engine::SearchContext::Id2DistanceMap from = small_list;
        engine::SearchContext::Id2DistanceMap into = large_list;
        status = engine::XSearchTask::MergeResult(from, into, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(into.size(), src_count + target_count);
        CheckResult(small_list, large_list, into, ascending);
    }

    // Case 4: same as case 3 with the roles of the two lists swapped.
    {
        engine::SearchContext::Id2DistanceMap into = small_list;
        engine::SearchContext::Id2DistanceMap from = large_list;
        status = engine::XSearchTask::MergeResult(from, into, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(into.size(), src_count + target_count);
        CheckResult(small_list, large_list, into, ascending);
    }
}

// Runs ClusterResult over a table of (nq, topk) shapes, including topk == 0 and
// large inputs. Each clustered result is spot-checked against the raw ids with
// CheckCluster. Parallel reduce is switched off in the global DB config first.
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
    server::ServerConfig& server_config = server::ServerConfig::GetInstance();
    server::ConfigNode& db_node = server_config.GetConfig(server::CONFIG_DB);
    // lvoc cannot work for std::function, so parallel reduce is set to false.
    db_node.SetValue(server::CONFIG_DB_PARALLEL_REDUCE, "false");

    bool ascending = true;
    std::vector<long> ids;
    std::vector<float> distances;
    engine::SearchContext::ResultSet clustered;

    // Builds nq x topk random data, clusters it and verifies the outcome.
    // Assertion failures return from this lambda only; the remaining cases still run.
    auto run_case = [&](int64_t nq, int64_t topk) {
        TimeRecorder recorder("DoCluster");
        clustered.clear();
        BuildResult(nq, topk, ascending, ids, distances);
        recorder.RecordSection("build id/dietance map");

        auto status = engine::XSearchTask::ClusterResult(ids, distances, nq, topk, clustered);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(clustered.size(), nq);

        recorder.RecordSection("cluster result");

        CheckCluster(ids, distances, clustered, nq, topk);
        recorder.RecordSection("check result");
    };

    const struct {
        int64_t nq;
        int64_t topk;
    } cases[] = {
        {10000, 1000},
        {333, 999},
        {1, 1000},
        {1, 1},
        {7, 0},
        {9999, 1},
        {10001, 1},
        {58273, 1234},
    };
    for (const auto& shape : cases) {
        run_case(shape.nq, shape.topk);
    }
}

// Merges a full (nq x topk) clustered result set with an "insufficient"
// (nq x insufficient_topk) one via TopkResult. Each merged row must hold exactly
// topk entries ordered by distance. Parallel reduce is switched off in the
// global DB config first.
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
    server::ServerConfig &config = server::ServerConfig::GetInstance();
    server::ConfigNode& db_config = config.GetConfig(server::CONFIG_DB);
    db_config.SetValue(server::CONFIG_DB_PARALLEL_REDUCE, "false");//lvoc cannot work for std::function, set to false

    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    std::vector<long> insufficient_ids;
    std::vector<float> insufficient_distence;
    engine::SearchContext::ResultSet insufficient_result;

    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
        src_result.clear();
        insufficient_result.clear();

        TimeRecorder rc("DoTopk");

        BuildResult(nq, topk, ascending, target_ids, target_distence);
        auto status = engine::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Cluster the data actually built for the insufficient set. Previously
        // target_ids/target_distence were reused here, so insufficient_ids was never consulted.
        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
        status = engine::XSearchTask::ClusterResult(insufficient_ids, insufficient_distence,
                                                    nq, insufficient_topk, insufficient_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Check TopkResult's own status rather than a stale one from clustering.
        status = engine::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("topk");

        CheckTopkResult(src_result, ascending, nq, topk);
        rc.RecordSection("check result");
    };

    DoTopk(5, 10, 4, false);
    DoTopk(20005, 998, 123, true);
//    DoTopk(9987, 12, 10, false);
//    DoTopk(77777, 1000, 1, false);
//    DoTopk(5432, 8899, 8899, true);
}