Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
274d96f9
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
274d96f9
编写于
10月 12, 2019
作者:
Y
yudong.cai
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MS-606 optimize reduce API, update unittest
Former-commit-id: 1e79b7dba4b7c90e5218fc75e5cc552152ec4bbe
上级
bf1f7ce6
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
213 addition
and
153 deletion
+213
-153
cpp/src/scheduler/task/SearchTask.cpp
cpp/src/scheduler/task/SearchTask.cpp
+20
-27
cpp/src/scheduler/task/SearchTask.h
cpp/src/scheduler/task/SearchTask.h
+5
-16
cpp/unittest/db/test_search.cpp
cpp/unittest/db/test_search.cpp
+188
-110
未找到文件。
cpp/src/scheduler/task/SearchTask.cpp
浏览文件 @
274d96f9
...
@@ -155,8 +155,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
...
@@ -155,8 +155,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
size_t
file_size
=
index_engine_
->
PhysicalSize
();
size_t
file_size
=
index_engine_
->
PhysicalSize
();
std
::
string
info
=
"Load file id:"
+
std
::
to_string
(
file_
->
id_
)
+
std
::
string
info
=
"Load file id:"
+
std
::
to_string
(
file_
->
id_
)
+
" file type:"
+
" file type:"
+
std
::
to_string
(
file_
->
file_type_
)
+
" size:"
+
std
::
to_string
(
file_size
)
+
std
::
to_string
(
file_
->
file_type_
)
+
" size:"
+
std
::
to_string
(
file_size
)
+
" bytes from location: "
+
file_
->
location_
+
" totally cost"
;
" bytes from location: "
+
file_
->
location_
+
" totally cost"
;
double
span
=
rc
.
ElapseFromBegin
(
info
);
double
span
=
rc
.
ElapseFromBegin
(
info
);
// for (auto &context : search_contexts_) {
// for (auto &context : search_contexts_) {
...
@@ -209,7 +209,8 @@ XSearchTask::Execute() {
...
@@ -209,7 +209,8 @@ XSearchTask::Execute() {
// step 3: pick up topk result
// step 3: pick up topk result
auto
spec_k
=
index_engine_
->
Count
()
<
topk
?
index_engine_
->
Count
()
:
topk
;
auto
spec_k
=
index_engine_
->
Count
()
<
topk
?
index_engine_
->
Count
()
:
topk
;
XSearchTask
::
MergeTopkToResultSet
(
output_ids
,
output_distance
,
spec_k
,
nq
,
topk
,
metric_l2
,
search_job
->
GetResult
());
XSearchTask
::
MergeTopkToResultSet
(
output_ids
,
output_distance
,
spec_k
,
nq
,
topk
,
metric_l2
,
search_job
->
GetResult
());
span
=
rc
.
RecordSection
(
hdr
+
", reduce topk"
);
span
=
rc
.
RecordSection
(
hdr
+
", reduce topk"
);
// search_job->AccumReduceCost(span);
// search_job->AccumReduceCost(span);
...
@@ -229,12 +230,8 @@ XSearchTask::Execute() {
...
@@ -229,12 +230,8 @@ XSearchTask::Execute() {
}
}
void
void
XSearchTask
::
MergeTopkToResultSet
(
const
std
::
vector
<
int64_t
>&
input_ids
,
XSearchTask
::
MergeTopkToResultSet
(
const
std
::
vector
<
int64_t
>&
input_ids
,
const
std
::
vector
<
float
>&
input_distance
,
const
std
::
vector
<
float
>&
input_distance
,
uint64_t
input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
,
uint64_t
input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
,
scheduler
::
ResultSet
&
result
)
{
scheduler
::
ResultSet
&
result
)
{
if
(
result
.
empty
())
{
if
(
result
.
empty
())
{
result
.
resize
(
nq
);
result
.
resize
(
nq
);
...
@@ -242,14 +239,14 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
...
@@ -242,14 +239,14 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
scheduler
::
Id2DistVec
result_buf
;
scheduler
::
Id2DistVec
result_buf
;
auto
&
result_i
=
result
[
i
];
auto
&
result_i
=
result
[
i
];
if
(
result
[
i
].
empty
())
{
if
(
result
[
i
].
empty
())
{
result_buf
.
resize
(
input_k
,
scheduler
::
IdDistPair
(
-
1
,
0.0
));
result_buf
.
resize
(
input_k
,
scheduler
::
IdDistPair
(
-
1
,
0.0
));
uint64_t
input_k_multi_i
=
input_k
*
i
;
uint64_t
input_k_multi_i
=
input_k
*
i
;
for
(
auto
k
=
0
;
k
<
input_k
;
++
k
)
{
for
(
auto
k
=
0
;
k
<
input_k
;
++
k
)
{
uint64_t
idx
=
input_k_multi_i
+
k
;
uint64_t
idx
=
input_k_multi_i
+
k
;
auto
&
result_buf_item
=
result_buf
[
k
];
auto
&
result_buf_item
=
result_buf
[
k
];
result_buf_item
.
first
=
input_ids
[
idx
];
result_buf_item
.
first
=
input_ids
[
idx
];
result_buf_item
.
second
=
input_distance
[
idx
];
result_buf_item
.
second
=
input_distance
[
idx
];
}
}
...
@@ -262,8 +259,8 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
...
@@ -262,8 +259,8 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
uint64_t
input_k_multi_i
=
input_k
*
i
;
uint64_t
input_k_multi_i
=
input_k
*
i
;
while
(
buf_k
<
output_k
&&
src_k
<
input_k
&&
tar_k
<
tar_size
)
{
while
(
buf_k
<
output_k
&&
src_k
<
input_k
&&
tar_k
<
tar_size
)
{
src_idx
=
input_k_multi_i
+
src_k
;
src_idx
=
input_k_multi_i
+
src_k
;
auto
&
result_buf_item
=
result_buf
[
buf_k
];
auto
&
result_buf_item
=
result_buf
[
buf_k
];
auto
&
result_item
=
result_i
[
tar_k
];
auto
&
result_item
=
result_i
[
tar_k
];
if
((
ascending
&&
input_distance
[
src_idx
]
<
result_item
.
second
)
||
if
((
ascending
&&
input_distance
[
src_idx
]
<
result_item
.
second
)
||
(
!
ascending
&&
input_distance
[
src_idx
]
>
result_item
.
second
))
{
(
!
ascending
&&
input_distance
[
src_idx
]
>
result_item
.
second
))
{
result_buf_item
.
first
=
input_ids
[
src_idx
];
result_buf_item
.
first
=
input_ids
[
src_idx
];
...
@@ -280,7 +277,7 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
...
@@ -280,7 +277,7 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
if
(
src_k
<
input_k
)
{
if
(
src_k
<
input_k
)
{
while
(
buf_k
<
output_k
&&
src_k
<
input_k
)
{
while
(
buf_k
<
output_k
&&
src_k
<
input_k
)
{
src_idx
=
input_k_multi_i
+
src_k
;
src_idx
=
input_k_multi_i
+
src_k
;
auto
&
result_buf_item
=
result_buf
[
buf_k
];
auto
&
result_buf_item
=
result_buf
[
buf_k
];
result_buf_item
.
first
=
input_ids
[
src_idx
];
result_buf_item
.
first
=
input_ids
[
src_idx
];
result_buf_item
.
second
=
input_distance
[
src_idx
];
result_buf_item
.
second
=
input_distance
[
src_idx
];
src_k
++
;
src_k
++
;
...
@@ -301,19 +298,15 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
...
@@ -301,19 +298,15 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
}
}
void
void
XSearchTask
::
MergeTopkArray
(
std
::
vector
<
int64_t
>&
tar_ids
,
XSearchTask
::
MergeTopkArray
(
std
::
vector
<
int64_t
>&
tar_ids
,
std
::
vector
<
float
>&
tar_distance
,
uint64_t
&
tar_input_k
,
std
::
vector
<
float
>&
tar_distance
,
const
std
::
vector
<
int64_t
>&
src_ids
,
const
std
::
vector
<
float
>&
src_distance
,
uint64_t
&
tar_input_k
,
uint64_t
src_input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
)
{
const
std
::
vector
<
int64_t
>&
src_ids
,
if
(
src_ids
.
empty
()
||
src_distance
.
empty
())
{
const
std
::
vector
<
float
>&
src_distance
,
return
;
uint64_t
src_input_k
,
}
uint64_t
nq
,
uint64_t
topk
,
std
::
vector
<
int64_t
>
id_buf
(
nq
*
topk
,
-
1
);
bool
ascending
)
{
std
::
vector
<
float
>
dist_buf
(
nq
*
topk
,
0.0
);
if
(
src_ids
.
empty
()
||
src_distance
.
empty
())
return
;
std
::
vector
<
int64_t
>
id_buf
(
nq
*
topk
,
-
1
);
std
::
vector
<
float
>
dist_buf
(
nq
*
topk
,
0.0
);
uint64_t
output_k
=
std
::
min
(
topk
,
tar_input_k
+
src_input_k
);
uint64_t
output_k
=
std
::
min
(
topk
,
tar_input_k
+
src_input_k
);
uint64_t
buf_k
,
src_k
,
tar_k
;
uint64_t
buf_k
,
src_k
,
tar_k
;
...
...
cpp/src/scheduler/task/SearchTask.h
浏览文件 @
274d96f9
...
@@ -39,24 +39,13 @@ class XSearchTask : public Task {
...
@@ -39,24 +39,13 @@ class XSearchTask : public Task {
public:
public:
static
void
static
void
MergeTopkToResultSet
(
const
std
::
vector
<
int64_t
>&
input_ids
,
MergeTopkToResultSet
(
const
std
::
vector
<
int64_t
>&
input_ids
,
const
std
::
vector
<
float
>&
input_distance
,
const
std
::
vector
<
float
>&
input_distance
,
uint64_t
input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
,
scheduler
::
ResultSet
&
result
);
uint64_t
input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
,
scheduler
::
ResultSet
&
result
);
static
void
static
void
MergeTopkArray
(
std
::
vector
<
int64_t
>&
tar_ids
,
MergeTopkArray
(
std
::
vector
<
int64_t
>&
tar_ids
,
std
::
vector
<
float
>&
tar_distance
,
uint64_t
&
tar_input_k
,
std
::
vector
<
float
>&
tar_distance
,
const
std
::
vector
<
int64_t
>&
src_ids
,
const
std
::
vector
<
float
>&
src_distance
,
uint64_t
src_input_k
,
uint64_t
&
tar_input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
);
const
std
::
vector
<
int64_t
>&
src_ids
,
const
std
::
vector
<
float
>&
src_distance
,
uint64_t
src_input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
);
public:
public:
TableFileSchemaPtr
file_
;
TableFileSchemaPtr
file_
;
...
...
cpp/unittest/db/test_search.cpp
浏览文件 @
274d96f9
...
@@ -28,20 +28,44 @@ namespace {
...
@@ -28,20 +28,44 @@ namespace {
namespace
ms
=
milvus
::
scheduler
;
namespace
ms
=
milvus
::
scheduler
;
void
void
BuildResult
(
uint64_t
nq
,
BuildResult
(
std
::
vector
<
int64_t
>&
output_ids
,
std
::
vector
<
float
>&
output_distance
,
uint64_t
topk
,
uint64_t
topk
,
bool
ascending
,
uint64_t
nq
,
std
::
vector
<
int64_t
>&
output_ids
,
bool
ascending
)
{
std
::
vector
<
float
>&
output_distence
)
{
output_ids
.
clear
();
output_ids
.
clear
();
output_ids
.
resize
(
nq
*
topk
);
output_ids
.
resize
(
nq
*
topk
);
output_dist
e
nce
.
clear
();
output_dist
a
nce
.
clear
();
output_dist
e
nce
.
resize
(
nq
*
topk
);
output_dist
a
nce
.
resize
(
nq
*
topk
);
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
uint64_t
j
=
0
;
j
<
topk
;
j
++
)
{
for
(
uint64_t
j
=
0
;
j
<
topk
;
j
++
)
{
output_ids
[
i
*
topk
+
j
]
=
(
int64_t
)(
drand48
()
*
100000
);
output_ids
[
i
*
topk
+
j
]
=
(
int64_t
)(
drand48
()
*
100000
);
output_distence
[
i
*
topk
+
j
]
=
ascending
?
(
j
+
drand48
())
:
((
topk
-
j
)
+
drand48
());
output_distance
[
i
*
topk
+
j
]
=
ascending
?
(
j
+
drand48
())
:
((
topk
-
j
)
+
drand48
());
}
}
}
void
CopyResult
(
std
::
vector
<
int64_t
>&
output_ids
,
std
::
vector
<
float
>&
output_distance
,
uint64_t
output_topk
,
std
::
vector
<
int64_t
>&
input_ids
,
std
::
vector
<
float
>&
input_distance
,
uint64_t
input_topk
,
uint64_t
nq
)
{
ASSERT_TRUE
(
input_ids
.
size
()
>=
nq
*
input_topk
);
ASSERT_TRUE
(
input_distance
.
size
()
>=
nq
*
input_topk
);
ASSERT_TRUE
(
output_topk
<=
input_topk
);
output_ids
.
clear
();
output_ids
.
resize
(
nq
*
output_topk
);
output_distance
.
clear
();
output_distance
.
resize
(
nq
*
output_topk
);
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
uint64_t
j
=
0
;
j
<
output_topk
;
j
++
)
{
output_ids
[
i
*
output_topk
+
j
]
=
input_ids
[
i
*
input_topk
+
j
];
output_distance
[
i
*
output_topk
+
j
]
=
input_distance
[
i
*
input_topk
+
j
];
}
}
}
}
}
}
...
@@ -51,8 +75,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
...
@@ -51,8 +75,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
const
std
::
vector
<
float
>&
input_distance_1
,
const
std
::
vector
<
float
>&
input_distance_1
,
const
std
::
vector
<
int64_t
>&
input_ids_2
,
const
std
::
vector
<
int64_t
>&
input_ids_2
,
const
std
::
vector
<
float
>&
input_distance_2
,
const
std
::
vector
<
float
>&
input_distance_2
,
uint64_t
nq
,
uint64_t
topk
,
uint64_t
topk
,
uint64_t
nq
,
bool
ascending
,
bool
ascending
,
const
milvus
::
scheduler
::
ResultSet
&
result
)
{
const
milvus
::
scheduler
::
ResultSet
&
result
)
{
ASSERT_EQ
(
result
.
size
(),
nq
);
ASSERT_EQ
(
result
.
size
(),
nq
);
...
@@ -96,32 +120,32 @@ TEST(DBSearchTest, TOPK_TEST) {
...
@@ -96,32 +120,32 @@ TEST(DBSearchTest, TOPK_TEST) {
/* test1, id1/dist1 valid, id2/dist2 empty */
/* test1, id1/dist1 valid, id2/dist2 empty */
ascending
=
true
;
ascending
=
true
;
BuildResult
(
NQ
,
TOP_K
,
ascending
,
ids1
,
dist1
);
BuildResult
(
ids1
,
dist1
,
TOP_K
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/* test2, id1/dist1 valid, id2/dist2 valid */
/* test2, id1/dist1 valid, id2/dist2 valid */
BuildResult
(
NQ
,
TOP_K
,
ascending
,
ids2
,
dist2
);
BuildResult
(
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/* test3, id1/dist1 small topk */
/* test3, id1/dist1 small topk */
ids1
.
clear
();
ids1
.
clear
();
dist1
.
clear
();
dist1
.
clear
();
result
.
clear
();
result
.
clear
();
BuildResult
(
NQ
,
TOP_K
/
2
,
ascending
,
ids1
,
dist1
);
BuildResult
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/* test4, id1/dist1 small topk, id2/dist2 small topk */
/* test4, id1/dist1 small topk, id2/dist2 small topk */
ids2
.
clear
();
ids2
.
clear
();
dist2
.
clear
();
dist2
.
clear
();
result
.
clear
();
result
.
clear
();
BuildResult
(
NQ
,
TOP_K
/
3
,
ascending
,
ids2
,
dist2
);
BuildResult
(
ids2
,
dist2
,
TOP_K
/
3
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
/
3
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
/
3
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////
ascending
=
false
;
ascending
=
false
;
...
@@ -132,145 +156,199 @@ TEST(DBSearchTest, TOPK_TEST) {
...
@@ -132,145 +156,199 @@ TEST(DBSearchTest, TOPK_TEST) {
result
.
clear
();
result
.
clear
();
/* test1, id1/dist1 valid, id2/dist2 empty */
/* test1, id1/dist1 valid, id2/dist2 empty */
BuildResult
(
NQ
,
TOP_K
,
ascending
,
ids1
,
dist1
);
BuildResult
(
ids1
,
dist1
,
TOP_K
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/* test2, id1/dist1 valid, id2/dist2 valid */
/* test2, id1/dist1 valid, id2/dist2 valid */
BuildResult
(
NQ
,
TOP_K
,
ascending
,
ids2
,
dist2
);
BuildResult
(
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/* test3, id1/dist1 small topk */
/* test3, id1/dist1 small topk */
ids1
.
clear
();
ids1
.
clear
();
dist1
.
clear
();
dist1
.
clear
();
result
.
clear
();
result
.
clear
();
BuildResult
(
NQ
,
TOP_K
/
2
,
ascending
,
ids1
,
dist1
);
BuildResult
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
/* test4, id1/dist1 small topk, id2/dist2 small topk */
/* test4, id1/dist1 small topk, id2/dist2 small topk */
ids2
.
clear
();
ids2
.
clear
();
dist2
.
clear
();
dist2
.
clear
();
result
.
clear
();
result
.
clear
();
BuildResult
(
NQ
,
TOP_K
/
3
,
ascending
,
ids2
,
dist2
);
BuildResult
(
ids2
,
dist2
,
TOP_K
/
3
,
NQ
,
ascending
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids1
,
dist1
,
TOP_K
/
2
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
/
3
,
NQ
,
TOP_K
,
ascending
,
result
);
ms
::
XSearchTask
::
MergeTopkToResultSet
(
ids2
,
dist2
,
TOP_K
/
3
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
NQ
,
TOP_K
,
ascending
,
result
);
CheckTopkResult
(
ids1
,
dist1
,
ids2
,
dist2
,
TOP_K
,
NQ
,
ascending
,
result
);
}
}
TEST
(
DBSearchTest
,
REDUCE_PERF_TEST
)
{
TEST
(
DBSearchTest
,
REDUCE_PERF_TEST
)
{
int32_t
nq
=
100
;
int32_t
top_k
=
1000
;
int32_t
index_file_num
=
478
;
/* sift1B dataset, index files num */
int32_t
index_file_num
=
478
;
/* sift1B dataset, index files num */
bool
ascending
=
true
;
bool
ascending
=
true
;
std
::
vector
<
int32_t
>
thread_vec
=
{
4
,
8
,
11
};
std
::
vector
<
int32_t
>
nq_vec
=
{
1
,
10
,
100
,
1000
};
std
::
vector
<
int32_t
>
topk_vec
=
{
1
,
4
,
16
,
64
,
256
,
1024
};
int32_t
NQ
=
nq_vec
[
nq_vec
.
size
()
-
1
];
int32_t
TOPK
=
topk_vec
[
topk_vec
.
size
()
-
1
];
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec
;
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec
;
std
::
vector
<
std
::
vector
<
float
>>
dist_vec
;
std
::
vector
<
std
::
vector
<
float
>>
dist_vec
;
std
::
vector
<
uint64_t
>
k_vec
;
std
::
vector
<
int64_t
>
input_ids
;
std
::
vector
<
int64_t
>
input_ids
;
std
::
vector
<
float
>
input_distance
;
std
::
vector
<
float
>
input_distance
;
ms
::
ResultSet
final_result
,
final_result_2
,
final_result_3
;
int32_t
i
,
k
,
step
;
int32_t
i
,
k
,
step
;
double
reduce_cost
=
0.0
;
milvus
::
TimeRecorder
rc
(
""
);
/* generate testing data */
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
BuildResult
(
nq
,
top_k
,
ascending
,
input_ids
,
input_distance
);
BuildResult
(
input_ids
,
input_distance
,
TOPK
,
NQ
,
ascending
);
id_vec
.
push_back
(
input_ids
);
id_vec
.
push_back
(
input_ids
);
dist_vec
.
push_back
(
input_distance
);
dist_vec
.
push_back
(
input_distance
);
k_vec
.
push_back
(
top_k
);
}
}
rc
.
RecordSection
(
"Method-1 result reduce start"
);
for
(
int32_t
max_thread_num
:
thread_vec
)
{
milvus
::
ThreadPool
threadPool
(
max_thread_num
);
/* method-1 */
std
::
list
<
std
::
future
<
void
>>
threads_list
;
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
ms
::
XSearchTask
::
MergeTopkToResultSet
(
id_vec
[
i
],
dist_vec
[
i
],
k_vec
[
i
],
nq
,
top_k
,
ascending
,
final_result
);
ASSERT_EQ
(
final_result
.
size
(),
nq
);
}
reduce_cost
=
rc
.
RecordSection
(
"Method-1 result reduce done"
);
std
::
cout
<<
"Method-1: total reduce time "
<<
reduce_cost
/
1000
<<
" ms"
<<
std
::
endl
;
/* method-2 */
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec_2
(
id_vec
);
std
::
vector
<
std
::
vector
<
float
>>
dist_vec_2
(
dist_vec
);
std
::
vector
<
uint64_t
>
k_vec_2
(
k_vec
);
rc
.
RecordSection
(
"Method-2 result reduce start"
);
for
(
int32_t
nq
:
nq_vec
)
{
for
(
int32_t
top_k
:
topk_vec
)
{
ms
::
ResultSet
final_result
,
final_result_2
,
final_result_3
;
for
(
step
=
1
;
step
<
index_file_num
;
step
*=
2
)
{
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec_1
(
index_file_num
);
for
(
i
=
0
;
i
+
step
<
index_file_num
;
i
+=
step
*
2
)
{
std
::
vector
<
std
::
vector
<
float
>>
dist_vec_1
(
index_file_num
);
ms
::
XSearchTask
::
MergeTopkArray
(
id_vec_2
[
i
],
dist_vec_2
[
i
],
k_vec_2
[
i
],
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
id_vec_2
[
i
+
step
],
dist_vec_2
[
i
+
step
],
k_vec_2
[
i
+
step
],
CopyResult
(
id_vec_1
[
i
],
dist_vec_1
[
i
],
top_k
,
id_vec
[
i
],
dist_vec
[
i
],
TOPK
,
nq
);
nq
,
top_k
,
ascending
);
}
}
}
ms
::
XSearchTask
::
MergeTopkToResultSet
(
id_vec_2
[
0
],
dist_vec_2
[
0
],
k_vec_2
[
0
],
nq
,
top_k
,
ascending
,
final_result_2
);
ASSERT_EQ
(
final_result_2
.
size
(),
nq
);
reduce_cost
=
rc
.
RecordSection
(
"Method-2 result reduce done"
);
std
::
cout
<<
"Method-2: total reduce time "
<<
reduce_cost
/
1000
<<
" ms"
<<
std
::
endl
;
for
(
i
=
0
;
i
<
nq
;
i
++
)
{
std
::
string
str1
=
"Method-1 "
+
std
::
to_string
(
max_thread_num
)
+
" "
+
ASSERT_EQ
(
final_result
[
i
].
size
(),
final_result_2
[
i
].
size
());
std
::
to_string
(
nq
)
+
" "
+
std
::
to_string
(
top_k
);
for
(
k
=
0
;
k
<
final_result
.
size
();
k
++
)
{
milvus
::
TimeRecorder
rc1
(
str1
);
ASSERT_EQ
(
final_result
[
i
][
k
].
first
,
final_result_2
[
i
][
k
].
first
);
ASSERT_EQ
(
final_result
[
i
][
k
].
second
,
final_result_2
[
i
][
k
].
second
);
///////////////////////////////////////////////////////////////////////////////////////
}
/* method-1 */
}
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
ms
::
XSearchTask
::
MergeTopkToResultSet
(
id_vec_1
[
i
],
dist_vec_1
[
i
],
top_k
,
nq
,
top_k
,
ascending
,
final_result
);
ASSERT_EQ
(
final_result
.
size
(),
nq
);
}
/* method-3 parallel */
rc1
.
RecordSection
(
"reduce done"
);
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec_3
(
id_vec
);
std
::
vector
<
std
::
vector
<
float
>>
dist_vec_3
(
dist_vec
);
std
::
vector
<
uint64_t
>
k_vec_3
(
k_vec
);
uint32_t
max_thread_count
=
std
::
min
(
std
::
thread
::
hardware_concurrency
()
-
1
,
(
uint32_t
)
MAX_THREADS_NUM
);
///////////////////////////////////////////////////////////////////////////////////////
milvus
::
ThreadPool
threadPool
(
max_thread_count
);
/* method-2 */
std
::
list
<
std
::
future
<
void
>>
threads_list
;
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec_2
(
index_file_num
);
std
::
vector
<
std
::
vector
<
float
>>
dist_vec_2
(
index_file_num
);
std
::
vector
<
uint64_t
>
k_vec_2
(
index_file_num
);
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
CopyResult
(
id_vec_2
[
i
],
dist_vec_2
[
i
],
top_k
,
id_vec
[
i
],
dist_vec
[
i
],
TOPK
,
nq
);
k_vec_2
[
i
]
=
top_k
;
}
rc
.
RecordSection
(
"Method-3 parallel result reduce start"
);
std
::
string
str2
=
"Method-2 "
+
std
::
to_string
(
max_thread_num
)
+
" "
+
std
::
to_string
(
nq
)
+
" "
+
std
::
to_string
(
top_k
);
milvus
::
TimeRecorder
rc2
(
str2
);
for
(
step
=
1
;
step
<
index_file_num
;
step
*=
2
)
{
for
(
step
=
1
;
step
<
index_file_num
;
step
*=
2
)
{
for
(
i
=
0
;
i
+
step
<
index_file_num
;
i
+=
step
*
2
)
{
for
(
i
=
0
;
i
+
step
<
index_file_num
;
i
+=
step
*
2
)
{
threads_list
.
push_back
(
ms
::
XSearchTask
::
MergeTopkArray
(
id_vec_2
[
i
],
dist_vec_2
[
i
],
k_vec_2
[
i
],
threadPool
.
enqueue
(
ms
::
XSearchTask
::
MergeTopkArray
,
id_vec_2
[
i
+
step
],
dist_vec_2
[
i
+
step
],
k_vec_2
[
i
+
step
],
std
::
ref
(
id_vec_3
[
i
]),
std
::
ref
(
dist_vec_3
[
i
]),
std
::
ref
(
k_vec_3
[
i
]),
nq
,
top_k
,
ascending
);
std
::
ref
(
id_vec_3
[
i
+
step
]),
std
::
ref
(
dist_vec_3
[
i
+
step
]),
std
::
ref
(
k_vec_3
[
i
+
step
]),
}
nq
,
top_k
,
ascending
));
}
}
ms
::
XSearchTask
::
MergeTopkToResultSet
(
id_vec_2
[
0
],
dist_vec_2
[
0
],
k_vec_2
[
0
],
nq
,
top_k
,
ascending
,
final_result_2
);
ASSERT_EQ
(
final_result_2
.
size
(),
nq
);
rc2
.
RecordSection
(
"reduce done"
);
for
(
i
=
0
;
i
<
nq
;
i
++
)
{
ASSERT_EQ
(
final_result
[
i
].
size
(),
final_result_2
[
i
].
size
());
for
(
k
=
0
;
k
<
final_result
[
i
].
size
();
k
++
)
{
if
(
final_result
[
i
][
k
].
first
!=
final_result_2
[
i
][
k
].
first
)
{
std
::
cout
<<
i
<<
" "
<<
k
<<
std
::
endl
;
}
ASSERT_EQ
(
final_result
[
i
][
k
].
first
,
final_result_2
[
i
][
k
].
first
);
ASSERT_EQ
(
final_result
[
i
][
k
].
second
,
final_result_2
[
i
][
k
].
second
);
}
}
while
(
threads_list
.
size
()
>
0
)
{
///////////////////////////////////////////////////////////////////////////////////////
int
nready
=
0
;
/* method-3 parallel */
for
(
auto
it
=
threads_list
.
begin
();
it
!=
threads_list
.
end
();
it
=
it
)
{
std
::
vector
<
std
::
vector
<
int64_t
>>
id_vec_3
(
index_file_num
);
auto
&
p
=
*
it
;
std
::
vector
<
std
::
vector
<
float
>>
dist_vec_3
(
index_file_num
);
std
::
chrono
::
milliseconds
span
(
0
);
std
::
vector
<
uint64_t
>
k_vec_3
(
index_file_num
);
if
(
p
.
wait_for
(
span
)
==
std
::
future_status
::
ready
)
{
for
(
i
=
0
;
i
<
index_file_num
;
i
++
)
{
threads_list
.
erase
(
it
++
);
CopyResult
(
id_vec_3
[
i
],
dist_vec_3
[
i
],
top_k
,
id_vec
[
i
],
dist_vec
[
i
],
TOPK
,
nq
);
++
nready
;
k_vec_3
[
i
]
=
top_k
;
}
else
{
++
it
;
}
}
}
if
(
nready
==
0
)
{
std
::
string
str3
=
"Method-3 "
+
std
::
to_string
(
max_thread_num
)
+
" "
+
std
::
this_thread
::
yield
();
std
::
to_string
(
nq
)
+
" "
+
std
::
to_string
(
top_k
);
milvus
::
TimeRecorder
rc3
(
str3
);
for
(
step
=
1
;
step
<
index_file_num
;
step
*=
2
)
{
for
(
i
=
0
;
i
+
step
<
index_file_num
;
i
+=
step
*
2
)
{
threads_list
.
push_back
(
threadPool
.
enqueue
(
ms
::
XSearchTask
::
MergeTopkArray
,
std
::
ref
(
id_vec_3
[
i
]),
std
::
ref
(
dist_vec_3
[
i
]),
std
::
ref
(
k_vec_3
[
i
]),
std
::
ref
(
id_vec_3
[
i
+
step
]),
std
::
ref
(
dist_vec_3
[
i
+
step
]),
std
::
ref
(
k_vec_3
[
i
+
step
]),
nq
,
top_k
,
ascending
));
}
while
(
threads_list
.
size
()
>
0
)
{
int
nready
=
0
;
for
(
auto
it
=
threads_list
.
begin
();
it
!=
threads_list
.
end
();
it
=
it
)
{
auto
&
p
=
*
it
;
std
::
chrono
::
milliseconds
span
(
0
);
if
(
p
.
wait_for
(
span
)
==
std
::
future_status
::
ready
)
{
threads_list
.
erase
(
it
++
);
++
nready
;
}
else
{
++
it
;
}
}
if
(
nready
==
0
)
{
std
::
this_thread
::
yield
();
}
}
}
ms
::
XSearchTask
::
MergeTopkToResultSet
(
id_vec_3
[
0
],
dist_vec_3
[
0
],
k_vec_3
[
0
],
nq
,
top_k
,
ascending
,
final_result_3
);
ASSERT_EQ
(
final_result_3
.
size
(),
nq
);
rc3
.
RecordSection
(
"reduce done"
);
for
(
i
=
0
;
i
<
nq
;
i
++
)
{
ASSERT_EQ
(
final_result
[
i
].
size
(),
final_result_3
[
i
].
size
());
for
(
k
=
0
;
k
<
final_result
[
i
].
size
();
k
++
)
{
ASSERT_EQ
(
final_result
[
i
][
k
].
first
,
final_result_3
[
i
][
k
].
first
);
ASSERT_EQ
(
final_result
[
i
][
k
].
second
,
final_result_3
[
i
][
k
].
second
);
}
}
}
}
}
}
}
}
ms
::
XSearchTask
::
MergeTopkToResultSet
(
id_vec_3
[
0
],
dist_vec_3
[
0
],
k_vec_3
[
0
],
nq
,
top_k
,
ascending
,
final_result_3
);
ASSERT_EQ
(
final_result_3
.
size
(),
nq
);
reduce_cost
=
rc
.
RecordSection
(
"Method-3 parallel result reduce done"
);
std
::
cout
<<
"Method-3 parallel: total reduce time "
<<
reduce_cost
/
1000
<<
" ms"
<<
std
::
endl
;
for
(
i
=
0
;
i
<
nq
;
i
++
)
{
ASSERT_EQ
(
final_result
[
i
].
size
(),
final_result_3
[
i
].
size
());
for
(
k
=
0
;
k
<
final_result
.
size
();
k
++
)
{
ASSERT_EQ
(
final_result
[
i
][
k
].
first
,
final_result_3
[
i
][
k
].
first
);
ASSERT_EQ
(
final_result
[
i
][
k
].
second
,
final_result_3
[
i
][
k
].
second
);
}
}
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录