Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
19dca4c9
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
19dca4c9
编写于
10月 21, 2019
作者:
S
starlord
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
#59 Topk result is incorrect for small dataset
Former-commit-id: f8c4a38365881252e66f280647aedf1e4c6395e5
上级
18108119
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
263 addition
and
245 deletion
+263
-245
core/src/scheduler/task/SearchTask.cpp
core/src/scheduler/task/SearchTask.cpp
+65
-65
core/src/scheduler/task/SearchTask.h
core/src/scheduler/task/SearchTask.h
+4
-4
core/unittest/db/test_search.cpp
core/unittest/db/test_search.cpp
+194
-176
未找到文件。
core/src/scheduler/task/SearchTask.cpp
浏览文件 @
19dca4c9
...
...
@@ -307,71 +307,71 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s
}
}
void
XSearchTask
::
MergeTopkArray
(
std
::
vector
<
int64_t
>&
tar_ids
,
std
::
vector
<
float
>&
tar_distance
,
uint64_t
&
tar_input_k
,
const
std
::
vector
<
int64_t
>&
src_ids
,
const
std
::
vector
<
float
>&
src_distance
,
uint64_t
src_input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
)
{
if
(
src_ids
.
empty
()
||
src_distance
.
empty
())
{
return
;
}
uint64_t
output_k
=
std
::
min
(
topk
,
tar_input_k
+
src_input_k
);
std
::
vector
<
int64_t
>
id_buf
(
nq
*
output_k
,
-
1
);
std
::
vector
<
float
>
dist_buf
(
nq
*
output_k
,
0.0
);
uint64_t
buf_k
,
src_k
,
tar_k
;
uint64_t
src_idx
,
tar_idx
,
buf_idx
;
uint64_t
src_input_k_multi_i
,
tar_input_k_multi_i
,
buf_k_multi_i
;
for
(
uint64_t
i
=
0
;
i
<
nq
;
i
++
)
{
src_input_k_multi_i
=
src_input_k
*
i
;
tar_input_k_multi_i
=
tar_input_k
*
i
;
buf_k_multi_i
=
output_k
*
i
;
buf_k
=
src_k
=
tar_k
=
0
;
while
(
buf_k
<
output_k
&&
src_k
<
src_input_k
&&
tar_k
<
tar_input_k
)
{
src_idx
=
src_input_k_multi_i
+
src_k
;
tar_idx
=
tar_input_k_multi_i
+
tar_k
;
buf_idx
=
buf_k_multi_i
+
buf_k
;
if
((
ascending
&&
src_distance
[
src_idx
]
<
tar_distance
[
tar_idx
])
||
(
!
ascending
&&
src_distance
[
src_idx
]
>
tar_distance
[
tar_idx
]))
{
id_buf
[
buf_idx
]
=
src_ids
[
src_idx
];
dist_buf
[
buf_idx
]
=
src_distance
[
src_idx
];
src_k
++
;
}
else
{
id_buf
[
buf_idx
]
=
tar_ids
[
tar_idx
];
dist_buf
[
buf_idx
]
=
tar_distance
[
tar_idx
];
tar_k
++
;
}
buf_k
++
;
}
if
(
buf_k
<
output_k
)
{
if
(
src_k
<
src_input_k
)
{
while
(
buf_k
<
output_k
&&
src_k
<
src_input_k
)
{
src_idx
=
src_input_k_multi_i
+
src_k
;
buf_idx
=
buf_k_multi_i
+
buf_k
;
id_buf
[
buf_idx
]
=
src_ids
[
src_idx
];
dist_buf
[
buf_idx
]
=
src_distance
[
src_idx
];
src_k
++
;
buf_k
++
;
}
}
else
{
while
(
buf_k
<
output_k
&&
tar_k
<
tar_input_k
)
{
tar_idx
=
tar_input_k_multi_i
+
tar_k
;
buf_idx
=
buf_k_multi_i
+
buf_k
;
id_buf
[
buf_idx
]
=
tar_ids
[
tar_idx
];
dist_buf
[
buf_idx
]
=
tar_distance
[
tar_idx
];
tar_k
++
;
buf_k
++
;
}
}
}
}
tar_ids
.
swap
(
id_buf
);
tar_distance
.
swap
(
dist_buf
);
tar_input_k
=
output_k
;
}
//
void
//
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
//
const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
//
uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
//
if (src_ids.empty() || src_distance.empty()) {
//
return;
//
}
//
//
uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
//
std::vector<int64_t> id_buf(nq * output_k, -1);
//
std::vector<float> dist_buf(nq * output_k, 0.0);
//
//
uint64_t buf_k, src_k, tar_k;
//
uint64_t src_idx, tar_idx, buf_idx;
//
uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i;
//
//
for (uint64_t i = 0; i < nq; i++) {
//
src_input_k_multi_i = src_input_k * i;
//
tar_input_k_multi_i = tar_input_k * i;
//
buf_k_multi_i = output_k * i;
//
buf_k = src_k = tar_k = 0;
//
while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) {
//
src_idx = src_input_k_multi_i + src_k;
//
tar_idx = tar_input_k_multi_i + tar_k;
//
buf_idx = buf_k_multi_i + buf_k;
//
if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) ||
//
(!ascending && src_distance[src_idx] > tar_distance[tar_idx])) {
//
id_buf[buf_idx] = src_ids[src_idx];
//
dist_buf[buf_idx] = src_distance[src_idx];
//
src_k++;
//
} else {
//
id_buf[buf_idx] = tar_ids[tar_idx];
//
dist_buf[buf_idx] = tar_distance[tar_idx];
//
tar_k++;
//
}
//
buf_k++;
//
}
//
//
if (buf_k < output_k) {
//
if (src_k < src_input_k) {
//
while (buf_k < output_k && src_k < src_input_k) {
//
src_idx = src_input_k_multi_i + src_k;
//
buf_idx = buf_k_multi_i + buf_k;
//
id_buf[buf_idx] = src_ids[src_idx];
//
dist_buf[buf_idx] = src_distance[src_idx];
//
src_k++;
//
buf_k++;
//
}
//
} else {
//
while (buf_k < output_k && tar_k < tar_input_k) {
//
tar_idx = tar_input_k_multi_i + tar_k;
//
buf_idx = buf_k_multi_i + buf_k;
//
id_buf[buf_idx] = tar_ids[tar_idx];
//
dist_buf[buf_idx] = tar_distance[tar_idx];
//
tar_k++;
//
buf_k++;
//
}
//
}
//
}
//
}
//
//
tar_ids.swap(id_buf);
//
tar_distance.swap(dist_buf);
//
tar_input_k = output_k;
//
}
}
// namespace scheduler
}
// namespace milvus
core/src/scheduler/task/SearchTask.h
浏览文件 @
19dca4c9
...
...
@@ -42,10 +42,10 @@ class XSearchTask : public Task {
MergeTopkToResultSet
(
const
std
::
vector
<
int64_t
>&
input_ids
,
const
std
::
vector
<
float
>&
input_distance
,
uint64_t
input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
,
scheduler
::
ResultSet
&
result
);
static
void
MergeTopkArray
(
std
::
vector
<
int64_t
>&
tar_ids
,
std
::
vector
<
float
>&
tar_distance
,
uint64_t
&
tar_input_k
,
const
std
::
vector
<
int64_t
>&
src_ids
,
const
std
::
vector
<
float
>&
src_distance
,
uint64_t
src_input_k
,
uint64_t
nq
,
uint64_t
topk
,
bool
ascending
);
//
static void
//
MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
//
const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k,
//
uint64_t nq, uint64_t topk, bool ascending);
public:
TableFileSchemaPtr
file_
;
...
...
core/unittest/db/test_search.cpp
浏览文件 @
19dca4c9
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录