vec_impl.cpp 9.4 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

X
MS-154  
xj.lin 已提交
18

S
starlord 已提交
19
#include "utils/Log.h"
X
xiaojun.lin 已提交
20 21 22
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
X
xiaojun.lin 已提交
23
#include "knowhere/index/vector_index/helpers/Cloner.h"
X
MS-154  
xj.lin 已提交
24 25 26 27 28 29

#include "vec_impl.h"
#include "data_transfer.h"


namespace zilliz {
X
xj.lin 已提交
30
namespace milvus {
X
MS-154  
xj.lin 已提交
31 32 33 34
namespace engine {

using namespace zilliz::knowhere;

S
starlord 已提交
35
// Builds the wrapped index in one pass: preprocess, train, then add the vectors.
//
// nb   - number of vectors passed to GenDatasetWithIds
// xb   - vector data passed to GenDatasetWithIds
// ids  - ids passed to GenDatasetWithIds alongside xb
// cfg  - configuration; must contain an integer "dim" entry, and is forwarded
//        unchanged to the preprocessor, Train and Add calls
// nt   - accepted for interface compatibility; not used by this implementation
// xt   - accepted for interface compatibility; not used by this implementation
//
// Returns KNOWHERE_SUCCESS on success. Exceptions are logged and translated:
// KnowhereException -> KNOWHERE_UNEXPECTED_ERROR,
// jsoncons::json_exception -> KNOWHERE_INVALID_ARGUMENT,
// any other std::exception -> KNOWHERE_ERROR.
ErrorCode VecIndexImpl::BuildAll(const long &nb,
                                 const float *xb,
                                 const long *ids,
                                 const Config &cfg,
                                 const long &nt,
                                 const float *xt) {
    try {
        // Remember the dimension so later Add/Search calls can build datasets.
        dim = cfg["dim"].as<int>();
        auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

        auto prep = index_->BuildPreprocessor(base_data, cfg);
        index_->set_preprocessor(prep);

        auto trained_model = index_->Train(base_data, cfg);
        index_->set_index_model(trained_model);

        index_->Add(base_data, cfg);
    } catch (const KnowhereException &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_ERROR;
    }

    return KNOWHERE_SUCCESS;
}

S
starlord 已提交
63
// Appends vectors to the wrapped index.
//
// nb, xb, ids - forwarded to GenDatasetWithIds together with the stored `dim`
// cfg         - forwarded unchanged to the underlying index's Add
//
// Returns KNOWHERE_SUCCESS on success. Exceptions are logged and translated:
// KnowhereException -> KNOWHERE_UNEXPECTED_ERROR,
// jsoncons::json_exception -> KNOWHERE_INVALID_ARGUMENT,
// any other std::exception -> KNOWHERE_ERROR.
ErrorCode VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
    try {
        // Uses the dimension recorded by an earlier build or load.
        auto new_vectors = GenDatasetWithIds(nb, dim, xb, ids);
        index_->Add(new_vectors, cfg);
    } catch (const KnowhereException &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_ERROR;
    }

    return KNOWHERE_SUCCESS;
}

S
starlord 已提交
81
// Searches the wrapped index and copies the results into caller-owned buffers.
//
// nq   - number of query vectors
// xq   - query data passed to GenDataset together with the stored `dim`
// dist - output buffer; receives nq * k float distances
// ids  - output buffer; receives nq * k ids
// cfg  - configuration; must contain an integer "k" entry. A copy of it is
//        passed through ParameterValidation before being handed to the index.
//
// Returns KNOWHERE_SUCCESS on success and KNOWHERE_INVALID_ARGUMENT when nq or
// k is negative. Exceptions are logged and translated:
// KnowhereException -> KNOWHERE_UNEXPECTED_ERROR,
// jsoncons::json_exception -> KNOWHERE_INVALID_ARGUMENT,
// any other std::exception -> KNOWHERE_ERROR.
ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
    // Result ids are read as int64_t and copied byte-wise into a `long` buffer;
    // that is only safe when both types have the same size.
    static_assert(sizeof(long) == sizeof(int64_t),
                  "VecIndexImpl::Search copies int64_t ids into a long buffer");

    try {
        auto k = cfg["k"].as<int>();

        // A negative count would wrap around to a huge unsigned size in the
        // copies below, so reject it up front.
        if (nq < 0 || k < 0) {
            WRAPPER_LOG_ERROR << "Search requires non-negative nq and k";
            return KNOWHERE_INVALID_ARGUMENT;
        }

        auto dataset = GenDataset(nq, dim, xq);

        // Validation works on a private copy so the caller's config is untouched.
        Config search_cfg = cfg;
        ParameterValidation(type, search_cfg);

        auto res = index_->Search(dataset, search_cfg);
        auto ids_array = res->array()[0];
        auto dis_array = res->array()[1];

        auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
        auto p_dist = dis_array->data()->GetValues<float>(1, 0);

        // TODO(linxj): avoid copy here.
        const size_t result_count = static_cast<size_t>(nq) * static_cast<size_t>(k);
        if (result_count > 0) {
            memcpy(ids, p_ids, sizeof(int64_t) * result_count);
            memcpy(dist, p_dist, sizeof(float) * result_count);
        }
    } catch (const KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
    }

    return KNOWHERE_SUCCESS;
}

// Serializes the wrapped index into a BinarySet.
//
// Side effect: the stored `type` is replaced by the result of
// ConvertToCpuIndexType(type) before serialization (presumably so the
// persisted index is tagged with its CPU type - confirm against the helper).
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
    const auto cpu_type = ConvertToCpuIndexType(type);
    type = cpu_type;

    return index_->Serialize();
}

S
starlord 已提交
136
// Restores the wrapped index from `index_binary` and refreshes the cached
// dimension from the loaded index.
//
// Always returns KNOWHERE_SUCCESS when it returns normally. Unlike the
// build/add/search methods, exceptions thrown by the underlying Load are not
// caught here and propagate to the caller.
ErrorCode VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    index_->Load(index_binary);

    // The dimension is only known once the index content has been loaded.
    const auto loaded_dim = Dimension();
    dim = loaded_dim;

    return KNOWHERE_SUCCESS;
}

X
xj.lin 已提交
142 143 144 145 146 147 148 149
int64_t VecIndexImpl::Dimension() {
    return index_->Dimension();
}

// Returns the count reported by the wrapped index.
int64_t VecIndexImpl::Count() {
    const int64_t index_count = index_->Count();
    return index_count;
}

X
xj.lin 已提交
150 151 152 153
// Returns the index type currently recorded for this wrapper.
IndexType VecIndexImpl::GetType() {
    return this->type;
}

154
// Creates a new wrapper around a GPU copy of the wrapped index.
//
// device_id - forwarded to cloner::CopyCpuToGpu
// cfg       - forwarded to cloner::CopyCpuToGpu
//
// The returned wrapper carries the type given by ConvertToGpuIndexType(type)
// and the same `dim` as this one. Exceptions from the copy are not handled here.
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
    // TODO(linxj): exception handle
    auto device_copy = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);

    auto gpu_wrapper = std::make_shared<VecIndexImpl>(device_copy, ConvertToGpuIndexType(type));
    gpu_wrapper->dim = dim;

    return gpu_wrapper;
}

// Creates a new wrapper around a CPU copy of the wrapped index.
//
// cfg - forwarded to cloner::CopyGpuToCpu
//
// The returned wrapper carries the type given by ConvertToCpuIndexType(type)
// and the same `dim` as this one. Exceptions from the copy are not handled here.
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
    // TODO(linxj): exception handle
    auto host_copy = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);

    auto cpu_wrapper = std::make_shared<VecIndexImpl>(host_copy, ConvertToCpuIndexType(type));
    cpu_wrapper->dim = dim;

    return cpu_wrapper;
}

170
// Creates a new wrapper around a clone of the wrapped index, keeping the same
// type and dimension. Exceptions from the underlying Clone are not handled here.
VecIndexPtr VecIndexImpl::Clone() {
    // TODO(linxj): exception handle
    auto duplicate = std::make_shared<VecIndexImpl>(index_->Clone(), type);
    duplicate->dim = dim;

    return duplicate;
}

int64_t VecIndexImpl::GetDeviceId() {
    if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)){
        return device_idx->GetGpuDevice();
    }
X
xj.lin 已提交
181 182
    // else
    return -1; // -1 == cpu
183 184
}

X
xj.lin 已提交
185
float *BFIndex::GetRawVectors() {
X
xj.lin 已提交
186 187 188
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawVectors(); }
    return nullptr;
X
xj.lin 已提交
189 190 191 192 193 194
}

// Returns the raw id storage of the wrapped IDMAP index, or nullptr when the
// wrapped index is not an IDMAP.
//
// Uses a checked cast, matching GetRawVectors, so a wrapper holding a
// different index type yields nullptr instead of undefined behaviour.
int64_t *BFIndex::GetRawIds() {
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawIds(); }
    return nullptr;
}

S
starlord 已提交
195
// Prepares the wrapped IDMAP index from configuration only (no vectors added).
//
// cfg - must contain an integer "dim" entry; forwarded to IDMAP::Train
//
// Returns KNOWHERE_SUCCESS on success. Exceptions are logged and translated:
// KnowhereException -> KNOWHERE_UNEXPECTED_ERROR,
// jsoncons::json_exception -> KNOWHERE_INVALID_ARGUMENT,
// any other std::exception -> KNOWHERE_ERROR.
ErrorCode BFIndex::Build(const Config &cfg) {
    try {
        dim = cfg["dim"].as<int>();

        // The wrapped index is treated as an IDMAP without a runtime check.
        auto flat_index = std::static_pointer_cast<IDMAP>(index_);
        flat_index->Train(cfg);
    } catch (const KnowhereException &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_ERROR;
    }

    return KNOWHERE_SUCCESS;
}

S
starlord 已提交
212
// Prepares the wrapped IDMAP index and adds the supplied vectors to it.
//
// nb, xb, ids - forwarded to GenDatasetWithIds
// cfg         - must contain an integer "dim" entry; forwarded to Train and Add
// nt, xt      - accepted for interface compatibility; not used here
//
// Returns KNOWHERE_SUCCESS on success. Exceptions are logged and translated:
// KnowhereException -> KNOWHERE_UNEXPECTED_ERROR,
// jsoncons::json_exception -> KNOWHERE_INVALID_ARGUMENT,
// any other std::exception -> KNOWHERE_ERROR.
ErrorCode BFIndex::BuildAll(const long &nb,
                            const float *xb,
                            const long *ids,
                            const Config &cfg,
                            const long &nt,
                            const float *xt) {
    try {
        dim = cfg["dim"].as<int>();
        auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

        // The wrapped index is treated as an IDMAP without a runtime check.
        auto flat_index = std::static_pointer_cast<IDMAP>(index_);
        flat_index->Train(cfg);

        index_->Add(base_data, cfg);
    } catch (const KnowhereException &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_ERROR;
    }

    return KNOWHERE_SUCCESS;
}

X
xj.lin 已提交
237
// TODO(linxj): add lock here.
S
starlord 已提交
238
// Builds the index (preprocess, train, add) and then replaces the wrapped
// index with a CPU copy obtained from the GPUIVF index.
//
// nb, xb, ids - forwarded to GenDatasetWithIds
// cfg         - must contain an integer "dim" entry; forwarded to the
//               preprocessor, Train and Add calls
// nt, xt      - accepted for interface compatibility; not used here
//
// Returns KNOWHERE_SUCCESS on success, and KNOWHERE_ERROR when the wrapped
// index is not a GPUIVF once building has finished. Exceptions are logged and
// translated: KnowhereException -> KNOWHERE_UNEXPECTED_ERROR,
// jsoncons::json_exception -> KNOWHERE_INVALID_ARGUMENT,
// any other std::exception -> KNOWHERE_ERROR.
ErrorCode IVFMixIndex::BuildAll(const long &nb,
                                const float *xb,
                                const long *ids,
                                const Config &cfg,
                                const long &nt,
                                const float *xt) {
    try {
        dim = cfg["dim"].as<int>();
        auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

        auto prep = index_->BuildPreprocessor(base_data, cfg);
        index_->set_preprocessor(prep);

        auto trained_model = index_->Train(base_data, cfg);
        index_->set_index_model(trained_model);

        index_->Add(base_data, cfg);

        // The build is expected to have happened on a GPUIVF index; anything
        // else is reported as a failure.
        auto gpu_ivf = std::dynamic_pointer_cast<GPUIVF>(index_);
        if (!gpu_ivf) {
            WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
            return KNOWHERE_ERROR;
        }

        // Swap the wrapped index for its CPU copy and update the type to match.
        auto cpu_ivf = gpu_ivf->CopyGpuToCpu(Config());
        index_ = cpu_ivf;
        type = ConvertToCpuIndexType(type);
    } catch (const KnowhereException &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &ex) {
        WRAPPER_LOG_ERROR << ex.what();
        return KNOWHERE_ERROR;
    }

    return KNOWHERE_SUCCESS;
}

S
starlord 已提交
275
// Restores the wrapped index from `index_binary` and refreshes the cached
// dimension from the loaded index.
//
// Always returns KNOWHERE_SUCCESS when it returns normally; exceptions thrown
// by the underlying Load are not caught here and propagate to the caller.
ErrorCode IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    //index_ = std::make_shared<IVF>();
    index_->Load(index_binary);

    // The dimension is only known once the index content has been loaded.
    const auto loaded_dim = Dimension();
    dim = loaded_dim;

    return KNOWHERE_SUCCESS;
}

X
MS-154  
xj.lin 已提交
282 283 284
}
}
}