// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>
#include <iostream>

#include "knowhere/adapter/Structure.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"

#include "Helper.h"
#include "unittest/utils.h"

class IDMAPTest : public DataGen, public TestGpuIndexBase {
X
xj.lin 已提交
30
 protected:
S
starlord 已提交
31 32
    void
    SetUp() override {
X
xiaojun.lin 已提交
33 34
        TestGpuIndexBase::SetUp();

X
xj.lin 已提交
35
        Init_with_default();
S
starlord 已提交
36
        index_ = std::make_shared<knowhere::IDMAP>();
X
xj.lin 已提交
37
    }
38

S
starlord 已提交
39 40
    void
    TearDown() override {
X
xiaojun.lin 已提交
41
        TestGpuIndexBase::TearDown();
42 43
    }

X
xj.lin 已提交
44
 protected:
S
starlord 已提交
45
    knowhere::IDMAPPtr index_ = nullptr;
X
xj.lin 已提交
46 47 48
};

// Train/Add/Search round trip on a CPU IDMAP, then verify that an index
// rebuilt from Serialize()/Load() holds the same data and answers queries.
TEST_F(IDMAPTest, idmap_basic) {
    ASSERT_TRUE(!xb.empty());

    auto conf = std::make_shared<knowhere::Cfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = knowhere::METRICTYPE::L2;

    index_->Train(conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dimension(), dim);
    ASSERT_TRUE(index_->GetRawVectors() != nullptr);
    ASSERT_TRUE(index_->GetRawIds() != nullptr);
    auto result = index_->Search(query_dataset, conf);
    AssertAnns(result, nq, k);
    //    PrintResult(result, nq, k);

    index_->Seal();
    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<knowhere::IDMAP>();
    new_index->Load(binaryset);

    // Check the *deserialized* index. Searching index_ again here (as the
    // original code did) would never exercise Load() at all.
    EXPECT_EQ(new_index->Count(), nb);
    EXPECT_EQ(new_index->Dimension(), dim);
    auto re_result = new_index->Search(query_dataset, conf);
    AssertAnns(re_result, nq, k);
    //    PrintResult(re_result, nq, k);
}

// Serialize an IDMAP, round-trip its binary payload through a file on disk,
// and verify that the index reloaded from those bytes still answers queries.
TEST_F(IDMAPTest, idmap_serialize) {
    // Write `bin` to `filename`, then read bin->size bytes back into `ret`
    // (which must have room for at least bin->size bytes).
    auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
        {
            // Scope the writer so it is destroyed before the file is read back.
            // NOTE(review): assumes FileIOWriter flushes/closes its stream in
            // its destructor -- confirm in unittest/utils.
            FileIOWriter writer(filename);
            writer(static_cast<void*>(bin->data.get()), bin->size);
        }

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    auto conf = std::make_shared<knowhere::Cfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = knowhere::METRICTYPE::L2;

    {
        // serialize index
        index_->Train(conf);
        // Pass the same config as the other tests instead of an empty Config().
        index_->Add(base_dataset, conf);
        auto re_result = index_->Search(query_dataset, conf);
        AssertAnns(re_result, nq, k);
        //        PrintResult(re_result, nq, k);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto binaryset = index_->Serialize();
        // The payload is stored and restored under the "IVF" key.
        auto bin = binaryset.GetByName("IVF");

        std::string filename = "/tmp/idmap_test_serialize.bin";
        // The buffer is allocated with new[], so its owner must release it with
        // delete[]. The default shared_ptr<uint8_t> deleter would call scalar
        // delete, which is undefined behavior. Taking ownership immediately
        // also avoids a leak if the file round trip fails.
        std::shared_ptr<uint8_t> data(new uint8_t[bin->size], std::default_delete<uint8_t[]>());
        serialize(filename, bin, data.get());

        binaryset.clear();
        binaryset.Append("IVF", data, bin->size);

        index_->Load(binaryset);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto result = index_->Search(query_dataset, conf);
        AssertAnns(result, nq, k);
        //        PrintResult(result, nq, k);
    }
}

// Exercise every IDMAP copy path and check search results after each one:
// CPU Clone(), CPU->GPU, GPU Serialize/Load, GPU Clone(), GPU->CPU and
// GPU->GPU (CopyGpuToGpu).
TEST_F(IDMAPTest, copy_test) {
    ASSERT_FALSE(xb.empty());

    auto conf = std::make_shared<knowhere::Cfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = knowhere::METRICTYPE::L2;

    // Build the CPU source index and sanity-check it.
    index_->Train(conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dimension(), dim);
    ASSERT_TRUE(index_->GetRawVectors() != nullptr);
    ASSERT_TRUE(index_->GetRawIds() != nullptr);
    auto cpu_res = index_->Search(query_dataset, conf);
    AssertAnns(cpu_res, nq, k);

    // CPU -> CPU via Clone().
    {
        auto cpu_copy = index_->Clone();
        auto cpu_copy_res = cpu_copy->Search(query_dataset, conf);
        AssertAnns(cpu_copy_res, nq, k);
    }

    {
        // CPU -> GPU.
        auto gpu_idx = knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf);
        auto gpu_res = gpu_idx->Search(query_dataset, conf);
        AssertAnns(gpu_res, nq, k);

        // Both raw-data accessors are expected to throw on the GPU copy.
        auto gpu_idmap = std::static_pointer_cast<knowhere::GPUIDMAP>(gpu_idx);
        ASSERT_THROW({ gpu_idmap->GetRawVectors(); }, knowhere::KnowhereException);
        ASSERT_THROW({ gpu_idmap->GetRawIds(); }, knowhere::KnowhereException);

        // Serialize/Load round trip performed on the GPU index itself.
        auto gpu_bin = gpu_idx->Serialize();
        gpu_idx->Load(gpu_bin);
        auto reloaded_res = gpu_idx->Search(query_dataset, conf);
        AssertAnns(reloaded_res, nq, k);

        // GPU -> GPU via Clone().
        auto gpu_copy = gpu_idx->Clone();
        auto gpu_copy_res = gpu_copy->Search(query_dataset, conf);
        AssertAnns(gpu_copy_res, nq, k);

        // GPU -> CPU: the raw-data accessors must work again on the host copy.
        auto cpu_idx = knowhere::cloner::CopyGpuToCpu(gpu_idx, conf);
        auto cpu_idx_res = cpu_idx->Search(query_dataset, conf);
        AssertAnns(cpu_idx_res, nq, k);
        auto cpu_idmap = std::static_pointer_cast<knowhere::IDMAP>(cpu_idx);
        ASSERT_TRUE(cpu_idmap->GetRawVectors() != nullptr);
        ASSERT_TRUE(cpu_idmap->GetRawIds() != nullptr);

        // GPU -> GPU via CopyGpuToGpu, starting from a fresh device copy.
        auto dev_idx = knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf);
        auto dev_copy = std::static_pointer_cast<knowhere::GPUIDMAP>(dev_idx)->CopyGpuToGpu(DEVICEID, conf);
        auto dev_res = dev_copy->Search(query_dataset, conf);
        AssertAnns(dev_res, nq, k);
    }
}