提交 bfd4c6c5 编写于 作者: X xiaojun.lin

upgrade SPTAG and support KDT and BKT

上级 0aa90ea9
......@@ -35,7 +35,9 @@ enum class EngineType {
NSG_MIX,
FAISS_IVFSQ8H,
FAISS_PQ,
MAX_VALUE = FAISS_PQ,
SPTAG_KDT,
SPTAG_BKT,
MAX_VALUE = SPTAG_BKT,
};
enum class MetricType {
......
......@@ -124,6 +124,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
#endif
break;
}
case EngineType::SPTAG_KDT: {
index = GetVecIndexFactory(IndexType::SPTAG_KDT_RNT_CPU);
break;
}
case EngineType::SPTAG_BKT: {
index = GetVecIndexFactory(IndexType::SPTAG_BKT_RNT_CPU);
break;
}
default: {
ENGINE_LOG_ERROR << "Unsupported index type";
return nullptr;
......
......@@ -30,10 +30,10 @@ set(external_srcs
set(index_srcs
knowhere/index/preprocessor/Normalize.cpp
knowhere/index/vector_index/IndexKDT.cpp
knowhere/index/vector_index/IndexSPTAG.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/helpers/KDTParameterMgr.cpp
knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
......
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <SPTAG/AnnService/inc/Core/Common.h>
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
#include <sstream>
#include <vector>
#undef mkdir
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
//#include "knowhere/index/preprocessor/normalize.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
namespace knowhere {
// Serialize the KDT index into a BinarySet.
//
// Not implemented yet: this always returns an empty BinarySet.
//
// TODO(zirui): export the index blobs with index_ptr_->SaveIndexToMemory() and
// append them to the set under the names "samples", "tree", "graph" and
// "metadata". Buffers allocated with new[] must be wrapped with an array
// deleter when they are handed to a shared_ptr.
BinarySet
CPUKDTRNG::Serialize() {
    BinarySet binary_set;
    return binary_set;
}
// Restore the index from a BinarySet.
//
// Currently a no-op: the implementation sketch below is commented out, so
// `binary_set` is ignored and the index is left unchanged.
void
CPUKDTRNG::Load(const BinarySet& binary_set) {
// TODO(zirui): dev
// std::vector<void*> index_blobs;
//
// auto samples = binary_set.GetByName("samples");
// index_blobs.push_back(samples->data.get());
//
// auto tree = binary_set.GetByName("tree");
// index_blobs.push_back(tree->data.get());
//
// auto graph = binary_set.GetByName("graph");
// index_blobs.push_back(graph->data.get());
//
// auto metadata = binary_set.GetByName("metadata");
// index_blobs.push_back(metadata->data.get());
//
// index_ptr_->LoadIndexFromMemory(index_blobs);
}
// PreprocessorPtr
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
// Build the SPTAG index from `origin`.
//
// Parameters are applied first, then the dataset is cloned and converted into
// the vector set / metadata set pair that SPTAG's BuildIndex consumes.
// No index model is produced yet, so the result is always nullptr.
IndexModelPtr
CPUKDTRNG::Train(const DatasetPtr& origin, const Config& train_config) {
    SetParameters(train_config);

    // Conversions run on a clone rather than on the caller's dataset.
    // (Cosine pre-normalisation through preprocessor_ is currently disabled.)
    DatasetPtr working_copy = origin->Clone();
    auto vectors = ConvertToVectorSet(working_copy);
    auto metadata = ConvertToMetadataSet(working_copy);

    index_ptr_->BuildIndex(vectors, metadata);

    // TODO: return IndexModelPtr
    return nullptr;
}
void
CPUKDTRNG::Add(const DatasetPtr& origin, const Config& add_config) {
SetParameters(add_config);
DatasetPtr dataset = origin->Clone();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto vectorset = ConvertToVectorSet(dataset);
auto metaset = ConvertToMetadataSet(dataset);
index_ptr_->AddIndex(vectorset, metaset);
}
// Push the built-in KDT defaults from KDTParameterMgr into the SPTAG index.
//
// `config` is accepted for interface symmetry but is not consulted: every
// parameter is set to its default value.
void
CPUKDTRNG::SetParameters(const Config& config) {
    const auto& defaults = KDTParameterMgr::GetInstance().GetKDTParameters();
    for (auto entry = defaults.begin(); entry != defaults.end(); ++entry) {
        index_ptr_->SetParameter(entry->first, entry->second);
    }
}
// Run a nearest-neighbour search for every query vector in `dataset`.
//
// Parameters are applied first, the dataset is converted into one
// SPTAG::QueryResult per query, each query is searched (in parallel when
// OpenMP is enabled), and the filled results are converted back to a dataset.
//
// The former debug dump of the first 100 input floats and of the first three
// components of each query was removed: it read past the end of small inputs
// and wrote to std::cout from inside the parallel loop.
DatasetPtr
CPUKDTRNG::Search(const DatasetPtr& dataset, const Config& config) {
    SetParameters(config);

    std::vector<SPTAG::QueryResult> query_results = ConvertToQueryResult(dataset, config);

    // A signed loop counter keeps the loop in OpenMP canonical form and avoids
    // a signed/unsigned comparison.
    const auto query_count = static_cast<int64_t>(query_results.size());
#pragma omp parallel for
    for (int64_t i = 0; i < query_count; ++i) {
        index_ptr_->SearchIndex(query_results[i]);
    }

    return ConvertToDataset(query_results);
}
// Number of vectors held by the underlying SPTAG index.
// (The original body computed the value but never returned it.)
int64_t
CPUKDTRNG::Count() {
    return index_ptr_->GetNumSamples();
}
// Feature dimension reported by the underlying SPTAG index.
// (The original body computed the value but never returned it.)
int64_t
CPUKDTRNG::Dimension() {
    return index_ptr_->GetFeatureDim();
}
// Cloning is not supported for this index type; the call always raises through
// KNOWHERE_THROW_MSG and never returns a value.
VectorIndexPtr
CPUKDTRNG::Clone() {
KNOWHERE_THROW_MSG("not support");
}
// Seal the index. Intentionally empty for this index type.
void
CPUKDTRNG::Seal() {
// do nothing
}
// TODO(linxj):
// Serialization of the index model is not implemented yet; an empty BinarySet
// is returned so the function has a well-defined result (the original body
// fell off the end of a non-void function).
BinarySet
CPUKDTRNGIndexModel::Serialize() {
    return BinarySet();
}
// Loading an index model is not implemented yet; `binary` is ignored.
void
CPUKDTRNGIndexModel::Load(const BinarySet& binary) {
}
} // namespace knowhere
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <SPTAG/AnnService/inc/Core/Common.h>
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
#include <sstream>
#include <vector>
#include <array>
#undef mkdir
#include "knowhere/index/vector_index/IndexSPTAG.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
namespace knowhere {
// Create a float-valued SPTAG index using the L2 distance.
//
// `IndexType` selects the tree algorithm: exactly "KDT" picks the KD-tree
// variant, any other string picks BKT.
CPUSPTAGRNG::CPUSPTAGRNG(const std::string& IndexType) {
    index_type_ = (IndexType == "KDT") ? SPTAG::IndexAlgoType::KDT : SPTAG::IndexAlgoType::BKT;
    index_ptr_ = SPTAG::VectorIndex::CreateInstance(index_type_, SPTAG::VectorValueType::Float);
    index_ptr_->SetParameter("DistCalcMethod", "L2");
}
// Serialize the SPTAG index into a BinarySet.
//
// The index reports the size of each blob it wants to write; one buffer per
// blob is allocated, SaveIndex fills them, and the first six are published as
// "samples", "tree", "graph", "deleteid", "metadata1" and "metadata2".
// The textual index configuration is stored under "config" as a
// NUL-terminated string (the terminator is included in the stored size).
//
// Raises through KNOWHERE_THROW_MSG if the index reports fewer than six blobs.
BinarySet
CPUSPTAGRNG::Serialize() {
    // Blobs are consumed by position below: 0=samples, 1=tree, 2=graph,
    // 3=deleteid, 4=metadata1, 5=metadata2.
    constexpr size_t kBlobCount = 6;

    std::shared_ptr<std::vector<std::uint64_t>> buffersize = index_ptr_->CalculateBufferSize();
    if (buffersize == nullptr || buffersize->size() < kBlobCount) {
        KNOWHERE_THROW_MSG("SPTAG index reported fewer than 6 blobs to serialize");
    }

    // Each buffer is owned by a shared_ptr from the moment it is allocated, so
    // nothing leaks if SaveIndex throws. The array deleter matches new[].
    std::vector<std::shared_ptr<uint8_t>> buffers;
    std::vector<SPTAG::ByteArray> index_blobs;
    buffers.reserve(buffersize->size());
    index_blobs.reserve(buffersize->size());
    for (const std::uint64_t blob_size : *buffersize) {
        buffers.emplace_back(new uint8_t[blob_size], std::default_delete<uint8_t[]>());
        // `false`: the ByteArray only views the memory, ownership stays with `buffers`.
        index_blobs.emplace_back(buffers.back().get(), blob_size, false);
    }

    std::string index_config;
    index_ptr_->SaveIndex(index_config, index_blobs);

    // Copy the whole configuration text plus a terminating NUL. Load() reads
    // this entry back as a C string.
    const size_t length = index_config.length();
    std::shared_ptr<uint8_t> config(new uint8_t[length + 1], std::default_delete<uint8_t[]>());
    index_config.copy(reinterpret_cast<char*>(config.get()), length);
    config.get()[length] = 0;

    BinarySet binary_set;
    binary_set.Append("samples", buffers[0], index_blobs[0].Length());
    binary_set.Append("tree", buffers[1], index_blobs[1].Length());
    binary_set.Append("deleteid", buffers[3], index_blobs[3].Length());
    binary_set.Append("metadata1", buffers[4], index_blobs[4].Length());
    binary_set.Append("metadata2", buffers[5], index_blobs[5].Length());
    binary_set.Append("config", config, length + 1);
    binary_set.Append("graph", buffers[2], index_blobs[2].Length());
    return binary_set;
}
// Restore the SPTAG index from a BinarySet produced by Serialize().
//
// The six blobs are handed to SPTAG in the fixed order samples, tree, graph,
// deleteid, metadata1, metadata2. They are passed as non-owning views, so the
// memory stays owned by `binary_set`.
// The "config" entry holds the textual index configuration.
void
CPUSPTAGRNG::Load(const BinarySet& binary_set) {
    std::vector<SPTAG::ByteArray> index_blobs;

    auto samples = binary_set.GetByName("samples");
    index_blobs.push_back(SPTAG::ByteArray(samples->data.get(), samples->size, false));

    auto tree = binary_set.GetByName("tree");
    index_blobs.push_back(SPTAG::ByteArray(tree->data.get(), tree->size, false));

    auto graph = binary_set.GetByName("graph");
    index_blobs.push_back(SPTAG::ByteArray(graph->data.get(), graph->size, false));

    auto deleteid = binary_set.GetByName("deleteid");
    index_blobs.push_back(SPTAG::ByteArray(deleteid->data.get(), deleteid->size, false));

    auto metadata1 = binary_set.GetByName("metadata1");
    index_blobs.push_back(SPTAG::ByteArray(metadata1->data.get(), metadata1->size, false));

    auto metadata2 = binary_set.GetByName("metadata2");
    index_blobs.push_back(SPTAG::ByteArray(metadata2->data.get(), metadata2->size, false));

    // Read the configuration bounded by its stored size instead of trusting a
    // NUL terminator to be present, then cut at the first NUL so a stored
    // terminator does not end up inside the string.
    auto config = binary_set.GetByName("config");
    std::string index_config(reinterpret_cast<const char*>(config->data.get()), static_cast<size_t>(config->size));
    const auto terminator = index_config.find('\0');
    if (terminator != std::string::npos) {
        index_config.resize(terminator);
    }

    index_ptr_->LoadIndex(index_config, index_blobs);
}
// PreprocessorPtr
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
// Build the SPTAG index from `origin`.
//
// Parameters are applied first, then the dataset is cloned and converted into
// the vector set / metadata set pair that SPTAG's BuildIndex consumes.
// No index model is produced yet, so the result is always nullptr.
IndexModelPtr
CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) {
    SetParameters(train_config);

    // Conversions run on a clone rather than on the caller's dataset.
    // (Cosine pre-normalisation through preprocessor_ is currently disabled.)
    DatasetPtr working_copy = origin->Clone();
    auto vectors = ConvertToVectorSet(working_copy);
    auto metadata = ConvertToMetadataSet(working_copy);

    index_ptr_->BuildIndex(vectors, metadata);

    // TODO: return IndexModelPtr
    return nullptr;
}
// Append vectors to the index.
//
// Currently a no-op: the implementation below is commented out, so both
// arguments are ignored and the index is left unchanged.
void
CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) {
// SetParameters(add_config);
// DatasetPtr dataset = origin->Clone();
//
// // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// // && preprocessor_) {
// // preprocessor_->Preprocess(dataset);
// //}
//
// auto vectorset = ConvertToVectorSet(dataset);
// auto metaset = ConvertToMetadataSet(dataset);
// index_ptr_->AddIndex(vectorset, metaset);
}
void
CPUSPTAGRNG::SetParameters(const Config& config) {
#define Assign(param_name, str_name) \
conf->param_name == INVALID_VALUE ? index_ptr_->SetParameter(str_name, std::to_string(build_cfg->param_name)) \
: index_ptr_->SetParameter(str_name, std::to_string(conf->param_name))
if (index_type_ == SPTAG::IndexAlgoType::KDT) {
auto conf = std::dynamic_pointer_cast<KDTCfg>(config);
auto build_cfg = SPTAGParameterMgr::GetInstance().GetKDTParameters();
Assign(kdtnumber, "KDTNumber");
Assign(numtopdimensionkdtsplit, "NumTopDimensionKDTSplit");
Assign(samples, "Samples");
Assign(tptnumber, "TPTNumber");
Assign(tptleafsize, "TPTLeafSize");
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
Assign(neighborhoodsize, "NeighborhoodSize");
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
Assign(graphcefscale, "GraphCEFScale");
Assign(refineiterations, "RefineIterations");
Assign(cef, "CEF");
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
Assign(numofthreads, "NumberOfThreads");
Assign(maxcheck, "MaxCheck");
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
} else {
auto conf = std::dynamic_pointer_cast<BKTCfg>(config);
auto build_cfg = SPTAGParameterMgr::GetInstance().GetBKTParameters();
Assign(bktnumber, "BKTNumber");
Assign(bktkmeansk, "BKTKMeansK");
Assign(bktleafsize, "BKTLeafSize");
Assign(samples, "Samples");
Assign(tptnumber, "TPTNumber");
Assign(tptleafsize, "TPTLeafSize");
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
Assign(neighborhoodsize, "NeighborhoodSize");
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
Assign(graphcefscale, "GraphCEFScale");
Assign(refineiterations, "RefineIterations");
Assign(cef, "CEF");
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
Assign(numofthreads, "NumberOfThreads");
Assign(maxcheck, "MaxCheck");
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
}
}
// Run a nearest-neighbour search for every query vector in `dataset`.
//
// Parameters are applied first, the dataset is converted into one
// SPTAG::QueryResult per query, each query is searched (in parallel when
// OpenMP is enabled), and the filled results are converted back to a dataset.
//
// The former debug dump of the first 100 input floats and of the first three
// components of each query was removed: it read past the end of small inputs
// and wrote to std::cout from inside the parallel loop.
DatasetPtr
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
    SetParameters(config);

    std::vector<SPTAG::QueryResult> query_results = ConvertToQueryResult(dataset, config);

    // A signed loop counter keeps the loop in OpenMP canonical form and avoids
    // a signed/unsigned comparison.
    const auto query_count = static_cast<int64_t>(query_results.size());
#pragma omp parallel for
    for (int64_t i = 0; i < query_count; ++i) {
        index_ptr_->SearchIndex(query_results[i]);
    }

    return ConvertToDataset(query_results);
}
// Number of vectors held by the underlying SPTAG index.
int64_t
CPUSPTAGRNG::Count() {
    const auto sample_count = index_ptr_->GetNumSamples();
    return sample_count;
}
// Feature dimension reported by the underlying SPTAG index.
int64_t
CPUSPTAGRNG::Dimension() {
    const auto feature_dim = index_ptr_->GetFeatureDim();
    return feature_dim;
}
// Cloning is not supported for this index type; the call always raises through
// KNOWHERE_THROW_MSG and never returns a value.
VectorIndexPtr
CPUSPTAGRNG::Clone() {
KNOWHERE_THROW_MSG("not support");
}
// Seal the index. Nothing needs to happen for this index type.
void
CPUSPTAGRNG::Seal() {
    // do nothing
}
// Serialization of the index model is not supported; an empty BinarySet is
// returned so the function has a well-defined result (the original body fell
// off the end of a non-void function).
BinarySet
CPUSPTAGRNGIndexModel::Serialize() {
    // KNOWHERE_THROW_MSG("not support"); // not support
    return BinarySet();
}
// Loading an index model is not supported; `binary` is ignored and the call
// silently does nothing.
void
CPUSPTAGRNGIndexModel::Load(const BinarySet& binary) {
// KNOWHERE_THROW_MSG("not support"); // not support
}
} // namespace knowhere
......@@ -18,70 +18,76 @@
#pragma once
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
#include <cstdint>
#include <memory>
#include <string>
#include "VectorIndex.h"
#include "knowhere/index/IndexModel.h"
namespace knowhere {
// VectorIndex implementation backed by an SPTAG KDT index over float vectors.
class CPUKDTRNG : public VectorIndex {
public:
// Creates the SPTAG KDT index and selects the L2 distance.
CPUKDTRNG() {
index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT, SPTAG::VectorValueType::Float);
index_ptr_->SetParameter("DistCalcMethod", "L2");
}
public:
// Persistence interface.
BinarySet
Serialize() override;
VectorIndexPtr
Clone() override;
void
Load(const BinarySet& index_array) override;
public:
// PreprocessorPtr
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
// Number of vectors in the index.
int64_t
Count() override;
// Dimension of the indexed vectors.
int64_t
Dimension() override;
// Builds the index from `dataset`.
IndexModelPtr
Train(const DatasetPtr& dataset, const Config& config) override;
// Adds the vectors of `dataset` to the index.
void
Add(const DatasetPtr& dataset, const Config& config) override;
// Searches the index for the queries in `dataset`.
DatasetPtr
Search(const DatasetPtr& dataset, const Config& config) override;
void
Seal() override;
private:
// Pushes index parameters into the SPTAG index before build/add/search.
void
SetParameters(const Config& config);
private:
PreprocessorPtr preprocessor_;
// The wrapped SPTAG index; created in the constructor.
std::shared_ptr<SPTAG::VectorIndex> index_ptr_;
};
using CPUKDTRNGPtr = std::shared_ptr<CPUKDTRNG>;
// IndexModel counterpart of CPUKDTRNG; holds a shared SPTAG index handle.
class CPUKDTRNGIndexModel : public IndexModel {
public:
BinarySet
Serialize() override;
void
Load(const BinarySet& binary) override;
private:
std::shared_ptr<SPTAG::VectorIndex> index_;
};
using CPUKDTRNGIndexModelPtr = std::shared_ptr<CPUKDTRNGIndexModel>;
// VectorIndex implementation backed by an SPTAG index over float vectors.
// The tree algorithm (KDT or BKT) is chosen at construction time.
class CPUSPTAGRNG : public VectorIndex {
public:
// `IndexType` equal to "KDT" selects KDT; any other value selects BKT.
explicit CPUSPTAGRNG(const std::string& IndexType);
public:
// Persistence interface.
BinarySet
Serialize() override;
VectorIndexPtr
Clone() override;
void
Load(const BinarySet& index_array) override;
public:
// PreprocessorPtr
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
// Number of vectors in the index.
int64_t
Count() override;
// Dimension of the indexed vectors.
int64_t
Dimension() override;
// Builds the index from `dataset`.
IndexModelPtr
Train(const DatasetPtr& dataset, const Config& config) override;
void
Add(const DatasetPtr& dataset, const Config& config) override;
// Searches the index for the queries in `dataset`.
DatasetPtr
Search(const DatasetPtr& dataset, const Config& config) override;
void
Seal() override;
private:
// Pushes index parameters into the SPTAG index; expects a KDTCfg or BKTCfg
// matching index_type_.
void
SetParameters(const Config& config);
private:
PreprocessorPtr preprocessor_;
// The wrapped SPTAG index; created in the constructor.
std::shared_ptr<SPTAG::VectorIndex> index_ptr_;
// Algorithm chosen in the constructor; decides which config type is expected.
SPTAG::IndexAlgoType index_type_;
};
using CPUSPTAGRNGPtr = std::shared_ptr<CPUSPTAGRNG>;
// IndexModel counterpart of CPUSPTAGRNG; holds a shared SPTAG index handle.
class CPUSPTAGRNGIndexModel : public IndexModel {
public:
BinarySet
Serialize() override;
void
Load(const BinarySet& binary) override;
private:
std::shared_ptr<SPTAG::VectorIndex> index_;
};
using CPUSPTAGRNGIndexModelPtr = std::shared_ptr<CPUSPTAGRNGIndexModel>;
} // namespace knowhere
......@@ -42,6 +42,32 @@ constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
constexpr int64_t DEFAULT_CANDIDATE_SISE = INVALID_VALUE;
constexpr int64_t DEFAULT_NNG_K = INVALID_VALUE;
// SPTAG Config
// Initial values for the SPTAGCfg / KDTCfg / BKTCfg fields. Every one is
// INVALID_VALUE, i.e. "not set by the caller".
// NOTE(review): the SPTAG index wrapper appears to substitute its own built-in
// default whenever a field still holds INVALID_VALUE - confirm against caller.
constexpr int64_t DEFAULT_SAMPLES = INVALID_VALUE;
constexpr int64_t DEFAULT_TPTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_TPTLEAFSIZE = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMTOPDIMENSIONTPTSPLIT = INVALID_VALUE;
constexpr int64_t DEFAULT_NEIGHBORHOODSIZE = INVALID_VALUE;
constexpr int64_t DEFAULT_GRAPHNEIGHBORHOODSCALE = INVALID_VALUE;
constexpr int64_t DEFAULT_GRAPHCEFSCALE = INVALID_VALUE;
constexpr int64_t DEFAULT_REFINEITERATIONS = INVALID_VALUE;
constexpr int64_t DEFAULT_CEF = INVALID_VALUE;
constexpr int64_t DEFAULT_MAXCHECKFORREFINEGRAPH = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMOFTHREADS = INVALID_VALUE;
constexpr int64_t DEFAULT_MAXCHECK = INVALID_VALUE;
constexpr int64_t DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS = INVALID_VALUE;
// KDT Config
constexpr int64_t DEFAULT_KDTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMTOPDIMENSIONKDTSPLIT = INVALID_VALUE;
// BKT Config
constexpr int64_t DEFAULT_BKTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_BKTKMEANSK = INVALID_VALUE;
constexpr int64_t DEFAULT_BKTLEAFSIZE = INVALID_VALUE;
struct IVFCfg : public Cfg {
int64_t nlist = DEFAULT_NLIST;
int64_t nprobe = DEFAULT_NPROBE;
......@@ -126,8 +152,57 @@ struct NSGCfg : public IVFCfg {
};
using NSGConfig = std::shared_ptr<NSGCfg>;
struct KDTCfg : public Cfg {
int64_t tptnubmber = -1;
// Parameters shared by the SPTAG KDT and BKT configurations.
// Every field starts at its DEFAULT_* constant.
struct SPTAGCfg : public Cfg {
int64_t samples = DEFAULT_SAMPLES;
int64_t tptnumber = DEFAULT_TPTNUMBER;
int64_t tptleafsize = DEFAULT_TPTLEAFSIZE;
int64_t numtopdimensiontptsplit = DEFAULT_NUMTOPDIMENSIONTPTSPLIT;
int64_t neighborhoodsize = DEFAULT_NEIGHBORHOODSIZE;
int64_t graphneighborhoodscale = DEFAULT_GRAPHNEIGHBORHOODSCALE;
int64_t graphcefscale = DEFAULT_GRAPHCEFSCALE;
int64_t refineiterations = DEFAULT_REFINEITERATIONS;
int64_t cef = DEFAULT_CEF;
int64_t maxcheckforrefinegraph = DEFAULT_MAXCHECKFORREFINEGRAPH;
int64_t numofthreads = DEFAULT_NUMOFTHREADS;
int64_t maxcheck = DEFAULT_MAXCHECK;
int64_t thresholdofnumberofcontinuousnobetterpropagation = DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION;
int64_t numberofinitialdynamicpivots = DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS;
int64_t numberofotherdynamicpivots = DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS;
SPTAGCfg() = default;
// No validation is performed: every configuration is reported as valid.
bool
CheckValid() override {
return true;
};
};
using SPTAGConfig = std::shared_ptr<SPTAGCfg>;
// SPTAG configuration for the KDT algorithm: the shared SPTAGCfg fields plus
// the two KD-tree specific ones.
struct KDTCfg : public SPTAGCfg {
int64_t kdtnumber = DEFAULT_KDTNUMBER;
int64_t numtopdimensionkdtsplit = DEFAULT_NUMTOPDIMENSIONKDTSPLIT;
KDTCfg() = default;
// No validation is performed: every configuration is reported as valid.
bool
CheckValid() override {
return true;
};
};
using KDTConfig = std::shared_ptr<KDTCfg>;
// SPTAG configuration for the BKT algorithm: the shared SPTAGCfg fields plus
// the three BKT specific ones.
struct BKTCfg : public SPTAGCfg {
int64_t bktnumber = DEFAULT_BKTNUMBER;
int64_t bktkmeansk = DEFAULT_BKTKMEANSK;
int64_t bktleafsize = DEFAULT_BKTLEAFSIZE;
BKTCfg() = default;
// No validation is performed: every configuration is reported as valid.
bool
CheckValid() override {
return true;
};
};
using BKTConfig = std::shared_ptr<BKTCfg>;
} // namespace knowhere
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <mutex>
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
namespace knowhere {
// Returns the built-in KDT parameter list (name/value string pairs) that the
// constructor filled in. The reference points at the manager's own member.
const std::vector<KDTParameter>&
KDTParameterMgr::GetKDTParameters() {
return kdt_parameters_;
}
// Fill the parameter list with the built-in KDT defaults, as name/value
// string pairs in a fixed order.
KDTParameterMgr::KDTParameterMgr() {
    static const char* const kDefaults[][2] = {
        {"KDTNumber", "1"},
        {"NumTopDimensionKDTSplit", "5"},
        {"NumSamplesKDTSplitConsideration", "100"},

        {"TPTNumber", "1"},
        {"TPTLeafSize", "2000"},
        {"NumTopDimensionTPTSplit", "5"},

        {"NeighborhoodSize", "32"},
        {"GraphNeighborhoodScale", "2"},
        {"GraphCEFScale", "2"},
        {"RefineIterations", "0"},
        {"CEF", "1000"},
        {"MaxCheckForRefineGraph", "10000"},

        {"NumberOfThreads", "1"},

        {"MaxCheck", "8192"},
        {"ThresholdOfNumberOfContinuousNoBetterPropagation", "3"},
        {"NumberOfInitialDynamicPivots", "50"},
        {"NumberOfOtherDynamicPivots", "4"},
    };

    for (const auto& row : kDefaults) {
        kdt_parameters_.emplace_back(row[0], row[1]);
    }
}
} // namespace knowhere
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <mutex>
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
namespace knowhere {
// Returns the built-in KDT defaults created by the constructor. The reference
// points at the manager's own shared_ptr member.
const KDTConfig&
SPTAGParameterMgr::GetKDTParameters() {
return kdt_config_;
}
// Returns the built-in BKT defaults created by the constructor. The reference
// points at the manager's own shared_ptr member.
const BKTConfig&
SPTAGParameterMgr::GetBKTParameters() {
return bkt_config_;
}
// Create the built-in default configurations for both SPTAG algorithms.
// The KDT and BKT defaults differ only in their tree-specific fields; all
// fields inherited from SPTAGCfg receive the same values.
SPTAGParameterMgr::SPTAGParameterMgr() {
    // Defaults for the fields both configurations inherit from SPTAGCfg.
    auto fill_shared = [](SPTAGCfg& cfg) {
        cfg.samples = 100;
        cfg.tptnumber = 1;
        cfg.tptleafsize = 2000;
        cfg.numtopdimensiontptsplit = 5;
        cfg.neighborhoodsize = 32;
        cfg.graphneighborhoodscale = 2;
        cfg.graphcefscale = 2;
        cfg.refineiterations = 0;
        cfg.cef = 1000;
        cfg.maxcheckforrefinegraph = 10000;
        cfg.numofthreads = 1;
        cfg.maxcheck = 8192;
        cfg.thresholdofnumberofcontinuousnobetterpropagation = 3;
        cfg.numberofinitialdynamicpivots = 50;
        cfg.numberofotherdynamicpivots = 4;
    };

    kdt_config_ = std::make_shared<KDTCfg>();
    kdt_config_->kdtnumber = 1;
    kdt_config_->numtopdimensionkdtsplit = 5;
    fill_shared(*kdt_config_);

    bkt_config_ = std::make_shared<BKTCfg>();
    bkt_config_->bktnumber = 1;
    bkt_config_->bktkmeansk = 32;
    bkt_config_->bktleafsize = 8;
    fill_shared(*bkt_config_);
}
} // namespace knowhere
\ No newline at end of file
......@@ -22,31 +22,40 @@
#include <utility>
#include <vector>
#include <SPTAG/AnnService/inc/Core/Common.h>
#include "IndexParameter.h"
namespace knowhere {
using KDTParameter = std::pair<std::string, std::string>;
using KDTConfig = std::shared_ptr<KDTCfg>;
using BKTConfig = std::shared_ptr<BKTCfg>;
class SPTAGParameterMgr {
public:
const KDTConfig&
GetKDTParameters();
const BKTConfig&
GetBKTParameters();
class KDTParameterMgr {
public:
const std::vector<KDTParameter>&
GetKDTParameters();
public:
static SPTAGParameterMgr&
GetInstance() {
static SPTAGParameterMgr instance;
return instance;
}
public:
static KDTParameterMgr&
GetInstance() {
static KDTParameterMgr instance;
return instance;
}
SPTAGParameterMgr(const SPTAGParameterMgr&) = delete;
KDTParameterMgr(const KDTParameterMgr&) = delete;
KDTParameterMgr&
operator=(const KDTParameterMgr&) = delete;
SPTAGParameterMgr&
operator=(const SPTAGParameterMgr&) = delete;
private:
KDTParameterMgr();
private:
SPTAGParameterMgr();
private:
std::vector<KDTParameter> kdt_parameters_;
};
private:
KDTConfig kdt_config_;
BKTConfig bkt_config_;
};
} // namespace knowhere
......@@ -195,7 +195,7 @@ namespace SPTAG
C = *((DimensionType*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(DimensionType);
Initialize(R, C, (T*)pDataPointsMemFile);
Initialize(R, C, (T*)pDataPointsMemFile, false);
std::cout << "Load " << name << " (" << R << ", " << C << ") Finish!" << std::endl;
return true;
}
......
......@@ -82,17 +82,17 @@ if (NOT TARGET test_idmap)
endif ()
target_link_libraries(test_idmap ${depend_libs} ${unittest_libs} ${basic_libs})
#<KDT-TEST>
set(kdt_srcs
#<SPTAG-TEST>
set(sptag_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/SptagAdapter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/preprocessor/Normalize.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexKDT.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
)
if (NOT TARGET test_kdt)
add_executable(test_kdt test_kdt.cpp ${kdt_srcs} ${util_srcs})
if (NOT TARGET test_sptag)
add_executable(test_sptag test_sptag.cpp ${sptag_srcs} ${util_srcs})
endif ()
target_link_libraries(test_kdt
target_link_libraries(test_sptag
SPTAGLibStatic
${depend_libs} ${unittest_libs} ${basic_libs})
......@@ -106,7 +106,7 @@ endif ()
install(TARGETS test_ivf DESTINATION unittest)
install(TARGETS test_idmap DESTINATION unittest)
install(TARGETS test_kdt DESTINATION unittest)
install(TARGETS test_sptag DESTINATION unittest)
if (KNOWHERE_GPU_VERSION)
install(TARGETS test_gpuresource DESTINATION unittest)
install(TARGETS test_customized_index DESTINATION unittest)
......
......@@ -23,7 +23,7 @@
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/Structure.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexSPTAG.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "unittest/utils.h"
......@@ -32,28 +32,38 @@ using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class KDTTest : public DataGen, public ::testing::Test {
class SPTAGTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
Generate(96, 1000, 10);
index_ = std::make_shared<knowhere::CPUKDTRNG>();
auto tempconf = std::make_shared<knowhere::KDTCfg>();
tempconf->tptnubmber = 1;
tempconf->k = 10;
conf = tempconf;
IndexType = GetParam();
Generate(128, 100, 5);
index_ = std::make_shared<knowhere::CPUSPTAGRNG>(IndexType);
if (IndexType == "KDT") {
auto tempconf = std::make_shared<knowhere::KDTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
conf = tempconf;
} else {
auto tempconf = std::make_shared<knowhere::BKTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
conf = tempconf;
}
Init_with_default();
}
protected:
knowhere::Config conf;
std::shared_ptr<knowhere::CPUKDTRNG> index_ = nullptr;
std::shared_ptr<knowhere::CPUSPTAGRNG> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(SPTAGParameters, SPTAGTest, Values("KDT", "BKT"));
// TODO(lxj): add test about count() and dimension()
TEST_F(KDTTest, kdt_basic) {
TEST_P(SPTAGTest, sptag_basic) {
assert(!xb.empty());
auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
......@@ -66,8 +76,8 @@ TEST_F(KDTTest, kdt_basic) {
AssertAnns(result, nq, k);
{
// auto ids = result->array()[0];
// auto dists = result->array()[1];
//auto ids = result->array()[0];
//auto dists = result->array()[1];
auto ids = result->ids();
auto dists = result->dist();
......@@ -75,10 +85,10 @@ TEST_F(KDTTest, kdt_basic) {
std::stringstream ss_dist;
for (auto i = 0; i < nq; i++) {
for (auto j = 0; j < k; ++j) {
//ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
//ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
ss_id << *((int64_t*)(ids) + i * k + j) << " ";
ss_dist << *((float*)(dists) + i * k + j) << " ";
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
}
ss_id << std::endl;
ss_dist << std::endl;
......@@ -88,57 +98,57 @@ TEST_F(KDTTest, kdt_basic) {
}
}
// TODO(zirui): enable test
// TEST_F(KDTTest, kdt_serialize) {
// assert(!xb.empty());
//
// auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
// index_->set_preprocessor(preprocessor);
//
// auto model = index_->Train(base_dataset, conf);
// // index_->Add(base_dataset, conf);
// auto binaryset = index_->Serialize();
// auto new_index = std::make_shared<knowhere::CPUKDTRNG>();
// new_index->Load(binaryset);
// auto result = new_index->Search(query_dataset, conf);
// AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
// ASSERT_EQ(new_index->Count(), nb);
// ASSERT_EQ(new_index->Dimension(), dim);
// ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
// ASSERT_NO_THROW({ new_index->Seal(); });
//
// {
// int fileno = 0;
// const std::string& base_name = "/tmp/kdt_serialize_test_bin_";
// std::vector<std::string> filename_list;
// std::vector<std::pair<std::string, size_t>> meta_list;
// for (auto& iter : binaryset.binary_map_) {
// const std::string& filename = base_name + std::to_string(fileno);
// FileIOWriter writer(filename);
// writer(iter.second->data.get(), iter.second->size);
//
// meta_list.emplace_back(std::make_pair(iter.first, iter.second->size));
// filename_list.push_back(filename);
// ++fileno;
// }
//
// knowhere::BinarySet load_data_list;
// for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
// auto bin_size = meta_list[i].second;
// FileIOReader reader(filename_list[i]);
//
// auto load_data = new uint8_t[bin_size];
// reader(load_data, bin_size);
// auto data = std::make_shared<uint8_t>();
// data.reset(load_data);
// load_data_list.Append(meta_list[i].first, data, bin_size);
// }
//
// auto new_index = std::make_shared<knowhere::CPUKDTRNG>();
// new_index->Load(load_data_list);
// auto result = new_index->Search(query_dataset, conf);
// AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
// }
//}
// Round-trips a trained SPTAG index through Serialize()/Load() twice: once directly in memory and
// once through temporary files under /tmp, checking after each reload that the index still
// answers the query set.
TEST_P(SPTAGTest, sptag_serialize) {
    assert(!xb.empty());

    auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
    index_->set_preprocessor(preprocessor);

    auto model = index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);

    // In-memory round trip: the reloaded index must search correctly and report the same
    // element count and dimension as the generated data set.
    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<knowhere::CPUSPTAGRNG>(IndexType);
    new_index->Load(binaryset);
    auto result = new_index->Search(query_dataset, conf);
    AssertAnns(result, nq, k);
    PrintResult(result, nq, k);
    ASSERT_EQ(new_index->Count(), nb);
    ASSERT_EQ(new_index->Dimension(), dim);
    // ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
    // ASSERT_NO_THROW({ new_index->Seal(); });

    {
        // File round trip: dump every binary of the set to its own file, remembering the
        // (name, size) pair needed to rebuild the BinarySet afterwards.
        int fileno = 0;
        const std::string base_name = "/tmp/sptag_serialize_test_bin_";
        std::vector<std::string> filename_list;
        std::vector<std::pair<std::string, size_t>> meta_list;
        for (auto& iter : binaryset.binary_map_) {
            const std::string filename = base_name + std::to_string(fileno);
            FileIOWriter writer(filename);
            writer(iter.second->data.get(), iter.second->size);

            meta_list.emplace_back(iter.first, iter.second->size);
            filename_list.push_back(filename);
            ++fileno;
        }

        // Read the files back into a fresh BinarySet.
        knowhere::BinarySet load_data_list;
        for (size_t i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
            const auto bin_size = meta_list[i].second;
            FileIOReader reader(filename_list[i]);

            // The buffer is an array, so it must be released with delete[]. A plain
            // shared_ptr<uint8_t>::reset(ptr) would use scalar delete, which is undefined
            // behaviour for memory obtained from new[].
            std::shared_ptr<uint8_t> data(new uint8_t[bin_size], std::default_delete<uint8_t[]>());
            reader(data.get(), bin_size);
            load_data_list.Append(meta_list[i].first, data, bin_size);
        }

        // The index rebuilt from the files must still answer the queries.
        auto reloaded_index = std::make_shared<knowhere::CPUSPTAGRNG>(IndexType);
        reloaded_index->Load(load_data_list);
        auto reloaded_result = reloaded_index->Search(query_dataset, conf);
        AssertAnns(reloaded_result, nq, k);
        PrintResult(reloaded_result, nq, k);
    }
}
......@@ -153,22 +153,24 @@ void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
auto ids = result->ids();
for (auto i = 0; i < nq; i++) {
EXPECT_EQ(i, *((int64_t*)(ids) + i * k));
EXPECT_EQ(i, *((int64_t*)(ids) + i * k));
// EXPECT_EQ(i, *(ids->data()->GetValues<int64_t>(1, i * k)));
}
}
void
PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
auto ids = result->array()[0];
auto dists = result->array()[1];
auto ids = result->ids();
auto dists = result->dist();
std::stringstream ss_id;
std::stringstream ss_dist;
for (auto i = 0; i < 10; i++) {
for (auto i = 0; i < nq; i++) {
for (auto j = 0; j < k; ++j) {
ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
//ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
//ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
ss_id << *((int64_t*)(ids) + i * k + j) << " ";
ss_dist << *((float*)(dists) + i * k + j) << " ";
}
ss_id << std::endl;
ss_dist << std::endl;
......
......@@ -204,5 +204,35 @@ NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
return conf;
}
// Creates the KDT configuration, copying the vector dimension and the metric type
// from the engine-level meta configuration.
knowhere::Config
SPTAGKDTConfAdapter::Match(const TempMetaConf& metaconf) {
    auto kdt_cfg = std::make_shared<knowhere::KDTCfg>();

    kdt_cfg->metric_type = metaconf.metric_type;
    kdt_cfg->d = metaconf.dim;

    return kdt_cfg;
}
// Creates the KDT search configuration; only the requested top-k is taken from
// the meta configuration. The index type argument is not consulted.
knowhere::Config
SPTAGKDTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
    auto search_cfg = std::make_shared<knowhere::KDTCfg>();

    search_cfg->k = metaconf.k;

    return search_cfg;
}
// Creates the BKT configuration, copying the vector dimension and the metric type
// from the engine-level meta configuration.
knowhere::Config
SPTAGBKTConfAdapter::Match(const TempMetaConf& metaconf) {
    auto bkt_cfg = std::make_shared<knowhere::BKTCfg>();

    bkt_cfg->metric_type = metaconf.metric_type;
    bkt_cfg->d = metaconf.dim;

    return bkt_cfg;
}
// Creates the BKT search configuration; only the requested top-k is taken from
// the meta configuration. The index type argument is not consulted.
knowhere::Config
SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
    auto search_cfg = std::make_shared<knowhere::BKTCfg>();

    search_cfg->k = metaconf.k;

    return search_cfg;
}
} // namespace engine
} // namespace milvus
......@@ -97,5 +97,23 @@ class NSGConfAdapter : public IVFConfAdapter {
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) final;
};
// Configuration adapter for the SPTAG KDT index: overrides both ConfAdapter hooks
// to translate a TempMetaConf into a knowhere::Config.
class SPTAGKDTConfAdapter : public ConfAdapter {
public:
// Produces the configuration from the meta configuration.
knowhere::Config
Match(const TempMetaConf& metaconf) override;
// Produces the search-time configuration from the meta configuration.
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
};
// Configuration adapter for the SPTAG BKT index: overrides both ConfAdapter hooks
// to translate a TempMetaConf into a knowhere::Config.
class SPTAGBKTConfAdapter : public ConfAdapter {
public:
// Produces the configuration from the meta configuration.
knowhere::Config
Match(const TempMetaConf& metaconf) override;
// Produces the search-time configuration from the meta configuration.
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
};
} // namespace engine
} // namespace milvus
......@@ -56,6 +56,9 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexType::FAISS_IVFPQ_MIX, ivfpq_mix);
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexType::NSG_MIX, nsg_mix);
REGISTER_CONF_ADAPTER(SPTAGKDTConfAdapter, IndexType::SPTAG_KDT_RNT_CPU, sptag_kdt);
REGISTER_CONF_ADAPTER(SPTAGBKTConfAdapter, IndexType::SPTAG_BKT_RNT_CPU, sptag_bkt);
}
} // namespace engine
......
......@@ -22,7 +22,7 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexSPTAG.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "utils/Log.h"
......@@ -128,7 +128,11 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
break;
}
case IndexType::SPTAG_KDT_RNT_CPU: {
index = std::make_shared<knowhere::CPUKDTRNG>();
index = std::make_shared<knowhere::CPUSPTAGRNG>("KDT");
break;
}
case IndexType::SPTAG_BKT_RNT_CPU: {
index = std::make_shared<knowhere::CPUSPTAGRNG>("BKT");
break;
}
case IndexType::FAISS_IVFSQ8_CPU: {
......
......@@ -49,6 +49,7 @@ enum class IndexType {
FAISS_IVFSQ8_HYBRID, // only support build on gpu.
NSG_MIX,
FAISS_IVFPQ_MIX,
SPTAG_BKT_RNT_CPU,
};
class VecIndex;
......@@ -139,6 +140,9 @@ write_index(VecIndexPtr index, const std::string& location);
extern VecIndexPtr
read_index(const std::string& location);
VecIndexPtr
read_index(const std::string& location, knowhere::BinarySet& index_binary);
extern VecIndexPtr
GetVecIndexFactory(const IndexType& type, const Config& cfg = Config());
......
......@@ -29,15 +29,16 @@
INITIALIZE_EASYLOGGINGPP
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::Combine;
class KnowhereWrapperTest
: public DataGenBase,
public TestWithParam<::std::tuple<milvus::engine::IndexType, std::string, int, int, int, int>> {
: public DataGenBase,
public TestWithParam<::std::tuple<milvus::engine::IndexType, std::string, int, int, int, int>> {
protected:
void SetUp() override {
void
SetUp() override {
#ifdef MILVUS_GPU_VERSION
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM);
#endif
......@@ -57,12 +58,13 @@ class KnowhereWrapperTest
conf = ParamGenerator::GetInstance().GenBuild(index_type, tempconf);
searchconf = ParamGenerator::GetInstance().GenSearchConf(index_type, tempconf);
// conf->k = k;
// conf->d = dim;
// conf->gpu_id = DEVICEID;
// conf->k = k;
// conf->d = dim;
// conf->gpu_id = DEVICEID;
}
void TearDown() override {
void
TearDown() override {
#ifdef MILVUS_GPU_VERSION
knowhere::FaissGpuResourceMgr::GetInstance().Free();
#endif
......@@ -75,22 +77,21 @@ class KnowhereWrapperTest
knowhere::Config searchconf;
};
INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
Values(
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
INSTANTIATE_TEST_CASE_P(
WrapperParam, KnowhereWrapperTest,
Values(
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
#ifdef MILVUS_GPU_VERSION
std::make_tuple(milvus::engine::IndexType::FAISS_IVFFLAT_GPU, "Default", DIM, NB, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IVFFLAT_MIX, "Default", 64, 1000, 10, 10),
// std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB,
// 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_MIX, "Default", DIM, NB, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IVFPQ_MIX, "Default", 64, 1000, 10, 10),
// std::make_tuple(IndexType::NSG_MIX, "Default", 128, 250000, 10, 10),
// std::make_tuple(milvus::engine::IndexType::NSG_MIX, "Default", 128, 250000, 10, 10),
#endif
// std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 250000, 10, 10),
// std::make_tuple(milvus::engine::IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 100, 10, 10),
// std::make_tuple(milvus::engine::IndexType::SPTAG_BKT_RNT_CPU, "Default", 128, 100, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IDMAP, "Default", 64, 1000, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IVFFLAT_CPU, "Default", 64, 1000, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_CPU, "Default", DIM, NB, 10, 10)));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册