test_nsg.cpp 4.1 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

X
xj.lin 已提交
18 19 20
#include <gtest/gtest.h>
#include <memory>

21 22 23 24
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
S
starlord 已提交
25
#include "knowhere/index/vector_index/nsg/NSGIO.h"
X
xj.lin 已提交
26

27
#include "unittest/utils.h"
X
xj.lin 已提交
28

S
starlord 已提交
29 30 31 32 33
// File-local alias so the tests in this file can write kn::X instead of
// zilliz::knowhere::X.
namespace {

namespace kn = zilliz::knowhere;

}  // namespace
X
xj.lin 已提交
34 35 36 37 38

// GoogleTest helpers for value-parameterized tests.
// NOTE(review): no use of these names is visible in this part of the file —
// confirm whether they are still needed.
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::Combine;

39
// GPU device id used by the fixture: it is passed to
// FaissGpuResourceMgr::InitDevice and written to the training config's gpu_id.
constexpr int64_t DEVICE_ID = 1;
X
xj.lin 已提交
40

S
starlord 已提交
41
class NSGInterfaceTest : public DataGen, public ::testing::Test {
X
xj.lin 已提交
42
 protected:
S
starlord 已提交
43 44 45 46
    void
    SetUp() override {
        // Init_with_default();
        kn::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, 1024 * 1024 * 200, 1024 * 1024 * 600, 2);
47
        Generate(256, 1000000, 1);
S
starlord 已提交
48
        index_ = std::make_shared<kn::NSG>();
49

S
starlord 已提交
50
        auto tmp_conf = std::make_shared<kn::NSGCfg>();
51 52 53 54 55 56 57
        tmp_conf->gpu_id = DEVICE_ID;
        tmp_conf->knng = 100;
        tmp_conf->nprobe = 32;
        tmp_conf->nlist = 16384;
        tmp_conf->search_length = 60;
        tmp_conf->out_degree = 70;
        tmp_conf->candidate_pool_size = 500;
S
starlord 已提交
58
        tmp_conf->metric_type = kn::METRICTYPE::L2;
59 60
        train_conf = tmp_conf;

S
starlord 已提交
61
        auto tmp2_conf = std::make_shared<kn::NSGCfg>();
62 63 64
        tmp2_conf->k = k;
        tmp2_conf->search_length = 30;
        search_conf = tmp2_conf;
X
xj.lin 已提交
65 66
    }

S
starlord 已提交
67 68 69
    void
    TearDown() override {
        kn::FaissGpuResourceMgr::GetInstance().Free();
X
xj.lin 已提交
70 71
    }

X
xj.lin 已提交
72
 protected:
S
starlord 已提交
73 74 75
    std::shared_ptr<kn::NSG> index_;
    kn::Config train_conf;
    kn::Config search_conf;
X
xj.lin 已提交
76 77
};

S
starlord 已提交
78 79
// Checks a search result: for every query index q in [0, nq), the id stored at
// position q * k of the first result array must equal q. Mismatches are
// reported with EXPECT_EQ, so all queries are checked even after a failure.
void
AssertAnns(const kn::DatasetPtr& result, const int& nq, const int& k) {
    auto id_array = result->array()[0];
    for (int query = 0; query < nq; ++query) {
        const auto first_slot = query * k;
        const auto first_id = id_array->data()->GetValues<int64_t>(1, first_slot);
        EXPECT_EQ(query, *first_id);
    }
}

86
// End-to-end check of the NSG index: train and search, a Serialize/Load round
// trip into a second index, the Count/Dimension accessors, and the
// Clone/Add/Seal calls.
TEST_F(NSGInterfaceTest, basic_test) {
    // A gtest assertion is used instead of assert(): assert() is compiled out
    // under NDEBUG and would abort the whole binary instead of failing the test.
    ASSERT_FALSE(xb.empty());

    // The value returned by Train() is not used by this test.
    index_->Train(base_dataset, train_conf);
    auto result = index_->Search(query_dataset, search_conf);
    AssertAnns(result, nq, k);

    // Round trip: the loaded index must answer the same queries correctly, so
    // the assertion runs on the result produced by the loaded index.
    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<kn::NSG>();
    new_index->Load(binaryset);
    auto new_result = new_index->Search(query_dataset, search_conf);
    AssertAnns(new_result, nq, k);

    ASSERT_EQ(index_->Count(), nb);
    ASSERT_EQ(index_->Dimension(), dim);
    // Clone() is expected to throw for this index type.
    ASSERT_THROW({ index_->Clone(); }, kn::KnowhereException);
    // Add() and Seal() are expected to complete without throwing.
    ASSERT_NO_THROW({
        index_->Add(base_dataset, kn::Config());
        index_->Seal();
    });

    {
        // Disabled manual search runs, kept for reference.
        // std::cout << "k = 1" << std::endl;
        // new_index->Search(GenQuery(1), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(10), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(100), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(1000), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(10000), Config::object{{"k", 1}});

        // std::cout << "k = 5" << std::endl;
        // new_index->Search(GenQuery(1), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(20), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(100), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(300), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(500), Config::object{{"k", 5}});
    }
}