test_kdt.cpp 5.5 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

X
xj.lin 已提交
18 19 20 21 22
#include <gtest/gtest.h>

#include <iostream>
#include <sstream>

S
starlord 已提交
23 24
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/Structure.h"
X
xiaojun.lin 已提交
25
#include "knowhere/common/Exception.h"
X
xiaojun.lin 已提交
26
#include "knowhere/index/vector_index/IndexKDT.h"
X
xiaojun.lin 已提交
27
#include "knowhere/index/vector_index/helpers/Definitions.h"
X
xj.lin 已提交
28

29
#include "unittest/utils.h"
S
starlord 已提交
30 31

namespace {
X
xj.lin 已提交
32

S
starlord 已提交
33
namespace kn = knowhere;
X
xj.lin 已提交
34

S
starlord 已提交
35
}  // namespace
X
xj.lin 已提交
36

37
using ::testing::Combine;
X
xj.lin 已提交
38 39 40
using ::testing::TestWithParam;
using ::testing::Values;

S
starlord 已提交
41
class KDTTest : public DataGen, public ::testing::Test {
X
xj.lin 已提交
42
 protected:
S
starlord 已提交
43 44 45
    void
    SetUp() override {
        index_ = std::make_shared<kn::CPUKDTRNG>();
X
xiaojun.lin 已提交
46

S
starlord 已提交
47
        auto tempconf = std::make_shared<kn::KDTCfg>();
X
xiaojun.lin 已提交
48 49 50 51
        tempconf->tptnubmber = 1;
        tempconf->k = 10;
        conf = tempconf;

X
xj.lin 已提交
52 53 54 55
        Init_with_default();
    }

 protected:
S
starlord 已提交
56 57
    kn::Config conf;
    std::shared_ptr<kn::CPUKDTRNG> index_ = nullptr;
X
xj.lin 已提交
58 59
};

S
starlord 已提交
60 61
// Expects, for every query q in [0, nq), that the first id stored for that query
// (element q * k of result array 0) equals q.
// NOTE(review): presumably the queries are drawn from the base vectors so each one is its
// own nearest neighbour - confirm against DataGen.
void
AssertAnns(const kn::DatasetPtr& result, const int& nq, const int& k) {
    const auto id_array = result->array()[0];
    for (int query = 0; query < nq; ++query) {
        const auto* top_hit = id_array->data()->GetValues<int64_t>(1, query * k);
        EXPECT_EQ(query, *top_hit);
    }
}

S
starlord 已提交
68 69
// Prints the ids (result array 0) and distances (result array 1) of the first queries to
// stdout, one line per query and k values per line.
//
// result: search result holding the id array at index 0 and the distance array at index 1.
// nq:     number of queries contained in `result`.
// k:      number of neighbours stored per query.
//
// At most 10 queries are printed to keep the output short.
void
PrintResult(const kn::DatasetPtr& result, const int& nq, const int& k) {
    auto ids = result->array()[0];
    auto dists = result->array()[1];

    // Never print more rows than the result holds: the loop used to run for a fixed
    // 10 rows and read past the end of the buffers whenever nq < 10.
    const int max_rows = 10;
    const int rows = nq < max_rows ? nq : max_rows;

    std::stringstream ss_id;
    std::stringstream ss_dist;
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < k; ++j) {
            ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
            ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
        }
        ss_id << '\n';
        ss_dist << '\n';
    }
    std::cout << "id\n" << ss_id.str() << std::endl;
    std::cout << "dist\n" << ss_dist.str() << std::endl;
}

S
starlord 已提交
87
// TODO(lxj): add test about count() and dimension()
X
xiaojun.lin 已提交
88
// Builds a KDT index (preprocessor -> train -> add), searches it with the query dataset,
// checks the first id of every query via AssertAnns and dumps all ids/distances to stdout.
TEST_F(KDTTest, kdt_basic) {
    // gtest assertion instead of assert(): it is still evaluated when NDEBUG is defined and
    // reports a test failure instead of aborting the whole test binary.
    ASSERT_FALSE(xb.empty());

    auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
    index_->set_preprocessor(preprocessor);

    auto model = index_->Train(base_dataset, conf);
    index_->set_index_model(model);
    index_->Add(base_dataset, conf);
    auto result = index_->Search(query_dataset, conf);
    AssertAnns(result, nq, k);

    {
        // Print every query row (all nq of them), k values per line.
        auto ids = result->array()[0];
        auto dists = result->array()[1];

        std::stringstream ss_id;
        std::stringstream ss_dist;
        for (int i = 0; i < nq; ++i) {
            for (int j = 0; j < k; ++j) {
                ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
                ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
            }
            ss_id << '\n';
            ss_dist << '\n';
        }
        std::cout << "id\n" << ss_id.str() << std::endl;
        std::cout << "dist\n" << ss_dist.str() << std::endl;
    }
}

X
xiaojun.lin 已提交
119
// Round-trips a trained KDT index through Serialize()/Load() twice:
//   1. directly from the in-memory BinarySet, and
//   2. after writing every binary to a file under /tmp and reading it back.
// After each reload the index is searched and the results are checked with AssertAnns.
TEST_F(KDTTest, kdt_serialize) {
    // gtest assertion instead of assert(): it is still evaluated when NDEBUG is defined and
    // reports a test failure instead of aborting the whole test binary.
    ASSERT_FALSE(xb.empty());

    auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
    index_->set_preprocessor(preprocessor);

    auto model = index_->Train(base_dataset, conf);
    // index_->Add(base_dataset, conf);
    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<kn::CPUKDTRNG>();
    new_index->Load(binaryset);
    auto result = new_index->Search(query_dataset, conf);
    AssertAnns(result, nq, k);
    PrintResult(result, nq, k);
    ASSERT_EQ(new_index->Count(), nb);
    ASSERT_EQ(new_index->Dimension(), dim);
    ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
    ASSERT_NO_THROW({ new_index->Seal(); });

    {
        int fileno = 0;
        const std::string base_name = "/tmp/kdt_serialize_test_bin_";
        std::vector<std::string> filename_list;
        std::vector<std::pair<std::string, size_t>> meta_list;

        // Dump every serialized binary into its own file, remembering its name and size.
        for (auto& iter : binaryset.binary_map_) {
            const std::string filename = base_name + std::to_string(fileno);
            FileIOWriter writer(filename);
            writer(iter.second->data.get(), iter.second->size);

            meta_list.emplace_back(iter.first, iter.second->size);
            filename_list.push_back(filename);
            ++fileno;
        }

        // Read the files back into a fresh BinarySet.
        kn::BinarySet load_data_list;
        for (size_t i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
            auto bin_size = meta_list[i].second;
            FileIOReader reader(filename_list[i]);

            // The buffer comes from new[], so the shared_ptr needs an array deleter;
            // the default deleter would call plain `delete`, which is undefined behaviour.
            std::shared_ptr<uint8_t> data(new uint8_t[bin_size], [](uint8_t* buffer) { delete[] buffer; });
            reader(data.get(), bin_size);
            load_data_list.Append(meta_list[i].first, data, bin_size);
        }

        auto reloaded_index = std::make_shared<kn::CPUKDTRNG>();
        reloaded_index->Load(load_data_list);
        auto reloaded_result = reloaded_index->Search(query_dataset, conf);
        AssertAnns(reloaded_result, nq, k);
        PrintResult(reloaded_result, nq, k);
    }
}