Commit 6d1e1578 — BaiXuePrincess / milvus

Authored on Oct 08, 2019 by yudong.cai

MS-606 speed up result reduce

Former-commit-id: 3414caf6afa687d79637890dd0f34c4d6c6dcd03

Parent: e884fdb0

Showing 5 changed files with 204 additions and 382 deletions (+204 -382)
Changed files:
  cpp/CHANGELOG.md                        +3    -2
  cpp/src/scheduler/job/SearchJob.h       +3    -2
  cpp/src/scheduler/task/SearchTask.cpp   +70   -143
  cpp/src/scheduler/task/SearchTask.h     +7    -9
  cpp/unittest/db/test_search.cpp         +121  -226
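What the commit does, in short: instead of first clustering each index file's raw search output into a per-query Id2DistanceMap (ClusterResult) and then merging whole ResultSets pairwise (MergeResult / the old TopkResult), the new TopkResult consumes the flat id/distance buffers directly and merges them into the accumulated per-query top-k in a single pass. The sketch below only illustrates the assumed flat layout of those buffers (row-major, k entries per query); it is not code from this commit.

// Sketch only: how the flat per-file search output is assumed to be laid out.
// Entry j of query i lives at index i * k + j (row-major).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const uint64_t nq = 2, k = 3;  // 2 queries, top-3 per query (assumed values)
    std::vector<int64_t> ids      = {11, 12, 13, 21, 22, 23};
    std::vector<float>   distance = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};

    for (uint64_t i = 0; i < nq; ++i) {
        for (uint64_t j = 0; j < k; ++j) {
            uint64_t idx = i * k + j;  // flat row-major index
            std::cout << "query " << i << " hit " << j
                      << ": id=" << ids[idx] << " dist=" << distance[idx] << "\n";
        }
    }
    return 0;
}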
cpp/CHANGELOG.md
@@ -14,15 +14,16 @@ Please mark all change in change log and use the ticket from JIRA.
 
 ## Improvement
 - MS-552 - Add and change the easylogging library
 - MS-553 - Refine cache code
-- MS-557 - Merge Log.h
+- MS-555 - Remove old scheduler
 - MS-556 - Add Job Definition in Scheduler
+- MS-557 - Merge Log.h
 - MS-558 - Refine status code
 - MS-562 - Add JobMgr and TaskCreator in Scheduler
 - MS-566 - Refactor cmake
-- MS-555 - Remove old scheduler
 - MS-574 - Milvus configuration refactor
 - MS-578 - Make sure milvus5.0 don't crack 0.3.1 data
 - MS-585 - Update namespace in scheduler
+- MS-606 - Speed up result reduce
 - MS-608 - Update TODO names
 - MS-609 - Update task construct function
cpp/src/scheduler/job/SearchJob.h
@@ -37,8 +37,9 @@ namespace scheduler {
 
 using engine::meta::TableFileSchemaPtr;
 using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
-using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
-using ResultSet = std::vector<Id2DistanceMap>;
+using IdDistPair = std::pair<int64_t, double>;
+using Id2DistVec = std::vector<IdDistPair>;
+using ResultSet = std::vector<Id2DistVec>;
 
 class SearchJob : public Job {
  public:
cpp/src/scheduler/task/SearchTask.cpp
@@ -78,18 +78,19 @@ std::mutex XSearchTask::merge_mutex_;
 
 void
 CollectFileMetrics(int file_type, size_t file_size) {
+    server::MetricsBase &inst = server::Metrics::GetInstance();
     switch (file_type) {
         case TableFileSchema::RAW:
         case TableFileSchema::TO_INDEX: {
-            server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
-            server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
-            server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
+            inst.RawFileSizeHistogramObserve(file_size);
+            inst.RawFileSizeTotalIncrement(file_size);
+            inst.RawFileSizeGaugeSet(file_size);
             break;
         }
         default: {
-            server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
-            server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
-            server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
+            inst.IndexFileSizeHistogramObserve(file_size);
+            inst.IndexFileSizeTotalIncrement(file_size);
+            inst.IndexFileSizeGaugeSet(file_size);
             break;
         }
     }
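The CollectFileMetrics hunk is a small cleanup: the metrics singleton is resolved once into a local reference instead of calling GetInstance() before every observation. A minimal standalone sketch of the pattern (the Metrics class below is a stand-in, not the project's server::Metrics API):

#include <cstddef>
#include <iostream>

// Stand-in singleton for illustration only.
class Metrics {
 public:
    static Metrics &GetInstance() {
        static Metrics instance;
        return instance;
    }
    void RawFileSizeTotalIncrement(size_t n) { raw_total_ += n; }
    void RawFileSizeGaugeSet(size_t n) { raw_gauge_ = n; }
    size_t raw_total() const { return raw_total_; }

 private:
    size_t raw_total_ = 0;
    size_t raw_gauge_ = 0;
};

void CollectFileMetrics(size_t file_size) {
    // Resolve the singleton once and reuse the reference for every update,
    // mirroring the `inst` local introduced in this hunk.
    Metrics &inst = Metrics::GetInstance();
    inst.RawFileSizeTotalIncrement(file_size);
    inst.RawFileSizeGaugeSet(file_size);
}

int main() {
    CollectFileMetrics(1024);
    std::cout << Metrics::GetInstance().raw_total() << std::endl;  // prints 1024
    return 0;
}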
@@ -206,16 +207,9 @@ XSearchTask::Execute() {
     double span = rc.RecordSection(hdr + ", do search");
     // search_job->AccumSearchCost(span);
 
-    // step 3: cluster result
-    scheduler::ResultSet result_set;
+    // step 3: pick up topk result
     auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
-    XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set);
-
-    span = rc.RecordSection(hdr + ", cluster result");
-    // search_job->AccumReduceCost(span);
-
-    // step 4: pick up topk result
-    XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult());
+    XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());
 
     span = rc.RecordSection(hdr + ", reduce topk");
     // search_job->AccumReduceCost(span);
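In Execute(), the separate "cluster result" step disappears and the raw output buffers feed the new TopkResult directly; the spec_k guard is kept because an index file may hold fewer vectors than the requested topk. A tiny illustration of that guard with assumed values (not project code):

#include <cstdint>
#include <iostream>

int main() {
    int64_t count = 37;  // vectors actually stored in this index file (assumed)
    int64_t topk = 64;   // top-k requested by the query

    // Same guard as in Execute(): never ask the reducer for more results
    // per query than the file can produce.
    int64_t spec_k = count < topk ? count : topk;
    std::cout << "spec_k = " << spec_k << "\n";  // prints 37
    return 0;
}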
@@ -235,142 +229,75 @@ XSearchTask::Execute() {
 }
 
 Status
-XSearchTask::ClusterResult(const std::vector<int64_t>& output_ids,
-                           const std::vector<float>& output_distance,
-                           uint64_t nq,
-                           uint64_t topk,
-                           scheduler::ResultSet &result_set) {
-    if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) {
-        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
-                          " distance array size: " + std::to_string(output_distance.size());
-        ENGINE_LOG_ERROR << msg;
-        return Status(DB_ERROR, msg);
-    }
-
-    result_set.clear();
-    result_set.resize(nq);
-
-    std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
-        for (auto i = from_index; i < to_index; i++) {
-            scheduler::Id2DistanceMap id_distance;
-            id_distance.reserve(topk);
-            for (auto k = 0; k < topk; k++) {
-                uint64_t index = i * topk + k;
-                if (output_ids[index] < 0) {
-                    continue;
-                }
-                id_distance.push_back(std::make_pair(output_ids[index], output_distance[index]));
-            }
-            result_set[i] = id_distance;
-        }
-    };
-
-    // if (NeedParallelReduce(nq, topk)) {
-    //     ParallelReduce(reduce_worker, nq);
-    // } else {
-    reduce_worker(0, nq);
-    // }
-
-    return Status::OK();
-}
-
-Status
-XSearchTask::MergeResult(scheduler::Id2DistanceMap &distance_src,
-                         scheduler::Id2DistanceMap &distance_target,
-                         uint64_t topk,
-                         bool ascending) {
-    // Note: the score_src and score_target are already arranged by score in ascending order
-    if (distance_src.empty()) {
-        ENGINE_LOG_WARNING << "Empty distance source array";
-        return Status::OK();
-    }
-
-    std::unique_lock<std::mutex> lock(merge_mutex_);
-    if (distance_target.empty()) {
-        distance_target.swap(distance_src);
-        return Status::OK();
-    }
-
-    size_t src_count = distance_src.size();
-    size_t target_count = distance_target.size();
-    scheduler::Id2DistanceMap distance_merged;
-    distance_merged.reserve(topk);
-    size_t src_index = 0, target_index = 0;
-    while (true) {
-        // all score_src items are merged, if score_merged.size() still less than topk
-        // move items from score_target to score_merged until score_merged.size() equal topk
-        if (src_index >= src_count) {
-            for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
-                distance_merged.push_back(distance_target[i]);
-            }
-            break;
-        }
-
-        // all score_target items are merged, if score_merged.size() still less than topk
-        // move items from score_src to score_merged until score_merged.size() equal topk
-        if (target_index >= target_count) {
-            for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
-                distance_merged.push_back(distance_src[i]);
-            }
-            break;
-        }
-
-        // compare score,
-        // if ascending = true, put smallest score to score_merged one by one
-        // else, put largest score to score_merged one by one
-        auto &src_pair = distance_src[src_index];
-        auto &target_pair = distance_target[target_index];
-        if (ascending) {
-            if (src_pair.second > target_pair.second) {
-                distance_merged.push_back(target_pair);
-                target_index++;
-            } else {
-                distance_merged.push_back(src_pair);
-                src_index++;
-            }
-        } else {
-            if (src_pair.second < target_pair.second) {
-                distance_merged.push_back(target_pair);
-                target_index++;
-            } else {
-                distance_merged.push_back(src_pair);
-                src_index++;
-            }
-        }
-
-        // score_merged.size() already equal topk
-        if (distance_merged.size() >= topk) {
-            break;
-        }
-    }
-
-    distance_target.swap(distance_merged);
-
-    return Status::OK();
-}
-
-Status
-XSearchTask::TopkResult(scheduler::ResultSet &result_src,
-                        uint64_t topk,
-                        bool ascending,
-                        scheduler::ResultSet &result_target) {
-    if (result_target.empty()) {
-        result_target.swap(result_src);
-        return Status::OK();
-    }
-
-    if (result_src.size() != result_target.size()) {
-        std::string msg = "Invalid result set size";
-        ENGINE_LOG_ERROR << msg;
-        return Status(DB_ERROR, msg);
-    }
-
-    std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
-        for (size_t i = from_index; i < to_index; i++) {
-            scheduler::Id2DistanceMap &score_src = result_src[i];
-            scheduler::Id2DistanceMap &score_target = result_target[i];
-            XSearchTask::MergeResult(score_src, score_target, topk, ascending);
-        }
-    };
-
-    // if (NeedParallelReduce(result_src.size(), topk)) {
-    //     ParallelReduce(ReduceWorker, result_src.size());
-    // } else {
-    ReduceWorker(0, result_src.size());
-    // }
-
-    return Status::OK();
-}
+XSearchTask::TopkResult(const std::vector<long> &input_ids,
+                        const std::vector<float> &input_distance,
+                        uint64_t input_k,
+                        uint64_t nq,
+                        uint64_t topk,
+                        bool ascending,
+                        scheduler::ResultSet &result) {
+    scheduler::ResultSet result_buf;
+    if (result.empty()) {
+        result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
+        for (auto i = 0; i < nq; ++i) {
+            auto &result_buf_i = result_buf[i];
+            uint64_t input_k_multi_i = input_k * i;
+            for (auto k = 0; k < input_k; ++k) {
+                uint64_t idx = input_k_multi_i + k;
+                auto &result_buf_item = result_buf_i[k];
+                result_buf_item.first = input_ids[idx];
+                result_buf_item.second = input_distance[idx];
+            }
+        }
+    } else {
+        size_t tar_size = result[0].size();
+        uint64_t output_k = std::min(topk, input_k + tar_size);
+        result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
+        for (auto i = 0; i < nq; ++i) {
+            size_t buf_k = 0, src_k = 0, tar_k = 0;
+            uint64_t src_idx;
+            auto &result_i = result[i];
+            auto &result_buf_i = result_buf[i];
+            uint64_t input_k_multi_i = input_k * i;
+            while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
+                src_idx = input_k_multi_i + src_k;
+                auto &result_buf_item = result_buf_i[buf_k];
+                auto &result_item = result_i[tar_k];
+                if ((ascending && input_distance[src_idx] < result_item.second) ||
+                    (!ascending && input_distance[src_idx] > result_item.second)) {
+                    result_buf_item.first = input_ids[src_idx];
+                    result_buf_item.second = input_distance[src_idx];
+                    src_k++;
+                } else {
+                    result_buf_item = result_item;
+                    tar_k++;
+                }
+                buf_k++;
+            }
+
+            if (buf_k < topk) {
+                if (src_k < input_k) {
+                    while (buf_k < output_k && src_k < input_k) {
+                        src_idx = input_k_multi_i + src_k;
+                        auto &result_buf_item = result_buf_i[buf_k];
+                        result_buf_item.first = input_ids[src_idx];
+                        result_buf_item.second = input_distance[src_idx];
+                        src_k++;
+                        buf_k++;
+                    }
+                } else {
+                    while (buf_k < output_k && tar_k < tar_size) {
+                        result_buf_i[buf_k] = result_i[tar_k];
+                        tar_k++;
+                        buf_k++;
+                    }
+                }
+            }
+        }
+    }
+
+    result.swap(result_buf);
+
+    return Status::OK();
+}
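The replacement TopkResult performs the whole reduce per query as one merge of two already-sorted sequences (the flat slice from the current file and the accumulated list), capped at topk, with no intermediate ClusterResult pass and no merge_mutex_. The following self-contained sketch shows the same idea with simplified types and a hypothetical MergeTopk helper; it is an illustration, not the project's implementation:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;

// Merge one file's flat, per-query-sorted top-k (ids/distances, row-major,
// input_k entries per query) into the accumulated per-query lists in result.
void MergeTopk(const std::vector<int64_t> &ids, const std::vector<float> &distances,
               uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending,
               ResultSet &result) {
    if (result.empty()) {
        result.resize(nq);  // first file: merge against empty per-query lists
    }
    for (uint64_t i = 0; i < nq; ++i) {
        const Id2DistVec &old_list = result[i];
        Id2DistVec merged;
        uint64_t limit = std::min<uint64_t>(topk, input_k + old_list.size());
        uint64_t s = 0, t = 0;
        while (merged.size() < limit) {
            bool take_src;
            if (s >= input_k) {
                take_src = false;
            } else if (t >= old_list.size()) {
                take_src = true;
            } else {
                double src_dist = distances[i * input_k + s];
                take_src = ascending ? (src_dist < old_list[t].second)
                                     : (src_dist > old_list[t].second);
            }
            if (take_src) {
                merged.emplace_back(ids[i * input_k + s], distances[i * input_k + s]);
                ++s;
            } else {
                merged.push_back(old_list[t]);
                ++t;
            }
        }
        result[i] = std::move(merged);  // safe: old_list is no longer read
    }
}

int main() {
    ResultSet result;
    std::vector<int64_t> ids1 = {1, 2};       // one query, top-2 from file 1
    std::vector<float>   dist1 = {0.1f, 0.4f};
    std::vector<int64_t> ids2 = {3, 4};       // top-2 from file 2
    std::vector<float>   dist2 = {0.2f, 0.3f};

    MergeTopk(ids1, dist1, 2, 1, 3, true, result);
    MergeTopk(ids2, dist2, 2, 1, 3, true, result);
    for (const auto &p : result[0]) {
        std::cout << p.first << ":" << p.second << " ";  // 1:0.1 3:0.2 4:0.3
    }
    std::cout << "\n";
    return 0;
}

Each call costs on the order of nq * min(topk, k_old + k_new) comparisons, and reducing many index files just repeats the call per file, which is the "result reduce" this commit speeds up.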
cpp/src/scheduler/task/SearchTask.h
@@ -39,15 +39,13 @@ class XSearchTask : public Task {
 
  public:
     static Status
-    ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distence, uint64_t nq,
-                  uint64_t topk, scheduler::ResultSet &result_set);
-
-    static Status
-    MergeResult(scheduler::Id2DistanceMap &distance_src, scheduler::Id2DistanceMap &distance_target, uint64_t topk,
-                bool ascending);
-
-    static Status
-    TopkResult(scheduler::ResultSet &result_src, uint64_t topk, bool ascending, scheduler::ResultSet &result_target);
+    TopkResult(const std::vector<long> &input_ids,
+               const std::vector<float> &input_distance,
+               uint64_t input_k,
+               uint64_t nq,
+               uint64_t topk,
+               bool ascending,
+               scheduler::ResultSet &result);
 
  public:
     TableFileSchemaPtr file_;
cpp/unittest/db/test_search.cpp
@@ -22,12 +22,9 @@
 #include "scheduler/task/SearchTask.h"
 #include "utils/TimeRecorder.h"
 
-namespace {
-
-namespace ms = milvus;
+using namespace milvus::scheduler;
 
-static constexpr uint64_t NQ = 15;
-static constexpr uint64_t TOP_K = 64;
+namespace {
 
 void
 BuildResult(uint64_t nq,
@@ -48,76 +45,36 @@ BuildResult(uint64_t nq,
     }
 }
 
 void
-CheckResult(const ms::scheduler::Id2DistanceMap &src_1,
-            const ms::scheduler::Id2DistanceMap &src_2,
-            const ms::scheduler::Id2DistanceMap &target,
-            bool ascending) {
-    for (uint64_t i = 0; i < target.size() - 1; i++) {
-        if (ascending) {
-            ASSERT_LE(target[i].second, target[i + 1].second);
-        } else {
-            ASSERT_GE(target[i].second, target[i + 1].second);
-        }
-    }
-
-    using ID2DistMap = std::map<int64_t, float>;
-    ID2DistMap src_map_1, src_map_2;
-    for (const auto &pair : src_1) {
-        src_map_1.insert(pair);
-    }
-    for (const auto &pair : src_2) {
-        src_map_2.insert(pair);
-    }
-
-    for (const auto &pair : target) {
-        ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() ||
-                    src_map_2.find(pair.first) != src_map_2.end());
-
-        float dist = src_map_1.find(pair.first) != src_map_1.end() ?
-                     src_map_1[pair.first] : src_map_2[pair.first];
-        ASSERT_LT(fabs(pair.second - dist), std::numeric_limits<float>::epsilon());
-    }
-}
-
-void
-CheckCluster(const std::vector<int64_t> &target_ids,
-             const std::vector<float> &target_distence,
-             const ms::scheduler::ResultSet &src_result,
-             int64_t nq,
-             int64_t topk) {
-    ASSERT_EQ(src_result.size(), nq);
-    for (int64_t i = 0; i < nq; i++) {
-        auto &res = src_result[i];
-        ASSERT_EQ(res.size(), topk);
-        if (res.empty()) {
-            continue;
-        }
-
-        ASSERT_EQ(res[0].first, target_ids[i * topk]);
-        ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]);
-    }
-}
-
-void
-CheckTopkResult(const ms::scheduler::ResultSet &src_result,
-                bool ascending,
-                int64_t nq,
-                int64_t topk) {
-    ASSERT_EQ(src_result.size(), nq);
-    for (int64_t i = 0; i < nq; i++) {
-        auto &res = src_result[i];
-        ASSERT_EQ(res.size(), topk);
-        if (res.empty()) {
-            continue;
-        }
-
-        for (int64_t k = 0; k < topk - 1; k++) {
-            if (ascending) {
-                ASSERT_LE(res[k].second, res[k + 1].second);
-            } else {
-                ASSERT_GE(res[k].second, res[k + 1].second);
-            }
-        }
-    }
-}
+CheckTopkResult(const std::vector<long> &input_ids_1,
+                const std::vector<float> &input_distance_1,
+                const std::vector<long> &input_ids_2,
+                const std::vector<float> &input_distance_2,
+                uint64_t nq,
+                uint64_t topk,
+                bool ascending,
+                const ResultSet &result) {
+    ASSERT_EQ(result.size(), nq);
+    ASSERT_EQ(input_ids_1.size(), input_distance_1.size());
+    ASSERT_EQ(input_ids_2.size(), input_distance_2.size());
+
+    uint64_t input_k1 = input_ids_1.size() / nq;
+    uint64_t input_k2 = input_ids_2.size() / nq;
+
+    for (uint64_t i = 0; i < nq; i++) {
+        std::vector<float> src_vec(input_distance_1.begin() + i * input_k1,
+                                   input_distance_1.begin() + (i + 1) * input_k1);
+        src_vec.insert(src_vec.end(),
+                       input_distance_2.begin() + i * input_k2,
+                       input_distance_2.begin() + (i + 1) * input_k2);
+        if (ascending) {
+            std::sort(src_vec.begin(), src_vec.end());
+        } else {
+            std::sort(src_vec.begin(), src_vec.end(), std::greater<float>());
+        }
+
+        uint64_t n = std::min(topk, input_k1 + input_k2);
+        for (uint64_t j = 0; j < n; j++) {
+            if (src_vec[j] != result[i][j].second) {
+                std::cout << src_vec[j] << " " << result[i][j].second << std::endl;
+            }
+            ASSERT_TRUE(src_vec[j] == result[i][j].second);
+        }
+    }
+}
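The rewritten CheckTopkResult in the test validates the merge against a brute-force oracle: for each query it concatenates the distances from both inputs, sorts them, and requires the first min(topk, k1 + k2) values to equal the merged result. A standalone sketch of that oracle check with made-up values:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // Distances produced for one query by two index files (each already sorted).
    std::vector<float> dist1 = {0.1f, 0.4f, 0.7f};
    std::vector<float> dist2 = {0.2f, 0.3f, 0.9f};
    // Distances of the merged top-4 produced by the reducer under test.
    std::vector<float> merged = {0.1f, 0.2f, 0.3f, 0.4f};
    uint64_t topk = 4;

    // Oracle: concatenate, sort, compare the first min(topk, k1 + k2) values.
    std::vector<float> oracle(dist1);
    oracle.insert(oracle.end(), dist2.begin(), dist2.end());
    std::sort(oracle.begin(), oracle.end());  // ascending metric assumed
    uint64_t n = std::min<uint64_t>(topk, oracle.size());
    for (uint64_t j = 0; j < n; j++) {
        assert(oracle[j] == merged[j]);
    }
    return 0;
}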
@@ -125,179 +82,117 @@ CheckTopkResult(const ms::scheduler::ResultSet &src_result,
 }  // namespace
 
 TEST(DBSearchTest, TOPK_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-    auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
-    ASSERT_FALSE(status.ok());
-    ASSERT_TRUE(src_result.empty());
-
-    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
-    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
-    ASSERT_TRUE(status.ok());
-    ASSERT_EQ(src_result.size(), NQ);
-
-    ms::scheduler::ResultSet target_result;
-    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
-    ASSERT_FALSE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    ASSERT_TRUE(src_result.empty());
-    ASSERT_EQ(target_result.size(), NQ);
-
-    std::vector<int64_t> src_ids;
-    std::vector<float> src_distence;
-    uint64_t wrong_topk = TOP_K - 10;
-    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
-
-    status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
-    ASSERT_TRUE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    for (uint64_t i = 0; i < NQ; i++) {
-        ASSERT_EQ(target_result[i].size(), TOP_K);
-    }
-
-    wrong_topk = TOP_K + 10;
-    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    for (uint64_t i = 0; i < NQ; i++) {
-        ASSERT_EQ(target_result[i].size(), TOP_K);
-    }
-}
-
-TEST(DBSearchTest, MERGE_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    std::vector<int64_t> src_ids;
-    std::vector<float> src_distence;
-    ms::scheduler::ResultSet src_result, target_result;
-
-    uint64_t src_count = 5, target_count = 8;
-    BuildResult(1, src_count, ascending, src_ids, src_distence);
-    BuildResult(1, target_count, ascending, target_ids, target_distence);
-    auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
-    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), 10);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target;
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count);
-        ASSERT_TRUE(src.empty());
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap target = src_result[0];
-        ms::scheduler::Id2DistanceMap src = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-}
-
-TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-
-    auto DoCluster = [&](int64_t nq, int64_t topk) {
-        ms::TimeRecorder rc("DoCluster");
-        src_result.clear();
-        BuildResult(nq, topk, ascending, target_ids, target_distence);
-        rc.RecordSection("build id/dietance map");
-
-        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(src_result.size(), nq);
-        rc.RecordSection("cluster result");
-
-        CheckCluster(target_ids, target_distence, src_result, nq, topk);
-        rc.RecordSection("check result");
-    };
-
-    DoCluster(10000, 1000);
-    DoCluster(333, 999);
-    DoCluster(1, 1000);
-    DoCluster(1, 1);
-    DoCluster(7, 0);
-    DoCluster(9999, 1);
-    DoCluster(10001, 1);
-    DoCluster(58273, 1234);
-}
-
-TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-
-    std::vector<int64_t> insufficient_ids;
-    std::vector<float> insufficient_distence;
-    ms::scheduler::ResultSet insufficient_result;
-
-    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
-        src_result.clear();
-        insufficient_result.clear();
-
-        ms::TimeRecorder rc("DoCluster");
-        BuildResult(nq, topk, ascending, target_ids, target_distence);
-        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
-        rc.RecordSection("cluster result");
-
-        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
-        status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, insufficient_topk,
-                                                           insufficient_result);
-        rc.RecordSection("cluster result");
-
-        ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
-        ASSERT_TRUE(status.ok());
-        rc.RecordSection("topk");
-
-        CheckTopkResult(src_result, ascending, nq, topk);
-        rc.RecordSection("check result");
-    };
-
-    DoTopk(5, 10, 4, false);
-    DoTopk(20005, 998, 123, true);
-    // DoTopk(9987, 12, 10, false);
-    // DoTopk(77777, 1000, 1, false);
-    // DoTopk(5432, 8899, 8899, true);
-}
+    uint64_t NQ = 15;
+    uint64_t TOP_K = 64;
+    bool ascending;
+    std::vector<long> ids1, ids2;
+    std::vector<float> dist1, dist2;
+    ResultSet result;
+    milvus::Status status;
+
+    /* test1, id1/dist1 valid, id2/dist2 empty */
+    ascending = true;
+    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test2, id1/dist1 valid, id2/dist2 valid */
+    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test3, id1/dist1 small topk */
+    ids1.clear();
+    dist1.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test4, id1/dist1 small topk, id2/dist2 small topk */
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /////////////////////////////////////////////////////////////////////////////////////////
+    ascending = false;
+    ids1.clear();
+    dist1.clear();
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+
+    /* test1, id1/dist1 valid, id2/dist2 empty */
+    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test2, id1/dist1 valid, id2/dist2 valid */
+    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test3, id1/dist1 small topk */
+    ids1.clear();
+    dist1.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test4, id1/dist1 small topk, id2/dist2 small topk */
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+}
+
+TEST(DBSearchTest, REDUCE_PERF_TEST) {
+    int32_t nq = 100;
+    int32_t top_k = 1000;
+    int32_t index_file_num = 478;   /* sift1B dataset, index files num */
+    bool ascending = true;
+
+    std::vector<long> input_ids;
+    std::vector<float> input_distance;
+    ResultSet final_result;
+    milvus::Status status;
+
+    double span, reduce_cost = 0.0;
+    milvus::TimeRecorder rc("");
+
+    for (int32_t i = 0; i < index_file_num; i++) {
+        BuildResult(nq, top_k, ascending, input_ids, input_distance);
+        rc.RecordSection("do search for context: " + std::to_string(i));
+
+        // pick up topk result
+        status = XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result);
+        ASSERT_TRUE(status.ok());
+        ASSERT_EQ(final_result.size(), nq);
+
+        span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
+        reduce_cost += span;
+    }
+
+    std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl;
+}