Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
679118d3
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
679118d3
编写于
7月 12, 2019
作者:
S
starlord
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MS-212 Support Inner product metric type
Former-commit-id: 068ed6d011b45f46abc485036ca8e3cf397dfcda
上级
cd941d55
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
70 addition
and
27 deletion
+70
-27
cpp/CHANGELOG.md
cpp/CHANGELOG.md
+1
-0
cpp/conf/server_config.template
cpp/conf/server_config.template
+2
-1
cpp/src/db/FaissExecutionEngine.cpp
cpp/src/db/FaissExecutionEngine.cpp
+13
-2
cpp/src/db/scheduler/task/SearchTask.cpp
cpp/src/db/scheduler/task/SearchTask.cpp
+33
-10
cpp/src/db/scheduler/task/SearchTask.h
cpp/src/db/scheduler/task/SearchTask.h
+4
-1
cpp/src/sdk/examples/simple/src/ClientTest.cpp
cpp/src/sdk/examples/simple/src/ClientTest.cpp
+1
-1
cpp/src/server/ServerConfig.h
cpp/src/server/ServerConfig.h
+1
-0
cpp/src/wrapper/IndexBuilder.cpp
cpp/src/wrapper/IndexBuilder.cpp
+6
-3
cpp/unittest/db/search_test.cpp
cpp/unittest/db/search_test.cpp
+9
-9
未找到文件。
cpp/CHANGELOG.md
浏览文件 @
679118d3
...
...
@@ -18,6 +18,7 @@ Please mark all change in change log and use the ticket from JIRA.
-
MS-204 - Support multi db_path
-
MS-206 - Support SQ8 index type
-
MS-208 - Add buildinde interface for C++ SDK
-
MS-212 - Support Inner product metric type
## New Feature
-
MS-195 - Add nlist and use_blas_threshold conf
...
...
cpp/conf/server_config.template
浏览文件 @
679118d3
...
...
@@ -36,4 +36,5 @@ cache_config: # cache configure
engine_config:
nprobe: 10
nlist: 16384
use_blas_threshold: 20
\ No newline at end of file
use_blas_threshold: 20
metric_type: L2 #L2 or Inner Product
\ No newline at end of file
cpp/src/db/FaissExecutionEngine.cpp
浏览文件 @
679118d3
...
...
@@ -22,15 +22,25 @@ namespace zilliz {
namespace
milvus
{
namespace
engine
{
namespace
{
std
::
string
GetMetricType
()
{
server
::
ServerConfig
&
config
=
server
::
ServerConfig
::
GetInstance
();
server
::
ConfigNode
engine_config
=
config
.
GetConfig
(
server
::
CONFIG_ENGINE
);
return
engine_config
.
GetValue
(
server
::
CONFIG_METRICTYPE
,
"L2"
);
}
}
FaissExecutionEngine
::
FaissExecutionEngine
(
uint16_t
dimension
,
const
std
::
string
&
location
,
const
std
::
string
&
build_index_type
,
const
std
::
string
&
raw_index_type
)
:
pIndex_
(
faiss
::
index_factory
(
dimension
,
raw_index_type
.
c_str
())),
location_
(
location
),
:
location_
(
location
),
build_index_type_
(
build_index_type
),
raw_index_type_
(
raw_index_type
)
{
std
::
string
metric_type
=
GetMetricType
();
faiss
::
MetricType
faiss_metric_type
=
(
metric_type
==
"L2"
)
?
faiss
::
METRIC_L2
:
faiss
::
METRIC_INNER_PRODUCT
;
pIndex_
.
reset
(
faiss
::
index_factory
(
dimension
,
raw_index_type
.
c_str
(),
faiss_metric_type
));
}
FaissExecutionEngine
::
FaissExecutionEngine
(
std
::
shared_ptr
<
faiss
::
Index
>
index
,
...
...
@@ -119,6 +129,7 @@ FaissExecutionEngine::BuildIndex(const std::string& location) {
auto
opd
=
std
::
make_shared
<
Operand
>
();
opd
->
d
=
pIndex_
->
d
;
opd
->
index_type
=
build_index_type_
;
opd
->
metric_type
=
GetMetricType
();
IndexBuilderPtr
pBuilder
=
GetIndexBuilder
(
opd
);
auto
from_index
=
dynamic_cast
<
faiss
::
IndexIDMap
*>
(
pIndex_
.
get
());
...
...
cpp/src/db/scheduler/task/SearchTask.cpp
浏览文件 @
679118d3
...
...
@@ -30,11 +30,20 @@ void CollectDurationMetrics(int index_type, double total_time) {
}
}
std
::
string
GetMetricType
()
{
server
::
ServerConfig
&
config
=
server
::
ServerConfig
::
GetInstance
();
server
::
ConfigNode
engine_config
=
config
.
GetConfig
(
server
::
CONFIG_ENGINE
);
return
engine_config
.
GetValue
(
server
::
CONFIG_METRICTYPE
,
"L2"
);
}
}
SearchTask
::
SearchTask
()
:
IScheduleTask
(
ScheduleTaskType
::
kSearch
)
{
std
::
string
metric_type
=
GetMetricType
();
if
(
metric_type
!=
"L2"
)
{
metric_l2
=
false
;
}
}
std
::
shared_ptr
<
IScheduleTask
>
SearchTask
::
Execute
()
{
...
...
@@ -71,7 +80,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
rc
.
Record
(
"cluster result"
);
//step 4: pick up topk result
SearchTask
::
TopkResult
(
result_set
,
inner_k
,
context
->
GetResult
());
SearchTask
::
TopkResult
(
result_set
,
inner_k
,
metric_l2
,
context
->
GetResult
());
rc
.
Record
(
"reduce topk"
);
}
catch
(
std
::
exception
&
ex
)
{
...
...
@@ -125,7 +134,8 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
Status
SearchTask
::
MergeResult
(
SearchContext
::
Id2DistanceMap
&
distance_src
,
SearchContext
::
Id2DistanceMap
&
distance_target
,
uint64_t
topk
)
{
uint64_t
topk
,
bool
ascending
)
{
//Note: the score_src and score_target are already arranged by score in ascending order
if
(
distance_src
.
empty
())
{
SERVER_LOG_WARNING
<<
"Empty distance source array"
;
...
...
@@ -161,15 +171,27 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
break
;
}
//compare score, put smallest score to score_merged one by one
//compare score,
// if ascending = true, put smallest score to score_merged one by one
// else, put largest score to score_merged one by one
auto
&
src_pair
=
distance_src
[
src_index
];
auto
&
target_pair
=
distance_target
[
target_index
];
if
(
src_pair
.
second
>
target_pair
.
second
)
{
distance_merged
.
push_back
(
target_pair
);
target_index
++
;
if
(
ascending
){
if
(
src_pair
.
second
>
target_pair
.
second
)
{
distance_merged
.
push_back
(
target_pair
);
target_index
++
;
}
else
{
distance_merged
.
push_back
(
src_pair
);
src_index
++
;
}
}
else
{
distance_merged
.
push_back
(
src_pair
);
src_index
++
;
if
(
src_pair
.
second
<
target_pair
.
second
)
{
distance_merged
.
push_back
(
target_pair
);
target_index
++
;
}
else
{
distance_merged
.
push_back
(
src_pair
);
src_index
++
;
}
}
//score_merged.size() already equal topk
...
...
@@ -185,6 +207,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
Status
SearchTask
::
TopkResult
(
SearchContext
::
ResultSet
&
result_src
,
uint64_t
topk
,
bool
ascending
,
SearchContext
::
ResultSet
&
result_target
)
{
if
(
result_target
.
empty
())
{
result_target
.
swap
(
result_src
);
...
...
@@ -200,7 +223,7 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
for
(
size_t
i
=
0
;
i
<
result_src
.
size
();
i
++
)
{
SearchContext
::
Id2DistanceMap
&
score_src
=
result_src
[
i
];
SearchContext
::
Id2DistanceMap
&
score_target
=
result_target
[
i
];
SearchTask
::
MergeResult
(
score_src
,
score_target
,
topk
);
SearchTask
::
MergeResult
(
score_src
,
score_target
,
topk
,
ascending
);
}
return
Status
::
OK
();
...
...
cpp/src/db/scheduler/task/SearchTask.h
浏览文件 @
679118d3
...
...
@@ -27,10 +27,12 @@ public:
static
Status
MergeResult
(
SearchContext
::
Id2DistanceMap
&
distance_src
,
SearchContext
::
Id2DistanceMap
&
distance_target
,
uint64_t
topk
);
uint64_t
topk
,
bool
ascending
);
static
Status
TopkResult
(
SearchContext
::
ResultSet
&
result_src
,
uint64_t
topk
,
bool
ascending
,
SearchContext
::
ResultSet
&
result_target
);
public:
...
...
@@ -38,6 +40,7 @@ public:
int
index_type_
=
0
;
//for metrics
ExecutionEnginePtr
index_engine_
;
std
::
vector
<
SearchContextPtr
>
search_contexts_
;
bool
metric_l2
=
true
;
};
using
SearchTaskPtr
=
std
::
shared_ptr
<
SearchTask
>
;
...
...
cpp/src/sdk/examples/simple/src/ClientTest.cpp
浏览文件 @
679118d3
...
...
@@ -98,7 +98,7 @@ namespace {
TableSchema
BuildTableSchema
()
{
TableSchema
tb_schema
;
tb_schema
.
table_name
=
TABLE_NAME
;
tb_schema
.
index_type
=
IndexType
::
gpu_ivf
sq8
;
tb_schema
.
index_type
=
IndexType
::
gpu_ivf
flat
;
tb_schema
.
dimension
=
TABLE_DIMENSION
;
tb_schema
.
store_raw_vector
=
true
;
...
...
cpp/src/server/ServerConfig.h
浏览文件 @
679118d3
...
...
@@ -47,6 +47,7 @@ static const std::string CONFIG_ENGINE = "engine_config";
static
const
std
::
string
CONFIG_NPROBE
=
"nprobe"
;
static
const
std
::
string
CONFIG_NLIST
=
"nlist"
;
static
const
std
::
string
CONFIG_DCBT
=
"use_blas_threshold"
;
static
const
std
::
string
CONFIG_METRICTYPE
=
"metric_type"
;
class
ServerConfig
{
public:
...
...
cpp/src/wrapper/IndexBuilder.cpp
浏览文件 @
679118d3
...
...
@@ -71,7 +71,8 @@ Index_ptr IndexBuilder::build_all(const long &nb,
{
LOG
(
DEBUG
)
<<
"Build index by GPU"
;
// TODO: list support index-type.
faiss
::
Index
*
ori_index
=
faiss
::
index_factory
(
opd_
->
d
,
opd_
->
get_index_type
(
nb
).
c_str
());
faiss
::
MetricType
metric_type
=
opd_
->
metric_type
==
"L2"
?
faiss
::
METRIC_L2
:
faiss
::
METRIC_INNER_PRODUCT
;
faiss
::
Index
*
ori_index
=
faiss
::
index_factory
(
opd_
->
d
,
opd_
->
get_index_type
(
nb
).
c_str
(),
metric_type
);
std
::
lock_guard
<
std
::
mutex
>
lk
(
gpu_resource
);
faiss
::
gpu
::
StandardGpuResources
res
;
...
...
@@ -90,7 +91,8 @@ Index_ptr IndexBuilder::build_all(const long &nb,
#else
{
LOG
(
DEBUG
)
<<
"Build index by CPU"
;
faiss
::
Index
*
index
=
faiss
::
index_factory
(
opd_
->
d
,
opd_
->
get_index_type
(
nb
).
c_str
());
faiss
::
MetricType
metric_type
=
opd_
->
metric_type
==
"L2"
?
faiss
::
METRIC_L2
:
faiss
::
METRIC_INNER_PRODUCT
;
faiss
::
Index
*
index
=
faiss
::
index_factory
(
opd_
->
d
,
opd_
->
get_index_type
(
nb
).
c_str
(),
metric_type
);
if
(
!
index
->
is_trained
)
{
nt
==
0
||
xt
==
nullptr
?
index
->
train
(
nb
,
xb
)
:
index
->
train
(
nt
,
xt
);
...
...
@@ -113,7 +115,8 @@ BgCpuBuilder::BgCpuBuilder(const zilliz::milvus::engine::Operand_ptr &opd) : Ind
Index_ptr
BgCpuBuilder
::
build_all
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
long
&
nt
,
const
float
*
xt
)
{
std
::
shared_ptr
<
faiss
::
Index
>
index
=
nullptr
;
index
.
reset
(
faiss
::
index_factory
(
opd_
->
d
,
opd_
->
get_index_type
(
nb
).
c_str
()));
faiss
::
MetricType
metric_type
=
opd_
->
metric_type
==
"L2"
?
faiss
::
METRIC_L2
:
faiss
::
METRIC_INNER_PRODUCT
;
index
.
reset
(
faiss
::
index_factory
(
opd_
->
d
,
opd_
->
get_index_type
(
nb
).
c_str
(),
metric_type
));
LOG
(
DEBUG
)
<<
"Build index by CPU"
;
{
...
...
cpp/unittest/db/search_test.cpp
浏览文件 @
679118d3
...
...
@@ -73,13 +73,13 @@ TEST(DBSearchTest, TOPK_TEST) {
ASSERT_EQ
(
src_result
.
size
(),
NQ
);
engine
::
SearchContext
::
ResultSet
target_result
;
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
t
rue
,
t
arget_result
);
ASSERT_TRUE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
src_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
true
,
src_result
);
ASSERT_FALSE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
t
rue
,
t
arget_result
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_TRUE
(
src_result
.
empty
());
ASSERT_EQ
(
target_result
.
size
(),
NQ
);
...
...
@@ -92,7 +92,7 @@ TEST(DBSearchTest, TOPK_TEST) {
status
=
engine
::
SearchTask
::
ClusterResult
(
src_ids
,
src_distence
,
NQ
,
wrong_topk
,
src_result
);
ASSERT_TRUE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
t
rue
,
t
arget_result
);
ASSERT_TRUE
(
status
.
ok
());
for
(
uint64_t
i
=
0
;
i
<
NQ
;
i
++
)
{
ASSERT_EQ
(
target_result
[
i
].
size
(),
TOP_K
);
...
...
@@ -101,7 +101,7 @@ TEST(DBSearchTest, TOPK_TEST) {
wrong_topk
=
TOP_K
+
10
;
BuildResult
(
NQ
,
wrong_topk
,
src_ids
,
src_distence
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
t
rue
,
t
arget_result
);
ASSERT_TRUE
(
status
.
ok
());
for
(
uint64_t
i
=
0
;
i
<
NQ
;
i
++
)
{
ASSERT_EQ
(
target_result
[
i
].
size
(),
TOP_K
);
...
...
@@ -126,7 +126,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine
::
SearchContext
::
Id2DistanceMap
src
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
target
=
target_result
[
0
];
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
,
true
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
10
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
...
...
@@ -135,7 +135,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine
::
SearchContext
::
Id2DistanceMap
src
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
target
;
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
,
true
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
src_count
);
ASSERT_TRUE
(
src
.
empty
());
...
...
@@ -145,7 +145,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine
::
SearchContext
::
Id2DistanceMap
src
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
target
=
target_result
[
0
];
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
,
true
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
src_count
+
target_count
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
...
...
@@ -154,7 +154,7 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine
::
SearchContext
::
Id2DistanceMap
target
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
src
=
target_result
[
0
];
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
,
true
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
src_count
+
target_count
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录