VecImpl.cpp 11.1 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

X
MS-154  
xj.lin 已提交
18

S
starlord 已提交
19
#include "wrapper/VecImpl.h"
S
starlord 已提交
20
#include "utils/Log.h"
X
xiaojun.lin 已提交
21
#include "knowhere/index/vector_index/IndexIDMAP.h"
X
xiaojun.lin 已提交
22
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
X
xiaojun.lin 已提交
23 24
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
X
xiaojun.lin 已提交
25
#include "knowhere/index/vector_index/helpers/Cloner.h"
26
#include "DataTransfer.h"
X
MS-154  
xj.lin 已提交
27

X
xiaojun.lin 已提交
28 29 30 31
/*
 * No parameter checking is done in this layer.
 * It is only responsible for index combination.
 */
X
MS-154  
xj.lin 已提交
32 33

namespace zilliz {
X
xj.lin 已提交
34
namespace milvus {
X
MS-154  
xj.lin 已提交
35 36
namespace engine {

37
// Builds the index in one pass: preprocessor, training, then insertion.
//
//   nb     - number of base vectors passed to the dataset builder
//   xb     - base vector data
//   ids    - ids handed to the dataset builder together with xb
//   cfg    - build configuration; cfg->d is stored in `dim`
//   nt, xt - not used by this implementation
//
// Returns Status::OK() on success. Exceptions raised while building are
// logged and turned into an error Status instead of propagating.
Status
VecIndexImpl::BuildAll(const int64_t &nb,
                       const float *xb,
                       const int64_t *ids,
                       const Config &cfg,
                       const int64_t &nt,
                       const float *xt) {
    try {
        dim = cfg->d;
        auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

        // The order matters: preprocessor first, then the trained model, then the data.
        index_->set_preprocessor(index_->BuildPreprocessor(base_data, cfg));
        index_->set_index_model(index_->Train(base_data, cfg));
        index_->Add(base_data, cfg);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

63
// Appends nb vectors (xb) with their ids to the wrapped index.
//
// The dataset is built with the previously stored `dim`; cfg is forwarded
// unchanged to the underlying index. Exceptions are logged and mapped to an
// error Status.
Status
VecIndexImpl::Add(const int64_t &nb, const float *xb, const int64_t *ids, const Config &cfg) {
    try {
        index_->Add(GenDatasetWithIds(nb, dim, xb, ids), cfg);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

79
// Runs a search for nq query vectors and copies the results to the caller.
//
//   nq   - number of query vectors in xq
//   xq   - query vector data
//   dist - output buffer; receives nq * cfg->k floats
//   ids  - output buffer; receives nq * cfg->k int64 ids
//   cfg  - search configuration; cfg->k is the number of results per query
//
// Exceptions from the underlying index are logged and mapped to an error
// Status.
Status
VecIndexImpl::Search(const int64_t &nq, const float *xq, float *dist, int64_t *ids, const Config &cfg) {
    try {
        auto k = cfg->k;
        auto query_data = GenDataset(nq, dim, xq);

        Config search_cfg = cfg;
        auto result = index_->Search(query_data, search_cfg);

        // Result column 0 carries the ids, column 1 the distances.
        auto id_column = result->array()[0];
        auto dist_column = result->array()[1];
        auto src_ids = id_column->data()->GetValues<int64_t>(1, 0);
        auto src_dist = dist_column->data()->GetValues<float>(1, 0);

        // TODO(linxj): avoid copy here.
        memcpy(ids, src_ids, sizeof(int64_t) * nq * k);
        memcpy(dist, src_dist, sizeof(float) * nq * k);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

124 125
// Serializes the wrapped index into a BinarySet.
//
// Note: this overwrites the member `type` with ConvertToCpuIndexType(type)
// before delegating to the underlying index.
zilliz::knowhere::BinarySet
VecIndexImpl::Serialize() {
    const auto cpu_type = ConvertToCpuIndexType(type);
    type = cpu_type;
    return index_->Serialize();
}

130 131
// Restores the wrapped index from index_binary and refreshes `dim` from the
// loaded index. Always returns Status::OK(); exceptions thrown by the
// underlying Load are not caught here.
Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    index_->Load(index_binary);

    // Re-read the dimension now that the index content has been replaced.
    dim = this->Dimension();

    return Status::OK();
}

137 138
int64_t
VecIndexImpl::Dimension() {
X
xj.lin 已提交
139 140 141
    return index_->Dimension();
}

142 143
int64_t
VecIndexImpl::Count() {
X
xj.lin 已提交
144 145 146
    return index_->Count();
}

147 148
// Returns the index type currently recorded on this wrapper.
IndexType
VecIndexImpl::GetType() {
    return this->type;
}

152 153
// Creates a new wrapper around a GPU copy of the wrapped index.
//
//   device_id - forwarded to cloner::CopyCpuToGpu
//   cfg       - forwarded to cloner::CopyCpuToGpu
//
// The new wrapper gets the GPU variant of `type` and the same `dim`.
// Exceptions thrown by the cloner are not caught here.
VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
    // TODO(linxj): exception handle
    auto device_copy = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);

    auto wrapped = std::make_shared<VecIndexImpl>(device_copy, ConvertToGpuIndexType(type));
    wrapped->dim = dim;
    return wrapped;
}

161 162
// Creates a new wrapper around a CPU copy of the wrapped index.
//
// cfg is forwarded to cloner::CopyGpuToCpu. The new wrapper gets the CPU
// variant of `type` and the same `dim`. Exceptions thrown by the cloner are
// not caught here.
VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) {
    // TODO(linxj): exception handle
    auto host_copy = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);

    auto wrapped = std::make_shared<VecIndexImpl>(host_copy, ConvertToCpuIndexType(type));
    wrapped->dim = dim;
    return wrapped;
}

170 171
// Creates a new wrapper around index_->Clone(), carrying over `type` and
// `dim`. Exceptions thrown while cloning are not caught here.
VecIndexPtr
VecIndexImpl::Clone() {
    // TODO(linxj): exception handle
    auto duplicate = std::make_shared<VecIndexImpl>(index_->Clone(), type);
    duplicate->dim = dim;
    return duplicate;
}

178 179
int64_t
VecIndexImpl::GetDeviceId() {
S
starlord 已提交
180
    if (auto device_idx = std::dynamic_pointer_cast<knowhere::GPUIndex>(index_)) {
181 182
        return device_idx->GetGpuDevice();
    }
X
xj.lin 已提交
183 184
    // else
    return -1; // -1 == cpu
185 186
}

187 188
float *
BFIndex::GetRawVectors() {
S
starlord 已提交
189
    auto raw_index = std::dynamic_pointer_cast<knowhere::IDMAP>(index_);
X
xj.lin 已提交
190 191
    if (raw_index) { return raw_index->GetRawVectors(); }
    return nullptr;
X
xj.lin 已提交
192 193
}

194 195
int64_t *
BFIndex::GetRawIds() {
S
starlord 已提交
196
    return std::static_pointer_cast<knowhere::IDMAP>(index_)->GetRawIds();
X
xj.lin 已提交
197 198
}

199 200
// Prepares the wrapped IDMAP index by calling its Train(cfg); no vectors are
// added here. cfg->d is stored in `dim`.
//
// Returns KNOWHERE_SUCCESS on success. Exceptions are logged and mapped to
// KNOWHERE_UNEXPECTED_ERROR (knowhere exceptions) or KNOWHERE_ERROR (any other
// std::exception).
ErrorCode
BFIndex::Build(const Config &cfg) {
    try {
        dim = cfg->d;
        auto id_map = std::static_pointer_cast<knowhere::IDMAP>(index_);
        id_map->Train(cfg);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return KNOWHERE_ERROR;
    }
    return KNOWHERE_SUCCESS;
}

214
// Builds the brute-force index: Train(cfg) on the wrapped IDMAP, then adds
// the nb vectors in xb with their ids.
//
//   cfg    - build configuration; cfg->d is stored in `dim`
//   nt, xt - not used by this implementation
//
// Exceptions are logged and mapped to an error Status.
Status
BFIndex::BuildAll(const int64_t &nb,
                  const float *xb,
                  const int64_t *ids,
                  const Config &cfg,
                  const int64_t &nt,
                  const float *xt) {
    try {
        dim = cfg->d;
        auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

        auto id_map = std::static_pointer_cast<knowhere::IDMAP>(index_);
        id_map->Train(cfg);
        index_->Add(base_data, cfg);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

X
xj.lin 已提交
237
// TODO(linxj): add lock here.
238
// Builds the index on the wrapped index, then swaps in a CPU copy of it.
//
// Steps: build the preprocessor, train, add the vectors, and finally, if the
// wrapped index is a knowhere::GPUIndex, replace `index_` with the result of
// CopyGpuToCpu and switch `type` to its CPU variant. If the wrapped index is
// not a GPUIndex, the build is reported as failed.
//
//   cfg    - build configuration; cfg->d is stored in `dim`
//   nt, xt - not used by this implementation
//
// Exceptions are logged and mapped to an error Status.
Status
IVFMixIndex::BuildAll(const int64_t &nb,
                      const float *xb,
                      const int64_t *ids,
                      const Config &cfg,
                      const int64_t &nt,
                      const float *xt) {
    try {
        dim = cfg->d;
        auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

        index_->set_preprocessor(index_->BuildPreprocessor(base_data, cfg));
        index_->set_index_model(index_->Train(base_data, cfg));
        index_->Add(base_data, cfg);

        auto device_index = std::dynamic_pointer_cast<knowhere::GPUIndex>(index_);
        if (!device_index) {
            WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
            return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
        }

        // Keep only the host-side copy from here on.
        index_ = device_index->CopyGpuToCpu(Config());
        type = ConvertToCpuIndexType(type);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

273 274
// Restores the wrapped index from index_binary and refreshes `dim` from the
// loaded index. Always returns Status::OK(); exceptions thrown by the
// underlying Load are not caught here.
Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    index_->Load(index_binary);

    // Re-read the dimension now that the index content has been replaced.
    dim = this->Dimension();

    return Status::OK();
}

J
JinHai-CN 已提交
280 281
// Loads the quantizer from the wrapped index when it is a
// knowhere::IVFSQHybrid.
//
//   conf - forwarded to IVFSQHybrid::LoadQuantizer
//
// Returns the quantizer produced by the hybrid index. If the wrapped index is
// not an IVFSQHybrid, the error is logged and nullptr is returned; previously
// that path fell off the end of the function without returning a value.
knowhere::QuantizerPtr
IVFHybridIndex::LoadQuantizer(const Config& conf) {
    // TODO(linxj): Hardcode here
    if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
        return new_idx->LoadQuantizer(conf);
    }

    WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
    return nullptr;
}

J
JinHai-CN 已提交
290 291
// Hands quantizer q to the wrapped index when it is a knowhere::IVFSQHybrid.
//
// Returns Status::OK() on success, or an error Status ("not support") when
// the wrapped index is not an IVFSQHybrid. Exceptions are logged and mapped
// to an error Status.
Status
IVFHybridIndex::SetQuantizer(const knowhere::QuantizerPtr& q) {
    try {
        // TODO(linxj): Hardcode here
        auto hybrid = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_);
        if (!hybrid) {
            WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << static_cast<int>(type);
            return Status(KNOWHERE_ERROR, "not support");
        }
        hybrid->SetQuantizer(q);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

J
JinHai-CN 已提交
310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326
// Calls UnsetQuantizer() on the wrapped index when it is a
// knowhere::IVFSQHybrid.
//
// Returns Status::OK() on success, or an error Status ("not support") when
// the wrapped index is not an IVFSQHybrid. Exceptions are logged and mapped
// to an error Status.
Status
IVFHybridIndex::UnsetQuantizer() {
    try {
        // TODO(linxj): Hardcode here
        auto hybrid = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_);
        if (!hybrid) {
            WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << static_cast<int>(type);
            return Status(KNOWHERE_ERROR, "not support");
        }
        hybrid->UnsetQuantizer();
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

// Forwards quantizer q and conf to LoadData() on the wrapped index when it is
// a knowhere::IVFSQHybrid.
//
// Returns Status::OK() on success, or an error Status ("not support") when
// the wrapped index is not an IVFSQHybrid. Exceptions are logged and mapped
// to an error Status.
Status IVFHybridIndex::LoadData(const knowhere::QuantizerPtr &q, const Config &conf) {
    try {
        // TODO(linxj): Hardcode here
        auto hybrid = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_);
        if (!hybrid) {
            WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << static_cast<int>(type);
            return Status(KNOWHERE_ERROR, "not support");
        }
        hybrid->LoadData(q, conf);
    } catch (const knowhere::KnowhereException &knowhere_err) {
        WRAPPER_LOG_ERROR << knowhere_err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, knowhere_err.what());
    } catch (const std::exception &std_err) {
        WRAPPER_LOG_ERROR << std_err.what();
        return Status(KNOWHERE_ERROR, std_err.what());
    }
    return Status::OK();
}

S
starlord 已提交
349 350 351
} // namespace engine
} // namespace milvus
} // namespace zilliz