Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
milvus
milvus
提交
ed029178
M
milvus
项目概览
milvus
/
milvus
12 个月 前同步成功
通知
261
Star
22476
Fork
2472
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ed029178
编写于
7月 24, 2019
作者:
S
starlord
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MS-266 Improve topk reduce time by using multi-threads
Former-commit-id: 79e4cfe6ade7b0fc059cd246c4b670a4b5343ca3
上级
3fdebc91
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
247 addition
and
56 deletion
+247
-56
cpp/conf/server_config.template
cpp/conf/server_config.template
+2
-0
cpp/src/db/DBImpl.cpp
cpp/src/db/DBImpl.cpp
+9
-9
cpp/src/db/scheduler/task/SearchTask.cpp
cpp/src/db/scheduler/task/SearchTask.cpp
+87
-21
cpp/src/server/ServerConfig.h
cpp/src/server/ServerConfig.h
+1
-0
cpp/unittest/db/search_test.cpp
cpp/unittest/db/search_test.cpp
+148
-26
未找到文件。
cpp/conf/server_config.template
浏览文件 @
ed029178
...
...
@@ -8,6 +8,8 @@ db_config:
db_path: @MILVUS_DB_PATH@ # milvus data storage path
db_slave_path: # secondary data storage path, split by semicolon
parallel_reduce: true # use multi-threads to reduce topk result
# URI format: dialect://username:password@host:port/database
# All parts except dialect are optional, but you MUST include the delimiters
# Currently dialect supports mysql or sqlite
...
...
cpp/src/db/DBImpl.cpp
浏览文件 @
ed029178
...
...
@@ -297,7 +297,7 @@ void DBImpl::StartMetricTask() {
return
;
}
ENGINE_LOG_
DEBUG
<<
"Start metric task"
;
ENGINE_LOG_
INFO
<<
"Start metric task"
;
server
::
Metrics
::
GetInstance
().
KeepingAliveCounterIncrement
(
METRIC_ACTION_INTERVAL
);
int64_t
cache_usage
=
cache
::
CpuCacheMgr
::
GetInstance
()
->
CacheUsage
();
...
...
@@ -312,7 +312,7 @@ void DBImpl::StartMetricTask() {
server
::
Metrics
::
GetInstance
().
GPUMemoryUsageGaugeSet
();
server
::
Metrics
::
GetInstance
().
OctetsSet
();
ENGINE_LOG_
DEBUG
<<
"Metric task finished"
;
ENGINE_LOG_
INFO
<<
"Metric task finished"
;
}
void
DBImpl
::
StartCompactionTask
()
{
...
...
@@ -322,8 +322,6 @@ void DBImpl::StartCompactionTask() {
return
;
}
ENGINE_LOG_DEBUG
<<
"Serialize insert cache"
;
//serialize memory data
std
::
set
<
std
::
string
>
temp_table_ids
;
mem_mgr_
->
Serialize
(
temp_table_ids
);
...
...
@@ -331,7 +329,9 @@ void DBImpl::StartCompactionTask() {
compact_table_ids_
.
insert
(
id
);
}
ENGINE_LOG_DEBUG
<<
"Insert cache serialized"
;
if
(
!
temp_table_ids
.
empty
())
{
SERVER_LOG_DEBUG
<<
"Insert cache serialized"
;
}
//compaction has been finished?
if
(
!
compact_thread_results_
.
empty
())
{
...
...
@@ -433,7 +433,7 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
}
void
DBImpl
::
BackgroundCompaction
(
std
::
set
<
std
::
string
>
table_ids
)
{
ENGINE_LOG_
DEBUG
<<
" Background compaction thread start"
;
ENGINE_LOG_
INFO
<<
" Background compaction thread start"
;
Status
status
;
for
(
auto
&
table_id
:
table_ids
)
{
...
...
@@ -452,7 +452,7 @@ void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
}
meta_ptr_
->
CleanUpFilesWithTTL
(
ttl
);
ENGINE_LOG_
DEBUG
<<
" Background compaction thread exit"
;
ENGINE_LOG_
INFO
<<
" Background compaction thread exit"
;
}
void
DBImpl
::
StartBuildIndexTask
(
bool
force
)
{
...
...
@@ -581,7 +581,7 @@ Status DBImpl::BuildIndexByTable(const std::string& table_id) {
}
void
DBImpl
::
BackgroundBuildIndex
()
{
ENGINE_LOG_
DEBUG
<<
" Background build index thread start"
;
ENGINE_LOG_
INFO
<<
" Background build index thread start"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
build_index_mutex_
);
meta
::
TableFilesSchema
to_index_files
;
...
...
@@ -599,7 +599,7 @@ void DBImpl::BackgroundBuildIndex() {
}
}
ENGINE_LOG_
DEBUG
<<
" Background build index thread exit"
;
ENGINE_LOG_
INFO
<<
" Background build index thread exit"
;
}
Status
DBImpl
::
DropAll
()
{
...
...
cpp/src/db/scheduler/task/SearchTask.cpp
浏览文件 @
ed029178
...
...
@@ -5,14 +5,60 @@
******************************************************************************/
#include "SearchTask.h"
#include "metrics/Metrics.h"
#include "
utils
/Log.h"
#include "
db
/Log.h"
#include "utils/TimeRecorder.h"
#include <thread>
namespace
zilliz
{
namespace
milvus
{
namespace
engine
{
namespace
{
static
constexpr
size_t
PARALLEL_REDUCE_THRESHOLD
=
10000
;
static
constexpr
size_t
PARALLEL_REDUCE_BATCH
=
1000
;
bool
NeedParallelReduce
(
uint64_t
nq
,
uint64_t
topk
)
{
server
::
ServerConfig
&
config
=
server
::
ServerConfig
::
GetInstance
();
server
::
ConfigNode
&
db_config
=
config
.
GetConfig
(
server
::
CONFIG_DB
);
bool
need_parallel
=
db_config
.
GetBoolValue
(
server
::
CONFIG_DB_PARALLEL_REDUCE
,
true
);
if
(
!
need_parallel
)
{
return
false
;
}
return
nq
*
topk
>=
PARALLEL_REDUCE_THRESHOLD
;
}
// Splits the index range [0, max_index) into contiguous batches, runs
// reduce_function(from_index, to_index) for every batch on its own thread,
// and joins all threads before returning.
//
// reduce_function: worker called with a half-open range [from_index, to_index).
//                  It is passed to the threads by reference, which is valid
//                  because every thread is joined before this function returns.
// max_index:       one past the last index to process; 0 starts no thread.
void ParallelReduce(std::function<void(size_t, size_t)>& reduce_function, size_t max_index) {
    size_t reduce_batch = PARALLEL_REDUCE_BATCH;

    // hardware_concurrency() is allowed to return 0 when the core count is
    // unknown. The previous unsigned "hardware_concurrency() - 1" wrapped to
    // UINT_MAX in that case, which shrank the batch to 1 and spawned one
    // thread per element. Clamp instead of subtracting blindly.
    const unsigned int hw_threads = std::thread::hardware_concurrency();
    const unsigned int thread_count = (hw_threads > 1) ? (hw_threads - 1) : 0; //not all core do this work
    if (thread_count > 0) {
        reduce_batch = max_index / thread_count + 1;
    }

    ENGINE_LOG_DEBUG << "use " << thread_count <<
                     " thread parallelly do reduce, each thread process " << reduce_batch << " vectors";

    // Threads are stored by value; no per-thread heap allocation is needed.
    std::vector<std::thread> thread_array;
    thread_array.reserve(max_index / reduce_batch + 1);

    size_t from_index = 0;
    while (from_index < max_index) {
        size_t to_index = from_index + reduce_batch;
        if (to_index > max_index) {
            to_index = max_index;  // last batch may be shorter
        }

        thread_array.emplace_back(std::ref(reduce_function), from_index, to_index);
        from_index = to_index;
    }

    for (auto& worker : thread_array) {
        worker.join();
    }
}
void
CollectDurationMetrics
(
int
index_type
,
double
total_time
)
{
switch
(
index_type
)
{
case
meta
::
TableFileSchema
::
RAW
:
{
...
...
@@ -32,7 +78,7 @@ void CollectDurationMetrics(int index_type, double total_time) {
std
::
string
GetMetricType
()
{
server
::
ServerConfig
&
config
=
server
::
ServerConfig
::
GetInstance
();
server
::
ConfigNode
engine_config
=
config
.
GetConfig
(
server
::
CONFIG_ENGINE
);
server
::
ConfigNode
&
engine_config
=
config
.
GetConfig
(
server
::
CONFIG_ENGINE
);
return
engine_config
.
GetValue
(
server
::
CONFIG_METRICTYPE
,
"L2"
);
}
...
...
@@ -51,7 +97,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
return
nullptr
;
}
SERVER
_LOG_DEBUG
<<
"Searching in file id:"
<<
index_id_
<<
" with "
ENGINE
_LOG_DEBUG
<<
"Searching in file id:"
<<
index_id_
<<
" with "
<<
search_contexts_
.
size
()
<<
" tasks"
;
server
::
TimeRecorder
rc
(
"DoSearch file id:"
+
std
::
to_string
(
index_id_
));
...
...
@@ -79,6 +125,9 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
auto
spec_k
=
index_engine_
->
Count
()
<
context
->
topk
()
?
index_engine_
->
Count
()
:
context
->
topk
();
SearchTask
::
ClusterResult
(
output_ids
,
output_distence
,
context
->
nq
(),
spec_k
,
result_set
);
span
=
rc
.
RecordSection
(
"cluster result for context:"
+
context
->
Identity
());
context
->
AccumReduceCost
(
span
);
//step 4: pick up topk result
SearchTask
::
TopkResult
(
result_set
,
inner_k
,
metric_l2
,
context
->
GetResult
());
...
...
@@ -86,7 +135,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
context
->
AccumReduceCost
(
span
);
}
catch
(
std
::
exception
&
ex
)
{
SERVER
_LOG_ERROR
<<
"SearchTask encounter exception: "
<<
ex
.
what
();
ENGINE
_LOG_ERROR
<<
"SearchTask encounter exception: "
<<
ex
.
what
();
context
->
IndexSearchDone
(
index_id_
);
//mark as done avoid dead lock, even search failed
continue
;
}
...
...
@@ -112,23 +161,32 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
if
(
output_ids
.
size
()
<
nq
*
topk
||
output_distence
.
size
()
<
nq
*
topk
)
{
std
::
string
msg
=
"Invalid id array size: "
+
std
::
to_string
(
output_ids
.
size
())
+
" distance array size: "
+
std
::
to_string
(
output_distence
.
size
());
SERVER
_LOG_ERROR
<<
msg
;
ENGINE
_LOG_ERROR
<<
msg
;
return
Status
::
Error
(
msg
);
}
result_set
.
clear
();
result_set
.
reserve
(
nq
);
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
SearchContext
::
Id2DistanceMap
id_distance
;
id_distance
.
reserve
(
topk
);
for
(
auto
k
=
0
;
k
<
topk
;
k
++
)
{
uint64_t
index
=
i
*
topk
+
k
;
if
(
output_ids
[
index
]
<
0
)
{
continue
;
result_set
.
resize
(
nq
);
std
::
function
<
void
(
size_t
,
size_t
)
>
reduce_worker
=
[
&
](
size_t
from_index
,
size_t
to_index
)
{
for
(
auto
i
=
from_index
;
i
<
to_index
;
i
++
)
{
SearchContext
::
Id2DistanceMap
id_distance
;
id_distance
.
reserve
(
topk
);
for
(
auto
k
=
0
;
k
<
topk
;
k
++
)
{
uint64_t
index
=
i
*
topk
+
k
;
if
(
output_ids
[
index
]
<
0
)
{
continue
;
}
id_distance
.
push_back
(
std
::
make_pair
(
output_ids
[
index
],
output_distence
[
index
]));
}
id_distance
.
push_back
(
std
::
make_pair
(
output_ids
[
index
],
output_distence
[
index
]))
;
result_set
[
i
]
=
id_distance
;
}
result_set
.
emplace_back
(
id_distance
);
};
if
(
NeedParallelReduce
(
nq
,
topk
))
{
ParallelReduce
(
reduce_worker
,
nq
);
}
else
{
reduce_worker
(
0
,
nq
);
}
return
Status
::
OK
();
...
...
@@ -140,7 +198,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
bool
ascending
)
{
//Note: the score_src and score_target are already arranged by score in ascending order
if
(
distance_src
.
empty
())
{
SERVER
_LOG_WARNING
<<
"Empty distance source array"
;
ENGINE
_LOG_WARNING
<<
"Empty distance source array"
;
return
Status
::
OK
();
}
...
...
@@ -218,14 +276,22 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
if
(
result_src
.
size
()
!=
result_target
.
size
())
{
std
::
string
msg
=
"Invalid result set size"
;
SERVER
_LOG_ERROR
<<
msg
;
ENGINE
_LOG_ERROR
<<
msg
;
return
Status
::
Error
(
msg
);
}
for
(
size_t
i
=
0
;
i
<
result_src
.
size
();
i
++
)
{
SearchContext
::
Id2DistanceMap
&
score_src
=
result_src
[
i
];
SearchContext
::
Id2DistanceMap
&
score_target
=
result_target
[
i
];
SearchTask
::
MergeResult
(
score_src
,
score_target
,
topk
,
ascending
);
std
::
function
<
void
(
size_t
,
size_t
)
>
ReduceWorker
=
[
&
](
size_t
from_index
,
size_t
to_index
)
{
for
(
size_t
i
=
from_index
;
i
<
to_index
;
i
++
)
{
SearchContext
::
Id2DistanceMap
&
score_src
=
result_src
[
i
];
SearchContext
::
Id2DistanceMap
&
score_target
=
result_target
[
i
];
SearchTask
::
MergeResult
(
score_src
,
score_target
,
topk
,
ascending
);
}
};
if
(
NeedParallelReduce
(
result_src
.
size
(),
topk
))
{
ParallelReduce
(
ReduceWorker
,
result_src
.
size
());
}
else
{
ReduceWorker
(
0
,
result_src
.
size
());
}
return
Status
::
OK
();
...
...
cpp/src/server/ServerConfig.h
浏览文件 @
ed029178
...
...
@@ -29,6 +29,7 @@ static const std::string CONFIG_DB_INDEX_TRIGGER_SIZE = "index_building_threshol
static
const
std
::
string
CONFIG_DB_ARCHIVE_DISK
=
"archive_disk_threshold"
;
static
const
std
::
string
CONFIG_DB_ARCHIVE_DAYS
=
"archive_days_threshold"
;
static
const
std
::
string
CONFIG_DB_INSERT_BUFFER_SIZE
=
"insert_buffer_size"
;
static
const
std
::
string
CONFIG_DB_PARALLEL_REDUCE
=
"parallel_reduce"
;
static
const
std
::
string
CONFIG_LOG
=
"log_config"
;
...
...
cpp/unittest/db/search_test.cpp
浏览文件 @
ed029178
...
...
@@ -6,6 +6,8 @@
#include <gtest/gtest.h>
#include "db/scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
#include <cmath>
#include <vector>
...
...
@@ -17,27 +19,33 @@ static constexpr uint64_t NQ = 15;
static
constexpr
uint64_t
TOP_K
=
64
;
void
BuildResult
(
uint64_t
nq
,
uint64_t
top_k
,
uint64_t
topk
,
bool
ascending
,
std
::
vector
<
long
>
&
output_ids
,
std
::
vector
<
float
>
&
output_distence
)
{
output_ids
.
clear
();
output_ids
.
resize
(
nq
*
top
_
k
);
output_ids
.
resize
(
nq
*
topk
);
output_distence
.
clear
();
output_distence
.
resize
(
nq
*
top
_
k
);
output_distence
.
resize
(
nq
*
topk
);
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
uint64_t
j
=
0
;
j
<
top
_
k
;
j
++
)
{
output_ids
[
i
*
top
_
k
+
j
]
=
(
long
)(
drand48
()
*
100000
);
output_distence
[
i
*
top
_k
+
j
]
=
j
+
drand48
(
);
for
(
uint64_t
j
=
0
;
j
<
topk
;
j
++
)
{
output_ids
[
i
*
topk
+
j
]
=
(
long
)(
drand48
()
*
100000
);
output_distence
[
i
*
top
k
+
j
]
=
ascending
?
(
j
+
drand48
())
:
((
topk
-
j
)
+
drand48
()
);
}
}
}
void
CheckResult
(
const
engine
::
SearchContext
::
Id2DistanceMap
&
src_1
,
const
engine
::
SearchContext
::
Id2DistanceMap
&
src_2
,
const
engine
::
SearchContext
::
Id2DistanceMap
&
target
)
{
const
engine
::
SearchContext
::
Id2DistanceMap
&
target
,
bool
ascending
)
{
for
(
uint64_t
i
=
0
;
i
<
target
.
size
()
-
1
;
i
++
)
{
ASSERT_LE
(
target
[
i
].
second
,
target
[
i
+
1
].
second
);
if
(
ascending
)
{
ASSERT_LE
(
target
[
i
].
second
,
target
[
i
+
1
].
second
);
}
else
{
ASSERT_GE
(
target
[
i
].
second
,
target
[
i
+
1
].
second
);
}
}
using
ID2DistMap
=
std
::
map
<
long
,
float
>
;
...
...
@@ -57,9 +65,52 @@ void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
}
}
// Checks a clustered result set against the flat id array it was built from.
//
// Asserts that src_result holds nq rows of exactly topk entries each, and
// that the first and last id of every non-empty row equal the ids at the
// matching positions (row * topk and row * topk + topk - 1) of target_ids.
// target_distence is accepted but not inspected.
void CheckCluster(const std::vector<long>& target_ids,
                  const std::vector<float>& target_distence,
                  const engine::SearchContext::ResultSet& src_result,
                  int64_t nq,
                  int64_t topk) {
    ASSERT_EQ(src_result.size(), nq);

    for (int64_t row = 0; row < nq; ++row) {
        const auto& row_result = src_result[row];
        ASSERT_EQ(row_result.size(), topk);

        // Nothing to compare for an empty row (topk == 0).
        if (!row_result.empty()) {
            const int64_t first_pos = row * topk;
            const int64_t last_pos = row * topk + topk - 1;
            ASSERT_EQ(row_result[0].first, target_ids[first_pos]);
            ASSERT_EQ(row_result[topk - 1].first, target_ids[last_pos]);
        }
    }
}
// Checks the shape and ordering of a top-k result set.
//
// Asserts that src_result holds nq rows of exactly topk entries each, and
// that the distances (.second) of adjacent entries in every row are
// non-decreasing when ascending is true, non-increasing otherwise.
void CheckTopkResult(const engine::SearchContext::ResultSet& src_result,
                     bool ascending,
                     int64_t nq,
                     int64_t topk) {
    ASSERT_EQ(src_result.size(), nq);

    for (int64_t row = 0; row < nq; ++row) {
        const auto& row_result = src_result[row];
        ASSERT_EQ(row_result.size(), topk);
        if (row_result.empty()) {
            continue;
        }

        // Compare each entry with its predecessor.
        for (int64_t pos = 1; pos < topk; ++pos) {
            if (!ascending) {
                ASSERT_GE(row_result[pos - 1].second, row_result[pos].second);
            } else {
                ASSERT_LE(row_result[pos - 1].second, row_result[pos].second);
            }
        }
    }
}
}
TEST
(
DBSearchTest
,
TOPK_TEST
)
{
bool
ascending
=
true
;
std
::
vector
<
long
>
target_ids
;
std
::
vector
<
float
>
target_distence
;
engine
::
SearchContext
::
ResultSet
src_result
;
...
...
@@ -67,19 +118,19 @@ TEST(DBSearchTest, TOPK_TEST) {
ASSERT_FALSE
(
status
.
ok
());
ASSERT_TRUE
(
src_result
.
empty
());
BuildResult
(
NQ
,
TOP_K
,
target_ids
,
target_distence
);
BuildResult
(
NQ
,
TOP_K
,
ascending
,
target_ids
,
target_distence
);
status
=
engine
::
SearchTask
::
ClusterResult
(
target_ids
,
target_distence
,
NQ
,
TOP_K
,
src_result
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
src_result
.
size
(),
NQ
);
engine
::
SearchContext
::
ResultSet
target_result
;
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
true
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
ascending
,
target_result
);
ASSERT_TRUE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
true
,
src_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
target_result
,
TOP_K
,
ascending
,
src_result
);
ASSERT_FALSE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
true
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
ascending
,
target_result
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_TRUE
(
src_result
.
empty
());
ASSERT_EQ
(
target_result
.
size
(),
NQ
);
...
...
@@ -87,21 +138,21 @@ TEST(DBSearchTest, TOPK_TEST) {
std
::
vector
<
long
>
src_ids
;
std
::
vector
<
float
>
src_distence
;
uint64_t
wrong_topk
=
TOP_K
-
10
;
BuildResult
(
NQ
,
wrong_topk
,
src_ids
,
src_distence
);
BuildResult
(
NQ
,
wrong_topk
,
ascending
,
src_ids
,
src_distence
);
status
=
engine
::
SearchTask
::
ClusterResult
(
src_ids
,
src_distence
,
NQ
,
wrong_topk
,
src_result
);
ASSERT_TRUE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
true
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
ascending
,
target_result
);
ASSERT_TRUE
(
status
.
ok
());
for
(
uint64_t
i
=
0
;
i
<
NQ
;
i
++
)
{
ASSERT_EQ
(
target_result
[
i
].
size
(),
TOP_K
);
}
wrong_topk
=
TOP_K
+
10
;
BuildResult
(
NQ
,
wrong_topk
,
src_ids
,
src_distence
);
BuildResult
(
NQ
,
wrong_topk
,
ascending
,
src_ids
,
src_distence
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
true
,
target_result
);
status
=
engine
::
SearchTask
::
TopkResult
(
src_result
,
TOP_K
,
ascending
,
target_result
);
ASSERT_TRUE
(
status
.
ok
());
for
(
uint64_t
i
=
0
;
i
<
NQ
;
i
++
)
{
ASSERT_EQ
(
target_result
[
i
].
size
(),
TOP_K
);
...
...
@@ -109,6 +160,7 @@ TEST(DBSearchTest, TOPK_TEST) {
}
TEST
(
DBSearchTest
,
MERGE_TEST
)
{
bool
ascending
=
true
;
std
::
vector
<
long
>
target_ids
;
std
::
vector
<
float
>
target_distence
;
std
::
vector
<
long
>
src_ids
;
...
...
@@ -116,8 +168,8 @@ TEST(DBSearchTest, MERGE_TEST) {
engine
::
SearchContext
::
ResultSet
src_result
,
target_result
;
uint64_t
src_count
=
5
,
target_count
=
8
;
BuildResult
(
1
,
src_count
,
src_ids
,
src_distence
);
BuildResult
(
1
,
target_count
,
target_ids
,
target_distence
);
BuildResult
(
1
,
src_count
,
ascending
,
src_ids
,
src_distence
);
BuildResult
(
1
,
target_count
,
ascending
,
target_ids
,
target_distence
);
auto
status
=
engine
::
SearchTask
::
ClusterResult
(
src_ids
,
src_distence
,
1
,
src_count
,
src_result
);
ASSERT_TRUE
(
status
.
ok
());
status
=
engine
::
SearchTask
::
ClusterResult
(
target_ids
,
target_distence
,
1
,
target_count
,
target_result
);
...
...
@@ -126,37 +178,107 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine
::
SearchContext
::
Id2DistanceMap
src
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
target
=
target_result
[
0
];
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
,
true
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
,
ascending
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
10
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
,
ascending
);
}
{
engine
::
SearchContext
::
Id2DistanceMap
src
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
target
;
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
,
true
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
10
,
ascending
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
src_count
);
ASSERT_TRUE
(
src
.
empty
());
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
,
ascending
);
}
{
engine
::
SearchContext
::
Id2DistanceMap
src
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
target
=
target_result
[
0
];
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
,
true
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
,
ascending
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
src_count
+
target_count
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
,
ascending
);
}
{
engine
::
SearchContext
::
Id2DistanceMap
target
=
src_result
[
0
];
engine
::
SearchContext
::
Id2DistanceMap
src
=
target_result
[
0
];
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
,
true
);
status
=
engine
::
SearchTask
::
MergeResult
(
src
,
target
,
30
,
ascending
);
ASSERT_TRUE
(
status
.
ok
());
ASSERT_EQ
(
target
.
size
(),
src_count
+
target_count
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
);
CheckResult
(
src_result
[
0
],
target_result
[
0
],
target
,
ascending
);
}
}
// Runs SearchTask::ClusterResult over a range of (nq, topk) shapes, small and
// large, and validates each clustered output with CheckCluster.
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
    const bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    // Builds random input for one shape, clusters it and verifies the result.
    // An ASSERT_* failure only leaves this lambda; later shapes still run.
    auto run_case = [&](int64_t nq, int64_t topk) {
        server::TimeRecorder rc("DoCluster");
        src_result.clear();
        BuildResult(nq, topk, ascending, target_ids, target_distence);
        rc.RecordSection("build id/dietance map");

        auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(src_result.size(), nq);
        rc.RecordSection("cluster result");

        CheckCluster(target_ids, target_distence, src_result, nq, topk);
        rc.RecordSection("check result");
    };

    // {nq, topk} pairs, executed in this order.
    const int64_t shapes[][2] = {
        {10000, 1000},
        {333, 999},
        {1, 1000},
        {1, 1},
        {7, 0},
        {9999, 1},
        {10001, 1},
        {58273, 1234},
    };
    for (const auto& shape : shapes) {
        run_case(shape[0], shape[1]);
    }
}
// Merges a result set with insufficient_topk entries per query into a result
// set with topk entries per query via SearchTask::TopkResult, for several
// shapes and both sort orders, and validates the merged output with
// CheckTopkResult.
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    std::vector<long> insufficient_ids;
    std::vector<float> insufficient_distence;
    engine::SearchContext::ResultSet insufficient_result;

    // An ASSERT_* failure only leaves this lambda; later shapes still run.
    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
        src_result.clear();
        insufficient_result.clear();

        server::TimeRecorder rc("DoCluster");

        BuildResult(nq, topk, ascending, target_ids, target_distence);
        auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
        // Cluster the arrays that were just built for the insufficient case.
        // Feeding target_ids/target_distence here (laid out with stride topk)
        // sliced rows across query boundaries and left the insufficient_*
        // arrays unused.
        status = engine::SearchTask::ClusterResult(insufficient_ids,
                                                   insufficient_distence,
                                                   nq,
                                                   insufficient_topk,
                                                   insufficient_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("cluster result");

        // Capture the merge status so the assertion below checks TopkResult
        // itself rather than the earlier ClusterResult call.
        status = engine::SearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("topk");

        CheckTopkResult(src_result, ascending, nq, topk);
        rc.RecordSection("check result");
    };

    DoTopk(5, 10, 4, false);
    DoTopk(20005, 998, 123, true);
    DoTopk(9987, 12, 10, false);
    DoTopk(77777, 1000, 1, false);
    DoTopk(5432, 8899, 8899, true);
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录