vec_impl.cpp 3.0 KB
Newer Older
X
MS-154  
xj.lin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

#include "knowhere/index/index.h"
#include "knowhere/index/index_model.h"
#include "knowhere/index/index_type.h"
#include "knowhere/adapter/sptag.h"
#include "knowhere/common/tensor.h"

#include "vec_impl.h"
#include "data_transfer.h"


namespace zilliz {
namespace vecwise {
namespace engine {

using namespace zilliz::knowhere;

/**
 * Builds the wrapped index from scratch: creates a preprocessor, trains a
 * model on the supplied vectors, then inserts those same vectors.
 *
 * @param nb   number of base vectors in `xb`
 * @param xb   base vector data, handed to GenDatasetWithIds with cfg["dim"]
 * @param ids  ids associated with the base vectors
 * @param cfg  configuration; "dim" is read here and the whole object is
 *             forwarded to every index_ call
 * @param nt   not used by this implementation
 * @param xt   not used by this implementation
 */
void VecIndexImpl::BuildAll(const long &nb,
                            const float *xb,
                            const long *ids,
                            const Config &cfg,
                            const long &nt,
                            const float *xt) {
    auto dim = cfg["dim"].as<int>();
    auto base_data = GenDatasetWithIds(nb, dim, xb, ids);

    // Step 1: derive and install the preprocessor.
    auto prep = index_->BuildPreprocessor(base_data, cfg);
    index_->set_preprocessor(prep);

    // Step 2: train and install the index model.
    auto trained_model = index_->Train(base_data, cfg);
    index_->set_index_model(trained_model);

    // Step 3: insert the vectors the model was trained on.
    index_->Add(base_data, cfg);
}

/**
 * Inserts `nb` vectors with their ids into the wrapped index.
 *
 * @param nb   number of vectors in `xb`
 * @param xb   vector data, handed to GenDatasetWithIds with cfg["dim"]
 * @param ids  ids associated with the vectors
 * @param cfg  configuration; "dim" is read here and the whole object is
 *             forwarded to index_->Add
 */
void VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
    // TODO(linxj): Assert index is trained;

    auto dim = cfg["dim"].as<int>();
    auto new_data = GenDatasetWithIds(nb, dim, xb, ids);

    index_->Add(new_data, cfg);
}

/**
 * Searches the wrapped index with `nq` query vectors and copies the result
 * ids and distances into caller-provided buffers.
 *
 * @param nq    number of query vectors in `xq`
 * @param xq    query vector data, handed to GenDataset with cfg["dim"]
 * @param dist  output buffer; receives nq * k floats
 * @param ids   output buffer; receives nq * k 64-bit ids
 * @param cfg   configuration; "dim" and "k" are read here and the whole
 *              object is forwarded to index_->Search
 */
void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
    // TODO(linxj): Assert index is trained;

    // The id copy below is sized with sizeof(int64_t) but writes through a
    // `long *`; refuse to build where the two differ instead of overrunning.
    static_assert(sizeof(long) == sizeof(int64_t),
                  "VecIndexImpl::Search copies int64_t ids into a long buffer");

    auto d = cfg["dim"].as<int>();
    auto k = cfg["k"].as<int>();
    auto dataset = GenDataset(nq, d, xq);

    auto res = index_->Search(dataset, cfg);
    // Result layout as consumed here: array()[0] carries the ids and
    // array()[1] carries the distances.
    auto ids_array = res->array()[0];
    auto dis_array = res->array()[1];

    // NOTE(review): buffer index 1 is presumably the values buffer of an
    // Arrow-style array (index 0 being the validity bitmap) - confirm.
    auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
    // Distances must be read from dis_array; reading them from ids_array
    // would hand id bytes back to the caller as distances.
    auto p_dist = dis_array->data()->GetValues<float>(1, 0);

    // TODO(linxj): avoid copy here.
    memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
    memcpy(dist, p_dist, sizeof(float) * nq * k);
}

/**
 * Serializes the wrapped index.
 *
 * @return the BinarySet produced by index_->Serialize()
 */
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
    auto binary_set = index_->Serialize();
    return binary_set;
}

/**
 * Restores the wrapped index from a previously serialized BinarySet.
 *
 * @param index_binary  binary data forwarded unchanged to index_->Load
 */
void VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    this->index_->Load(index_binary);
}

}
}
}