IndexBuilder.cpp 2.1 KB
Newer Older
X
xj.lin 已提交
1 2 3 4 5 6 7 8
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

#include "mutex"

X
xj.lin 已提交
9 10 11 12
#include <faiss/gpu/StandardGpuResources.h>
#include "faiss/gpu/GpuIndexIVFFlat.h"
#include "faiss/gpu/GpuAutoTune.h"

X
xj.lin 已提交
13 14
#include "IndexBuilder.h"

X
xj.lin 已提交
15

X
xj.lin 已提交
16 17 18 19 20 21
namespace zilliz {
namespace vecwise {
namespace engine {

using std::vector;

X
xj.lin 已提交
22
static std::mutex gpu_resource;
X
xj.lin 已提交
23 24 25 26 27

// Construct a builder that keeps a copy of the operand handle `opd`.
// The operand is read later by build_all() (its `d` and `index_type` fields).
IndexBuilder::IndexBuilder(const Operand_ptr &opd)
    : opd_(opd) {
}

X
xj.lin 已提交
28
// By default the index is built on the GPU.
29 30 31 32 33
// Build an index from raw buffers and return it wrapped in an Index.
//
// Steps:
//   1. create a CPU index with faiss::index_factory from opd_->d and
//      opd_->index_type;
//   2. clone it to GPU device 0;
//   3. train the GPU index if it reports !is_trained, using (nt, xt) when a
//      training set is given, otherwise (nb, xb);
//   4. add the nb vectors in xb with their ids;
//   5. copy the result back to a CPU index and wrap it.
//
// Parameters:
//   nb  - number of vectors passed in xb
//   xb  - vectors to add to the index
//   ids - ids passed to add_with_ids together with xb
//   nt  - number of training vectors in xt; 0 means "no training set"
//   xt  - training vectors; nullptr means "no training set"
//
// The GPU part runs while holding the file-level `gpu_resource` mutex, so
// concurrent calls are serialized from the GPU clone onwards.
//
// Both intermediate indexes are owned by std::unique_ptr so they are released
// even if one of the faiss calls throws.
Index_ptr IndexBuilder::build_all(const long &nb,
                                  const float* xb,
                                  const long* ids,
                                  const long &nt,
                                  const float* xt) {
    std::shared_ptr<faiss::Index> host_index = nullptr;
    {
        // TODO: list the supported index types.
        std::unique_ptr<faiss::Index> ori_index(
            faiss::index_factory(opd_->d, opd_->index_type.c_str()));

        std::lock_guard<std::mutex> lk(gpu_resource);
        faiss::gpu::StandardGpuResources res;

        // Declared after `res` so that it is destroyed before `res`.
        std::unique_ptr<faiss::Index> device_index(
            faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index.get()));

        if (!device_index->is_trained) {
            // Fall back to the base vectors when no training set is supplied.
            if (nt == 0 || xt == nullptr) {
                device_index->train(nb, xb);
            } else {
                device_index->train(nt, xt);
            }
        }
        device_index->add_with_ids(nb, xb, ids);

        host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index.get()));
    }

    return std::make_shared<Index>(host_index);
}

// Convenience overload taking std::vector buffers; forwards to the
// raw-pointer build_all() overload.
//
// Parameters:
//   nb  - number of vectors in xb
//   xb  - vectors to add to the index
//   ids - ids forwarded together with xb
//   nt  - number of training vectors in xt
//   xt  - training vectors; may be empty
//
// An empty `xt` is forwarded as nullptr: vector::data() is not guaranteed to
// return nullptr for an empty vector, and the pointer overload relies on a
// null pointer (or nt == 0) to detect "no training set".
Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
                                  const vector<long> &ids,
                                  const long &nt, const vector<float> &xt) {
    const float *train_data = xt.empty() ? nullptr : xt.data();
    return build_all(nb, xb.data(), ids.data(), nt, train_data);
}

// TODO: turn this into a proper factory pattern later.
// Create an IndexBuilder for the given operand and return it as a shared
// handle.
IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd) {
    IndexBuilderPtr builder = std::make_shared<IndexBuilder>(opd);
    return builder;
}

}
}
}