test_search.cpp 10.6 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

J
jinhai 已提交
18 19
#include <gtest/gtest.h>

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <map>
#include <vector>

#include "scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
W
wxyu 已提交
24

25 26
namespace {

S
starlord 已提交
27
namespace ms = milvus;

// Shape of the synthetic result set used by TOPK_TEST: NQ rows with TOP_K
// entries each. (Already internal linkage: constexpr inside an anonymous
// namespace, so no `static` is needed.)
constexpr uint64_t NQ = 15;
constexpr uint64_t TOP_K = 64;

S
starlord 已提交
32 33 34 35 36 37
/**
 * Fill @p output_ids and @p output_distence with `nq * topk` random entries,
 * laid out row-major (row i occupies indices [i * topk, (i + 1) * topk)).
 *
 * Each id is drand48() scaled to [0, 100000) and truncated; duplicates are
 * possible. The distance at column j is `j + r` when @p ascending is true and
 * `(topk - j) + r` otherwise, where r is a fresh drand48() draw in [0, 1).
 * Every row is therefore ordered in the requested direction (up to float
 * rounding).
 *
 * Both vectors are overwritten; any previous contents are discarded.
 */
void
BuildResult(uint64_t nq,
            uint64_t topk,
            bool ascending,
            std::vector<int64_t> &output_ids,
            std::vector<float> &output_distence) {
    const uint64_t total = nq * topk;
    output_ids.assign(total, 0);
    output_distence.assign(total, 0.0f);

    uint64_t pos = 0;
    for (uint64_t row = 0; row < nq; ++row) {
        for (uint64_t col = 0; col < topk; ++col, ++pos) {
            // Draw the id before the distance so the random sequence is consumed
            // in the same order for every caller.
            output_ids[pos] = static_cast<int64_t>(drand48() * 100000);
            const uint64_t rank = ascending ? col : topk - col;
            output_distence[pos] = rank + drand48();
        }
    }
}

S
starlord 已提交
51 52 53 54 55 56 57
/**
 * Validate a merged result row @p target against the two rows it was built from.
 *
 * Checks that
 *   - @p target is ordered by distance: non-decreasing when @p ascending,
 *     non-increasing otherwise;
 *   - every id in @p target occurs in @p src_1 or @p src_2, and its distance
 *     equals the distance recorded there. @p src_1 is consulted first; if an id
 *     occurs more than once in a source row, its first occurrence is used.
 *
 * Uses gtest ASSERT_* macros, so the first failed check returns from this
 * function.
 */
void
CheckResult(const ms::scheduler::Id2DistanceMap &src_1,
            const ms::scheduler::Id2DistanceMap &src_2,
            const ms::scheduler::Id2DistanceMap &target,
            bool ascending) {
    // Compare each element with its predecessor. Starting at 1 keeps this safe
    // for an empty target; `target.size() - 1` would underflow size_t there.
    for (size_t i = 1; i < target.size(); ++i) {
        if (ascending) {
            ASSERT_LE(target[i - 1].second, target[i].second);
        } else {
            ASSERT_GE(target[i - 1].second, target[i].second);
        }
    }

    // The range constructor keeps the first entry for a repeated id, like insert().
    using ID2DistMap = std::map<int64_t, float>;
    const ID2DistMap src_map_1(src_1.begin(), src_1.end());
    const ID2DistMap src_map_2(src_2.begin(), src_2.end());

    for (const auto &pair : target) {
        // Single lookup per map, and no operator[] (which would default-insert).
        auto it = src_map_1.find(pair.first);
        if (it == src_map_1.end()) {
            it = src_map_2.find(pair.first);
            ASSERT_TRUE(it != src_map_2.end());
        }
        ASSERT_LT(std::fabs(pair.second - it->second), std::numeric_limits<float>::epsilon());
    }
}

S
starlord 已提交
81 82 83 84 85 86
/**
 * Verify that @p src_result is the per-row split of the flat arrays produced
 * by BuildResult. It must hold @p nq rows of exactly @p topk entries. The
 * first and last id of every row must match the corresponding entries of
 * @p target_ids.
 *
 * Only the row endpoints' ids are compared; @p target_distence is accepted but
 * not inspected.
 */
void
CheckCluster(const std::vector<int64_t> &target_ids,
             const std::vector<float> &target_distence,
             const ms::scheduler::ResultSet &src_result,
             int64_t nq,
             int64_t topk) {
    ASSERT_EQ(src_result.size(), nq);

    int64_t row_begin = 0;  // flat index of the current row's first entry
    for (const auto &row : src_result) {
        ASSERT_EQ(row.size(), topk);
        if (!row.empty()) {
            ASSERT_EQ(row.front().first, target_ids[row_begin]);
            ASSERT_EQ(row.back().first, target_ids[row_begin + topk - 1]);
        }
        row_begin += topk;
    }
}

S
starlord 已提交
101 102 103 104 105
/**
 * Verify a TopkResult output. @p src_result must hold @p nq rows of exactly
 * @p topk entries. Each row must be ordered by distance: non-decreasing when
 * @p ascending, non-increasing otherwise.
 */
void
CheckTopkResult(const ms::scheduler::ResultSet &src_result,
                bool ascending,
                int64_t nq,
                int64_t topk) {
    ASSERT_EQ(src_result.size(), nq);

    for (const auto &row : src_result) {
        ASSERT_EQ(row.size(), topk);
        // Compare each entry with its predecessor; no iterations when topk <= 1.
        for (int64_t k = 1; k < topk; ++k) {
            if (ascending) {
                ASSERT_LE(row[k - 1].second, row[k].second);
            } else {
                ASSERT_GE(row[k - 1].second, row[k].second);
            }
        }
    }
}

S
starlord 已提交
125
} // namespace
126 127

// Exercises XSearchTask::ClusterResult and TopkResult on NQ x TOP_K result
// sets. Covers empty inputs, a source with fewer than TOP_K entries per row,
// and a source with more than TOP_K entries per row.
TEST(DBSearchTest, TOPK_TEST) {
    bool ascending = true;
    std::vector<int64_t> target_ids;
    std::vector<float> target_distence;
    ms::scheduler::ResultSet src_result;

    // Clustering empty id/distance arrays is expected to fail and leave the result empty.
    auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_FALSE(status.ok());
    ASSERT_TRUE(src_result.empty());

    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);

    // Empty source merged into an empty target: expected to succeed.
    ms::scheduler::ResultSet target_result;
    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());

    // Empty source merged into the NQ-row src_result: the row counts differ and a failure is expected.
    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
    ASSERT_FALSE(status.ok());

    // Merging into the empty target consumes the source and yields NQ rows.
    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    ASSERT_TRUE(src_result.empty());
    ASSERT_EQ(target_result.size(), NQ);

    // Source rows shorter than TOP_K: every merged row still holds TOP_K entries.
    std::vector<int64_t> src_ids;
    std::vector<float> src_distence;
    uint64_t wrong_topk = TOP_K - 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for (uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }

    // Source rows longer than TOP_K: the merge must cut every row back to TOP_K.
    wrong_topk = TOP_K + 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    // Cluster the oversized ids before merging. Without this step TopkResult
    // would receive the src_result left over from the previous merge, and this
    // case would not be exercised at all.
    status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for (uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }
}

// Exercises XSearchTask::MergeResult on single-row results of 5 (source) and 8
// (target) random ascending entries. Cases:
//   - a topk smaller than the combined size;
//   - an empty target;
//   - a topk larger than the combined size, with the two rows in either role.
// CheckResult validates the ordering and the ids/distances of each merge.
TEST(DBSearchTest, MERGE_TEST) {
    const bool ascending = true;
    uint64_t src_count = 5;
    uint64_t target_count = 8;

    std::vector<int64_t> src_ids;
    std::vector<float> src_distence;
    std::vector<int64_t> target_ids;
    std::vector<float> target_distence;
    BuildResult(1, src_count, ascending, src_ids, src_distence);
    BuildResult(1, target_count, ascending, target_ids, target_distence);

    ms::scheduler::ResultSet src_result;
    ms::scheduler::ResultSet target_result;
    auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
    ASSERT_TRUE(status.ok());
    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
    ASSERT_TRUE(status.ok());

    // Pristine rows: every case below merges fresh copies and validates against these.
    const ms::scheduler::Id2DistanceMap &src_row = src_result[0];
    const ms::scheduler::Id2DistanceMap &target_row = target_result[0];

    {
        // topk = 10 < 5 + 8: the merged row is expected to hold exactly topk entries.
        auto src = src_row;
        auto target = target_row;
        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), 10);
        CheckResult(src_row, target_row, target, ascending);
    }

    {
        // Empty target: expected to receive all src_count entries, leaving src empty.
        auto src = src_row;
        ms::scheduler::Id2DistanceMap target;
        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count);
        ASSERT_TRUE(src.empty());
        CheckResult(src_row, target_row, target, ascending);
    }

    {
        // topk = 30 > 5 + 8: nothing is dropped.
        auto src = src_row;
        auto target = target_row;
        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_row, target_row, target, ascending);
    }

    {
        // Same as the previous case with the two rows swapped.
        auto target = src_row;
        auto src = target_row;
        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_row, target_row, target, ascending);
    }
}
230 231 232

// Clusters flat id/distance arrays of various (nq, topk) shapes, including
// topk == 0 and single-column inputs, and checks the per-row split with
// CheckCluster. Phases are marked with TimeRecorder::RecordSection.
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
    const bool ascending = true;
    std::vector<int64_t> target_ids;
    std::vector<float> target_distence;
    ms::scheduler::ResultSet src_result;

    // Note: a failed ASSERT_* inside this lambda only returns from the lambda;
    // the remaining DoCluster calls still run (the test is still marked failed).
    auto DoCluster = [&](int64_t nq, int64_t topk) {
        ms::TimeRecorder rc("DoCluster");
        src_result.clear();
        BuildResult(nq, topk, ascending, target_ids, target_distence);
        rc.RecordSection("build id/distance map");

        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(src_result.size(), nq);

        rc.RecordSection("cluster result");

        CheckCluster(target_ids, target_distence, src_result, nq, topk);
        rc.RecordSection("check result");
    };

    DoCluster(10000, 1000);
    DoCluster(333, 999);
    DoCluster(1, 1000);
    DoCluster(1, 1);
    DoCluster(7, 0);
    DoCluster(9999, 1);
    DoCluster(10001, 1);
    DoCluster(58273, 1234);
}

// Merges a full nq x topk result set with a second set holding only
// insufficient_topk entries per row. Every merged row must end up with topk
// entries, ordered in the requested direction.
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
    std::vector<int64_t> target_ids;
    std::vector<float> target_distence;
    ms::scheduler::ResultSet src_result;

    std::vector<int64_t> insufficient_ids;
    std::vector<float> insufficient_distence;
    ms::scheduler::ResultSet insufficient_result;

    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
        src_result.clear();
        insufficient_result.clear();

        ms::TimeRecorder rc("DoTopk");

        BuildResult(nq, topk, ascending, target_ids, target_distence);
        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Cluster the insufficient arrays themselves. Clustering target_ids with
        // the smaller topk would split rows at the wrong stride (rows straddling
        // two original rows are no longer sorted) and ignore the data just built.
        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
        status = ms::scheduler::XSearchTask::ClusterResult(insufficient_ids,
                                                           insufficient_distence,
                                                           nq,
                                                           insufficient_topk,
                                                           insufficient_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Capture TopkResult's own status rather than re-checking the previous step's.
        status = ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("topk");

        CheckTopkResult(src_result, ascending, nq, topk);
        rc.RecordSection("check result");
    };

    DoTopk(5, 10, 4, false);
    DoTopk(20005, 998, 123, true);
//    DoTopk(9987, 12, 10, false);
//    DoTopk(77777, 1000, 1, false);
//    DoTopk(5432, 8899, 8899, true);
}