test_nsg.cpp 4.1 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

X
xj.lin 已提交
18 19 20
#include <gtest/gtest.h>
#include <memory>

21 22 23 24
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
S
starlord 已提交
25
#include "knowhere/index/vector_index/nsg/NSGIO.h"
X
xj.lin 已提交
26

27
#include "unittest/utils.h"
X
xj.lin 已提交
28

29
using ::testing::Combine;
X
xj.lin 已提交
30 31 32
using ::testing::TestWithParam;
using ::testing::Values;

33
// GPU ordinal shared by the fixture: it is passed to FaissGpuResourceMgr::InitDevice
// and stored as NSGCfg::gpu_id. NOTE(review): ordinal 1 presumably means the host
// needs a second GPU — confirm this is intended for CI machines.
constexpr int64_t DEVICE_ID{1};
X
xj.lin 已提交
34

S
starlord 已提交
35
class NSGInterfaceTest : public DataGen, public ::testing::Test {
X
xj.lin 已提交
36
 protected:
S
starlord 已提交
37 38 39
    void
    SetUp() override {
        // Init_with_default();
S
starlord 已提交
40
        knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, 1024 * 1024 * 200, 1024 * 1024 * 600, 2);
X
xiaojun.lin 已提交
41
        Generate(256, 1000000 / 100, 1);
S
starlord 已提交
42
        index_ = std::make_shared<knowhere::NSG>();
43

S
starlord 已提交
44
        auto tmp_conf = std::make_shared<knowhere::NSGCfg>();
45
        tmp_conf->gpu_id = DEVICE_ID;
X
xiaojun.lin 已提交
46 47 48 49 50 51
        tmp_conf->knng = 20;
        tmp_conf->nprobe = 8;
        tmp_conf->nlist = 163;
        tmp_conf->search_length = 40;
        tmp_conf->out_degree = 30;
        tmp_conf->candidate_pool_size = 100;
S
starlord 已提交
52
        tmp_conf->metric_type = knowhere::METRICTYPE::L2;
53 54
        train_conf = tmp_conf;

S
starlord 已提交
55
        auto tmp2_conf = std::make_shared<knowhere::NSGCfg>();
56 57 58
        tmp2_conf->k = k;
        tmp2_conf->search_length = 30;
        search_conf = tmp2_conf;
X
xj.lin 已提交
59 60
    }

S
starlord 已提交
61 62
    void
    TearDown() override {
S
starlord 已提交
63
        knowhere::FaissGpuResourceMgr::GetInstance().Free();
X
xj.lin 已提交
64 65
    }

X
xj.lin 已提交
66
 protected:
S
starlord 已提交
67 68 69
    std::shared_ptr<knowhere::NSG> index_;
    knowhere::Config train_conf;
    knowhere::Config search_conf;
X
xj.lin 已提交
70 71
};

S
starlord 已提交
72
void
S
starlord 已提交
73
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
X
xj.lin 已提交
74 75 76 77 78 79
    auto ids = result->array()[0];
    for (auto i = 0; i < nq; i++) {
        EXPECT_EQ(i, *(ids->data()->GetValues<int64_t>(1, i * k)));
    }
}

80
// End-to-end NSG test: train, search, serialize/reload, search again, then
// check metadata, that Clone() throws and that Add()/Seal() do not throw.
TEST_F(NSGInterfaceTest, basic_test) {
    // ASSERT_FALSE stays active in every build type; assert() is removed under NDEBUG.
    ASSERT_FALSE(xb.empty());

    // index_ is searched directly after Train(), so the returned value is not needed here.
    index_->Train(base_dataset, train_conf);
    auto result = index_->Search(query_dataset, search_conf);
    AssertAnns(result, nq, k);

    // Serialize/Load round trip: the reloaded index must answer the same queries correctly.
    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<knowhere::NSG>();
    new_index->Load(binaryset);
    auto new_result = new_index->Search(query_dataset, search_conf);
    // Check new_result (not result) so the deserialized index is actually verified.
    AssertAnns(new_result, nq, k);

    ASSERT_EQ(index_->Count(), nb);
    ASSERT_EQ(index_->Dimension(), dim);
    // Clone() is expected to throw for this index type.
    ASSERT_THROW({ index_->Clone(); }, knowhere::KnowhereException);
    ASSERT_NO_THROW({
        index_->Add(base_dataset, knowhere::Config());
        index_->Seal();
    });

    {
        // Disabled ad-hoc searches over varying query counts and k values.
        // std::cout << "k = 1" << std::endl;
        // new_index->Search(GenQuery(1), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(10), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(100), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(1000), Config::object{{"k", 1}});
        // new_index->Search(GenQuery(10000), Config::object{{"k", 1}});

        // std::cout << "k = 5" << std::endl;
        // new_index->Search(GenQuery(1), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(20), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(100), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(300), Config::object{{"k", 5}});
        // new_index->Search(GenQuery(500), Config::object{{"k", 5}});
    }
}