vec_impl.cpp 3.9 KB
Newer Older
X
MS-154  
xj.lin 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

X
xj.lin 已提交
7
#include <src/utils/Log.h>
X
xj.lin 已提交
8
#include "knowhere/index/vector_index/idmap.h"
X
MS-154  
xj.lin 已提交
9 10 11 12 13 14

#include "vec_impl.h"
#include "data_transfer.h"


namespace zilliz {
X
xj.lin 已提交
15
namespace milvus {
X
MS-154  
xj.lin 已提交
16 17 18 19 20 21 22 23 24 25
namespace engine {

using namespace zilliz::knowhere;

void VecIndexImpl::BuildAll(const long &nb,
                            const float *xb,
                            const long *ids,
                            const Config &cfg,
                            const long &nt,
                            const float *xt) {
X
xj.lin 已提交
26 27
    dim = cfg["dim"].as<int>();
    auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
X
MS-154  
xj.lin 已提交
28 29 30

    auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
    index_->set_preprocessor(preprocessor);
X
xj.lin 已提交
31 32 33
    auto nlist = int(nb / 1000000.0 * 16384);
    auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
    auto model = index_->Train(dataset, cfg_t);
X
MS-154  
xj.lin 已提交
34 35 36 37 38 39 40
    index_->set_index_model(model);
    index_->Add(dataset, cfg);
}

void VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
    // TODO(linxj): Assert index is trained;

X
xj.lin 已提交
41
    auto d = cfg.get_with_default("dim", dim);
X
MS-154  
xj.lin 已提交
42 43 44 45 46 47 48 49 50
    auto dataset = GenDatasetWithIds(nb, d, xb, ids);

    index_->Add(dataset, cfg);
}

void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
    // TODO: Assert index is trained;

    auto k = cfg["k"].as<int>();
X
xj.lin 已提交
51
    auto d = cfg.get_with_default("dim", dim);
X
MS-154  
xj.lin 已提交
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
    auto dataset = GenDataset(nq, d, xq);

    Config search_cfg;
    auto res = index_->Search(dataset, cfg);
    auto ids_array = res->array()[0];
    auto dis_array = res->array()[1];

    //{
    //    auto& ids = ids_array;
    //    auto& dists = dis_array;
    //    std::stringstream ss_id;
    //    std::stringstream ss_dist;
    //    for (auto i = 0; i < 10; i++) {
    //        for (auto j = 0; j < k; ++j) {
    //            ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
    //            ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
    //        }
    //        ss_id << std::endl;
    //        ss_dist << std::endl;
    //    }
    //    std::cout << "id\n" << ss_id.str() << std::endl;
    //    std::cout << "dist\n" << ss_dist.str() << std::endl;
    //}

    auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
X
xj.lin 已提交
77
    auto p_dist = dis_array->data()->GetValues<float>(1, 0);
X
MS-154  
xj.lin 已提交
78 79 80 81 82 83 84 85 86 87 88 89

    // TODO(linxj): avoid copy here.
    memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
    memcpy(dist, p_dist, sizeof(float) * nq * k);
}

// Returns the binary representation produced by the wrapped index's
// Serialize().
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
    auto binary_set = index_->Serialize();
    return binary_set;
}

// Loads the wrapped index from `index_binary`, then refreshes the stored
// `dim` member from the value reported by Dimension().
void VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    index_->Load(index_binary);

    dim = this->Dimension();
}

X
xj.lin 已提交
93 94 95 96 97 98 99 100 101
int64_t VecIndexImpl::Dimension() {
    return index_->Dimension();
}

// Returns the count reported by the wrapped index.
int64_t VecIndexImpl::Count() {
    auto index_count = index_->Count();
    return index_count;
}

float *BFIndex::GetRawVectors() {
X
xj.lin 已提交
102 103 104
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawVectors(); }
    return nullptr;
X
xj.lin 已提交
105 106 107 108 109 110 111 112 113 114 115
}

// Returns the raw id storage of the underlying IDMAP index, or nullptr when
// the wrapped index is not an IDMAP.
//
// Uses a checked downcast, matching GetRawVectors(), instead of dereferencing
// the result of an unchecked static_pointer_cast.
int64_t *BFIndex::GetRawIds() {
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawIds(); }
    return nullptr;
}

// Stores `d` in the `dim` member and trains the underlying IDMAP index with
// that dimension.
void BFIndex::Build(const int64_t &d) {
    dim = d;

    auto idmap = std::static_pointer_cast<IDMAP>(index_);
    idmap->Train(dim);
}

X
xj.lin 已提交
116 117 118 119 120 121 122 123 124 125 126 127 128
// Builds the IDMAP index in one call: reads "dim" from cfg, trains the index
// with that dimension and adds all `nb` vectors with their ids.
//
// nt and xt are not used by this implementation.
void BFIndex::BuildAll(const long &nb,
                       const float *xb,
                       const long *ids,
                       const Config &cfg,
                       const long &nt,
                       const float *xt) {
    dim = cfg["dim"].as<int>();
    auto all_vectors = GenDatasetWithIds(nb, dim, xb, ids);

    auto idmap = std::static_pointer_cast<IDMAP>(index_);
    idmap->Train(dim);

    index_->Add(all_vectors, cfg);
}

X
MS-154  
xj.lin 已提交
129 130 131
}
}
}