vec_impl.cpp 5.2 KB
Newer Older
X
MS-154  
xj.lin 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

X
xj.lin 已提交
7
#include <src/utils/Log.h>
X
xj.lin 已提交
8
#include "knowhere/index/vector_index/idmap.h"
X
xj.lin 已提交
9
#include "knowhere/index/vector_index/gpu_ivf.h"
X
MS-154  
xj.lin 已提交
10 11 12

#include "vec_impl.h"
#include "data_transfer.h"
X
xj.lin 已提交
13
#include "wrapper_log.h"
X
MS-154  
xj.lin 已提交
14 15 16


namespace zilliz {
X
xj.lin 已提交
17
namespace milvus {
X
MS-154  
xj.lin 已提交
18 19 20 21 22 23 24 25 26 27
namespace engine {

using namespace zilliz::knowhere;

void VecIndexImpl::BuildAll(const long &nb,
                            const float *xb,
                            const long *ids,
                            const Config &cfg,
                            const long &nt,
                            const float *xt) {
X
xj.lin 已提交
28 29
    dim = cfg["dim"].as<int>();
    auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
X
MS-154  
xj.lin 已提交
30 31 32

    auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
    index_->set_preprocessor(preprocessor);
X
xj.lin 已提交
33 34 35
    auto nlist = int(nb / 1000000.0 * 16384);
    auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
    auto model = index_->Train(dataset, cfg_t);
X
MS-154  
xj.lin 已提交
36 37 38 39 40 41 42
    index_->set_index_model(model);
    index_->Add(dataset, cfg);
}

void VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
    // TODO(linxj): Assert index is trained;

X
xj.lin 已提交
43
    auto d = cfg.get_with_default("dim", dim);
X
MS-154  
xj.lin 已提交
44 45 46 47 48 49 50 51 52
    auto dataset = GenDatasetWithIds(nb, d, xb, ids);

    index_->Add(dataset, cfg);
}

void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
    // TODO: Assert index is trained;

    auto k = cfg["k"].as<int>();
X
xj.lin 已提交
53
    auto d = cfg.get_with_default("dim", dim);
X
MS-154  
xj.lin 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
    auto dataset = GenDataset(nq, d, xq);

    Config search_cfg;
    auto res = index_->Search(dataset, cfg);
    auto ids_array = res->array()[0];
    auto dis_array = res->array()[1];

    //{
    //    auto& ids = ids_array;
    //    auto& dists = dis_array;
    //    std::stringstream ss_id;
    //    std::stringstream ss_dist;
    //    for (auto i = 0; i < 10; i++) {
    //        for (auto j = 0; j < k; ++j) {
    //            ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
    //            ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
    //        }
    //        ss_id << std::endl;
    //        ss_dist << std::endl;
    //    }
    //    std::cout << "id\n" << ss_id.str() << std::endl;
    //    std::cout << "dist\n" << ss_dist.str() << std::endl;
    //}

    auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
X
xj.lin 已提交
79
    auto p_dist = dis_array->data()->GetValues<float>(1, 0);
X
MS-154  
xj.lin 已提交
80 81 82 83 84 85 86 87 88 89 90 91

    // TODO(linxj): avoid copy here.
    memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
    memcpy(dist, p_dist, sizeof(float) * nq * k);
}

/// Serialize the wrapped knowhere index into a BinarySet.
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
    auto binary_set = index_->Serialize();
    return binary_set;
}

/// Restore the wrapped index from `index_binary`, then refresh the cached
/// dimension from the freshly loaded index.
void VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    this->index_->Load(index_binary);
    this->dim = this->Dimension();
}

X
xj.lin 已提交
95 96 97 98 99 100 101 102
int64_t VecIndexImpl::Dimension() {
    return index_->Dimension();
}

int64_t VecIndexImpl::Count() {
    return index_->Count();
}

X
xj.lin 已提交
103 104 105 106
/// The IndexType recorded for this wrapper.
IndexType VecIndexImpl::GetType() {
    return this->type;
}

X
xj.lin 已提交
107
float *BFIndex::GetRawVectors() {
X
xj.lin 已提交
108 109 110
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawVectors(); }
    return nullptr;
X
xj.lin 已提交
111 112 113 114 115 116 117 118 119 120 121
}

/// Raw id storage of the underlying IDMAP index, or nullptr when the wrapped
/// index is not an IDMAP.
int64_t *BFIndex::GetRawIds() {
    // Checked cast, consistent with GetRawVectors(): an unchecked
    // static_pointer_cast to the wrong dynamic type is undefined behaviour.
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawIds(); }
    return nullptr;
}

/// Record dimension `d` and train the underlying IDMAP index with it.
/// Assumes index_ holds an IDMAP (the cast is unchecked).
void BFIndex::Build(const int64_t &d) {
    dim = d;
    auto idmap = std::static_pointer_cast<IDMAP>(index_);
    idmap->Train(dim);
}

X
xj.lin 已提交
122 123 124 125 126 127 128 129 130 131 132 133 134
/// Brute-force build: read the dimension from cfg["dim"], train the IDMAP
/// with it, and add all `nb` base vectors with their ids.
/// `nt` and `xt` are not used. Assumes index_ holds an IDMAP (unchecked cast).
void BFIndex::BuildAll(const long &nb,
                       const float *xb,
                       const long *ids,
                       const Config &cfg,
                       const long &nt,
                       const float *xt) {
    dim = cfg["dim"].as<int>();
    auto base = GenDatasetWithIds(nb, dim, xb, ids);

    auto idmap = std::static_pointer_cast<IDMAP>(index_);
    idmap->Train(dim);
    index_->Add(base, cfg);
}

X
xj.lin 已提交
135 136 137 138 139 140 141
// TODO(linxj): add lock here.
void IVFMixIndex::BuildAll(const long &nb,
                           const float *xb,
                           const long *ids,
                           const Config &cfg,
                           const long &nt,
                           const float *xt) {
X
xj.lin 已提交
142 143
    WRAPPER_LOG_DEBUG << "Get Into Build IVFMIX";

X
xj.lin 已提交
144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
    dim = cfg["dim"].as<int>();
    auto dataset = GenDatasetWithIds(nb, dim, xb, ids);

    auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
    index_->set_preprocessor(preprocessor);
    auto nlist = int(nb / 1000000.0 * 16384);
    auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
    auto model = index_->Train(dataset, cfg_t);
    index_->set_index_model(model);
    index_->Add(dataset, cfg);

    if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
        auto host_index = device_index->Copy_index_gpu_to_cpu();
        index_ = host_index;
    } else {
X
xj.lin 已提交
159
        WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
X
xj.lin 已提交
160 161 162 163 164 165 166 167 168
    }
}

/// Replace index_ with a fresh IVF instance, load `index_binary` into it,
/// and refresh the cached dimension from the loaded index.
void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    this->index_ = std::make_shared<IVF>();
    this->index_->Load(index_binary);
    this->dim = this->Dimension();
}

X
MS-154  
xj.lin 已提交
169 170 171
}
}
}