提交 679118d3 编写于 作者: S starlord

MS-212 Support Inner product metric type


Former-commit-id: 068ed6d011b45f46abc485036ca8e3cf397dfcda
上级 cd941d55
......@@ -18,6 +18,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-204 - Support multi db_path
- MS-206 - Support SQ8 index type
- MS-208 - Add buildinde interface for C++ SDK
- MS-212 - Support Inner product metric type
## New Feature
- MS-195 - Add nlist and use_blas_threshold conf
......
......@@ -36,4 +36,5 @@ cache_config: # cache configure
engine_config:
nprobe: 10
nlist: 16384
use_blas_threshold: 20
\ No newline at end of file
use_blas_threshold: 20
metric_type: L2 #L2 or Inner Product
\ No newline at end of file
......@@ -22,15 +22,25 @@ namespace zilliz {
namespace milvus {
namespace engine {
namespace {
std::string GetMetricType() {
server::ServerConfig &config = server::ServerConfig::GetInstance();
server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
}
}
FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension,
const std::string& location,
const std::string& build_index_type,
const std::string& raw_index_type)
: pIndex_(faiss::index_factory(dimension, raw_index_type.c_str())),
location_(location),
: location_(location),
build_index_type_(build_index_type),
raw_index_type_(raw_index_type) {
std::string metric_type = GetMetricType();
faiss::MetricType faiss_metric_type = (metric_type == "L2") ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
pIndex_.reset(faiss::index_factory(dimension, raw_index_type.c_str(), faiss_metric_type));
}
FaissExecutionEngine::FaissExecutionEngine(std::shared_ptr<faiss::Index> index,
......@@ -119,6 +129,7 @@ FaissExecutionEngine::BuildIndex(const std::string& location) {
auto opd = std::make_shared<Operand>();
opd->d = pIndex_->d;
opd->index_type = build_index_type_;
opd->metric_type = GetMetricType();
IndexBuilderPtr pBuilder = GetIndexBuilder(opd);
auto from_index = dynamic_cast<faiss::IndexIDMap*>(pIndex_.get());
......
......@@ -30,11 +30,20 @@ void CollectDurationMetrics(int index_type, double total_time) {
}
}
std::string GetMetricType() {
server::ServerConfig &config = server::ServerConfig::GetInstance();
server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
}
}
SearchTask::SearchTask()
: IScheduleTask(ScheduleTaskType::kSearch) {
std::string metric_type = GetMetricType();
if(metric_type != "L2") {
metric_l2 = false;
}
}
std::shared_ptr<IScheduleTask> SearchTask::Execute() {
......@@ -71,7 +80,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
rc.Record("cluster result");
//step 4: pick up topk result
SearchTask::TopkResult(result_set, inner_k, context->GetResult());
SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
rc.Record("reduce topk");
} catch (std::exception& ex) {
......@@ -125,7 +134,8 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
SearchContext::Id2DistanceMap &distance_target,
uint64_t topk) {
uint64_t topk,
bool ascending) {
//Note: the score_src and score_target are already arranged by score in ascending order
if(distance_src.empty()) {
SERVER_LOG_WARNING << "Empty distance source array";
......@@ -161,15 +171,27 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
break;
}
//compare score, put smallest score to score_merged one by one
//compare score,
// if ascending = true, put smallest score to score_merged one by one
// else, put largest score to score_merged one by one
auto& src_pair = distance_src[src_index];
auto& target_pair = distance_target[target_index];
if(src_pair.second > target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
if(ascending){
if(src_pair.second > target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
} else {
distance_merged.push_back(src_pair);
src_index++;
}
} else {
distance_merged.push_back(src_pair);
src_index++;
if(src_pair.second < target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
} else {
distance_merged.push_back(src_pair);
src_index++;
}
}
//score_merged.size() already equal topk
......@@ -185,6 +207,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
bool ascending,
SearchContext::ResultSet &result_target) {
if (result_target.empty()) {
result_target.swap(result_src);
......@@ -200,7 +223,7 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
for (size_t i = 0; i < result_src.size(); i++) {
SearchContext::Id2DistanceMap &score_src = result_src[i];
SearchContext::Id2DistanceMap &score_target = result_target[i];
SearchTask::MergeResult(score_src, score_target, topk);
SearchTask::MergeResult(score_src, score_target, topk, ascending);
}
return Status::OK();
......
......@@ -27,10 +27,12 @@ public:
static Status MergeResult(SearchContext::Id2DistanceMap &distance_src,
SearchContext::Id2DistanceMap &distance_target,
uint64_t topk);
uint64_t topk,
bool ascending);
static Status TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
bool ascending,
SearchContext::ResultSet &result_target);
public:
......@@ -38,6 +40,7 @@ public:
int index_type_ = 0; //for metrics
ExecutionEnginePtr index_engine_;
std::vector<SearchContextPtr> search_contexts_;
bool metric_l2 = true;
};
using SearchTaskPtr = std::shared_ptr<SearchTask>;
......
......@@ -98,7 +98,7 @@ namespace {
TableSchema BuildTableSchema() {
TableSchema tb_schema;
tb_schema.table_name = TABLE_NAME;
tb_schema.index_type = IndexType::gpu_ivfsq8;
tb_schema.index_type = IndexType::gpu_ivfflat;
tb_schema.dimension = TABLE_DIMENSION;
tb_schema.store_raw_vector = true;
......
......@@ -47,6 +47,7 @@ static const std::string CONFIG_ENGINE = "engine_config";
static const std::string CONFIG_NPROBE = "nprobe";
static const std::string CONFIG_NLIST = "nlist";
static const std::string CONFIG_DCBT = "use_blas_threshold";
static const std::string CONFIG_METRICTYPE = "metric_type";
class ServerConfig {
public:
......
......@@ -71,7 +71,8 @@ Index_ptr IndexBuilder::build_all(const long &nb,
{
LOG(DEBUG) << "Build index by GPU";
// TODO: list support index-type.
faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type);
std::lock_guard<std::mutex> lk(gpu_resource);
faiss::gpu::StandardGpuResources res;
......@@ -90,7 +91,8 @@ Index_ptr IndexBuilder::build_all(const long &nb,
#else
{
LOG(DEBUG) << "Build index by CPU";
faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type);
if (!index->is_trained) {
nt == 0 || xt == nullptr ? index->train(nb, xb)
: index->train(nt, xt);
......@@ -113,7 +115,8 @@ BgCpuBuilder::BgCpuBuilder(const zilliz::milvus::engine::Operand_ptr &opd) : Ind
Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) {
std::shared_ptr<faiss::Index> index = nullptr;
index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()));
faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type));
LOG(DEBUG) << "Build index by CPU";
{
......
......@@ -73,13 +73,13 @@ TEST(DBSearchTest, TOPK_TEST) {
ASSERT_EQ(src_result.size(), NQ);
engine::SearchContext::ResultSet target_result;
status = engine::SearchTask::TopkResult(target_result, TOP_K, target_result);
status = engine::SearchTask::TopkResult(target_result, TOP_K, true, target_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::TopkResult(target_result, TOP_K, src_result);
status = engine::SearchTask::TopkResult(target_result, TOP_K, true, src_result);
ASSERT_FALSE(status.ok());
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
ASSERT_TRUE(status.ok());
ASSERT_TRUE(src_result.empty());
ASSERT_EQ(target_result.size(), NQ);
......@@ -92,7 +92,7 @@ TEST(DBSearchTest, TOPK_TEST) {
status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
ASSERT_TRUE(status.ok());
for(uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
......@@ -101,7 +101,7 @@ TEST(DBSearchTest, TOPK_TEST) {
wrong_topk = TOP_K + 10;
BuildResult(NQ, wrong_topk, src_ids, src_distence);
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
ASSERT_TRUE(status.ok());
for(uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
......@@ -126,7 +126,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 10);
status = engine::SearchTask::MergeResult(src, target, 10, true);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), 10);
CheckResult(src_result[0], target_result[0], target);
......@@ -135,7 +135,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target;
status = engine::SearchTask::MergeResult(src, target, 10);
status = engine::SearchTask::MergeResult(src, target, 10, true);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count);
ASSERT_TRUE(src.empty());
......@@ -145,7 +145,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 30);
status = engine::SearchTask::MergeResult(src, target, 30, true);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target);
......@@ -154,7 +154,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine::SearchContext::Id2DistanceMap target = src_result[0];
engine::SearchContext::Id2DistanceMap src = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 30);
status = engine::SearchTask::MergeResult(src, target, 30, true);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册