Unverified commit a08b51c2 authored by groot, committed by GitHub

merge json to master to get docker image (#1500)

* General proto api for NNS libraries
Signed-off-by: groot <yihua.mo@zilliz.com>

* refactor confadapter
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* fix unittest failures
Signed-off-by: groot <yihua.mo@zilliz.com>

* update test_add
Signed-off-by: zhenwu <zw@zilliz.com>

* update knowhere
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* update test cases
Signed-off-by: Xiaohai Xu <xiaohaix@student.unimelb.edu.au>

* Update cases

* C++ sdk for json parameters
Signed-off-by: groot <yihua.mo@zilliz.com>

* update unittest
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* fix unittest failures
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix case
Signed-off-by: del-zhenwu <zw@zilliz.com>

* modify test_index.py
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* update
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* update sptag
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* update...
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* Build Pass
Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>

* knowhere/wrapper ut pass
Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>

* update util
Signed-off-by: Xiaohai Xu <xiaohaix@student.unimelb.edu.au>

* fix wal case
Signed-off-by: del-zhenwu <zw@zilliz.com>

* modify test_search_vectors
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* update ci
Signed-off-by: del-zhenwu <zw@zilliz.com>

* update util
Signed-off-by: Xiaohai Xu <xiaohaix@student.unimelb.edu.au>

* modify test_search_vectors
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* add hnsw in http module & modify index apis
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* modify search in http module
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* fix build error
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix typo in test_index and test_search
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* update...
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* index apis in http module done
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* fix build index bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* search apis unittest pass
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* web test pass
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* update confadapter
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* update util
Signed-off-by: Xiaohai Xu <xiaohaix@student.unimelb.edu.au>

* code format
Signed-off-by: groot <yihua.mo@zilliz.com>

* code format
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix vectors results bug (fix #1476)
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* clang format
Signed-off-by: Yhz <yinghao.zou@zilliz.com>

* update test
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* fix unittest
Signed-off-by: groot <yihua.mo@zilliz.com>

* add test_config
Signed-off-by: Xiaohai Xu <xiaohaix@student.unimelb.edu.au>

* add log
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix a build error
Signed-off-by: groot <yihua.mo@zilliz.com>

* add invalid param search test
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* fix range check
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* compact/flush case passed
Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com>

* fix unittest failures
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix unittest failures
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix unittest failures
Signed-off-by: groot <yihua.mo@zilliz.com>

* validate json parameters in request
Signed-off-by: groot <yihua.mo@zilliz.com>

* add unittest cases
Signed-off-by: groot <yihua.mo@zilliz.com>

* update test index/search
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* update test_config
Signed-off-by: sahuang <xiaohaix@student.unimelb.edu.au>

* fix
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* support nsg and ivf-nlist
Signed-off-by: Nicky <nicky.xj.lin@gmail.com>

* update
Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>

* fix validation bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix python test bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix python test bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix python test bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix python test bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* code format
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix python test failure
Signed-off-by: groot <yihua.mo@zilliz.com>

* remove rnsg cases
Signed-off-by: zhenwu <zw@zilliz.com>

* fix python test failure
Signed-off-by: groot <yihua.mo@zilliz.com>

* Update changelog
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* Fix typo
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* add pq to test_index && multithread test
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* add pq to test_search
Signed-off-by: shengjh <jianghong.sheng@zilliz.com>

* Fix format
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* Update CHANGELOG
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* Fix compiling error
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* Fix compiling error
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* fix config bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* code format
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix config test
Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>

* Update CHANGELOG.md
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* Update CHANGELOG.md
Signed-off-by: JinHai-CN <hai.jin@zilliz.com>

* disable config test case
Signed-off-by: zhenwu <zw@zilliz.com>
Co-authored-by: Nicky <nicky.xj.lin@gmail.com>
Co-authored-by: zhenwu <zw@zilliz.com>
Co-authored-by: Xiaohai Xu <xiaohaix@student.unimelb.edu.au>
Co-authored-by: shengjh <jianghong.sheng@zilliz.com>
Co-authored-by: xiaojun.lin <xiaojun.lin@zilliz.com>
Co-authored-by: Yhz <yinghao.zou@zilliz.com>
Co-authored-by: del-zhenwu <zhenxiang.li@zilliz.com>
Co-authored-by: JinHai-CN <hai.jin@zilliz.com>
Parent 2a101eaa
......@@ -11,7 +11,7 @@ Please mark all change in change log and use the issue from GitHub
- \#805 IVFTest.gpu_seal_test unittest failed
- \#831 Judge branch error in CommonUtil.cpp
- \#977 Server crash when create tables concurrently
- \#990 check gpu resources setting when assign repeated value
- \#990 Check gpu resources setting when assign repeated value
- \#995 table count set to 0 if no tables found
- \#1010 improve error message when offset or page_size is equal 0
- \#1022 check if partition name is legal
......@@ -19,8 +19,8 @@ Please mark all change in change log and use the issue from GitHub
- \#1029 check if table exists when try to delete partition
- \#1066 optimize http insert and search speed
- \#1067 Add binary vectors support in http server
- \#1075 improve error message when page size or offset is illegal
- \#1082 check page_size or offset value to avoid float
- \#1075 Improve error message when page size or offset is illegal
- \#1082 Check page_size or offset value to avoid float
- \#1115 http server support load table into memory
- \#1152 Error log output continuously after server start
- \#1211 Server down caused by searching with index_type: HNSW
......@@ -86,6 +86,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1320 Remove debug logging from faiss
- \#1426 Support to configure whether to enabled autoflush and the autoflush interval
- \#1444 Improve delete
- \#1448 General proto api for NNS libraries
- \#1480 Add return code for AVX512 selection
- \#1524 Update config "preload_table" description
......
......@@ -2,6 +2,7 @@ timeout(time: 60, unit: 'MINUTES') {
dir ("tests/milvus_python_test") {
// sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com'
sh 'python3 -m pip install -r requirements.txt'
sh 'python3 -m pip install git+https://github.com/BossZou/pymilvus.git@nns'
sh "pytest . --alluredir=\"test_out/dev/single/sqlite\" --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local"
}
......
......@@ -112,18 +112,18 @@ class DB {
virtual Status
QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
Size(uint64_t& result) = 0;
......
......@@ -371,8 +371,10 @@ DBImpl::PreloadTable(const std::string& table_id) {
} else {
engine_type = (EngineType)file.engine_type_;
}
ExecutionEnginePtr engine = EngineFactory::Build(file.dimension_, file.location_, engine_type,
(MetricType)file.metric_type_, file.nlist_);
auto json = milvus::json::parse(file.index_params_);
ExecutionEnginePtr engine =
EngineFactory::Build(file.dimension_, file.location_, engine_type, (MetricType)file.metric_type_, json);
fiu_do_on("DBImpl.PreloadTable.null_engine", engine = nullptr);
if (engine == nullptr) {
ENGINE_LOG_ERROR << "Invalid engine type";
......@@ -382,7 +384,7 @@ DBImpl::PreloadTable(const std::string& table_id) {
size += engine->PhysicalSize();
fiu_do_on("DBImpl.PreloadTable.exceed_cache", size = available_size + 1);
if (size > available_size) {
ENGINE_LOG_DEBUG << "Pre-load canceled since cache almost full";
ENGINE_LOG_DEBUG << "Pre-load cancelled since cache is almost full";
return Status(SERVER_CACHE_FULL, "Cache is full");
} else {
try {
......@@ -1110,8 +1112,8 @@ DBImpl::DropIndex(const std::string& table_id) {
Status
DBImpl::QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) {
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;
}
......@@ -1119,14 +1121,15 @@ DBImpl::QueryByID(const std::shared_ptr<server::Context>& context, const std::st
VectorsData vectors_data = VectorsData();
vectors_data.id_array_.emplace_back(vector_id);
vectors_data.vector_count_ = 1;
Status result = Query(context, table_id, partition_tags, k, nprobe, vectors_data, result_ids, result_distances);
Status result =
Query(context, table_id, partition_tags, k, extra_params, vectors_data, result_ids, result_distances);
return result;
}
Status
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query");
if (!initialized_.load(std::memory_order_acquire)) {
......@@ -1169,7 +1172,7 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, extra_params, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
query_ctx->GetTraceContext()->GetSpan()->Finish();
......@@ -1179,8 +1182,8 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
Status
DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query by file id");
if (!initialized_.load(std::memory_order_acquire)) {
......@@ -1208,7 +1211,7 @@ DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, extra_params, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
query_ctx->GetTraceContext()->GetSpan()->Finish();
......@@ -1230,8 +1233,8 @@ DBImpl::Size(uint64_t& result) {
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status
DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
const meta::TableFilesSchema& files, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_async_ctx = context->Child("Query Async");
server::CollectQueryMetrics metrics(vectors.vector_count_);
......@@ -1242,7 +1245,7 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::s
auto status = OngoingFileChecker::GetInstance().MarkOngoingFiles(files);
ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size();
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nprobe, vectors);
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, extra_params, vectors);
for (auto& file : files) {
scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
job->AddIndexFile(file_ptr);
......
......@@ -131,18 +131,18 @@ class DBImpl : public DB, public server::CacheConfigHandler {
Status
QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
Size(uint64_t& result) override;
......@@ -154,8 +154,8 @@ class DBImpl : public DB, public server::CacheConfigHandler {
private:
Status
QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances);
const meta::TableFilesSchema& files, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances);
Status
GetVectorByIdHelper(const std::string& table_id, IDNumber vector_id, VectorsData& vector,
......
......@@ -22,6 +22,7 @@
#include "db/engine/ExecutionEngine.h"
#include "segment/Types.h"
#include "utils/Json.h"
namespace milvus {
namespace engine {
......@@ -35,8 +36,8 @@ typedef std::vector<faiss::Index::distance_t> ResultDistances;
struct TableIndex {
int32_t engine_type_ = (int)EngineType::FAISS_IDMAP;
int32_t nlist_ = 16384;
int32_t metric_type_ = (int)MetricType::L2;
milvus::json extra_params_ = {{"nlist", 16384}};
};
struct VectorsData {
......
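Editorial aside, not part of the commit diff: after this change an index is described by a free-form JSON object rather than a single nlist integer. A minimal sketch of how a TableIndex might now be populated; the include path and the nlist value of 4096 are illustrative assumptions, while the enum values come from the struct shown above.

#include "db/Types.h"  // assumed header for TableIndex, per the hunk above

milvus::engine::TableIndex MakeExampleIndex() {
    milvus::engine::TableIndex index;
    index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IDMAP;  // default seen in the struct above
    index.metric_type_ = (int)milvus::engine::MetricType::L2;
    index.extra_params_ = {{"nlist", 4096}};  // replaces the removed nlist_ field
    return index;
}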
......@@ -211,7 +211,7 @@ GetParentPath(const std::string& path, std::string& parent_path) {
bool
IsSameIndex(const TableIndex& index1, const TableIndex& index2) {
return index1.engine_type_ == index2.engine_type_ && index1.nlist_ == index2.nlist_ &&
return index1.engine_type_ == index2.engine_type_ && index1.extra_params_ == index2.extra_params_ &&
index1.metric_type_ == index2.metric_type_;
}
......
......@@ -20,7 +20,7 @@ namespace engine {
ExecutionEnginePtr
EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist) {
const milvus::json& index_params) {
if (index_type == EngineType::INVALID) {
ENGINE_LOG_ERROR << "Unsupported engine type";
return nullptr;
......@@ -28,7 +28,7 @@ EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType
ENGINE_LOG_DEBUG << "EngineFactory index type: " << (int)index_type;
ExecutionEnginePtr execution_engine_ptr =
std::make_shared<ExecutionEngineImpl>(dimension, location, index_type, metric_type, nlist);
std::make_shared<ExecutionEngineImpl>(dimension, location, index_type, metric_type, index_params);
execution_engine_ptr->Init();
return execution_engine_ptr;
......
......@@ -12,6 +12,7 @@
#pragma once
#include "ExecutionEngine.h"
#include "utils/Json.h"
#include "utils/Status.h"
#include <string>
......@@ -23,7 +24,7 @@ class EngineFactory {
public:
static ExecutionEnginePtr
Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist);
const milvus::json& index_params);
};
} // namespace engine
......
......@@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "utils/Json.h"
#include "utils/Status.h"
namespace milvus {
......@@ -94,15 +95,16 @@ class ExecutionEngine {
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) = 0;
virtual Status
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels, bool hybrid) = 0;
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
bool hybrid) = 0;
virtual Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) = 0;
Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) = 0;
virtual Status
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) = 0;
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) = 0;
virtual std::shared_ptr<ExecutionEngine>
BuildIndex(const std::string& location, EngineType engine_type) = 0;
......
......@@ -43,22 +43,22 @@ namespace engine {
namespace {
Status
MappingMetricType(MetricType metric_type, knowhere::METRICTYPE& kw_type) {
MappingMetricType(MetricType metric_type, milvus::json& conf) {
switch (metric_type) {
case MetricType::IP:
kw_type = knowhere::METRICTYPE::IP;
conf[knowhere::Metric::TYPE] = knowhere::Metric::IP;
break;
case MetricType::L2:
kw_type = knowhere::METRICTYPE::L2;
conf[knowhere::Metric::TYPE] = knowhere::Metric::L2;
break;
case MetricType::HAMMING:
kw_type = knowhere::METRICTYPE::HAMMING;
conf[knowhere::Metric::TYPE] = knowhere::Metric::HAMMING;
break;
case MetricType::JACCARD:
kw_type = knowhere::METRICTYPE::JACCARD;
conf[knowhere::Metric::TYPE] = knowhere::Metric::JACCARD;
break;
case MetricType::TANIMOTO:
kw_type = knowhere::METRICTYPE::TANIMOTO;
conf[knowhere::Metric::TYPE] = knowhere::Metric::TANIMOTO;
break;
default:
return Status(DB_ERROR, "Unsupported metric type");
......@@ -94,8 +94,12 @@ class CachedQuantizer : public cache::DataObj {
};
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist)
: location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
MetricType metric_type, const milvus::json& index_params)
: location_(location),
dim_(dimension),
index_type_(index_type),
metric_type_(metric_type),
index_params_(index_params) {
EngineType tmp_index_type = server::ValidationUtil::IsBinaryMetricType((int32_t)metric_type)
? EngineType::FAISS_BIN_IDMAP
: EngineType::FAISS_IDMAP;
......@@ -104,16 +108,15 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
throw Exception(DB_ERROR, "Unsupported index type");
}
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = dimension;
auto status = MappingMetricType(metric_type, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
milvus::json conf = index_params;
conf[knowhere::meta::DEVICEID] = gpu_num_;
conf[knowhere::meta::DIM] = dimension;
MappingMetricType(metric_type, conf);
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->Match(temp_conf);
if (!adapter->CheckTrain(conf)) {
throw Exception(DB_ERROR, "Illegal index params");
}
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
if (auto bf_index = std::dynamic_pointer_cast<BFIndex>(index_)) {
......@@ -127,8 +130,12 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
}
ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist)
: index_(std::move(index)), location_(location), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
MetricType metric_type, const milvus::json& index_params)
: index_(std::move(index)),
location_(location),
index_type_(index_type),
metric_type_(metric_type),
index_params_(index_params) {
}
VecIndexPtr
......@@ -273,10 +280,9 @@ ExecutionEngineImpl::HybridLoad() const {
auto best_index = std::distance(all_free_mem.begin(), max_e);
auto best_device_id = gpus[best_index];
auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
quantizer_conf->mode = 1;
quantizer_conf->gpu_id = best_device_id;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, best_device_id}, {"mode", 1}};
auto quantizer = index_->LoadQuantizer(quantizer_conf);
ENGINE_LOG_DEBUG << "Quantizer params: " << quantizer_conf.dump();
if (quantizer == nullptr) {
ENGINE_LOG_ERROR << "quantizer is nullptr";
}
......@@ -403,19 +409,15 @@ ExecutionEngineImpl::Load(bool to_cache) {
if (index_type_ == EngineType::FAISS_IDMAP || index_type_ == EngineType::FAISS_BIN_IDMAP) {
index_ = index_type_ == EngineType::FAISS_IDMAP ? GetVecIndexFactory(IndexType::FAISS_IDMAP)
: GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = dim_;
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
if (!status.ok()) {
return status;
}
milvus::json conf{{knowhere::meta::DEVICEID, gpu_num_}, {knowhere::meta::DIM, dim_}};
MappingMetricType(metric_type_, conf);
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->Match(temp_conf);
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
if (!adapter->CheckTrain(conf)) {
throw Exception(DB_ERROR, "Illegal index params");
}
status = segment_reader_ptr->Load();
auto status = segment_reader_ptr->Load();
if (!status.ok()) {
std::string msg = "Failed to load segment from " + location_;
ENGINE_LOG_ERROR << msg;
......@@ -453,7 +455,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
float_vectors.data(), Config());
status = std::static_pointer_cast<BFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
int64_t index_size = vectors->GetCount() * conf->d * sizeof(float);
int64_t index_size = vectors->GetCount() * dim_ * sizeof(float);
int64_t bitset_size = vectors->GetCount() / 8;
index_->set_size(index_size + bitset_size);
} else if (index_type_ == EngineType::FAISS_BIN_IDMAP) {
......@@ -465,7 +467,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
vectors_data.data(), Config());
status = std::static_pointer_cast<BinBFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
int64_t index_size = vectors->GetCount() * conf->d * sizeof(uint8_t);
int64_t index_size = vectors->GetCount() * dim_ * sizeof(uint8_t);
int64_t bitset_size = vectors->GetCount() / 8;
index_->set_size(index_size + bitset_size);
}
......@@ -548,9 +550,7 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
if (device_id != NOT_FOUND) {
// cache hit
auto config = std::make_shared<knowhere::QuantizerCfg>();
config->gpu_id = device_id;
config->mode = 2;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID : device_id}, {"mode" : 2}};
auto new_index = index_->LoadData(quantizer, config);
index_ = new_index;
}
......@@ -723,19 +723,18 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
throw Exception(DB_ERROR, "Unsupported index type");
}
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = Dimension();
temp_conf.nlist = nlist_;
temp_conf.size = Count();
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
milvus::json conf = index_params_;
conf[knowhere::meta::DIM] = Dimension();
conf[knowhere::meta::ROWS] = Count();
conf[knowhere::meta::DEVICEID] = gpu_num_;
MappingMetricType(metric_type_, conf);
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType());
auto conf = adapter->Match(temp_conf);
if (!adapter->CheckTrain(conf)) {
throw Exception(DB_ERROR, "Illegal index params");
}
ENGINE_LOG_DEBUG << "Index config: " << conf.dump();
auto status = Status::OK();
if (from_index) {
status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
} else if (bin_from_index) {
......@@ -746,7 +745,7 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
}
ENGINE_LOG_DEBUG << "Finish build index file: " << location << " size: " << to_index->Size();
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, index_params_);
}
// map offsets to ids
......@@ -761,8 +760,8 @@ MapUids(const std::vector<segment::doc_id_t>& uids, int64_t* labels, size_t num)
}
Status
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) {
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) {
#if 0
if (index_type_ == EngineType::FAISS_IVFSQ8H) {
if (!hybrid) {
......@@ -786,9 +785,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
if (device_id != NOT_FOUND) {
// cache hit
auto config = std::make_shared<knowhere::QuantizerCfg>();
config->gpu_id = device_id;
config->mode = 2;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID : device_id}, {"mode" : 2}};
auto new_index = index_->LoadData(quantizer, config);
index_ = new_index;
}
......@@ -824,15 +821,13 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
milvus::json conf = extra_params;
conf[knowhere::meta::TOPK] = k;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
if (!adapter->CheckSearch(conf, index_->GetType())) {
throw Exception(DB_ERROR, "Illegal search params");
}
if (hybrid) {
HybridLoad();
......@@ -858,8 +853,8 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
}
Status
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances,
int64_t* labels, bool hybrid) {
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params,
float* distances, int64_t* labels, bool hybrid) {
TimeRecorder rc("ExecutionEngineImpl::Search uint8");
if (index_ == nullptr) {
......@@ -867,15 +862,13 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
milvus::json conf = extra_params;
conf[knowhere::meta::TOPK] = k;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
if (!adapter->CheckSearch(conf, index_->GetType())) {
throw Exception(DB_ERROR, "Illegal search params");
}
if (hybrid) {
HybridLoad();
......@@ -901,8 +894,8 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
}
Status
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances,
int64_t* labels, bool hybrid) {
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params,
float* distances, int64_t* labels, bool hybrid) {
TimeRecorder rc("ExecutionEngineImpl::Search vector of ids");
if (index_ == nullptr) {
......@@ -910,15 +903,13 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search by ids Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
milvus::json conf = extra_params;
conf[knowhere::meta::TOPK] = k;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
if (!adapter->CheckSearch(conf, index_->GetType())) {
throw Exception(DB_ERROR, "Illegal search params");
}
if (hybrid) {
HybridLoad();
......@@ -993,19 +984,13 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, float* vector, bool hybrid
return Status(DB_ERROR, "index is null");
}
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
if (hybrid) {
HybridLoad();
}
// Only one id for now
std::vector<int64_t> ids{id};
auto status = index_->GetVectorById(1, ids.data(), vector, conf);
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
if (hybrid) {
HybridUnset();
......@@ -1026,19 +1011,13 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, uint8_t* vector, bool hybr
ENGINE_LOG_DEBUG << "Get binary vector by id: " << id;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
if (hybrid) {
HybridLoad();
}
// Only one id for now
std::vector<int64_t> ids{id};
auto status = index_->GetVectorById(1, ids.data(), vector, conf);
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
if (hybrid) {
HybridUnset();
......@@ -1075,7 +1054,7 @@ ExecutionEngineImpl::Init() {
std::vector<int64_t> gpu_ids;
Status s = config.GetGpuResourceConfigBuildIndexResources(gpu_ids);
if (!s.ok()) {
gpu_num_ = knowhere::INVALID_VALUE;
gpu_num_ = -1;
return s;
}
for (auto id : gpu_ids) {
......
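Editorial sketch, not part of the diff: calling the reworked Search() overload with the new JSON parameter. The key name "nprobe" is an assumption that mirrors the integer argument this commit removes; the adapter's CheckSearch() seen above is what ultimately validates whatever keys the caller supplies.

#include <vector>

void ExampleSearch(const milvus::engine::ExecutionEnginePtr& engine, const float* query, int64_t k) {
    milvus::json extra_params = {{"nprobe", 32}};  // assumed key, formerly a plain uint64_t argument
    std::vector<float> distances(k);
    std::vector<int64_t> labels(k);
    engine->Search(/*n=*/1, query, k, extra_params, distances.data(), labels.data(), /*hybrid=*/false);
}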
......@@ -11,7 +11,8 @@
#pragma once
#include <src/segment/SegmentReader.h>
#include "segment/SegmentReader.h"
#include "utils/Json.h"
#include <memory>
#include <string>
......@@ -26,10 +27,10 @@ namespace engine {
class ExecutionEngineImpl : public ExecutionEngine {
public:
ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist);
const milvus::json& index_params);
ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist);
const milvus::json& index_params);
Status
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override;
......@@ -77,16 +78,16 @@ class ExecutionEngineImpl : public ExecutionEngine {
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) override;
Status
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
bool hybrid = false) override;
Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid = false) override;
Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid = false) override;
Status
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) override;
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) override;
ExecutionEnginePtr
BuildIndex(const std::string& location, EngineType engine_type) override;
......@@ -136,7 +137,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
int64_t dim_;
std::string location_;
int64_t nlist_ = 0;
milvus::json index_params_;
int64_t gpu_num_ = 0;
};
......
......@@ -54,7 +54,7 @@ struct TableSchema {
int64_t flag_ = 0;
int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE;
int32_t engine_type_ = DEFAULT_ENGINE_TYPE;
int32_t nlist_ = DEFAULT_NLIST;
std::string index_params_ = "{ \"nlist\": 16384 }";
int32_t metric_type_ = DEFAULT_METRIC_TYPE;
std::string owner_table_;
std::string partition_tag_;
......@@ -89,7 +89,7 @@ struct TableFileSchema {
int64_t created_on_ = 0;
int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE; // not persist to meta
int32_t engine_type_ = DEFAULT_ENGINE_TYPE;
int32_t nlist_ = DEFAULT_NLIST; // not persist to meta
std::string index_params_; // not persist to meta
int32_t metric_type_ = DEFAULT_METRIC_TYPE; // not persist to meta
uint64_t flush_lsn_ = 0;
}; // TableFileSchema
......
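Editorial aside, not part of the diff: the meta layer now persists index parameters as a JSON string (index_params_) and parses them back into milvus::json wherever an engine is built, mirroring the milvus::json::parse() call added in DBImpl::PreloadTable and the extra_params_.dump() call in UpdateTableIndex below. A tiny round-trip sketch, assuming milvus::json follows the usual nlohmann::json interface:

std::string stored = "{ \"nlist\": 16384 }";        // TableSchema::index_params_ default shown above
milvus::json parsed = milvus::json::parse(stored);  // as done when building an ExecutionEngine
std::string round_trip = parsed.dump();             // as done when persisting TableIndex::extra_params_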
......@@ -144,7 +144,7 @@ static const MetaSchema TABLES_SCHEMA(META_TABLES, {
MetaField("flag", "BIGINT", "DEFAULT 0 NOT NULL"),
MetaField("index_file_size", "BIGINT", "DEFAULT 1024 NOT NULL"),
MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"),
MetaField("nlist", "INT", "DEFAULT 16384 NOT NULL"),
MetaField("index_params", "VARCHAR(512)", "NOT NULL"),
MetaField("metric_type", "INT", "DEFAULT 1 NOT NULL"),
MetaField("owner_table", "VARCHAR(255)", "NOT NULL"),
MetaField("partition_tag", "VARCHAR(255)", "NOT NULL"),
......@@ -398,7 +398,7 @@ MySQLMetaImpl::CreateTable(TableSchema& table_schema) {
std::string flag = std::to_string(table_schema.flag_);
std::string index_file_size = std::to_string(table_schema.index_file_size_);
std::string engine_type = std::to_string(table_schema.engine_type_);
std::string nlist = std::to_string(table_schema.nlist_);
std::string& index_params = table_schema.index_params_;
std::string metric_type = std::to_string(table_schema.metric_type_);
std::string& owner_table = table_schema.owner_table_;
std::string& partition_tag = table_schema.partition_tag_;
......@@ -407,9 +407,9 @@ MySQLMetaImpl::CreateTable(TableSchema& table_schema) {
createTableQuery << "INSERT INTO " << META_TABLES << " VALUES(" << id << ", " << mysqlpp::quote << table_id
<< ", " << state << ", " << dimension << ", " << created_on << ", " << flag << ", "
<< index_file_size << ", " << engine_type << ", " << nlist << ", " << metric_type << ", "
<< mysqlpp::quote << owner_table << ", " << mysqlpp::quote << partition_tag << ", "
<< mysqlpp::quote << version << ", " << flush_lsn << ");";
<< index_file_size << ", " << engine_type << ", " << mysqlpp::quote << index_params << ", "
<< metric_type << ", " << mysqlpp::quote << owner_table << ", " << mysqlpp::quote
<< partition_tag << ", " << mysqlpp::quote << version << ", " << flush_lsn << ");";
ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTable: " << createTableQuery.str();
......@@ -446,8 +446,8 @@ MySQLMetaImpl::DescribeTable(TableSchema& table_schema) {
mysqlpp::Query describeTableQuery = connectionPtr->query();
describeTableQuery
<< "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, nlist, metric_type"
<< " ,owner_table, partition_tag, version, flush_lsn"
<< "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, index_params"
<< " , metric_type ,owner_table, partition_tag, version, flush_lsn"
<< " FROM " << META_TABLES << " WHERE table_id = " << mysqlpp::quote << table_schema.table_id_
<< " AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";";
......@@ -465,7 +465,7 @@ MySQLMetaImpl::DescribeTable(TableSchema& table_schema) {
table_schema.flag_ = resRow["flag"];
table_schema.index_file_size_ = resRow["index_file_size"];
table_schema.engine_type_ = resRow["engine_type"];
table_schema.nlist_ = resRow["nlist"];
resRow["index_params"].to_string(table_schema.index_params_);
table_schema.metric_type_ = resRow["metric_type"];
resRow["owner_table"].to_string(table_schema.owner_table_);
resRow["partition_tag"].to_string(table_schema.partition_tag_);
......@@ -534,7 +534,7 @@ MySQLMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
}
mysqlpp::Query allTablesQuery = connectionPtr->query();
allTablesQuery << "SELECT id, table_id, dimension, engine_type, nlist, index_file_size, metric_type"
allTablesQuery << "SELECT id, table_id, dimension, engine_type, index_params, index_file_size, metric_type"
<< " ,owner_table, partition_tag, version, flush_lsn"
<< " FROM " << META_TABLES << " WHERE state <> " << std::to_string(TableSchema::TO_DELETE)
<< " AND owner_table = \"\";";
......@@ -551,7 +551,7 @@ MySQLMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
table_schema.dimension_ = resRow["dimension"];
table_schema.index_file_size_ = resRow["index_file_size"];
table_schema.engine_type_ = resRow["engine_type"];
table_schema.nlist_ = resRow["nlist"];
resRow["index_params"].to_string(table_schema.index_params_);
table_schema.metric_type_ = resRow["metric_type"];
resRow["owner_table"].to_string(table_schema.owner_table_);
resRow["partition_tag"].to_string(table_schema.partition_tag_);
......@@ -673,6 +673,7 @@ MySQLMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
file_schema.created_on_ = utils::GetMicroSecTimeStamp();
file_schema.updated_time_ = file_schema.created_on_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.index_params_ = table_schema.index_params_;
if (file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW ||
file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW_MERGE) {
......@@ -683,7 +684,6 @@ MySQLMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
file_schema.engine_type_ = table_schema.engine_type_;
}
file_schema.nlist_ = table_schema.nlist_;
file_schema.metric_type_ = table_schema.metric_type_;
std::string id = "NULL"; // auto-increment
......@@ -785,7 +785,7 @@ MySQLMetaImpl::GetTableFiles(const std::string& table_id, const std::vector<size
resRow["segment_id"].to_string(file_schema.segment_id_);
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.engine_type_ = resRow["engine_type"];
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
resRow["file_id"].to_string(file_schema.file_id_);
file_schema.file_type_ = resRow["file_type"];
......@@ -844,7 +844,7 @@ MySQLMetaImpl::GetTableFilesBySegmentId(const std::string& segment_id,
resRow["segment_id"].to_string(file_schema.segment_id_);
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.engine_type_ = resRow["engine_type"];
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
resRow["file_id"].to_string(file_schema.file_id_);
file_schema.file_type_ = resRow["file_type"];
......@@ -900,7 +900,8 @@ MySQLMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex& i
updateTableIndexParamQuery << "UPDATE " << META_TABLES << " SET id = " << id << " ,state = " << state
<< " ,dimension = " << dimension << " ,created_on = " << created_on
<< " ,engine_type = " << index.engine_type_ << " ,nlist = " << index.nlist_
<< " ,engine_type = " << index.engine_type_
<< " ,index_params = " << mysqlpp::quote << index.extra_params_.dump()
<< " ,metric_type = " << index.metric_type_
<< " WHERE table_id = " << mysqlpp::quote << table_id << ";";
......@@ -1044,7 +1045,7 @@ MySQLMetaImpl::GetTableFilesByFlushLSN(uint64_t flush_lsn, TableFilesSchema& tab
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
......@@ -1263,7 +1264,7 @@ MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index
}
mysqlpp::Query describeTableIndexQuery = connectionPtr->query();
describeTableIndexQuery << "SELECT engine_type, nlist, index_file_size, metric_type"
describeTableIndexQuery << "SELECT engine_type, index_params, index_file_size, metric_type"
<< " FROM " << META_TABLES << " WHERE table_id = " << mysqlpp::quote << table_id
<< " AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";";
......@@ -1275,7 +1276,9 @@ MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index
const mysqlpp::Row& resRow = res[0];
index.engine_type_ = resRow["engine_type"];
index.nlist_ = resRow["nlist"];
std::string str_index_params;
resRow["index_params"].to_string(str_index_params);
index.extra_params_ = milvus::json::parse(str_index_params);
index.metric_type_ = resRow["metric_type"];
} else {
return Status(DB_NOT_FOUND, "Table " + table_id + " not found");
......@@ -1334,7 +1337,7 @@ MySQLMetaImpl::DropTableIndex(const std::string& table_id) {
// set table index type to raw
dropTableIndexQuery << "UPDATE " << META_TABLES
<< " SET engine_type = " << std::to_string(DEFAULT_ENGINE_TYPE)
<< " ,nlist = " << std::to_string(DEFAULT_NLIST)
<< " , index_params = '{}'"
<< " WHERE table_id = " << mysqlpp::quote << table_id << ";";
ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropTableIndex: " << dropTableIndexQuery.str();
......@@ -1426,7 +1429,7 @@ MySQLMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Tab
mysqlpp::Query allPartitionsQuery = connectionPtr->query();
allPartitionsQuery << "SELECT table_id, id, state, dimension, created_on, flag, index_file_size,"
<< " engine_type, nlist, metric_type, partition_tag, version FROM " << META_TABLES
<< " engine_type, index_params, metric_type, partition_tag, version FROM " << META_TABLES
<< " WHERE owner_table = " << mysqlpp::quote << table_id << " AND state <> "
<< std::to_string(TableSchema::TO_DELETE) << ";";
......@@ -1445,7 +1448,7 @@ MySQLMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Tab
partition_schema.flag_ = resRow["flag"];
partition_schema.index_file_size_ = resRow["index_file_size"];
partition_schema.engine_type_ = resRow["engine_type"];
partition_schema.nlist_ = resRow["nlist"];
resRow["index_params"].to_string(partition_schema.index_params_);
partition_schema.metric_type_ = resRow["metric_type"];
partition_schema.owner_table_ = table_id;
resRow["partition_tag"].to_string(partition_schema.partition_tag_);
......@@ -1562,7 +1565,7 @@ MySQLMetaImpl::FilesToSearch(const std::string& table_id, const std::vector<size
resRow["segment_id"].to_string(table_file.segment_id_);
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.engine_type_ = resRow["engine_type"];
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
resRow["file_id"].to_string(table_file.file_id_);
table_file.file_type_ = resRow["file_type"];
......@@ -1644,7 +1647,7 @@ MySQLMetaImpl::FilesToMerge(const std::string& table_id, TableFilesSchema& files
table_file.date_ = resRow["date"];
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.engine_type_ = resRow["engine_type"];
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
table_file.created_on_ = resRow["created_on"];
table_file.dimension_ = table_schema.dimension_;
......@@ -1722,7 +1725,7 @@ MySQLMetaImpl::FilesToIndex(TableFilesSchema& files) {
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
......@@ -1809,7 +1812,7 @@ MySQLMetaImpl::FilesByType(const std::string& table_id, const std::vector<int>&
file_schema.created_on_ = resRow["created_on"];
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
file_schema.dimension_ = table_schema.dimension_;
......
......@@ -68,7 +68,8 @@ StoragePrototype(const std::string& path) {
make_column("created_on", &TableSchema::created_on_),
make_column("flag", &TableSchema::flag_, default_value(0)),
make_column("index_file_size", &TableSchema::index_file_size_),
make_column("engine_type", &TableSchema::engine_type_), make_column("nlist", &TableSchema::nlist_),
make_column("engine_type", &TableSchema::engine_type_),
make_column("index_params", &TableSchema::index_params_),
make_column("metric_type", &TableSchema::metric_type_),
make_column("owner_table", &TableSchema::owner_table_, default_value("")),
make_column("partition_tag", &TableSchema::partition_tag_, default_value("")),
......@@ -213,7 +214,7 @@ SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
auto groups = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
&TableSchema::nlist_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::flush_lsn_),
where(c(&TableSchema::table_id_) == table_schema.table_id_ and
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
......@@ -226,7 +227,7 @@ SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
table_schema.flag_ = std::get<4>(groups[0]);
table_schema.index_file_size_ = std::get<5>(groups[0]);
table_schema.engine_type_ = std::get<6>(groups[0]);
table_schema.nlist_ = std::get<7>(groups[0]);
table_schema.index_params_ = std::get<7>(groups[0]);
table_schema.metric_type_ = std::get<8>(groups[0]);
table_schema.owner_table_ = std::get<9>(groups[0]);
table_schema.partition_tag_ = std::get<10>(groups[0]);
......@@ -272,7 +273,7 @@ SqliteMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
auto selected = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::table_id_, &TableSchema::dimension_, &TableSchema::created_on_,
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
&TableSchema::nlist_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::flush_lsn_),
where(c(&TableSchema::state_) != (int)TableSchema::TO_DELETE and c(&TableSchema::owner_table_) == ""));
for (auto& table : selected) {
......@@ -284,7 +285,7 @@ SqliteMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
schema.flag_ = std::get<4>(table);
schema.index_file_size_ = std::get<5>(table);
schema.engine_type_ = std::get<6>(table);
schema.nlist_ = std::get<7>(table);
schema.index_params_ = std::get<7>(table);
schema.metric_type_ = std::get<8>(table);
schema.owner_table_ = std::get<9>(table);
schema.partition_tag_ = std::get<10>(table);
......@@ -373,6 +374,7 @@ SqliteMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
file_schema.created_on_ = utils::GetMicroSecTimeStamp();
file_schema.updated_time_ = file_schema.created_on_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.index_params_ = table_schema.index_params_;
if (file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW ||
file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW_MERGE) {
......@@ -383,7 +385,6 @@ SqliteMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
file_schema.engine_type_ = table_schema.engine_type_;
}
file_schema.nlist_ = table_schema.nlist_;
file_schema.metric_type_ = table_schema.metric_type_;
// multi-threads call sqlite update may get exception('bad logic', etc), so we add a lock here
......@@ -436,7 +437,7 @@ SqliteMetaImpl::GetTableFiles(const std::string& table_id, const std::vector<siz
file_schema.created_on_ = std::get<8>(file);
file_schema.dimension_ = table_schema.dimension_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
utils::GetTableFilePath(options_, file_schema);
......@@ -486,7 +487,7 @@ SqliteMetaImpl::GetTableFilesBySegmentId(const std::string& segment_id,
file_schema.created_on_ = std::get<9>(file);
file_schema.dimension_ = table_schema.dimension_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
utils::GetTableFilePath(options_, file_schema);
......@@ -601,7 +602,7 @@ SqliteMetaImpl::GetTableFilesByFlushLSN(uint64_t flush_lsn, TableFilesSchema& ta
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
table_files.push_back(table_file);
}
......@@ -721,7 +722,7 @@ SqliteMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex&
table_schema.partition_tag_ = std::get<7>(tables[0]);
table_schema.version_ = std::get<8>(tables[0]);
table_schema.engine_type_ = index.engine_type_;
table_schema.nlist_ = index.nlist_;
table_schema.index_params_ = index.extra_params_.dump();
table_schema.metric_type_ = index.metric_type_;
ConnectorPtr->update(table_schema);
......@@ -773,12 +774,12 @@ SqliteMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& inde
fiu_do_on("SqliteMetaImpl.DescribeTableIndex.throw_exception", throw std::exception());
auto groups = ConnectorPtr->select(
columns(&TableSchema::engine_type_, &TableSchema::nlist_, &TableSchema::metric_type_),
columns(&TableSchema::engine_type_, &TableSchema::index_params_, &TableSchema::metric_type_),
where(c(&TableSchema::table_id_) == table_id and c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
if (groups.size() == 1) {
index.engine_type_ = std::get<0>(groups[0]);
index.nlist_ = std::get<1>(groups[0]);
index.extra_params_ = milvus::json::parse(std::get<1>(groups[0]));
index.metric_type_ = std::get<2>(groups[0]);
} else {
return Status(DB_NOT_FOUND, "Table " + table_id + " not found");
......@@ -813,7 +814,7 @@ SqliteMetaImpl::DropTableIndex(const std::string& table_id) {
// set table index type to raw
ConnectorPtr->update_all(
set(c(&TableSchema::engine_type_) = DEFAULT_ENGINE_TYPE, c(&TableSchema::nlist_) = DEFAULT_NLIST),
set(c(&TableSchema::engine_type_) = DEFAULT_ENGINE_TYPE, c(&TableSchema::index_params_) = "{}"),
where(c(&TableSchema::table_id_) == table_id));
ENGINE_LOG_DEBUG << "Successfully drop table index, table id = " << table_id;
......@@ -886,13 +887,14 @@ SqliteMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Ta
server::MetricCollector metric;
fiu_do_on("SqliteMetaImpl.ShowPartitions.throw_exception", throw std::exception());
auto partitions =
ConnectorPtr->select(columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_,
&TableSchema::created_on_, &TableSchema::flag_, &TableSchema::index_file_size_,
&TableSchema::engine_type_, &TableSchema::nlist_, &TableSchema::metric_type_,
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::table_id_),
where(c(&TableSchema::owner_table_) == table_id and
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
auto partitions = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::partition_tag_,
&TableSchema::version_, &TableSchema::table_id_),
where(c(&TableSchema::owner_table_) == table_id and
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
for (size_t i = 0; i < partitions.size(); i++) {
meta::TableSchema partition_schema;
partition_schema.id_ = std::get<0>(partitions[i]);
......@@ -902,7 +904,7 @@ SqliteMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Ta
partition_schema.flag_ = std::get<4>(partitions[i]);
partition_schema.index_file_size_ = std::get<5>(partitions[i]);
partition_schema.engine_type_ = std::get<6>(partitions[i]);
partition_schema.nlist_ = std::get<7>(partitions[i]);
partition_schema.index_params_ = std::get<7>(partitions[i]);
partition_schema.metric_type_ = std::get<8>(partitions[i]);
partition_schema.owner_table_ = table_id;
partition_schema.partition_tag_ = std::get<9>(partitions[i]);
......@@ -995,7 +997,7 @@ SqliteMetaImpl::FilesToSearch(const std::string& table_id, const std::vector<siz
table_file.engine_type_ = std::get<8>(file);
table_file.dimension_ = table_schema.dimension_;
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
......@@ -1063,7 +1065,7 @@ SqliteMetaImpl::FilesToMerge(const std::string& table_id, TableFilesSchema& file
table_file.created_on_ = std::get<8>(file);
table_file.dimension_ = table_schema.dimension_;
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
......@@ -1134,7 +1136,7 @@ SqliteMetaImpl::FilesToIndex(TableFilesSchema& files) {
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
files.push_back(table_file);
}
......@@ -1192,7 +1194,7 @@ SqliteMetaImpl::FilesByType(const std::string& table_id, const std::vector<int>&
file_schema.dimension_ = table_schema.dimension_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
switch (file_schema.file_type_) {
......
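To make the schema change above concrete, here is a minimal sketch (not part of the diff) of how the new index_params_ column round-trips: the update path stores the caller's extra parameters as a JSON string via dump(), and DescribeTableIndex parses that string back with milvus::json::parse. Only the dump/parse calls are taken from the diff; everything else is illustrative.

    #include <cassert>
    #include <string>
    #include "src/utils/Json.h"   // location of the milvus::json alias, as included elsewhere in this change

    int main() {
        milvus::json extra_params = {{"nlist", 16384}};     // e.g. what the caller supplies for an IVF index
        std::string stored = extra_params.dump();            // value written into TableSchema::index_params_
        milvus::json loaded = milvus::json::parse(stored);   // value read back in DescribeTableIndex
        assert(loaded["nlist"].get<int64_t>() == 16384);
        return 0;
    }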
......@@ -4,6 +4,14 @@ import "status.proto";
package milvus.grpc;
/**
* @brief general usage
*/
message KeyValuePair {
string key = 1;
string value = 2;
}
/**
* @brief Table name
*/
......@@ -21,6 +29,7 @@ message TableNameList {
/**
* @brief Table schema
* metric_type: 1-L2, 2-IP
*/
message TableSchema {
Status status = 1;
......@@ -28,6 +37,7 @@ message TableSchema {
int64 dimension = 3;
int64 index_file_size = 4;
int32 metric_type = 5;
repeated KeyValuePair extra_params = 6;
}
/**
......@@ -62,6 +72,7 @@ message InsertParam {
repeated RowRecord row_record_array = 2;
repeated int64 row_id_array = 3; //optional
string partition_tag = 4;
repeated KeyValuePair extra_params = 5;
}
/**
......@@ -77,10 +88,10 @@ message VectorIds {
*/
message SearchParam {
string table_name = 1;
repeated RowRecord query_record_array = 2;
int64 topk = 3;
int64 nprobe = 4;
repeated string partition_tag_array = 5;
repeated string partition_tag_array = 2;
repeated RowRecord query_record_array = 3;
int64 topk = 4;
repeated KeyValuePair extra_params = 5;
}
/**
......@@ -96,10 +107,10 @@ message SearchInFilesParam {
*/
message SearchByIDParam {
string table_name = 1;
int64 id = 2;
int64 topk = 3;
int64 nprobe = 4;
repeated string partition_tag_array = 5;
repeated string partition_tag_array = 2;
int64 id = 3;
int64 topk = 4;
repeated KeyValuePair extra_params = 5;
}
/**
......@@ -143,23 +154,15 @@ message Command {
string cmd = 1;
}
/**
* @brief Index
* @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
* @metric_type: 1-L2, 2-IP
*/
message Index {
int32 index_type = 1;
int32 nlist = 2;
}
/**
* @brief Index params
* @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
*/
message IndexParam {
Status status = 1;
string table_name = 2;
Index index = 3;
int32 index_type = 3;
repeated KeyValuePair extra_params = 4;
}
/**
......
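The gist of the proto change is that index- and search-specific knobs now travel as generic key/value pairs instead of dedicated fields such as nlist and nprobe. A hedged sketch of the client side using the standard protobuf-generated C++ accessors; the generated header name and the "params" key are assumptions, not shown in this diff.

    #include "milvus.pb.h"   // assumed name of the header generated from this .proto

    int main() {
        ::milvus::grpc::IndexParam index_param;
        index_param.set_table_name("demo_table");
        index_param.set_index_type(2);                      // 2 = ivflat, per the comment above
        auto* build_kv = index_param.add_extra_params();    // repeated KeyValuePair extra_params = 4
        build_kv->set_key("params");                        // key name is an assumption
        build_kv->set_value(R"({"nlist": 16384})");         // index-specific parameters as a JSON string

        ::milvus::grpc::SearchParam search_param;           // query vectors omitted here for brevity
        search_param.set_table_name("demo_table");
        search_param.set_topk(10);
        auto* search_kv = search_param.add_extra_params();  // replaces the removed int64 nprobe field
        search_kv->set_key("params");
        search_kv->set_value(R"({"nprobe": 32})");
        return 0;
    }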
......@@ -22,7 +22,6 @@ endif ()
set(external_srcs
knowhere/adapter/SptagAdapter.cpp
knowhere/adapter/VectorAdapter.cpp
knowhere/common/Exception.cpp
knowhere/common/Timer.cpp
)
......@@ -117,4 +116,4 @@ set(INDEX_INCLUDE_DIRS
${LAPACK_INCLUDE_DIR}
)
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
\ No newline at end of file
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
......@@ -43,7 +43,8 @@ std::vector<SPTAG::QueryResult>
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) {
GETTENSOR(dataset);
std::vector<SPTAG::QueryResult> query_results(rows, SPTAG::QueryResult(nullptr, config->k, true));
std::vector<SPTAG::QueryResult> query_results(rows,
SPTAG::QueryResult(nullptr, config[meta::TOPK].get<int64_t>(), true));
for (auto i = 0; i < rows; ++i) {
query_results[i].SetTarget(&p_data[i * dim]);
}
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/adapter/VectorAdapter.h"
namespace knowhere {
namespace meta {
const char* DIM = "dim";
const char* TENSOR = "tensor";
const char* ROWS = "rows";
const char* IDS = "ids";
const char* DISTANCE = "distance";
}; // namespace meta
} // namespace knowhere
......@@ -13,17 +13,10 @@
#include <string>
#include "knowhere/common/Dataset.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace knowhere {
namespace meta {
extern const char* DIM;
extern const char* TENSOR;
extern const char* ROWS;
extern const char* IDS;
extern const char* DISTANCE;
}; // namespace meta
#define GETTENSOR(dataset) \
auto dim = dataset->Get<int64_t>(meta::DIM); \
auto rows = dataset->Get<int64_t>(meta::ROWS); \
......
......@@ -11,64 +11,10 @@
#pragma once
#include <memory>
#include <sstream>
#include "Log.h"
#include "knowhere/common/Exception.h"
#include "src/utils/Json.h"
namespace knowhere {
enum class METRICTYPE {
INVALID = 0,
L2 = 1,
IP = 2,
HAMMING = 20,
JACCARD = 21,
TANIMOTO = 22,
};
// General Config
constexpr int64_t INVALID_VALUE = -1;
constexpr int64_t DEFAULT_K = INVALID_VALUE;
constexpr int64_t DEFAULT_DIM = INVALID_VALUE;
constexpr int64_t DEFAULT_GPUID = INVALID_VALUE;
constexpr METRICTYPE DEFAULT_TYPE = METRICTYPE::INVALID;
struct Cfg {
METRICTYPE metric_type = DEFAULT_TYPE;
int64_t k = DEFAULT_K;
int64_t gpu_id = DEFAULT_GPUID;
int64_t d = DEFAULT_DIM;
Cfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, METRICTYPE type)
: metric_type(type), k(k), gpu_id(gpu_id), d(dim) {
}
Cfg() = default;
virtual bool
CheckValid() {
if (metric_type == METRICTYPE::IP || metric_type == METRICTYPE::L2) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
void
Dump() {
KNOWHERE_LOG_DEBUG << DumpImpl().str();
}
virtual std::stringstream
DumpImpl() {
std::stringstream ss;
ss << "dim: " << d << ", metric: " << int(metric_type) << ", gpuid: " << gpu_id << ", k: " << k;
return ss;
}
};
using Config = std::shared_ptr<Cfg>;
using Config = milvus::json;
} // namespace knowhere
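Since Config is now just milvus::json, index parameters are built and read as ordinary JSON. A minimal sketch using the key constants introduced in IndexParameter.h further down in this diff; the values are illustrative.

    #include <cstdint>
    #include <string>
    #include "knowhere/common/Config.h"                               // using Config = milvus::json
    #include "knowhere/index/vector_index/helpers/IndexParameter.h"   // meta::, IndexParams::, Metric:: constants

    int main() {
        knowhere::Config conf{
            {knowhere::meta::DIM, 128},
            {knowhere::meta::TOPK, 10},
            {knowhere::IndexParams::nlist, 100},
            {knowhere::IndexParams::nprobe, 4},
            {knowhere::Metric::TYPE, knowhere::Metric::L2},
        };
        // Typed access replaces the old strongly typed Cfg members.
        auto topk = conf[knowhere::meta::TOPK].get<int64_t>();
        auto metric = conf[knowhere::Metric::TYPE].get<std::string>();
        (void)topk;
        (void)metric;
        return 0;
    }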
......@@ -16,6 +16,7 @@
#include "IndexModel.h"
#include "IndexType.h"
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/Dataset.h"
#include "knowhere/index/preprocessor/Preprocessor.h"
......
......@@ -14,6 +14,7 @@
#include <utility>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
......
......@@ -15,6 +15,8 @@
#include <faiss/MetaIndexes.h>
#include <faiss/index_factory.h>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
......@@ -43,13 +45,13 @@ BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
}
GETBINARYTENSOR(dataset)
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, Config());
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
......@@ -90,14 +92,9 @@ BinaryIDMAP::Add(const DatasetPtr& dataset, const Config& config) {
void
BinaryIDMAP::Train(const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
if (build_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
config->CheckValid();
const char* type = "BFlat";
auto index = faiss::index_binary_factory(config->d, type, GetMetricType(config->metric_type));
auto index = faiss::index_binary_factory(config[meta::DIM].get<int64_t>(), type,
GetMetricType(config[Metric::TYPE].get<std::string>()));
index_.reset(index);
}
......@@ -181,26 +178,18 @@ BinaryIDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize");
}
// auto search_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
// if (search_cfg == nullptr) {
// KNOWHERE_THROW_MSG("not support this kind of config");
// }
// GETBINARYTENSOR(dataset)
auto dim = dataset->Get<int64_t>(meta::DIM);
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
auto* pdistances = (int32_t*)p_dist;
// index_->searchById(rows, (uint8_t*)p_data, config->k, pdistances, p_id, bitset_);
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
index_->search_by_id(rows, p_data, config->k, pdistances, p_id, bitset_);
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
......
......@@ -15,9 +15,11 @@
#include <faiss/IndexBinaryIVF.h>
#include <chrono>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
namespace knowhere {
......@@ -45,22 +47,17 @@ BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
// if (search_cfg == nullptr) {
// KNOWHERE_THROW_MSG("not support this kind of config");
// }
GETBINARYTENSOR(dataset)
try {
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, config);
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
......@@ -108,29 +105,20 @@ BinaryIVF::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distanc
std::shared_ptr<faiss::IVFSearchParameters>
BinaryIVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->max_codes = config.get_with_default("max_codes", size_t(0));
params->nprobe = config[IndexParams::nprobe];
// params->max_codes = config["max_code"];
return params;
}
IndexModelPtr
BinaryIVF::Train(const DatasetPtr& dataset, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
auto build_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETBINARYTENSOR(dataset)
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, GetMetricType(build_cfg->metric_type));
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type));
faiss::IndexBinary* coarse_quantizer =
new faiss::IndexBinaryFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, config[IndexParams::nlist],
GetMetricType(config[Metric::TYPE].get<std::string>()));
index->train(rows, (uint8_t*)p_data);
index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
index_ = index;
......@@ -190,17 +178,11 @@ BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
// if (search_cfg == nullptr) {
// KNOWHERE_THROW_MSG("not support this kind of config");
// }
// GETBINARYTENSOR(dataset)
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
try {
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
......@@ -208,9 +190,7 @@ BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
auto p_dist = (float*)malloc(p_dist_size);
int32_t* pdistances = (int32_t*)p_dist;
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
// index_->searchById(rows, (uint8_t*)p_data, config->k, pdistances, p_id, blacklist);
index_->search_by_id(rows, p_data, config->k, pdistances, p_id, bitset_);
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
......
......@@ -16,6 +16,7 @@
#include <faiss/MetaIndexes.h>
#include <faiss/index_io.h>
#include <fiu-local.h>
#include <string>
#ifdef MILVUS_GPU_VERSION
......@@ -127,7 +128,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t& k, Graph& graph, const Conf
int64_t K = k + 1;
auto ntotal = Count();
size_t dim = config->d;
size_t dim = config[meta::DIM];
auto batch_size = 1000;
auto tail_batch_size = ntotal % batch_size;
auto batch_search_count = ntotal / batch_size;
......
......@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <memory>
#include <string>
#include <faiss/gpu/GpuCloner.h>
#include <faiss/gpu/GpuIndexIVF.h>
......@@ -28,21 +29,16 @@ namespace knowhere {
IndexModelPtr
GPUIVF::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_id_;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type), idx_config);
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, config[IndexParams::nlist],
GetMetricType(config[Metric::TYPE].get<std::string>()), idx_config);
device_index.train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
......@@ -121,15 +117,13 @@ GPUIVF::LoadImpl(const BinarySet& index_binary) {
}
void
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
if (device_index) {
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(cfg);
device_index->nprobe = search_cfg->nprobe;
// assert(device_index->getNumProbes() == search_cfg->nprobe);
device_index->nprobe = config[IndexParams::nprobe];
ResScope rs(res_, gpu_id_);
device_index->search(n, (float*)data, k, distances, labels);
} else {
......
......@@ -15,6 +15,7 @@
#include <faiss/index_factory.h>
#include <memory>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
......@@ -25,20 +26,16 @@ namespace knowhere {
IndexModelPtr
GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
auto device_index = new faiss::gpu::GpuIndexIVFPQ(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
build_cfg->m, build_cfg->nbits,
GetMetricType(build_cfg->metric_type)); // IP not support
auto device_index = new faiss::gpu::GpuIndexIVFPQ(
temp_resource->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m],
config[IndexParams::nbits],
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
......@@ -51,11 +48,10 @@ GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
std::shared_ptr<faiss::IVFSearchParameters>
GPUIVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
params->nprobe = config[IndexParams::nprobe];
// params->scan_table_threshold = config["scan_table_threhold"]
// params->polysemous_ht = config["polysemous_ht"]
// params->max_codes = config["max_codes"]
return params;
}
......
......@@ -13,6 +13,7 @@
#include <faiss/index_factory.h>
#include <memory>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
......@@ -23,18 +24,14 @@ namespace knowhere {
IndexModelPtr
GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ" << config[IndexParams::nbits];
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
......
......@@ -79,20 +79,15 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("search conf is null");
}
index_->setEf(search_cfg->ef);
GETTENSOR(dataset)
size_t id_size = sizeof(int64_t) * config->k;
size_t dist_size = sizeof(float) * config->k;
size_t id_size = sizeof(int64_t) * config[meta::TOPK].get<int64_t>();
size_t dist_size = sizeof(float) * config[meta::TOPK].get<int64_t>();
auto p_id = (int64_t*)malloc(id_size * rows);
auto p_dist = (float*)malloc(dist_size * rows);
index_->setEf(config[IndexParams::ef]);
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
#pragma omp parallel for
......@@ -103,13 +98,13 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
// if (normalize) {
// std::vector<float> norm_vector(Dimension());
// normalize_vector((float*)(single_query), norm_vector.data(), Dimension());
// ret = index_->searchKnn((float*)(norm_vector.data()), config->k, compare);
// ret = index_->searchKnn((float*)(norm_vector.data()), config[meta::TOPK].get<int64_t>(), compare);
// } else {
// ret = index_->searchKnn((float*)single_query, config->k, compare);
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn((float*)single_query, config->k, compare);
ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
while (ret.size() < config->k) {
while (ret.size() < config[meta::TOPK]) {
ret.push_back(std::make_pair(-1, -1));
}
std::vector<float> dist;
......@@ -125,8 +120,8 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
std::transform(ret.begin(), ret.end(), std::back_inserter(ids),
[](const std::pair<float, int64_t>& e) { return e.second; });
memcpy(p_dist + i * config->k, dist.data(), dist_size);
memcpy(p_id + i * config->k, ids.data(), id_size);
memcpy(p_dist + i * config[meta::TOPK].get<int64_t>(), dist.data(), dist_size);
memcpy(p_id + i * config[meta::TOPK].get<int64_t>(), ids.data(), id_size);
}
auto ret_ds = std::make_shared<Dataset>();
......@@ -137,21 +132,17 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
IndexModelPtr
IndexHNSW::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
if (build_cfg == nullptr) {
KNOWHERE_THROW_MSG("build conf is null");
}
GETTENSOR(dataset)
hnswlib::SpaceInterface<float>* space;
if (config->metric_type == METRICTYPE::L2) {
if (config[Metric::TYPE] == Metric::L2) {
space = new hnswlib::L2Space(dim);
} else if (config->metric_type == METRICTYPE::IP) {
} else if (config[Metric::TYPE] == Metric::IP) {
space = new hnswlib::InnerProductSpace(dim);
normalize = true;
}
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, build_cfg->M, build_cfg->ef);
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
config[IndexParams::efConstruction].get<int64_t>());
return nullptr;
}
......
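For HNSW the keys read after this change are Metric::TYPE, IndexParams::M and IndexParams::efConstruction at build time, plus meta::TOPK and IndexParams::ef at search time (dim and rows come from the dataset). A small sketch of matching parameter sets; the values are illustrative, not taken from the tests.

    #include "knowhere/common/Config.h"
    #include "knowhere/index/vector_index/helpers/IndexParameter.h"

    int main() {
        knowhere::Config hnsw_build_conf{
            {knowhere::Metric::TYPE, knowhere::Metric::L2},
            {knowhere::IndexParams::M, 16},
            {knowhere::IndexParams::efConstruction, 200},
        };
        knowhere::Config hnsw_search_conf{
            {knowhere::meta::TOPK, 10},
            {knowhere::IndexParams::ef, 64},
        };
        (void)hnsw_build_conf;
        (void)hnsw_search_conf;
        return 0;
    }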
......@@ -22,6 +22,7 @@
#endif
#include <string>
#include <vector>
#include "knowhere/adapter/VectorAdapter.h"
......@@ -61,13 +62,13 @@ IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
}
GETTENSOR(dataset)
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (float*)p_data, config->k, p_dist, p_id, Config());
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
......@@ -144,10 +145,9 @@ IDMAP::GetRawIds() {
void
IDMAP::Train(const Config& config) {
config->CheckValid();
const char* type = "IDMap,Flat";
auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
auto index = faiss::index_factory(config[meta::DIM].get<int64_t>(), type,
GetMetricType(config[Metric::TYPE].get<std::string>()));
index_.reset(index);
}
......@@ -214,7 +214,7 @@ IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
......@@ -222,8 +222,8 @@ IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
// todo: enable search by id (zhiru)
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
// index_->searchById(rows, (float*)p_data, config->k, p_dist, p_id, blacklist);
index_->search_by_id(rows, p_data, config->k, p_dist, p_id, bitset_);
// index_->searchById(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, blacklist);
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
......
......@@ -26,6 +26,7 @@
#include <fiu-local.h>
#include <chrono>
#include <memory>
#include <string>
#include <utility>
#include <vector>
......@@ -43,16 +44,11 @@ using stdclock = std::chrono::high_resolution_clock;
IndexModelPtr
IVF::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type));
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
GetMetricType(config[Metric::TYPE].get<std::string>()));
index->train(rows, (float*)p_data);
// TODO(linxj): override here. train return model or not.
......@@ -106,24 +102,19 @@ IVF::Search(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
GETTENSOR(dataset)
try {
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
auto elems = rows * search_cfg->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (float*)p_data, search_cfg->k, p_dist, p_id, config);
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
......@@ -163,9 +154,8 @@ std::shared_ptr<faiss::IVFSearchParameters>
IVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->max_codes = config.get_with_default("max_codes", size_t(0));
params->nprobe = config[IndexParams::nprobe];
// params->max_codes = config["max_codes"];
return params;
}
......@@ -185,7 +175,7 @@ IVF::GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& c
int64_t K = k + 1;
auto ntotal = Count();
size_t dim = config->d;
size_t dim = config[meta::DIM];
auto batch_size = 1000;
auto tail_batch_size = ntotal % batch_size;
auto batch_search_count = ntotal / batch_size;
......@@ -279,12 +269,6 @@ IVF::GetVectorById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
// auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
auto elems = dataset->Get<int64_t>(meta::DIM);
......@@ -311,16 +295,11 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
try {
auto elems = rows * search_cfg->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
......@@ -330,7 +309,7 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
// todo: enable search by id (zhiru)
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
index_ivf->search_by_id(rows, p_data, search_cfg->k, p_dist, p_id, bitset_);
index_ivf->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
......
......@@ -16,6 +16,7 @@
#endif
#include <memory>
#include <string>
#include <utility>
#include "knowhere/adapter/VectorAdapter.h"
......@@ -30,16 +31,12 @@ namespace knowhere {
IndexModelPtr
IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type));
auto index =
std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits);
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
config[IndexParams::m].get<int64_t>(),
config[IndexParams::nbits].get<int64_t>());
index->train(rows, (float*)p_data);
return std::make_shared<IVFIndexModel>(index);
......@@ -48,11 +45,10 @@ IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
std::shared_ptr<faiss::IVFSearchParameters>
IVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
params->nprobe = config[IndexParams::nprobe];
// params->scan_table_threshold = config["scan_table_threhold"]
// params->polysemous_ht = config["polysemous_ht"]
// params->max_codes = config["max_codes"]
return params;
}
......
......@@ -16,6 +16,7 @@
#include <faiss/index_factory.h>
#include <memory>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
......@@ -30,17 +31,13 @@ namespace knowhere {
IndexModelPtr
IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ" << config[IndexParams::nbits];
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
build_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> ret_index;
......
......@@ -19,6 +19,7 @@
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/index_factory.h>
#include <fiu-local.h>
#include <string>
#include <utility>
namespace knowhere {
......@@ -30,19 +31,14 @@ namespace knowhere {
IndexModelPtr
IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
// std::lock_guard<std::mutex> lk(g_mutex);
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ8Hybrid";
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
......@@ -133,17 +129,10 @@ IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distanc
}
QuantizerPtr
IVFSQHybrid::LoadQuantizer(const Config& conf) {
IVFSQHybrid::LoadQuantizer(const Config& config) {
// std::lock_guard<std::mutex> lk(g_mutex);
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
if (quantizer_conf->mode != 1) {
KNOWHERE_THROW_MSG("mode only support 1 in this func");
}
}
auto gpu_id = quantizer_conf->gpu_id;
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
ResScope rs(res, gpu_id, false);
faiss::gpu::GpuClonerOptions option;
......@@ -152,7 +141,7 @@ IVFSQHybrid::LoadQuantizer(const Config& conf) {
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = nullptr;
index_composition->mode = quantizer_conf->mode; // only 1
index_composition->mode = 1; // only 1
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
delete gpu_index;
......@@ -205,19 +194,10 @@ IVFSQHybrid::UnsetQuantizer() {
}
VectorIndexPtr
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& config) {
// std::lock_guard<std::mutex> lk(g_mutex);
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
if (quantizer_conf->mode != 2) {
KNOWHERE_THROW_MSG("mode only support 2 in this func");
}
} else {
KNOWHERE_THROW_MSG("conf error");
}
auto gpu_id = quantizer_conf->gpu_id;
int64_t gpu_id = config[knowhere::meta::DEVICEID];
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
ResScope rs(res, gpu_id, false);
......@@ -231,7 +211,7 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = ivf_quantizer->quantizer;
index_composition->mode = quantizer_conf->mode; // only 2
index_composition->mode = 2; // only 2
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
std::shared_ptr<faiss::Index> new_idx;
......
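With QuantizerCfg dropped, the only thing LoadQuantizer and LoadData still read from the config is the device id; the copy mode is now fixed inside each function (1 for LoadQuantizer, 2 for LoadData). A sketch of the caller side under that assumption; the header path is assumed.

    #include <memory>
    #include "knowhere/common/Config.h"
    #include "knowhere/index/vector_index/helpers/IndexParameter.h"
    #include "knowhere/index/vector_index/IndexIVFSQHybrid.h"   // assumed header name

    // Move an already-trained hybrid index onto GPU 0: quantizer first, then the data.
    void LoadHybridToGpu(const std::shared_ptr<knowhere::IVFSQHybrid>& hybrid_index) {
        knowhere::Config conf{{knowhere::meta::DEVICEID, 0}};   // only the gpu id is read now
        auto quantizer = hybrid_index->LoadQuantizer(conf);     // internally mode 1: copy quantizer
        hybrid_index->LoadData(quantizer, conf);                // internally mode 2: copy data
    }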
......@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Timer.h"
......@@ -23,6 +24,7 @@
#endif
#include <fiu-local.h>
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/nsg/NSG.h"
......@@ -72,23 +74,21 @@ NSG::Load(const BinarySet& index_binary) {
DatasetPtr
NSG::Search(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GETTENSOR(dataset)
auto elems = rows * build_cfg->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
algo::SearchParams s_params;
s_params.search_length = build_cfg->search_length;
index_->Search((float*)p_data, rows, dim, build_cfg->k, p_dist, p_id, s_params);
s_params.search_length = config[IndexParams::search_length];
index_->Search((float*)p_data, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id, s_params);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
......@@ -98,41 +98,35 @@ NSG::Search(const DatasetPtr& dataset, const Config& config) {
IndexModelPtr
NSG::Train(const DatasetPtr& dataset, const Config& config) {
config->Dump();
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
auto idmap = std::make_shared<IDMAP>();
idmap->Train(config);
idmap->AddWithoutId(dataset, config);
Graph knng;
const float* raw_data = idmap->GetRawVectors();
#ifdef MILVUS_GPU_VERSION
if (build_cfg->gpu_id == knowhere::INVALID_VALUE) {
if (config[knowhere::meta::DEVICEID].get<int64_t>() == -1) {
auto preprocess_index = std::make_shared<IVF>();
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->Add(dataset, config);
preprocess_index->GenGraph(raw_data, build_cfg->knng, knng, config);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
} else {
auto gpu_idx = cloner::CopyCpuToGpu(idmap, build_cfg->gpu_id, config);
auto gpu_idx = cloner::CopyCpuToGpu(idmap, config[knowhere::meta::DEVICEID].get<int64_t>(), config);
auto gpu_idmap = std::dynamic_pointer_cast<GPUIDMAP>(gpu_idx);
gpu_idmap->GenGraph(raw_data, build_cfg->knng, knng, config);
gpu_idmap->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
}
#else
auto preprocess_index = std::make_shared<IVF>();
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(raw_data, build_cfg->knng, knng, config);
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
#endif
algo::BuildParams b_params;
b_params.candidate_pool_size = build_cfg->candidate_pool_size;
b_params.out_degree = build_cfg->out_degree;
b_params.search_length = build_cfg->search_length;
b_params.candidate_pool_size = config[IndexParams::candidate];
b_params.out_degree = config[IndexParams::out_degree];
b_params.search_length = config[IndexParams::search_length];
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
......
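NSG::Train now pulls everything it needs from the JSON config: the device id (-1 selects the CPU preprocessing path), the kNN-graph size, and the NSG build parameters, with nlist/nprobe and dim consumed by the temporary IVF index used to build the graph. A hedged sketch of a matching build config; the values are illustrative.

    #include "knowhere/common/Config.h"
    #include "knowhere/index/vector_index/helpers/IndexParameter.h"

    int main() {
        knowhere::Config nsg_build_conf{
            {knowhere::Metric::TYPE, knowhere::Metric::L2},
            {knowhere::meta::DIM, 128},                  // read by the IDMAP/IVF preprocessing step
            {knowhere::meta::DEVICEID, -1},              // -1: build the kNN graph on CPU
            {knowhere::IndexParams::knng, 100},
            {knowhere::IndexParams::search_length, 60},
            {knowhere::IndexParams::out_degree, 50},
            {knowhere::IndexParams::candidate, 300},
            {knowhere::IndexParams::nlist, 128},         // IVF parameters for graph construction
            {knowhere::IndexParams::nprobe, 8},
        };
        (void)nsg_build_conf;
        return 0;
    }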
......@@ -123,9 +123,6 @@ CPUSPTAGRNG::Load(const BinarySet& binary_set) {
IndexModelPtr
CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) {
SetParameters(train_config);
if (train_config != nullptr) {
train_config->CheckValid(); // throw exception
}
DatasetPtr dataset = origin; // TODO(linxj): copy or reference?
......@@ -159,62 +156,56 @@ CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) {
void
CPUSPTAGRNG::SetParameters(const Config& config) {
#define Assign(param_name, str_name) \
conf->param_name == INVALID_VALUE ? index_ptr_->SetParameter(str_name, std::to_string(build_cfg->param_name)) \
: index_ptr_->SetParameter(str_name, std::to_string(conf->param_name))
#define Assign(param_name, str_name) \
index_ptr_->SetParameter(str_name, std::to_string(build_cfg[param_name].get<int64_t>()))
if (index_type_ == SPTAG::IndexAlgoType::KDT) {
auto conf = std::dynamic_pointer_cast<KDTCfg>(config);
auto build_cfg = SPTAGParameterMgr::GetInstance().GetKDTParameters();
Assign(kdtnumber, "KDTNumber");
Assign(numtopdimensionkdtsplit, "NumTopDimensionKDTSplit");
Assign(samples, "Samples");
Assign(tptnumber, "TPTNumber");
Assign(tptleafsize, "TPTLeafSize");
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
Assign(neighborhoodsize, "NeighborhoodSize");
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
Assign(graphcefscale, "GraphCEFScale");
Assign(refineiterations, "RefineIterations");
Assign(cef, "CEF");
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
Assign(numofthreads, "NumberOfThreads");
Assign(maxcheck, "MaxCheck");
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
Assign("kdtnumber", "KDTNumber");
Assign("numtopdimensionkdtsplit", "NumTopDimensionKDTSplit");
Assign("samples", "Samples");
Assign("tptnumber", "TPTNumber");
Assign("tptleafsize", "TPTLeafSize");
Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit");
Assign("neighborhoodsize", "NeighborhoodSize");
Assign("graphneighborhoodscale", "GraphNeighborhoodScale");
Assign("graphcefscale", "GraphCEFScale");
Assign("refineiterations", "RefineIterations");
Assign("cef", "CEF");
Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph");
Assign("numofthreads", "NumberOfThreads");
Assign("maxcheck", "MaxCheck");
Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots");
Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots");
} else {
auto conf = std::dynamic_pointer_cast<BKTCfg>(config);
auto build_cfg = SPTAGParameterMgr::GetInstance().GetBKTParameters();
Assign(bktnumber, "BKTNumber");
Assign(bktkmeansk, "BKTKMeansK");
Assign(bktleafsize, "BKTLeafSize");
Assign(samples, "Samples");
Assign(tptnumber, "TPTNumber");
Assign(tptleafsize, "TPTLeafSize");
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
Assign(neighborhoodsize, "NeighborhoodSize");
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
Assign(graphcefscale, "GraphCEFScale");
Assign(refineiterations, "RefineIterations");
Assign(cef, "CEF");
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
Assign(numofthreads, "NumberOfThreads");
Assign(maxcheck, "MaxCheck");
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
Assign("bktnumber", "BKTNumber");
Assign("bktkmeansk", "BKTKMeansK");
Assign("bktleafsize", "BKTLeafSize");
Assign("samples", "Samples");
Assign("tptnumber", "TPTNumber");
Assign("tptleafsize", "TPTLeafSize");
Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit");
Assign("neighborhoodsize", "NeighborhoodSize");
Assign("graphneighborhoodscale", "GraphNeighborhoodScale");
Assign("graphcefscale", "GraphCEFScale");
Assign("refineiterations", "RefineIterations");
Assign("cef", "CEF");
Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph");
Assign("numofthreads", "NumberOfThreads");
Assign("maxcheck", "MaxCheck");
Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots");
Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots");
}
}
DatasetPtr
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
SetParameters(config);
// if (config != nullptr) {
// config->CheckValid(); // throw exception
// }
auto p_data = dataset->Get<const float*>(meta::TENSOR);
for (auto i = 0; i < 10; ++i) {
......
......@@ -23,9 +23,9 @@ struct Quantizer {
};
using QuantizerPtr = std::shared_ptr<Quantizer>;
struct QuantizerCfg : Cfg {
int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
};
using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
// struct QuantizerCfg : Cfg {
// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
// };
// using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
} // namespace knowhere
......@@ -17,47 +17,23 @@
namespace knowhere {
faiss::MetricType
GetMetricType(METRICTYPE& type) {
if (type == METRICTYPE::L2) {
GetMetricType(const std::string& type) {
if (type == Metric::L2) {
return faiss::METRIC_L2;
}
if (type == METRICTYPE::IP) {
if (type == Metric::IP) {
return faiss::METRIC_INNER_PRODUCT;
}
// binary only
if (type == METRICTYPE::JACCARD) {
if (type == Metric::JACCARD) {
return faiss::METRIC_Jaccard;
}
if (type == METRICTYPE::TANIMOTO) {
if (type == Metric::TANIMOTO) {
return faiss::METRIC_Tanimoto;
}
if (type == METRICTYPE::HAMMING) {
if (type == Metric::HAMMING) {
return faiss::METRIC_Hamming;
}
KNOWHERE_THROW_MSG("Metric type is invalid");
}
std::stringstream
IVFCfg::DumpImpl() {
auto ss = Cfg::DumpImpl();
ss << ", nlist: " << nlist << ", nprobe: " << nprobe;
return ss;
}
std::stringstream
IVFSQCfg::DumpImpl() {
auto ss = IVFCfg::DumpImpl();
ss << ", nbits: " << nbits;
return ss;
}
std::stringstream
NSGCfg::DumpImpl() {
auto ss = IVFCfg::DumpImpl();
ss << ", knng: " << knng << ", search_length: " << search_length << ", out_degree: " << out_degree
<< ", candidate: " << candidate_pool_size;
return ss;
}
} // namespace knowhere
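GetMetricType now maps the Metric::* string constants to faiss metric enums instead of switching on METRICTYPE; a minimal usage sketch:

    #include <faiss/Index.h>
    #include "knowhere/index/vector_index/helpers/IndexParameter.h"

    int main() {
        faiss::MetricType l2 = knowhere::GetMetricType(knowhere::Metric::L2);   // faiss::METRIC_L2
        faiss::MetricType ip = knowhere::GetMetricType(knowhere::Metric::IP);   // faiss::METRIC_INNER_PRODUCT
        // Unrecognized strings throw via KNOWHERE_THROW_MSG("Metric type is invalid").
        (void)l2;
        (void)ip;
        return 0;
    }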
......@@ -12,240 +12,49 @@
#pragma once
#include <faiss/Index.h>
#include <memory>
#include "knowhere/common/Config.h"
#include <string>
namespace knowhere {
extern faiss::MetricType
GetMetricType(METRICTYPE& type);
// IVF Config
constexpr int64_t DEFAULT_NLIST = INVALID_VALUE;
constexpr int64_t DEFAULT_NPROBE = INVALID_VALUE;
constexpr int64_t DEFAULT_NSUBVECTORS = INVALID_VALUE;
constexpr int64_t DEFAULT_NBITS = INVALID_VALUE;
constexpr int64_t DEFAULT_SCAN_TABLE_THREHOLD = INVALID_VALUE;
constexpr int64_t DEFAULT_POLYSEMOUS_HT = INVALID_VALUE;
constexpr int64_t DEFAULT_MAX_CODES = INVALID_VALUE;
// NSG Config
constexpr int64_t DEFAULT_SEARCH_LENGTH = INVALID_VALUE;
constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
constexpr int64_t DEFAULT_CANDIDATE_SISE = INVALID_VALUE;
constexpr int64_t DEFAULT_NNG_K = INVALID_VALUE;
// SPTAG Config
constexpr int64_t DEFAULT_SAMPLES = INVALID_VALUE;
constexpr int64_t DEFAULT_TPTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_TPTLEAFSIZE = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMTOPDIMENSIONTPTSPLIT = INVALID_VALUE;
constexpr int64_t DEFAULT_NEIGHBORHOODSIZE = INVALID_VALUE;
constexpr int64_t DEFAULT_GRAPHNEIGHBORHOODSCALE = INVALID_VALUE;
constexpr int64_t DEFAULT_GRAPHCEFSCALE = INVALID_VALUE;
constexpr int64_t DEFAULT_REFINEITERATIONS = INVALID_VALUE;
constexpr int64_t DEFAULT_CEF = INVALID_VALUE;
constexpr int64_t DEFAULT_MAXCHECKFORREFINEGRAPH = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMOFTHREADS = INVALID_VALUE;
constexpr int64_t DEFAULT_MAXCHECK = INVALID_VALUE;
constexpr int64_t DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS = INVALID_VALUE;
// KDT Config
constexpr int64_t DEFAULT_KDTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMTOPDIMENSIONKDTSPLIT = INVALID_VALUE;
// BKT Config
constexpr int64_t DEFAULT_BKTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_BKTKMEANSK = INVALID_VALUE;
constexpr int64_t DEFAULT_BKTLEAFSIZE = INVALID_VALUE;
// HNSW Config
constexpr int64_t DEFAULT_M = INVALID_VALUE;
constexpr int64_t DEFAULT_EF = INVALID_VALUE;
struct IVFCfg : public Cfg {
int64_t nlist = DEFAULT_NLIST;
int64_t nprobe = DEFAULT_NPROBE;
IVFCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
METRICTYPE type)
: Cfg(dim, k, gpu_id, type), nlist(nlist), nprobe(nprobe) {
}
IVFCfg() = default;
std::stringstream
DumpImpl() override;
// bool
// CheckValid() override {
// return true;
// };
};
using IVFConfig = std::shared_ptr<IVFCfg>;
struct IVFBinCfg : public IVFCfg {
bool
CheckValid() override {
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
metric_type == METRICTYPE::JACCARD) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
};
struct IVFSQCfg : public IVFCfg {
// TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16
int64_t nbits = DEFAULT_NBITS;
IVFSQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
const int64_t& nbits, METRICTYPE type)
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), nbits(nbits) {
}
std::stringstream
DumpImpl() override;
IVFSQCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using IVFSQConfig = std::shared_ptr<IVFSQCfg>;
struct IVFPQCfg : public IVFCfg {
int64_t m = DEFAULT_NSUBVECTORS; // number of subquantizers(subvector)
int64_t nbits = DEFAULT_NBITS; // number of bit per subvector index
// TODO(linxj): not use yet
int64_t scan_table_threhold = DEFAULT_SCAN_TABLE_THREHOLD;
int64_t polysemous_ht = DEFAULT_POLYSEMOUS_HT;
int64_t max_codes = DEFAULT_MAX_CODES;
namespace meta {
constexpr const char* DIM = "dim";
constexpr const char* TENSOR = "tensor";
constexpr const char* ROWS = "rows";
constexpr const char* IDS = "ids";
constexpr const char* DISTANCE = "distance";
constexpr const char* TOPK = "k";
constexpr const char* DEVICEID = "gpu_id";
}; // namespace meta
namespace IndexParams {
// IVF Params
constexpr const char* nprobe = "nprobe";
constexpr const char* nlist = "nlist";
constexpr const char* m = "m"; // PQ
constexpr const char* nbits = "nbits"; // PQ/SQ
// NSG Params
constexpr const char* knng = "knng";
constexpr const char* search_length = "search_length";
constexpr const char* out_degree = "out_degree";
constexpr const char* candidate = "candidate_pool_size";
// HNSW Params
constexpr const char* efConstruction = "efConstruction";
constexpr const char* M = "M";
constexpr const char* ef = "ef";
} // namespace IndexParams
namespace Metric {
constexpr const char* TYPE = "metric_type";
constexpr const char* IP = "IP";
constexpr const char* L2 = "L2";
constexpr const char* HAMMING = "HAMMING";
constexpr const char* JACCARD = "JACCARD";
constexpr const char* TANIMOTO = "TANIMOTO";
} // namespace Metric
IVFPQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
const int64_t& nbits, const int64_t& m, METRICTYPE type)
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), m(m), nbits(nbits) {
}
IVFPQCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using IVFPQConfig = std::shared_ptr<IVFPQCfg>;
struct NSGCfg : public IVFCfg {
int64_t knng = DEFAULT_NNG_K;
int64_t search_length = DEFAULT_SEARCH_LENGTH;
int64_t out_degree = DEFAULT_OUT_DEGREE;
int64_t candidate_pool_size = DEFAULT_CANDIDATE_SISE;
NSGCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
const int64_t& knng, const int64_t& search_length, const int64_t& out_degree, const int64_t& candidate_size,
METRICTYPE type)
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type),
knng(knng),
search_length(search_length),
out_degree(out_degree),
candidate_pool_size(candidate_size) {
}
NSGCfg() = default;
std::stringstream
DumpImpl() override;
// bool
// CheckValid() override {
// return true;
// };
};
using NSGConfig = std::shared_ptr<NSGCfg>;
struct SPTAGCfg : public Cfg {
int64_t samples = DEFAULT_SAMPLES;
int64_t tptnumber = DEFAULT_TPTNUMBER;
int64_t tptleafsize = DEFAULT_TPTLEAFSIZE;
int64_t numtopdimensiontptsplit = DEFAULT_NUMTOPDIMENSIONTPTSPLIT;
int64_t neighborhoodsize = DEFAULT_NEIGHBORHOODSIZE;
int64_t graphneighborhoodscale = DEFAULT_GRAPHNEIGHBORHOODSCALE;
int64_t graphcefscale = DEFAULT_GRAPHCEFSCALE;
int64_t refineiterations = DEFAULT_REFINEITERATIONS;
int64_t cef = DEFAULT_CEF;
int64_t maxcheckforrefinegraph = DEFAULT_MAXCHECKFORREFINEGRAPH;
int64_t numofthreads = DEFAULT_NUMOFTHREADS;
int64_t maxcheck = DEFAULT_MAXCHECK;
int64_t thresholdofnumberofcontinuousnobetterpropagation = DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION;
int64_t numberofinitialdynamicpivots = DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS;
int64_t numberofotherdynamicpivots = DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS;
SPTAGCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using SPTAGConfig = std::shared_ptr<SPTAGCfg>;
struct KDTCfg : public SPTAGCfg {
int64_t kdtnumber = DEFAULT_KDTNUMBER;
int64_t numtopdimensionkdtsplit = DEFAULT_NUMTOPDIMENSIONKDTSPLIT;
KDTCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using KDTConfig = std::shared_ptr<KDTCfg>;
struct BKTCfg : public SPTAGCfg {
int64_t bktnumber = DEFAULT_BKTNUMBER;
int64_t bktkmeansk = DEFAULT_BKTKMEANSK;
int64_t bktleafsize = DEFAULT_BKTLEAFSIZE;
BKTCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using BKTConfig = std::shared_ptr<BKTCfg>;
struct BinIDMAPCfg : public Cfg {
bool
CheckValid() override {
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
metric_type == METRICTYPE::JACCARD) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
};
struct HNSWCfg : public Cfg {
int64_t M = DEFAULT_M;
int64_t ef = DEFAULT_EF;
HNSWCfg() = default;
};
using HNSWConfig = std::shared_ptr<HNSWCfg>;
extern faiss::MetricType
GetMetricType(const std::string& type);
} // namespace knowhere
......@@ -15,55 +15,53 @@
namespace knowhere {
const KDTConfig&
const Config&
SPTAGParameterMgr::GetKDTParameters() {
return kdt_config_;
}
const BKTConfig&
const Config&
SPTAGParameterMgr::GetBKTParameters() {
return bkt_config_;
}
SPTAGParameterMgr::SPTAGParameterMgr() {
kdt_config_ = std::make_shared<KDTCfg>();
kdt_config_->kdtnumber = 1;
kdt_config_->numtopdimensionkdtsplit = 5;
kdt_config_->samples = 100;
kdt_config_->tptnumber = 1;
kdt_config_->tptleafsize = 2000;
kdt_config_->numtopdimensiontptsplit = 5;
kdt_config_->neighborhoodsize = 32;
kdt_config_->graphneighborhoodscale = 2;
kdt_config_->graphcefscale = 2;
kdt_config_->refineiterations = 0;
kdt_config_->cef = 1000;
kdt_config_->maxcheckforrefinegraph = 10000;
kdt_config_->numofthreads = 1;
kdt_config_->maxcheck = 8192;
kdt_config_->thresholdofnumberofcontinuousnobetterpropagation = 3;
kdt_config_->numberofinitialdynamicpivots = 50;
kdt_config_->numberofotherdynamicpivots = 4;
kdt_config_["kdtnumber"] = 1;
kdt_config_["numtopdimensionkdtsplit"] = 5;
kdt_config_["samples"] = 100;
kdt_config_["tptnumber"] = 1;
kdt_config_["tptleafsize"] = 2000;
kdt_config_["numtopdimensiontptsplit"] = 5;
kdt_config_["neighborhoodsize"] = 32;
kdt_config_["graphneighborhoodscale"] = 2;
kdt_config_["graphcefscale"] = 2;
kdt_config_["refineiterations"] = 0;
kdt_config_["cef"] = 1000;
kdt_config_["maxcheckforrefinegraph"] = 10000;
kdt_config_["numofthreads"] = 1;
kdt_config_["maxcheck"] = 8192;
kdt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3;
kdt_config_["numberofinitialdynamicpivots"] = 50;
kdt_config_["numberofotherdynamicpivots"] = 4;
bkt_config_ = std::make_shared<BKTCfg>();
bkt_config_->bktnumber = 1;
bkt_config_->bktkmeansk = 32;
bkt_config_->bktleafsize = 8;
bkt_config_->samples = 100;
bkt_config_->tptnumber = 1;
bkt_config_->tptleafsize = 2000;
bkt_config_->numtopdimensiontptsplit = 5;
bkt_config_->neighborhoodsize = 32;
bkt_config_->graphneighborhoodscale = 2;
bkt_config_->graphcefscale = 2;
bkt_config_->refineiterations = 0;
bkt_config_->cef = 1000;
bkt_config_->maxcheckforrefinegraph = 10000;
bkt_config_->numofthreads = 1;
bkt_config_->maxcheck = 8192;
bkt_config_->thresholdofnumberofcontinuousnobetterpropagation = 3;
bkt_config_->numberofinitialdynamicpivots = 50;
bkt_config_->numberofotherdynamicpivots = 4;
bkt_config_["bktnumber"] = 1;
bkt_config_["bktkmeansk"] = 32;
bkt_config_["bktleafsize"] = 8;
bkt_config_["samples"] = 100;
bkt_config_["tptnumber"] = 1;
bkt_config_["tptleafsize"] = 2000;
bkt_config_["numtopdimensiontptsplit"] = 5;
bkt_config_["neighborhoodsize"] = 32;
bkt_config_["graphneighborhoodscale"] = 2;
bkt_config_["graphcefscale"] = 2;
bkt_config_["refineiterations"] = 0;
bkt_config_["cef"] = 1000;
bkt_config_["maxcheckforrefinegraph"] = 10000;
bkt_config_["numofthreads"] = 1;
bkt_config_["maxcheck"] = 8192;
bkt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3;
bkt_config_["numberofinitialdynamicpivots"] = 50;
bkt_config_["numberofotherdynamicpivots"] = 4;
}
} // namespace knowhere
......@@ -18,18 +18,15 @@
#include <SPTAG/AnnService/inc/Core/Common.h>
#include "IndexParameter.h"
#include "knowhere/common/Config.h"
namespace knowhere {
using KDTConfig = std::shared_ptr<KDTCfg>;
using BKTConfig = std::shared_ptr<BKTCfg>;
class SPTAGParameterMgr {
public:
const KDTConfig&
const Config&
GetKDTParameters();
const BKTConfig&
const Config&
GetBKTParameters();
public:
......@@ -48,8 +45,8 @@ class SPTAGParameterMgr {
SPTAGParameterMgr();
private:
KDTConfig kdt_config_;
BKTConfig bkt_config_;
Config kdt_config_;
Config bkt_config_;
};
} // namespace knowhere
......@@ -29,16 +29,9 @@ namespace algo {
unsigned int seed = 100;
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric)
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric)
: dimension(dimension), ntotal(n), metric_type(metric) {
switch (metric) {
case METRICTYPE::L2:
distance_ = new DistanceL2;
break;
case METRICTYPE::IP:
distance_ = new DistanceIP;
break;
}
distance_ = new DistanceL2; // hardcode here
}
NsgIndex::~NsgIndex() {
......
......@@ -13,6 +13,7 @@
#include <cstddef>
#include <mutex>
#include <string>
#include <vector>
#include <boost/dynamic_bitset.hpp>
......@@ -41,8 +42,8 @@ using Graph = std::vector<std::vector<node_t>>;
class NsgIndex {
public:
size_t dimension;
size_t ntotal; // total nb of indexed vectors
METRICTYPE metric_type; // L2 | IP
size_t ntotal; // total nb of indexed vectors
std::string metric_type; // todo(linxj) IP
Distance* distance_;
float* ori_data_;
......@@ -62,7 +63,7 @@ class NsgIndex {
size_t out_degree;
public:
explicit NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric = METRICTYPE::L2);
explicit NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric = "L2");
NsgIndex() = default;
......
......@@ -14,10 +14,6 @@
#include <omp.h>
#ifdef __SSE__
#include <immintrin.h>
#endif
#include <faiss/utils/utils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>
......
......@@ -30,7 +30,6 @@ set(util_srcs
${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/VectorAdapter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp
${INDEX_SOURCE_DIR}/unittest/utils.cpp
......
......@@ -72,35 +72,32 @@ class ParamGenerator {
knowhere::Config
Gen(const ParameterType& type) {
if (type == ParameterType::ivf) {
auto tempconf = std::make_shared<knowhere::IVFCfg>();
tempconf->d = DIM;
tempconf->gpu_id = DEVICEID;
tempconf->nlist = 100;
tempconf->nprobe = 4;
tempconf->k = K;
tempconf->metric_type = knowhere::METRICTYPE::L2;
return tempconf;
return knowhere::Config{
{knowhere::meta::DIM, DIM},
{knowhere::meta::TOPK, K},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
{knowhere::meta::DEVICEID, DEVICEID},
};
} else if (type == ParameterType::ivfpq) {
auto tempconf = std::make_shared<knowhere::IVFPQCfg>();
tempconf->d = DIM;
tempconf->gpu_id = DEVICEID;
tempconf->nlist = 100;
tempconf->nprobe = 4;
tempconf->k = K;
tempconf->m = 4;
tempconf->nbits = 8;
tempconf->metric_type = knowhere::METRICTYPE::L2;
return tempconf;
return knowhere::Config{
{knowhere::meta::DIM, DIM},
{knowhere::meta::TOPK, K},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::IndexParams::m, 4},
{knowhere::IndexParams::nbits, 8},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
{knowhere::meta::DEVICEID, DEVICEID},
};
} else if (type == ParameterType::ivfsq) {
auto tempconf = std::make_shared<knowhere::IVFSQCfg>();
tempconf->d = DIM;
tempconf->gpu_id = DEVICEID;
tempconf->nlist = 100;
tempconf->nprobe = 4;
tempconf->k = K;
tempconf->nbits = 8;
tempconf->metric_type = knowhere::METRICTYPE::L2;
return tempconf;
return knowhere::Config{
{knowhere::meta::DIM, DIM}, {knowhere::meta::TOPK, K},
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4},
{knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, knowhere::Metric::L2},
{knowhere::meta::DEVICEID, DEVICEID},
};
}
}
};
......
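ParamGenerator::Gen now returns brace-initialized knowhere::Config objects instead of typed Cfg structs, and the tests later read values back with expressions like conf[knowhere::meta::TOPK]. A minimal sketch of that construction and read-back, with plain string keys standing in for the knowhere::meta/IndexParams/Metric constants and an assumed nlohmann::json-backed Config:
#include <iostream>
#include <nlohmann/json.hpp>
int main() {
    // String keys stand in for knowhere::meta::DIM, knowhere::meta::TOPK, etc.
    nlohmann::json conf{
        {"dim", 64},    {"k", 10},
        {"nlist", 100}, {"nprobe", 4},
        {"metric_type", "L2"},
    };
    // Read-back, analogous to AssertAnns(result, nq, conf[knowhere::meta::TOPK]).
    std::cout << "topk=" << conf["k"].get<int>() << std::endl;
    return 0;
}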
......@@ -21,7 +21,7 @@ using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
......@@ -37,17 +37,17 @@ class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::MET
};
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
knowhere::METRICTYPE::HAMMING));
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
ASSERT_TRUE(!xb.empty());
knowhere::METRICTYPE MetricType = GetParam();
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = MetricType;
std::string MetricType = GetParam();
knowhere::Config conf{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, k},
{knowhere::Metric::TYPE, MetricType},
};
index_->Train(conf);
index_->Add(base_dataset, conf);
......@@ -88,11 +88,12 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
reader(ret, bin->size);
};
knowhere::METRICTYPE MetricType = GetParam();
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = MetricType;
std::string MetricType = GetParam();
knowhere::Config conf{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, k},
{knowhere::Metric::TYPE, MetricType},
};
{
// serialize index
......
......@@ -27,25 +27,24 @@ using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
knowhere::METRICTYPE MetricType = GetParam();
std::string MetricType = GetParam();
Init_with_binary_default();
// nb = 1000000;
// nq = 1000;
// k = 1000;
// Generate(DIM, NB, NQ);
index_ = std::make_shared<knowhere::BinaryIVF>();
auto x_conf = std::make_shared<knowhere::IVFBinCfg>();
x_conf->d = dim;
x_conf->k = k;
x_conf->metric_type = MetricType;
x_conf->nlist = 100;
x_conf->nprobe = 10;
conf = x_conf;
conf->Dump();
knowhere::Config temp_conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k},
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 10},
{knowhere::Metric::TYPE, MetricType},
};
conf = temp_conf;
}
void
......@@ -59,8 +58,7 @@ class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRI
};
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
knowhere::METRICTYPE::HAMMING));
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
TEST_P(BinaryIVFTest, binaryivf_basic) {
assert(!xb.empty());
......@@ -75,7 +73,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb);
......@@ -123,7 +121,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
// index_->set_index_model(model);
// index_->Add(base_dataset, conf);
// auto result = index_->Search(query_dataset, conf);
// AssertAnns(result, nq, conf->k);
// AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// }
{
......@@ -147,7 +145,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
}
......@@ -69,7 +69,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
for (int i = 0; i < 3; ++i) {
auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf);
auto result = gpu_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
}
......@@ -86,30 +86,18 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
auto quantization = pair.second;
auto result = gpu_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
quantizer_conf->mode = 2; // only copy data
quantizer_conf->gpu_id = DEVICEID;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, DEVICEID}, {"mode", 2}};
for (int i = 0; i < 2; ++i) {
auto hybrid_idx = std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
hybrid_idx->Load(binaryset);
auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf);
auto result = new_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
{
// invalid quantizer config
quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
auto hybrid_idx = std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, nullptr));
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, quantizer_conf));
quantizer_conf->mode = 2; // only copy data
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, quantizer_conf));
}
}
{
......@@ -126,7 +114,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
hybrid_idx->SetQuantizer(quantization);
auto result = hybrid_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
hybrid_idx->UnsetQuantizer();
}
......
......@@ -45,10 +45,8 @@ class IDMAPTest : public DataGen, public TestGpuIndexBase {
TEST_F(IDMAPTest, idmap_basic) {
ASSERT_TRUE(!xb.empty());
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = knowhere::METRICTYPE::L2;
knowhere::Config conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
// null faiss index
{
......@@ -107,10 +105,8 @@ TEST_F(IDMAPTest, idmap_serialize) {
reader(ret, bin->size);
};
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = knowhere::METRICTYPE::L2;
knowhere::Config conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
{
// serialize index
......@@ -146,10 +142,8 @@ TEST_F(IDMAPTest, idmap_serialize) {
TEST_F(IDMAPTest, copy_test) {
ASSERT_TRUE(!xb.empty());
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = knowhere::METRICTYPE::L2;
knowhere::Config conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
index_->Train(conf);
index_->Add(base_dataset, conf);
......
......@@ -62,7 +62,7 @@ class IVFTest : public DataGen, public TestWithParam<::std::tuple<std::string, P
Generate(DIM, NB, NQ);
index_ = IndexFactory(index_type);
conf = ParamGenerator::GetInstance().Gen(parameter_type_);
conf->Dump();
// KNOWHERE_LOG_DEBUG << "conf: " << conf->dump();
}
void
......@@ -109,7 +109,7 @@ TEST_P(IVFTest, ivf_basic) {
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
if (index_type.find("GPU") == std::string::npos && index_type.find("Hybrid") == std::string::npos &&
......@@ -190,7 +190,7 @@ TEST_P(IVFTest, ivf_serialize) {
index_->set_index_model(model);
index_->Add(base_dataset, conf);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
}
{
......@@ -214,7 +214,7 @@ TEST_P(IVFTest, ivf_serialize) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
}
}
......@@ -232,7 +232,7 @@ TEST_P(IVFTest, clone_test) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
auto AssertEqual = [&](knowhere::DatasetPtr p1, knowhere::DatasetPtr p2) {
......@@ -254,7 +254,7 @@ TEST_P(IVFTest, clone_test) {
// EXPECT_NO_THROW({
// auto clone_index = index_->Clone();
// auto clone_result = clone_index->Search(query_dataset, conf);
// //AssertAnns(result, nq, conf->k);
// //AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// AssertEqual(result, clone_result);
// std::cout << "inplace clone [" << index_type << "] success" << std::endl;
// });
......@@ -339,7 +339,7 @@ TEST_P(IVFTest, gpu_seal_test) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
fiu_init(0);
fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0);
......@@ -374,7 +374,7 @@ TEST_P(IVFTest, invalid_gpu_source) {
}
auto invalid_conf = ParamGenerator::GetInstance().Gen(parameter_type_);
invalid_conf->gpu_id = -1;
invalid_conf[knowhere::meta::DEVICEID] = -1;
if (index_type == "GPUIVF") {
// null faiss index
......@@ -430,15 +430,6 @@ TEST_P(IVFTest, IVFSQHybrid_test) {
ASSERT_TRUE(index != nullptr);
ASSERT_ANY_THROW(index->UnsetQuantizer());
knowhere::QuantizerConfig config = std::make_shared<knowhere::QuantizerCfg>();
config->gpu_id = knowhere::INVALID_VALUE;
// mode = -1
ASSERT_ANY_THROW(index->LoadQuantizer(config));
config->mode = 1;
ASSERT_ANY_THROW(index->LoadQuantizer(config));
config->gpu_id = DEVICEID;
// index->LoadQuantizer(config);
ASSERT_ANY_THROW(index->SetQuantizer(nullptr));
}
......
......@@ -46,24 +46,19 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test {
Generate(256, 1000000 / 100, 1);
index_ = std::make_shared<knowhere::NSG>();
auto tmp_conf = std::make_shared<knowhere::NSGCfg>();
tmp_conf->gpu_id = DEVICEID;
tmp_conf->d = 256;
tmp_conf->knng = 20;
tmp_conf->nprobe = 8;
tmp_conf->nlist = 163;
tmp_conf->search_length = 40;
tmp_conf->out_degree = 30;
tmp_conf->candidate_pool_size = 100;
tmp_conf->metric_type = knowhere::METRICTYPE::L2;
train_conf = tmp_conf;
train_conf->Dump();
train_conf = knowhere::Config{{knowhere::meta::DIM, 256},
{knowhere::IndexParams::nlist, 163},
{knowhere::IndexParams::nprobe, 8},
{knowhere::IndexParams::knng, 20},
{knowhere::IndexParams::search_length, 40},
{knowhere::IndexParams::out_degree, 30},
{knowhere::IndexParams::candidate, 100},
{knowhere::Metric::TYPE, knowhere::Metric::L2}};
auto tmp2_conf = std::make_shared<knowhere::NSGCfg>();
tmp2_conf->k = k;
tmp2_conf->search_length = 30;
search_conf = tmp2_conf;
search_conf->Dump();
search_conf = knowhere::Config{
{knowhere::meta::TOPK, k},
{knowhere::IndexParams::search_length, 30},
};
}
void
......@@ -87,9 +82,9 @@ TEST_F(NSGInterfaceTest, basic_test) {
ASSERT_ANY_THROW(index_->Search(query_dataset, search_conf));
ASSERT_ANY_THROW(index_->Serialize());
}
train_conf->gpu_id = knowhere::INVALID_VALUE;
auto model_invalid_gpu = index_->Train(base_dataset, train_conf);
train_conf->gpu_id = DEVICEID;
// train_conf->gpu_id = knowhere::INVALID_VALUE;
// auto model_invalid_gpu = index_->Train(base_dataset, train_conf);
train_conf[knowhere::meta::DEVICEID] = DEVICEID;
auto model = index_->Train(base_dataset, train_conf);
auto result = index_->Search(query_dataset, search_conf);
AssertAnns(result, nq, k);
......
......@@ -33,17 +33,17 @@ class SPTAGTest : public DataGen, public TestWithParam<std::string> {
Generate(128, 100, 5);
index_ = std::make_shared<knowhere::CPUSPTAGRNG>(IndexType);
if (IndexType == "KDT") {
auto tempconf = std::make_shared<knowhere::KDTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
tempconf->metric_type = knowhere::METRICTYPE::L2;
conf = tempconf;
conf = knowhere::Config{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, 10},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
};
} else {
auto tempconf = std::make_shared<knowhere::BKTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
tempconf->metric_type = knowhere::METRICTYPE::L2;
conf = tempconf;
conf = knowhere::Config{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, 10},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
};
}
Init_with_default();
......
......@@ -16,9 +16,9 @@
namespace milvus {
namespace scheduler {
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, const milvus::json& extra_params,
const engine::VectorsData& vectors)
: Job(JobType::SEARCH), context_(context), topk_(topk), nprobe_(nprobe), vectors_(vectors) {
: Job(JobType::SEARCH), context_(context), topk_(topk), extra_params_(extra_params), vectors_(vectors) {
}
bool
......@@ -72,7 +72,7 @@ SearchJob::Dump() const {
json ret{
{"topk", topk_},
{"nq", vectors_.vector_count_},
{"nprobe", nprobe_},
{"extra_params", extra_params_.dump()},
};
auto base = Job::Dump();
ret.insert(base.begin(), base.end());
......
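SearchJob now carries an arbitrary milvus::json extra_params in place of the single nprobe value, and Dump() serializes it with dump(). A minimal sketch of building such a parameter object and folding it into the summary json, assuming milvus::json is an nlohmann::json alias:
#include <iostream>
#include <nlohmann/json.hpp>
int main() {
    // Assumed stand-in for milvus::json; "nprobe" is an illustrative key.
    nlohmann::json extra_params{{"nprobe", 32}};
    // Mirrors what SearchJob::Dump() assembles from topk, nq and extra_params.
    nlohmann::json ret{
        {"topk", 10},
        {"nq", 100},
        {"extra_params", extra_params.dump()},
    };
    std::cout << ret.dump() << std::endl;
    return 0;
}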
......@@ -40,7 +40,7 @@ using ResultDistances = engine::ResultDistances;
class SearchJob : public Job {
public:
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, const milvus::json& extra_params,
const engine::VectorsData& vectors);
public:
......@@ -79,9 +79,9 @@ class SearchJob : public Job {
return vectors_.vector_count_;
}
uint64_t
nprobe() const {
return nprobe_;
const milvus::json&
extra_params() const {
return extra_params_;
}
const engine::VectorsData&
......@@ -103,7 +103,7 @@ class SearchJob : public Job {
const std::shared_ptr<server::Context> context_;
uint64_t topk_ = 0;
uint64_t nprobe_ = 0;
milvus::json extra_params_;
// TODO: smart pointer
const engine::VectorsData& vectors_;
......
......@@ -41,8 +41,9 @@ XBuildIndexTask::XBuildIndexTask(TableFileSchemaPtr file, TaskLabelPtr label)
engine_type = (EngineType)file->engine_type_;
}
auto json = milvus::json::parse(file_->index_params_);
to_index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
(MetricType)file_->metric_type_, file_->nlist_);
(MetricType)file_->metric_type_, json);
}
}
......
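XBuildIndexTask (and XSearchTask below, guarded by an empty-string check) now rebuilds the index parameters from the persisted string with milvus::json::parse before handing them to EngineFactory::Build. A minimal round-trip sketch, assuming milvus::json is an nlohmann::json alias and using an illustrative parameter string:
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>
int main() {
    // Illustrative persisted index_params_ content, not a confirmed schema.
    std::string index_params = R"({"nlist": 16384})";
    nlohmann::json json_params;
    if (!index_params.empty()) {
        // Parse the stored string back into a json object, as the task code does.
        json_params = nlohmann::json::parse(index_params);
    }
    std::cout << "nlist=" << json_params["nlist"].get<int>() << std::endl;
    return 0;
}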
......@@ -115,8 +115,12 @@ XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, TableF
engine_type = (EngineType)file->engine_type_;
}
milvus::json json_params;
if (!file_->index_params_.empty()) {
json_params = milvus::json::parse(file_->index_params_);
}
index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
(MetricType)file_->metric_type_, file_->nlist_);
(MetricType)file_->metric_type_, json_params);
}
}
......@@ -217,7 +221,8 @@ XSearchTask::Execute() {
// step 1: allocate memory
uint64_t nq = search_job->nq();
uint64_t topk = search_job->topk();
uint64_t nprobe = search_job->nprobe();
const milvus::json& extra_params = search_job->extra_params();
ENGINE_LOG_DEBUG << "Search job extra params: " << extra_params.dump();
const engine::VectorsData& vectors = search_job->vectors();
output_ids.resize(topk * nq);
......@@ -235,13 +240,13 @@ XSearchTask::Execute() {
}
Status s;
if (!vectors.float_data_.empty()) {
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, nprobe, output_distance.data(),
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, extra_params, output_distance.data(),
output_ids.data(), hybrid);
} else if (!vectors.binary_data_.empty()) {
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, nprobe, output_distance.data(),
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, extra_params, output_distance.data(),
output_ids.data(), hybrid);
} else if (!vectors.id_array_.empty()) {
s = index_engine_->Search(nq, vectors.id_array_, topk, nprobe, output_distance.data(),
s = index_engine_->Search(nq, vectors.id_array_, topk, extra_params, output_distance.data(),
output_ids.data(), hybrid);
}
......
......@@ -1009,8 +1009,7 @@ Config::CheckCacheConfigCpuCacheCapacity(const std::string& value) {
std::cerr << "WARNING: cpu cache capacity value is too big" << std::endl;
}
std::string str =
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_INSERT_BUFFER_SIZE, CONFIG_CACHE_INSERT_BUFFER_SIZE_DEFAULT);
std::string str = GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_INSERT_BUFFER_SIZE, "0");
int64_t buffer_value = std::stoll(str);
int64_t insert_buffer_size = buffer_value * GB;
......@@ -1059,9 +1058,8 @@ Config::CheckCacheConfigInsertBufferSize(const std::string& value) {
return Status(SERVER_INVALID_ARGUMENT, msg);
}
std::string str =
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT);
int64_t cache_size = std::stoll(str);
std::string str = GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, "0");
int64_t cache_size = std::stoll(str) * GB;
uint64_t total_mem = 0, free_mem = 0;
CommonUtil::GetSystemMemInfo(total_mem, free_mem);
......
......@@ -70,8 +70,8 @@ RequestHandler::DropTable(const std::shared_ptr<Context>& context, const std::st
Status
RequestHandler::CreateIndex(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist) {
BaseRequestPtr request_ptr = CreateIndexRequest::Create(context, table_name, index_type, nlist);
const milvus::json& json_params) {
BaseRequestPtr request_ptr = CreateIndexRequest::Create(context, table_name, index_type, json_params);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
......@@ -123,11 +123,11 @@ RequestHandler::ShowTableInfo(const std::shared_ptr<Context>& context, const std
Status
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result) {
BaseRequestPtr request_ptr =
SearchRequest::Create(context, table_name, vectors, topk, nprobe, partition_list, file_id_list, result);
SearchRequest::Create(context, table_name, vectors, topk, extra_params, partition_list, file_id_list, result);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
......@@ -135,10 +135,10 @@ RequestHandler::Search(const std::shared_ptr<Context>& context, const std::strin
Status
RequestHandler::SearchByID(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
TopKQueryResult& result) {
int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, TopKQueryResult& result) {
BaseRequestPtr request_ptr =
SearchByIDRequest::Create(context, table_name, vector_id, topk, nprobe, partition_list, result);
SearchByIDRequest::Create(context, table_name, vector_id, topk, extra_params, partition_list, result);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
......
......@@ -38,7 +38,7 @@ class RequestHandler {
Status
CreateIndex(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist);
const milvus::json& json_params);
Status
Insert(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
......@@ -60,12 +60,13 @@ class RequestHandler {
Status
Search(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result);
Status
SearchByID(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list, TopKQueryResult& result);
const milvus::json& extra_params, const std::vector<std::string>& partition_list,
TopKQueryResult& result);
Status
DescribeTable(const std::shared_ptr<Context>& context, const std::string& table_name, TableSchema& table_schema);
......
......@@ -17,6 +17,7 @@
#include "grpc/gen-status/status.grpc.pb.h"
#include "grpc/gen-status/status.pb.h"
#include "server/context/Context.h"
#include "utils/Json.h"
#include "utils/Status.h"
#include <condition_variable>
......@@ -73,17 +74,15 @@ struct TopKQueryResult {
struct IndexParam {
std::string table_name_;
int64_t index_type_;
int64_t nlist_;
std::string extra_params_;
IndexParam() {
index_type_ = 0;
nlist_ = 0;
}
IndexParam(const std::string& table_name, int64_t index_type, int64_t nlist) {
IndexParam(const std::string& table_name, int64_t index_type) {
table_name_ = table_name;
index_type_ = index_type;
nlist_ = nlist;
}
};
......
......@@ -24,14 +24,17 @@ namespace milvus {
namespace server {
CreateIndexRequest::CreateIndexRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
int64_t index_type, int64_t nlist)
: BaseRequest(context, DDL_DML_REQUEST_GROUP), table_name_(table_name), index_type_(index_type), nlist_(nlist) {
int64_t index_type, const milvus::json& json_params)
: BaseRequest(context, DDL_DML_REQUEST_GROUP),
table_name_(table_name),
index_type_(index_type),
json_params_(json_params) {
}
BaseRequestPtr
CreateIndexRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist) {
return std::shared_ptr<BaseRequest>(new CreateIndexRequest(context, table_name, index_type, nlist));
const milvus::json& json_params) {
return std::shared_ptr<BaseRequest>(new CreateIndexRequest(context, table_name, index_type, json_params));
}
Status
......@@ -69,7 +72,7 @@ CreateIndexRequest::OnExecute() {
return status;
}
status = ValidationUtil::ValidateTableIndexNlist(nlist_);
status = ValidationUtil::ValidateIndexParams(json_params_, table_schema, index_type_);
if (!status.ok()) {
return status;
}
......@@ -109,7 +112,7 @@ CreateIndexRequest::OnExecute() {
// step 3: create index
engine::TableIndex index;
index.engine_type_ = adapter_index_type;
index.nlist_ = nlist_;
index.extra_params_ = json_params_;
status = DBWrapper::DB()->CreateIndex(table_name_, index);
fiu_do_on("CreateIndexRequest.OnExecute.create_index_fail",
status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
......
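CreateIndexRequest now takes the index parameters as a json object, validates them through ValidationUtil::ValidateIndexParams, and stores them in index.extra_params_; DescribeIndexRequest later returns them as a string via extra_params_.dump(). A minimal sketch of a caller assembling such parameters (the key name is illustrative, not a confirmed schema):
#include <iostream>
#include <nlohmann/json.hpp>
int main() {
    // Illustrative index-creation parameters for an IVF-style index.
    nlohmann::json json_params{{"nlist", 16384}};
    // The request keeps the object itself; clients see it again as a dumped string.
    std::cout << json_params.dump() << std::endl;
    return 0;
}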
......@@ -21,11 +21,12 @@ namespace server {
class CreateIndexRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type, int64_t nlist);
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
const milvus::json& json_params);
protected:
CreateIndexRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist);
const milvus::json& json_params);
Status
OnExecute() override;
......@@ -33,7 +34,7 @@ class CreateIndexRequest : public BaseRequest {
private:
const std::string table_name_;
const int64_t index_type_;
const int64_t nlist_;
milvus::json json_params_;
};
} // namespace server
......
......@@ -78,7 +78,7 @@ DescribeIndexRequest::OnExecute() {
index_param_.table_name_ = table_name_;
index_param_.index_type_ = index.engine_type_;
index_param_.nlist_ = index.nlist_;
index_param_.extra_params_ = index.extra_params_.dump();
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
......
......@@ -34,23 +34,23 @@ namespace milvus {
namespace server {
SearchByIDRequest::SearchByIDRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
int64_t vector_id, int64_t topk, int64_t nprobe,
int64_t vector_id, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, TopKQueryResult& result)
: BaseRequest(context, DQL_REQUEST_GROUP),
table_name_(table_name),
vector_id_(vector_id),
topk_(topk),
nprobe_(nprobe),
extra_params_(extra_params),
partition_list_(partition_list),
result_(result) {
}
BaseRequestPtr
SearchByIDRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
TopKQueryResult& result) {
int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, TopKQueryResult& result) {
return std::shared_ptr<BaseRequest>(
new SearchByIDRequest(context, table_name, vector_id, topk, nprobe, partition_list, result));
new SearchByIDRequest(context, table_name, vector_id, topk, extra_params, partition_list, result));
}
Status
......@@ -59,7 +59,7 @@ SearchByIDRequest::OnExecute() {
auto pre_query_ctx = context_->Child("Pre query");
std::string hdr = "SearchByIDRequest(table=" + table_name_ + ", id=" + std::to_string(vector_id_) +
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
", k=" + std::to_string(topk_) + ", extra_params=" + extra_params_.dump() + ")";
TimeRecorder rc(hdr);
......@@ -88,6 +88,11 @@ SearchByIDRequest::OnExecute() {
}
}
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
if (!status.ok()) {
return status;
}
// Check whether GPU search resource is enabled
#ifdef MILVUS_GPU_VERSION
Config& config = Config::GetInstance();
......@@ -122,11 +127,6 @@ SearchByIDRequest::OnExecute() {
return status;
}
status = ValidationUtil::ValidateSearchNprobe(nprobe_, table_schema);
if (!status.ok()) {
return status;
}
rc.RecordSection("check validation");
// step 5: search vectors
......@@ -140,8 +140,8 @@ SearchByIDRequest::OnExecute() {
pre_query_ctx->GetTraceContext()->GetSpan()->Finish();
status = DBWrapper::DB()->QueryByID(context_, table_name_, partition_list_, (size_t)topk_, nprobe_, vector_id_,
result_ids, result_distances);
status = DBWrapper::DB()->QueryByID(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
vector_id_, result_ids, result_distances);
#ifdef MILVUS_ENABLE_PROFILING
ProfilerStop();
......
......@@ -30,11 +30,11 @@ class SearchByIDRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list, TopKQueryResult& result);
const milvus::json& extra_params, const std::vector<std::string>& partition_list, TopKQueryResult& result);
protected:
SearchByIDRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
TopKQueryResult& result);
Status
......@@ -44,7 +44,7 @@ class SearchByIDRequest : public BaseRequest {
const std::string table_name_;
const int64_t vector_id_;
int64_t topk_;
int64_t nprobe_;
milvus::json extra_params_;
const std::vector<std::string> partition_list_;
TopKQueryResult& result_;
......
......@@ -26,14 +26,14 @@ namespace milvus {
namespace server {
SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result)
: BaseRequest(context, DQL_REQUEST_GROUP),
table_name_(table_name),
vectors_data_(vectors),
topk_(topk),
nprobe_(nprobe),
extra_params_(extra_params),
partition_list_(partition_list),
file_id_list_(file_id_list),
result_(result) {
......@@ -41,11 +41,11 @@ SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std:
BaseRequestPtr
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result) {
return std::shared_ptr<BaseRequest>(
new SearchRequest(context, table_name, vectors, topk, nprobe, partition_list, file_id_list, result));
new SearchRequest(context, table_name, vectors, topk, extra_params, partition_list, file_id_list, result));
}
Status
......@@ -56,7 +56,7 @@ SearchRequest::OnExecute() {
auto pre_query_ctx = context_->Child("Pre query");
std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
", k=" + std::to_string(topk_) + ", extra_params=" + extra_params_.dump() + ")";
TimeRecorder rc(hdr);
......@@ -84,13 +84,13 @@ SearchRequest::OnExecute() {
}
}
// step 3: check search parameter
status = ValidationUtil::ValidateSearchTopk(topk_, table_schema);
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
if (!status.ok()) {
return status;
}
status = ValidationUtil::ValidateSearchNprobe(nprobe_, table_schema);
// step 3: check search parameter
status = ValidationUtil::ValidateSearchTopk(topk_, table_schema);
if (!status.ok()) {
return status;
}
......@@ -150,10 +150,10 @@ SearchRequest::OnExecute() {
return status;
}
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, nprobe_,
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
vectors_data_, result_ids, result_distances);
} else {
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, nprobe_,
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, extra_params_,
vectors_data_, result_ids, result_distances);
}
......
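Both search paths now run ValidationUtil::ValidateSearchParams over the json extra_params before querying, alongside the existing topk check. The sketch below is a hypothetical illustration of that kind of check; the key name, bounds, and error handling are assumptions, not the actual Milvus validation logic:
#include <cstdint>
#include <stdexcept>
#include <nlohmann/json.hpp>
// Hypothetical parameter check: require nprobe to be present and within range.
void
ValidateSearchParamsSketch(const nlohmann::json& extra_params, int64_t nlist) {
    if (!extra_params.contains("nprobe")) {
        throw std::invalid_argument("missing search param: nprobe");
    }
    int64_t nprobe = extra_params["nprobe"].get<int64_t>();
    if (nprobe < 1 || nprobe > nlist) {
        throw std::invalid_argument("nprobe out of range");
    }
}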
(The diffs for the remaining changed files in this commit are collapsed and not shown here.)