kdtree.cpp 4.0 KB
Newer Older
X
xj.lin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <new>
#include <sstream>
#include <vector>

#include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "knowhere/index/vector_index/definitions.h"
#include "knowhere/adapter/sptag.h"
#include "knowhere/adapter/structure.h"


using namespace zilliz::knowhere;

/**
 * Build a synthetic dataset of `n` vectors of dimension `d`.
 *
 * Every component of row i is float(base + i); the id of row i is i
 * (ids are not offset by `base`).
 *
 * @param n    number of vectors (rows)
 * @param d    dimension of each vector
 * @param base offset added to the row index to form the component values
 * @return a Dataset with one int64 id array (field "id") and one float
 *         tensor of shape {n, d} (field "data")
 * @throws std::bad_alloc if either buffer cannot be allocated
 */
DatasetPtr
generate_dataset(int64_t n, int64_t d, int64_t base) {
    auto elems = n * d;
    auto p_data = static_cast<float *>(malloc(elems * sizeof(float)));
    // One id per row.
    auto p_id = static_cast<int64_t *>(malloc(n * sizeof(int64_t)));
    // Explicit check instead of assert so it still fires in NDEBUG builds;
    // malloc(0) may legitimately return nullptr, so only non-empty sizes count.
    if ((p_data == nullptr && elems > 0) || (p_id == nullptr && n > 0)) {
        free(p_data);
        free(p_id);
        throw std::bad_alloc();
    }

    for (int64_t i = 0; i < n; ++i) {
        for (int64_t j = 0; j < d; ++j) {
            p_data[i * d + j] = static_cast<float>(base + i);
        }
        p_id[i] = i;
    }

    // NOTE(review): assumes the *Smart constructors take ownership of the
    // malloc'd buffers (they are not freed here) — confirm in adapter code.
    std::vector<int64_t> shape{n, d};
    auto tensor = ConstructFloatTensorSmart(reinterpret_cast<uint8_t *>(p_data), elems * sizeof(float), shape);
    std::vector<TensorPtr> tensors{tensor};
    std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
    auto tensor_schema = std::make_shared<Schema>(tensor_fields);

    auto id_array = ConstructInt64ArraySmart(reinterpret_cast<uint8_t *>(p_id), n * sizeof(int64_t));
    std::vector<ArrayPtr> arrays{id_array};
    std::vector<FieldPtr> array_fields{ConstructInt64Field("id")};
    // The id array must be described by its own "id" field, not the tensor's.
    auto array_schema = std::make_shared<Schema>(array_fields);

    auto dataset = std::make_shared<Dataset>(std::move(arrays), array_schema,
                                             std::move(tensors), tensor_schema);

    return dataset;
}

/**
 * Build a synthetic query set of `n` vectors of dimension `d`.
 *
 * Every component of query i is float(base + i), so query i has the same
 * values as row i of generate_dataset(..., base).
 *
 * @param n    number of query vectors
 * @param d    dimension of each vector
 * @param k    not used here; kept for interface compatibility (the caller
 *             passes k through the search Config)
 * @param base offset added to the row index to form the component values
 * @return a Dataset holding one float tensor of shape {n, d} (field "data")
 * @throws std::bad_alloc if the buffer cannot be allocated
 */
DatasetPtr
generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) {
    (void) k;

    size_t size = sizeof(float) * n * d;
    auto v = static_cast<float *>(malloc(size));
    // malloc(0) may legitimately return nullptr, so only non-empty sizes count.
    if (v == nullptr && size > 0) {
        throw std::bad_alloc();
    }

    for (int64_t i = 0; i < n; ++i) {
        for (int64_t j = 0; j < d; ++j) {
            v[i * d + j] = static_cast<float>(base + i);
        }
    }

    // NOTE(review): assumes MakeMutableBufferSmart takes ownership of `v`
    // (it is not freed here) — confirm in adapter code.
    auto buffer = MakeMutableBufferSmart(reinterpret_cast<uint8_t *>(v), size);
    std::vector<int64_t> shape{n, d};
    auto float_type = std::make_shared<arrow::FloatType>();
    auto tensor = std::make_shared<Tensor>(float_type, buffer, shape);
    std::vector<TensorPtr> data{tensor};

    auto field = std::make_shared<Field>("data", float_type);
    std::vector<FieldPtr> fields{field};
    auto schema = std::make_shared<Schema>(fields);

    return std::make_shared<Dataset>(data, schema);
}


/**
 * Smoke test for CPUKDTRNG.
 *
 * Builds a preprocessor, trains, adds one vector, round-trips the index
 * through Serialize/Load, then searches and prints the top-k ids and
 * distances for each query.
 */
int
main(int argc, char *argv[]) {
    auto kdt_index = std::make_shared<CPUKDTRNG>();

    const auto d = 10;
    const auto k = 3;
    const auto nquery = 10;

    // 100 training vectors, ids [0, 99]; every component of row i equals i.
    auto train = generate_dataset(100, d, 0);
    // One extra vector for Add(). generate_dataset numbers ids from 0, and
    // base 0 makes its values equal to train row 0.
    // NOTE(review): this was previously labelled "ID [100]"; if a distinct
    // vector/id was intended, pass base 100 and offset ids in generate_dataset.
    auto base = generate_dataset(1, d, 0);
    // Queries equal to train rows 0 .. nquery-1.
    auto queries = generate_queries(nquery, d, k, 0);

    // Build Preprocessor
    auto preprocessor = kdt_index->BuildPreprocessor(train, Config());

    // Set Preprocessor
    kdt_index->set_preprocessor(preprocessor);

    Config train_config;
    train_config["TPTNumber"] = "64";
    // Train
    kdt_index->Train(train, train_config);

    // Add
    kdt_index->Add(base, Config());

    // Round-trip through Serialize/Load so the search runs on the loaded copy.
    // (To search the in-memory index instead, use kdt_index directly.)
    auto binary = kdt_index->Serialize();
    auto new_index = std::make_shared<CPUKDTRNG>();
    new_index->Load(binary);

    Config search_config;
    search_config[META_K] = int64_t (k);

    // Search
    auto result = new_index->Search(queries, search_config);

    // Print Result.
    // NOTE(review): assumes array 0 holds int64 ids and array 1 float
    // distances, laid out row-major as nquery x k — confirm against Search.
    {
        auto ids = result->array()[0];
        auto dists = result->array()[1];

        std::stringstream ss_id;
        std::stringstream ss_dist;
        for (int i = 0; i < nquery; ++i) {
            for (int j = 0; j < k; ++j) {
                ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
                ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
            }
            ss_id << '\n';
            ss_dist << '\n';
        }
        std::cout << "id\n" << ss_id.str() << std::endl;
        std::cout << "dist\n" << ss_dist.str() << std::endl;
    }

    return 0;
}