search_test.cpp 10.1 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

J
jinhai 已提交
18 19
#include <gtest/gtest.h>
#include <cmath>
G
groot 已提交
20
#include <vector>
21 22 23

#include "scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
W
wxyu 已提交
24

G
groot 已提交
25 26 27 28 29 30 31 32
using namespace zilliz::milvus;

namespace {

// Number of queries used by the fixed-size tests in this file.
constexpr uint64_t NQ{15};
// Number of results requested per query by the fixed-size tests in this file.
constexpr uint64_t TOP_K{64};

// Fills `output_ids` / `output_distence` with nq * topk synthetic search hits,
// stored row-major: query i occupies indices [i * topk, (i + 1) * topk).
//
// nq:        number of queries (rows) to generate.
// topk:      number of hits generated per query.
// ascending: when true, entry j of a row gets distance j + r; otherwise it
//            gets (topk - j) + r, where r is a drand48() sample. Rows are
//            therefore already ordered in the requested direction.
// output_ids, output_distence: overwritten; previous content is discarded.
//
// Ids are drand48() * 100000 truncated to long, so they may repeat.
void BuildResult(uint64_t nq,
                 uint64_t topk,
                 bool ascending,
                 std::vector<long> &output_ids,
                 std::vector<float> &output_distence) {
    const uint64_t total = nq * topk;
    output_ids.assign(total, 0);
    output_distence.assign(total, 0.0f);

    for(uint64_t i = 0; i < nq; i++) {
        for(uint64_t j = 0; j < topk; j++) {
            const uint64_t pos = i * topk + j;
            // The id is drawn before the distance; keeping this order keeps the
            // generated sequence unchanged for a given drand48() state.
            output_ids[pos] = static_cast<long>(drand48() * 100000);
            const uint64_t base = ascending ? j : (topk - j);
            output_distence[pos] = static_cast<float>(base + drand48());
        }
    }
}

W
wxyu 已提交
50 51 52
// Validates a merged result list against the two lists it was merged from.
//
// src_1, src_2: the two source lists of (id, distance) pairs.
// target:       the merged list to validate.
// ascending:    true if distances in `target` must be non-decreasing,
//               false if they must be non-increasing.
//
// Checks that `target` is ordered in the requested direction and that every
// id in `target` is present in at least one source with (almost) the same
// distance. When an id is present in both sources, the distance from `src_1`
// is the one compared. All checks are fatal gtest assertions, so the function
// returns at the first failure.
void CheckResult(const scheduler::Id2DistanceMap& src_1,
        const scheduler::Id2DistanceMap& src_2,
        const scheduler::Id2DistanceMap& target,
        bool ascending) {
    // `i + 1 < size()` rather than `i < size() - 1`: the latter wraps around
    // for an empty list (unsigned arithmetic) and would index out of bounds.
    for(uint64_t i = 0; i + 1 < target.size(); i++) {
        if(ascending) {
            ASSERT_LE(target[i].second, target[i + 1].second);
        } else {
            ASSERT_GE(target[i].second, target[i + 1].second);
        }
    }

    // Index both sources by id; for duplicated ids the first occurrence wins
    // because std::map::insert does not overwrite.
    using ID2DistMap = std::map<long, float>;
    ID2DistMap src_map_1, src_map_2;
    for(const auto& pair : src_1) {
        src_map_1.insert(pair);
    }
    for(const auto& pair : src_2) {
        src_map_2.insert(pair);
    }

    for(const auto& pair : target) {
        const auto it_1 = src_map_1.find(pair.first);
        const auto it_2 = src_map_2.find(pair.first);
        ASSERT_TRUE(it_1 != src_map_1.end() || it_2 != src_map_2.end());

        const float dist = (it_1 != src_map_1.end()) ? it_1->second : it_2->second;
        ASSERT_LT(std::fabs(pair.second - dist), std::numeric_limits<float>::epsilon());
    }
}

79 80
void CheckCluster(const std::vector<long>& target_ids,
        const std::vector<float>& target_distence,
W
wxyu 已提交
81
        const scheduler::ResultSet& src_result,
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
        int64_t nq,
        int64_t topk) {
    ASSERT_EQ(src_result.size(), nq);
    for(int64_t i = 0; i < nq; i++) {
        auto& res = src_result[i];
        ASSERT_EQ(res.size(), topk);

        if(res.empty()) {
            continue;
        }

        ASSERT_EQ(res[0].first, target_ids[i*topk]);
        ASSERT_EQ(res[topk - 1].first, target_ids[i*topk + topk - 1]);
    }
}

W
wxyu 已提交
98
// Validates the shape and ordering of a top-k result set.
//
// src_result: result set to validate.
// ascending:  true if distances in each row must be non-decreasing,
//             false if they must be non-increasing.
// nq:         expected number of rows.
// topk:       expected number of entries in every row.
//
// All checks are fatal gtest assertions, so the function returns at the first
// failure.
void CheckTopkResult(const scheduler::ResultSet& src_result,
                     bool ascending,
                     int64_t nq,
                     int64_t topk) {
    ASSERT_EQ(static_cast<int64_t>(src_result.size()), nq);
    for(int64_t i = 0; i < nq; i++) {
        const auto& res = src_result[i];
        ASSERT_EQ(static_cast<int64_t>(res.size()), topk);

        // Reached only when topk == 0 (size was just asserted): nothing to compare.
        if(res.empty()) {
            continue;
        }

        // Compare every adjacent pair of distances in the row.
        for(int64_t k = 0; k + 1 < topk; k++) {
            if(ascending) {
                ASSERT_LE(res[k].second, res[k + 1].second);
            } else {
                ASSERT_GE(res[k].second, res[k + 1].second);
            }
        }
    }
}

G
groot 已提交
121 122 123
}

// Exercises XSearchTask::ClusterResult and XSearchTask::TopkResult on
// fixed-size (NQ x TOP_K) data: empty inputs, empty/non-empty source and
// target combinations, and sources whose rows are shorter or longer than the
// requested top-k.
TEST(DBSearchTest, TOPK_TEST) {
    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    scheduler::ResultSet src_result;

    // Clustering empty id/distance arrays is expected to fail and to leave the
    // result set empty.
    auto status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_FALSE(status.ok());
    ASSERT_TRUE(src_result.empty());

    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
    status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);

    // The same empty set passed as both source and target is expected to be accepted.
    scheduler::ResultSet target_result;
    status = scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());

    // Empty source merged into a populated target is expected to be rejected.
    status = scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
    ASSERT_FALSE(status.ok());

    // Populated source merged into an empty target: the source is expected to
    // end up empty and the target to hold all NQ rows.
    status = scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    ASSERT_TRUE(src_result.empty());
    ASSERT_EQ(target_result.size(), NQ);

    // Source rows shorter than TOP_K: the merged rows must still hold TOP_K entries.
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    uint64_t wrong_topk = TOP_K - 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    status = scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }

    // Source rows longer than TOP_K: the merged rows must be capped at TOP_K.
    wrong_topk = TOP_K + 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    // The freshly built data has to be clustered before it is merged; without
    // this step the merge below would reuse whatever the previous scenario
    // left in src_result and the longer-row case would never be exercised.
    status = scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }
}

// Exercises XSearchTask::MergeResult on a single query with a 5-entry source
// list and an 8-entry target list, for several top-k limits and with the
// source/target roles swapped.
TEST(DBSearchTest, MERGE_TEST) {
    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    scheduler::ResultSet src_result, target_result;

    uint64_t src_count = 5, target_count = 8;
    BuildResult(1, src_count, ascending, src_ids, src_distence);
    BuildResult(1, target_count, ascending, target_ids, target_distence);
    auto status = scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
    ASSERT_TRUE(status.ok());
    status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
    ASSERT_TRUE(status.ok());

    // Limit (10) smaller than the combined size (5 + 8): result is capped at the limit.
    {
        scheduler::Id2DistanceMap src = src_result[0];
        scheduler::Id2DistanceMap target = target_result[0];
        status = scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), 10u);
        CheckResult(src_result[0], target_result[0], target, ascending);
    }

    // Empty target: it is expected to receive all source entries and the
    // source is expected to be left empty.
    {
        scheduler::Id2DistanceMap src = src_result[0];
        scheduler::Id2DistanceMap target;
        status = scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count);
        ASSERT_TRUE(src.empty());
        CheckResult(src_result[0], target_result[0], target, ascending);
    }

    // Limit (30) larger than the combined size: every entry is kept.
    {
        scheduler::Id2DistanceMap src = src_result[0];
        scheduler::Id2DistanceMap target = target_result[0];
        status = scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_result[0], target_result[0], target, ascending);
    }

    // Same as above with the roles swapped: the 8-entry list is merged into
    // the 5-entry list.
    {
        scheduler::Id2DistanceMap target = src_result[0];
        scheduler::Id2DistanceMap src = target_result[0];
        status = scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_result[0], target_result[0], target, ascending);
    }
}
226 227 228 229 230

// Runs XSearchTask::ClusterResult over a range of (nq, topk) shapes, from a
// single entry up to tens of thousands of queries, and spot-checks every
// clustered row against the generated input. Each phase is timed with
// TimeRecorder.
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    scheduler::ResultSet src_result;

    // Builds nq * topk synthetic hits, clusters them and validates the result.
    // The buffers captured by reference are reused across invocations.
    auto DoCluster = [&](int64_t nq, int64_t topk) {
        TimeRecorder rc("DoCluster");
        src_result.clear();
        BuildResult(nq, topk, ascending, target_ids, target_distence);
        rc.RecordSection("build id/dietance map");

        auto status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(static_cast<int64_t>(src_result.size()), nq);

        rc.RecordSection("cluster result");

        CheckCluster(target_ids, target_distence, src_result, nq, topk);
        rc.RecordSection("check result");
    };

    DoCluster(10000, 1000);
    DoCluster(333, 999);
    DoCluster(1, 1000);
    DoCluster(1, 1);
    DoCluster(7, 0);  // topk == 0: rows are expected to be empty
    DoCluster(9999, 1);
    DoCluster(10001, 1);
    DoCluster(58273, 1234);  // largest case: roughly 72 million generated entries
}

// Runs XSearchTask::TopkResult on larger inputs: a result set with
// `insufficient_topk` entries per query is merged into a result set that
// already holds `topk` entries per query, and the merged rows are checked for
// size and ordering. Each phase is timed with TimeRecorder.
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    scheduler::ResultSet src_result;

    std::vector<long> insufficient_ids;
    std::vector<float> insufficient_distence;
    scheduler::ResultSet insufficient_result;

    // nq:                number of queries.
    // topk:              row length of the full set and merge limit.
    // insufficient_topk: row length of the set merged into the full set.
    // ascending:         ordering direction of the generated distances.
    auto DoTopk = [&](int64_t nq, int64_t topk,int64_t insufficient_topk, bool ascending) {
        src_result.clear();
        insufficient_result.clear();

        TimeRecorder rc("DoCluster");

        BuildResult(nq, topk, ascending, target_ids, target_distence);
        auto status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Cluster the data generated for the insufficient set. Re-slicing the
        // full-size arrays with the shorter row length would produce rows that
        // straddle the original rows and are not ordered.
        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
        status = scheduler::XSearchTask::ClusterResult(insufficient_ids, insufficient_distence, nq, insufficient_topk, insufficient_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Capture the merge status so that the assertion checks this call and
        // not the preceding ClusterResult.
        status = scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("topk");

        CheckTopkResult(src_result, ascending, nq, topk);
        rc.RecordSection("check result");
    };

    DoTopk(5, 10, 4, false);
    DoTopk(20005, 998, 123, true);
//    DoTopk(9987, 12, 10, false);
//    DoTopk(77777, 1000, 1, false);
//    DoTopk(5432, 8899, 8899, true);
}