IndexBuilder.cpp 4.1 KB
Newer Older
X
xj.lin 已提交
1 2 3 4 5 6 7 8
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

#include "mutex"

X
xj.lin 已提交
9

X
xj.lin 已提交
10
#ifdef GPU_VERSION
X
xj.lin 已提交
11
#include <faiss/gpu/StandardGpuResources.h>
X
xj.lin 已提交
12 13
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuAutoTune.h>
X
xj.lin 已提交
14 15
#endif

X
xj.lin 已提交
16

X
xj.lin 已提交
17 18 19 20 21
#include <faiss/IndexFlat.h>
#include <easylogging++.h>


#include "server/ServerConfig.h"
X
xj.lin 已提交
22 23
#include "IndexBuilder.h"

X
xj.lin 已提交
24

X
xj.lin 已提交
25 26 27 28
namespace zilliz {
namespace vecwise {
namespace engine {

X
xj.lin 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
// Process-wide holder of the GPU device id used for index builds.
// The id is taken from the server config section (key "gpu_index", with 0
// passed as the fallback argument) when the singleton is first constructed.
class GpuResources {
 public:
    static GpuResources &GetInstance() {
        static GpuResources instance;
        return instance;
    }

    // Singleton: copying would silently detach a caller from the shared state.
    GpuResources(const GpuResources &) = delete;
    GpuResources &operator=(const GpuResources &) = delete;

    // (Re)reads the device id from the server config.
    void SelectGpu() {
        using namespace zilliz::vecwise::server;
        ServerConfig &config = ServerConfig::GetInstance();
        ConfigNode server_config = config.GetConfig(CONFIG_SERVER);
        gpu_num_ = server_config.GetInt32Value("gpu_index", 0);
    }

    // Returns the currently selected device id.
    int32_t GetGpu() const {
        return gpu_num_;
    }

 private:
    GpuResources() { SelectGpu(); }

    int32_t gpu_num_ = 0;
};

X
xj.lin 已提交
54 55
using std::vector;

X
xj.lin 已提交
56
namespace {
// Guards GPU index construction in IndexBuilder::build_all.
std::mutex gpu_resource;
// Guards training/insertion in BgCpuBuilder::build_all.
std::mutex cpu_resource;
}  // namespace
X
xj.lin 已提交
58 59 60 61 62

// Stores the operand whose dimension (`d`) and index type drive build_all.
// Initialized in the member-initializer list rather than assigned in the body.
IndexBuilder::IndexBuilder(const Operand_ptr &opd) : opd_(opd) {}

X
xj.lin 已提交
63
// Default: build use gpu
64
Index_ptr IndexBuilder::build_all(const long &nb,
X
xj.lin 已提交
65 66
                                  const float *xb,
                                  const long *ids,
67
                                  const long &nt,
X
xj.lin 已提交
68
                                  const float *xt) {
X
xj.lin 已提交
69
    std::shared_ptr<faiss::Index> host_index = nullptr;
X
xj.lin 已提交
70
#ifdef GPU_VERSION
X
xj.lin 已提交
71
    {
X
xj.lin 已提交
72
        // TODO: list support index-type.
X
xj.lin 已提交
73
        faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
X
xj.lin 已提交
74 75 76

        std::lock_guard<std::mutex> lk(gpu_resource);
        faiss::gpu::StandardGpuResources res;
X
xj.lin 已提交
77
        auto device_index = faiss::gpu::index_cpu_to_gpu(&res, GpuResources::GetInstance().GetGpu(), ori_index);
X
xj.lin 已提交
78 79 80
        if (!device_index->is_trained) {
            nt == 0 || xt == nullptr ? device_index->train(nb, xb)
                                     : device_index->train(nt, xt);
X
xj.lin 已提交
81
        }
X
xj.lin 已提交
82
        device_index->add_with_ids(nb, xb, ids); // TODO: support with add_with_IDMAP
X
xj.lin 已提交
83 84

        host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
X
xj.lin 已提交
85

X
xj.lin 已提交
86 87 88
        delete device_index;
        delete ori_index;
    }
X
xj.lin 已提交
89 90 91 92 93 94 95 96 97 98 99
#else
    {
        faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
        if (!index->is_trained) {
            nt == 0 || xt == nullptr ? index->train(nb, xb)
                                     : index->train(nt, xt);
        }
        index->add_with_ids(nb, xb, ids);
        host_index.reset(index);
    }
#endif
100

X
xj.lin 已提交
101
    return std::make_shared<Index>(host_index);
102 103 104 105 106 107
}

// std::vector convenience overload; forwards to the pointer-based build_all.
Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
                                  const vector<long> &ids,
                                  const long &nt, const vector<float> &xt) {
    // An empty training vector has no storage (xt.data() is unspecified), so
    // pass nullptr instead: the pointer overload then trains on xb rather than
    // reading nt vectors from a buffer that does not exist.
    const float *train_data = xt.empty() ? nullptr : xt.data();
    return build_all(nb, xb.data(), ids.data(), nt, train_data);
}

X
xj.lin 已提交
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
// Builder that always constructs indexes on the CPU; stores opd via the base.
BgCpuBuilder::BgCpuBuilder(const Operand_ptr &opd) : IndexBuilder(opd) {}

// CPU-only build. The index object is created outside the lock; training and
// insertion are serialized through cpu_resource. When nt == 0 or xt is null the
// index is trained on the vectors being added (nb, xb).
Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) {
    std::shared_ptr<faiss::Index> cpu_index(
        faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()));

    std::unique_lock<std::mutex> guard(cpu_resource);
    if (!cpu_index->is_trained) {
        const bool train_on_base = (nt == 0) || (xt == nullptr);
        if (train_on_base) {
            cpu_index->train(nb, xb);
        } else {
            cpu_index->train(nt, xt);
        }
    }
    cpu_index->add_with_ids(nb, xb, ids);
    guard.unlock();

    return std::make_shared<Index>(cpu_index);
}

// TODO: Be Factory pattern later
X
xj.lin 已提交
129
IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd) {
X
xj.lin 已提交
130 131 132 133 134 135
    if (opd->index_type == "IDMap") {
        // TODO: fix hardcode
        IndexBuilderPtr index = nullptr;
        return std::make_shared<BgCpuBuilder>(opd);
    }

X
xj.lin 已提交
136 137 138 139 140 141
    return std::make_shared<IndexBuilder>(opd);
}

}
}
}