VecImpl.cpp 9.3 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

X
MS-154  
xj.lin 已提交
18

S
starlord 已提交
19
#include "utils/Log.h"
X
xiaojun.lin 已提交
20 21 22
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
X
xiaojun.lin 已提交
23
#include "knowhere/index/vector_index/helpers/Cloner.h"
24 25
#include "VecImpl.h"
#include "DataTransfer.h"
X
MS-154  
xj.lin 已提交
26 27 28


namespace zilliz {
X
xj.lin 已提交
29
namespace milvus {
X
MS-154  
xj.lin 已提交
30 31 32 33
namespace engine {

using namespace zilliz::knowhere;

34 35 36 37 38 39 40
// Builds the wrapped index from scratch using the nb vectors in xb:
// builds and installs a preprocessor, trains and installs a model, then adds the vectors.
//
// @param nb   number of vectors in xb (and ids in ids)
// @param xb   vector data; the dimension is read from cfg["dim"] and cached in `dim`
// @param ids  ids attached to the vectors
// @param cfg  build configuration; must contain "dim"
// @param nt   not used by this implementation (training reuses the xb dataset)
// @param xt   not used by this implementation
// @return Status::OK() on success; otherwise a status whose code reflects the caught
//         exception type (Knowhere / jsoncons / other std::exception) and whose message
//         is the exception text.
Status
VecIndexImpl::BuildAll(const long &nb,
                       const float *xb,
                       const long *ids,
                       const Config &cfg,
                       const long &nt,
                       const float *xt) {
    try {
        dim = cfg["dim"].as<int>();
        auto build_set = GenDatasetWithIds(nb, dim, xb, ids);

        auto prep = index_->BuildPreprocessor(build_set, cfg);
        index_->set_preprocessor(prep);

        auto trained_model = index_->Train(build_set, cfg);
        index_->set_index_model(trained_model);

        index_->Add(build_set, cfg);
    } catch (const KnowhereException &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, err.what());
    } catch (const jsoncons::json_exception &err) {
        // Typically raised while reading cfg (e.g. a missing or mistyped "dim").
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_INVALID_ARGUMENT, err.what());
    } catch (const std::exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_ERROR, err.what());
    }
    return Status::OK();
}

63 64
// Adds nb vectors (with their ids) to the wrapped index, using the dimension cached in `dim`.
//
// @return Status::OK() on success; otherwise a status whose code reflects the caught
//         exception type and whose message is the exception text.
Status
VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
    try {
        auto batch = GenDatasetWithIds(nb, dim, xb, ids);
        index_->Add(batch, cfg);
    } catch (const KnowhereException &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, err.what());
    } catch (const jsoncons::json_exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_INVALID_ARGUMENT, err.what());
    } catch (const std::exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_ERROR, err.what());
    }
    return Status::OK();
}

82 83
// Searches the wrapped index with nq query vectors and copies the results out.
//
// k is read from cfg["k"]. The caller-provided buffers `ids` and `dist` must each hold at
// least nq * k elements; exactly that many ids and distances are copied into them.
// The config is copied and passed through ParameterValidation(type, ...) before searching.
//
// @return Status::OK() on success; otherwise a status whose code reflects the caught
//         exception type and whose message is the exception text.
Status
VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
    // Result ids come back as int64_t and are copied byte-wise into the caller's `long`
    // buffer, so both types must have the same width. This holds on LP64 (Linux/macOS)
    // but not on LLP64 (Windows), where the copy would overrun `ids`.
    static_assert(sizeof(long) == sizeof(int64_t),
                  "VecIndexImpl::Search copies int64_t ids into a long buffer");
    try {
        auto k = cfg["k"].as<int>();
        auto dataset = GenDataset(nq, dim, xq);

        Config search_cfg = cfg;
        ParameterValidation(type, search_cfg);

        auto res = index_->Search(dataset, search_cfg);
        // Column 0 holds the result ids, column 1 the distances.
        auto ids_array = res->array()[0];
        auto dis_array = res->array()[1];

        auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
        auto p_dist = dis_array->data()->GetValues<float>(1, 0);

        // TODO(linxj): avoid copy here.
        memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
        memcpy(dist, p_dist, sizeof(float) * nq * k);
    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
    return Status::OK();
}

133 134
// Serializes the wrapped index into a BinarySet.
// Side effect: `type` is first converted to its CPU counterpart via ConvertToCpuIndexType
// (NOTE(review): presumably because the serialized form is reloaded as a CPU index — confirm).
zilliz::knowhere::BinarySet
VecIndexImpl::Serialize() {
    auto cpu_type = ConvertToCpuIndexType(type);
    type = cpu_type;
    auto binary_set = index_->Serialize();
    return binary_set;
}

139 140
// Loads the wrapped index from a serialized BinarySet and refreshes the cached `dim`
// from the loaded index. Exceptions thrown by the underlying Load are not caught here.
Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    this->index_->Load(index_binary);
    this->dim = this->Dimension();
    return Status::OK();
}

146 147
int64_t
VecIndexImpl::Dimension() {
X
xj.lin 已提交
148 149 150
    return index_->Dimension();
}

151 152
int64_t
VecIndexImpl::Count() {
X
xj.lin 已提交
153 154 155
    return index_->Count();
}

156 157
// Returns the index type currently recorded for this wrapper.
IndexType
VecIndexImpl::GetType() {
    return this->type;
}

161 162
// Copies the wrapped index to GPU `device_id` (via cloner::CopyCpuToGpu) and returns a new
// wrapper holding the copy, tagged with the GPU variant of `type` and the same `dim`.
// Exceptions from the cloner propagate to the caller.
VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
    // TODO(linxj): exception handle
    auto device_copy = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);
    auto wrapped = std::make_shared<VecIndexImpl>(device_copy, ConvertToGpuIndexType(type));
    wrapped->dim = this->dim;
    return wrapped;
}

170 171
// Copies the wrapped index back to host memory (via cloner::CopyGpuToCpu) and returns a new
// wrapper holding the copy, tagged with the CPU variant of `type` and the same `dim`.
// Exceptions from the cloner propagate to the caller.
VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) {
    // TODO(linxj): exception handle
    auto host_copy = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);
    auto wrapped = std::make_shared<VecIndexImpl>(host_copy, ConvertToCpuIndexType(type));
    wrapped->dim = this->dim;
    return wrapped;
}

179 180
// Returns a new wrapper around a clone of the wrapped index, with the same `type` and `dim`.
VecIndexPtr
VecIndexImpl::Clone() {
    // TODO(linxj): exception handle
    auto duplicate = std::make_shared<VecIndexImpl>(index_->Clone(), this->type);
    duplicate->dim = this->dim;
    return duplicate;
}

187 188 189
int64_t
VecIndexImpl::GetDeviceId() {
    if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)) {
190 191
        return device_idx->GetGpuDevice();
    }
X
xj.lin 已提交
192 193
    // else
    return -1; // -1 == cpu
194 195
}

196 197
float *
BFIndex::GetRawVectors() {
X
xj.lin 已提交
198 199 200
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawVectors(); }
    return nullptr;
X
xj.lin 已提交
201 202
}

203 204
int64_t *
BFIndex::GetRawIds() {
X
xj.lin 已提交
205 206 207
    return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
}

208 209
// Reads "dim" from cfg into `dim` and trains the wrapped IDMAP index with cfg alone
// (no vectors are added here).
//
// @return KNOWHERE_SUCCESS on success; otherwise an error code chosen by the type of the
//         caught exception (Knowhere / jsoncons / other std::exception), after logging it.
ErrorCode
BFIndex::Build(const Config &cfg) {
    ErrorCode result = KNOWHERE_SUCCESS;
    try {
        dim = cfg["dim"].as<int>();
        auto idmap = std::static_pointer_cast<IDMAP>(index_);
        idmap->Train(cfg);
    } catch (const KnowhereException &err) {
        WRAPPER_LOG_ERROR << err.what();
        result = KNOWHERE_UNEXPECTED_ERROR;
    } catch (const jsoncons::json_exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        result = KNOWHERE_INVALID_ARGUMENT;
    } catch (const std::exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        result = KNOWHERE_ERROR;
    }
    return result;
}

226 227 228 229 230 231 232
// Builds the wrapped IDMAP index: reads "dim" from cfg, trains with cfg, then adds the nb
// vectors in xb with their ids.
//
// @param nt, xt  not used by this implementation
// @return Status::OK() on success; otherwise a status whose code reflects the caught
//         exception type and whose message is the exception text.
Status
BFIndex::BuildAll(const long &nb,
                  const float *xb,
                  const long *ids,
                  const Config &cfg,
                  const long &nt,
                  const float *xt) {
    try {
        dim = cfg["dim"].as<int>();
        auto build_set = GenDatasetWithIds(nb, dim, xb, ids);

        auto idmap = std::static_pointer_cast<IDMAP>(index_);
        idmap->Train(cfg);
        index_->Add(build_set, cfg);
    } catch (const KnowhereException &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, err.what());
    } catch (const jsoncons::json_exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_INVALID_ARGUMENT, err.what());
    } catch (const std::exception &err) {
        WRAPPER_LOG_ERROR << err.what();
        return Status(KNOWHERE_ERROR, err.what());
    }
    return Status::OK();
}

X
xj.lin 已提交
252
// TODO(linxj): add lock here.
253 254 255 256 257 258 259
Status
IVFMixIndex::BuildAll(const long &nb,
                      const float *xb,
                      const long *ids,
                      const Config &cfg,
                      const long &nt,
                      const float *xt) {
X
xj.lin 已提交
260 261 262 263 264 265
    try {
        dim = cfg["dim"].as<int>();
        auto dataset = GenDatasetWithIds(nb, dim, xb, ids);

        auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
        index_->set_preprocessor(preprocessor);
X
xj.lin 已提交
266
        auto model = index_->Train(dataset, cfg);
X
xj.lin 已提交
267 268 269 270
        index_->set_index_model(model);
        index_->Add(dataset, cfg);

        if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
W
wxyu 已提交
271
            auto host_index = device_index->CopyGpuToCpu(Config());
X
xj.lin 已提交
272
            index_ = host_index;
X
xj.lin 已提交
273
            type = ConvertToCpuIndexType(type);
X
xj.lin 已提交
274 275
        } else {
            WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
276
            return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
X
xj.lin 已提交
277 278 279
        }
    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
280
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
X
xj.lin 已提交
281 282
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
283
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
X
xj.lin 已提交
284 285
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
286
        return Status(KNOWHERE_ERROR, e.what());
X
xj.lin 已提交
287
    }
288
    return Status::OK();
X
xj.lin 已提交
289 290
}

291 292
// Loads the wrapped index from a serialized BinarySet into the existing index_ object and
// refreshes the cached `dim`. Exceptions thrown by the underlying Load are not caught here.
Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    this->index_->Load(index_binary);
    this->dim = this->Dimension();
    return Status::OK();
}

X
MS-154  
xj.lin 已提交
299 300 301
}
}
}