#include #include #include "knowhere/index/vector_index/cpu_kdt_rng.h" #include "knowhere/index/vector_index/definitions.h" #include "knowhere/adapter/sptag.h" #include "knowhere/adapter/structure.h" using namespace zilliz::knowhere; DatasetPtr generate_dataset(int64_t n, int64_t d, int64_t base) { auto elems = n * d; auto p_data = (float *) malloc(elems * sizeof(float)); auto p_id = (int64_t *) malloc(elems * sizeof(int64_t)); assert(p_data != nullptr && p_id != nullptr); for (auto i = 0; i < n; ++i) { for (auto j = 0; j < d; ++j) { p_data[i * d + j] = float(base + i); } p_id[i] = i; } std::vector shape{n, d}; auto tensor = ConstructFloatTensorSmart((uint8_t *) p_data, elems * sizeof(float), shape); std::vector tensors{tensor}; std::vector tensor_fields{ConstructFloatField("data")}; auto tensor_schema = std::make_shared(tensor_fields); auto id_array = ConstructInt64ArraySmart((uint8_t *) p_id, n * sizeof(int64_t)); std::vector arrays{id_array}; std::vector array_fields{ConstructInt64Field("id")}; auto array_schema = std::make_shared(tensor_fields); auto dataset = std::make_shared(std::move(arrays), array_schema, std::move(tensors), tensor_schema); return dataset; } DatasetPtr generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) { size_t size = sizeof(float) * n * d; auto v = (float *) malloc(size); // TODO(lxj): check malloc for (auto i = 0; i < n; ++i) { for (auto j = 0; j < d; ++j) { v[i * d + j] = float(base + i); } } std::vector data; auto buffer = MakeMutableBufferSmart((uint8_t *) v, size); std::vector shape{n, d}; auto float_type = std::make_shared(); auto tensor = std::make_shared(float_type, buffer, shape); data.push_back(tensor); Config meta; meta[META_ROWS] = int64_t (n); meta[META_DIM] = int64_t (d); meta[META_K] = int64_t (k); auto type = std::make_shared(); auto field = std::make_shared("data", type); std::vector fields{field}; auto schema = std::make_shared(fields); return std::make_shared(data, schema); } int main(int argc, char *argv[]) { auto kdt_index = std::make_shared(); const auto d = 10; const auto k = 3; const auto nquery = 10; // ID [0, 99] auto train = generate_dataset(100, d, 0); // ID [100] auto base = generate_dataset(1, d, 0); auto queries = generate_queries(nquery, d, k, 0); // Build Preprocessor auto preprocessor = kdt_index->BuildPreprocessor(train, Config()); // Set Preprocessor kdt_index->set_preprocessor(preprocessor); Config train_config; train_config["TPTNumber"] = "64"; // Train kdt_index->Train(train, train_config); // Add kdt_index->Add(base, Config()); auto binary = kdt_index->Serialize(); auto new_index = std::make_shared(); new_index->Load(binary); // auto new_index = kdt_index; Config search_config; search_config[META_K] = int64_t (k); // Search auto result = new_index->Search(queries, search_config); // Print Result { auto ids = result->array()[0]; auto dists = result->array()[1]; std::stringstream ss_id; std::stringstream ss_dist; for (auto i = 0; i < nquery; i++) { for (auto j = 0; j < k; ++j) { ss_id << *ids->data()->GetValues(1, i * k + j) << " "; ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; } ss_id << std::endl; ss_dist << std::endl; } std::cout << "id\n" << ss_id.str() << std::endl; std::cout << "dist\n" << ss_dist.str() << std::endl; } }