Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
milvus
milvus
提交
af3c14a8
M
milvus
项目概览
milvus
/
milvus
11 个月 前同步成功
通知
260
Star
22476
Fork
2472
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
af3c14a8
编写于
10月 13, 2020
作者:
B
bigsheeper
提交者:
yefu.chen
10月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add batched search support
Signed-off-by:
N
bigsheeper
<
yihao.dai@zilliz.com
>
上级
d7e6b993
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
94 addition
and
47 deletion
+94
-47
proxy/src/message_client/ClientV2.cpp
proxy/src/message_client/ClientV2.cpp
+51
-24
reader/read_node/query_node.go
reader/read_node/query_node.go
+38
-19
reader/read_node/segment.go
reader/read_node/segment.go
+2
-2
sdk/examples/simple/search.cpp
sdk/examples/simple/search.cpp
+3
-2
未找到文件。
proxy/src/message_client/ClientV2.cpp
浏览文件 @
af3c14a8
...
...
@@ -76,14 +76,21 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
}
std
::
vector
<
float
>
all_scores
;
std
::
vector
<
float
>
all_distance
;
std
::
vector
<
int64_t
>
all_entities_ids
;
// Proxy get numQueries from row_num.
auto
numQueries
=
results
[
0
]
->
row_num
();
auto
topK
=
results
[
0
]
->
distances_size
()
/
numQueries
;
// 2d array for multiple queries
std
::
vector
<
std
::
vector
<
float
>>
all_distance
(
numQueries
);
std
::
vector
<
std
::
vector
<
int64_t
>>
all_entities_ids
(
numQueries
);
std
::
vector
<
bool
>
all_valid_row
;
std
::
vector
<
grpc
::
RowData
>
all_row_data
;
std
::
vector
<
grpc
::
KeyValuePair
>
all_kv_pairs
;
grpc
::
Status
status
;
int
row_num
=
0
;
//
int row_num = 0;
for
(
auto
&
result_per_node
:
results
)
{
if
(
result_per_node
->
status
().
error_code
()
!=
grpc
::
ErrorCode
::
SUCCESS
)
{
...
...
@@ -91,46 +98,66 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
// one_node_res->entities().status().error_code() != grpc::ErrorCode::SUCCESS) {
return
Status
(
DB_ERROR
,
"QueryNode return wrong status!"
);
}
for
(
int
j
=
0
;
j
<
result_per_node
->
distances_size
();
j
++
)
{
all_scores
.
push_back
(
result_per_node
->
scores
()[
j
]);
all_distance
.
push_back
(
result_per_node
->
distances
()[
j
]);
// all_kv_pairs.push_back(result_per_node->extra_params()[j]);
}
for
(
int
k
=
0
;
k
<
result_per_node
->
entities
().
ids_size
();
++
k
)
{
all_entities_ids
.
push_back
(
result_per_node
->
entities
().
ids
(
k
));
// all_valid_row.push_back(result_per_node->entities().valid_row(k));
// all_row_data.push_back(result_per_node->entities().rows_data(k));
}
if
(
result_per_node
->
row_num
()
>
row_num
)
{
row_num
=
result_per_node
->
row_num
();
// assert(result_per_node->row_num() == numQueries);
for
(
int
i
=
0
;
i
<
numQueries
;
i
++
)
{
for
(
int
j
=
i
*
topK
;
j
<
(
i
+
1
)
*
topK
&&
j
<
result_per_node
->
distances_size
();
j
++
)
{
all_scores
.
push_back
(
result_per_node
->
scores
()[
j
]);
all_distance
[
i
].
push_back
(
result_per_node
->
distances
()[
j
]);
all_entities_ids
[
i
].
push_back
(
result_per_node
->
entities
().
ids
(
j
));
}
}
// for (int j = 0; j < result_per_node->distances_size(); j++) {
// all_scores.push_back(result_per_node->scores()[j]);
// all_distance.push_back(result_per_node->distances()[j]);
//// all_kv_pairs.push_back(result_per_node->extra_params()[j]);
// }
// for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
// all_entities_ids.push_back(result_per_node->entities().ids(k));
//// all_valid_row.push_back(result_per_node->entities().valid_row(k));
//// all_row_data.push_back(result_per_node->entities().rows_data(k));
// }
// if (result_per_node->row_num() > row_num) {
// row_num = result_per_node->row_num();
// }
status
=
result_per_node
->
status
();
}
std
::
vector
<
int
>
index
(
all_distance
.
size
());
std
::
vector
<
std
::
vector
<
int
>>
index_array
;
for
(
int
i
=
0
;
i
<
numQueries
;
i
++
)
{
auto
&
distance
=
all_distance
[
i
];
std
::
vector
<
int
>
index
(
distance
.
size
());
iota
(
index
.
begin
(),
index
.
end
(),
0
);
iota
(
index
.
begin
(),
index
.
end
(),
0
);
std
::
stable_sort
(
index
.
begin
(),
index
.
end
(),
[
&
distance
](
size_t
i1
,
size_t
i2
)
{
return
distance
[
i1
]
<
distance
[
i2
];
});
index_array
.
emplace_back
(
index
);
}
std
::
stable_sort
(
index
.
begin
(),
index
.
end
(),
[
&
all_distance
](
size_t
i1
,
size_t
i2
)
{
return
all_distance
[
i1
]
>
all_distance
[
i2
];
});
grpc
::
Entities
result_entities
;
for
(
int
m
=
0
;
m
<
result
->
row_num
();
++
m
)
{
result
->
add_scores
(
all_scores
[
index
[
m
]]);
result
->
add_distances
(
all_distance
[
index
[
m
]]);
for
(
int
i
=
0
;
i
<
numQueries
;
i
++
)
{
for
(
int
m
=
0
;
m
<
topK
;
++
m
)
{
result
->
add_scores
(
all_scores
[
index_array
[
i
][
m
]]);
result
->
add_distances
(
all_distance
[
i
][
index_array
[
i
][
m
]]);
// result->add_extra_params();
// result->mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
result_entities
.
add_ids
(
all_entities_ids
[
index
[
m
]]);
result_entities
.
add_ids
(
all_entities_ids
[
i
][
index_array
[
i
]
[
m
]]);
// result_entities.add_valid_row(all_valid_row[index[m]]);
// result_entities.add_rows_data();
// result_entities.mutable_rows_data(m)->CopyFrom(all_row_data[index[m]]);
}
}
result_entities
.
mutable_status
()
->
CopyFrom
(
status
);
result
->
set_row_num
(
row_num
);
result
->
set_row_num
(
numQueries
);
result
->
mutable_entities
()
->
CopyFrom
(
result_entities
);
result
->
set_query_id
(
results
[
0
]
->
query_id
());
// result->set_client_id(results[0]->client_id());
...
...
reader/read_node/query_node.go
浏览文件 @
af3c14a8
...
...
@@ -276,7 +276,7 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
if
node
.
msgCounter
.
InsertCounter
/
CountInsertMsgBaseline
!=
BaselineCounter
{
node
.
WriteQueryLog
()
BaselineCounter
=
node
.
msgCounter
.
InsertCounter
/
CountInsertMsgBaseline
BaselineCounter
=
node
.
msgCounter
.
InsertCounter
/
CountInsertMsgBaseline
}
if
msgLen
[
0
]
==
0
&&
len
(
node
.
buffer
.
InsertDeleteBuffer
)
<=
0
{
...
...
@@ -339,10 +339,10 @@ func (node *QueryNode) RunSearch(wg *sync.WaitGroup) {
case
msg
:=
<-
node
.
messageClient
.
GetSearchChan
()
:
node
.
messageClient
.
SearchMsg
=
node
.
messageClient
.
SearchMsg
[
:
0
]
node
.
messageClient
.
SearchMsg
=
append
(
node
.
messageClient
.
SearchMsg
,
msg
)
fmt
.
Println
(
"Do Search..."
)
//for {
//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
var
status
=
node
.
Search
(
node
.
messageClient
.
SearchMsg
)
fmt
.
Println
(
"Do Search done"
)
if
status
.
ErrorCode
!=
0
{
fmt
.
Println
(
"Search Failed"
)
node
.
PublishFailedSearchResult
()
...
...
@@ -504,8 +504,8 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
}
wg
.
Add
(
1
)
var
deleteTimestamps
=
node
.
deleteData
.
deleteTimestamps
[
segmentID
]
fmt
.
Println
(
"Doing delete......"
)
go
node
.
DoDelete
(
segmentID
,
&
deleteIDs
,
&
deleteTimestamps
,
&
wg
)
fmt
.
Println
(
"Do delete done"
)
}
wg
.
Wait
()
...
...
@@ -513,7 +513,6 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
}
func
(
node
*
QueryNode
)
DoInsert
(
segmentID
int64
,
wg
*
sync
.
WaitGroup
)
msgPb
.
Status
{
fmt
.
Println
(
"Doing insert..., len = "
,
len
(
node
.
insertData
.
insertIDs
[
segmentID
]))
var
targetSegment
,
err
=
node
.
GetSegmentBySegmentID
(
segmentID
)
if
err
!=
nil
{
fmt
.
Println
(
err
.
Error
())
...
...
@@ -526,6 +525,7 @@ func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Statu
offsets
:=
node
.
insertData
.
insertOffset
[
segmentID
]
err
=
targetSegment
.
SegmentInsert
(
offsets
,
&
ids
,
&
timestamps
,
&
records
)
fmt
.
Println
(
"Do insert done, len = "
,
len
(
node
.
insertData
.
insertIDs
[
segmentID
]))
node
.
QueryLog
(
len
(
ids
))
...
...
@@ -584,8 +584,6 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// TODO: Do not receive batched search requests
for
_
,
msg
:=
range
searchMessages
{
var
clientId
=
msg
.
ClientId
var
resultsTmp
=
make
([]
SearchResultTmp
,
0
)
var
searchTimestamp
=
msg
.
Timestamp
// ServiceTimeSync update by readerTimeSync, which is get from proxy.
...
...
@@ -610,6 +608,11 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// 2. Get query information from query json
query
:=
node
.
QueryJson2Info
(
&
queryJson
)
// 2d slice for receiving multiple queries's results
var
resultsTmp
=
make
([][]
SearchResultTmp
,
query
.
NumQueries
)
for
i
:=
0
;
i
<
int
(
query
.
NumQueries
);
i
++
{
resultsTmp
[
i
]
=
make
([]
SearchResultTmp
,
0
)
}
// 3. Do search in all segments
for
_
,
segment
:=
range
node
.
SegmentsMap
{
...
...
@@ -625,18 +628,30 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
return
msgPb
.
Status
{
ErrorCode
:
1
}
}
for
i
:=
0
;
i
<
len
(
res
.
ResultIds
);
i
++
{
resultsTmp
=
append
(
resultsTmp
,
SearchResultTmp
{
ResultId
:
res
.
ResultIds
[
i
],
ResultDistance
:
res
.
ResultDistances
[
i
]})
for
i
:=
0
;
i
<
int
(
query
.
NumQueries
);
i
++
{
for
j
:=
i
*
query
.
TopK
;
j
<
(
i
+
1
)
*
query
.
TopK
;
j
++
{
resultsTmp
[
i
]
=
append
(
resultsTmp
[
i
],
SearchResultTmp
{
ResultId
:
res
.
ResultIds
[
j
],
ResultDistance
:
res
.
ResultDistances
[
j
],
})
}
}
}
// 4. Reduce results
sort
.
Slice
(
resultsTmp
,
func
(
i
,
j
int
)
bool
{
return
resultsTmp
[
i
]
.
ResultDistance
<
resultsTmp
[
j
]
.
ResultDistance
})
if
len
(
resultsTmp
)
>
query
.
TopK
{
resultsTmp
=
resultsTmp
[
:
query
.
TopK
]
for
_
,
rTmp
:=
range
resultsTmp
{
sort
.
Slice
(
rTmp
,
func
(
i
,
j
int
)
bool
{
return
rTmp
[
i
]
.
ResultDistance
<
rTmp
[
j
]
.
ResultDistance
})
}
for
_
,
rTmp
:=
range
resultsTmp
{
if
len
(
rTmp
)
>
query
.
TopK
{
rTmp
=
rTmp
[
:
query
.
TopK
]
}
}
var
entities
=
msgPb
.
Entities
{
Ids
:
make
([]
int64
,
0
),
}
...
...
@@ -649,15 +664,19 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
QueryId
:
msg
.
Uid
,
ClientId
:
clientId
,
}
for
_
,
res
:=
range
resultsTmp
{
results
.
Entities
.
Ids
=
append
(
results
.
Entities
.
Ids
,
res
.
ResultId
)
results
.
Distances
=
append
(
results
.
Distances
,
res
.
ResultDistance
)
results
.
Scores
=
append
(
results
.
Distances
,
float32
(
0
))
for
_
,
rTmp
:=
range
resultsTmp
{
for
_
,
res
:=
range
rTmp
{
results
.
Entities
.
Ids
=
append
(
results
.
Entities
.
Ids
,
res
.
ResultId
)
results
.
Distances
=
append
(
results
.
Distances
,
res
.
ResultDistance
)
results
.
Scores
=
append
(
results
.
Distances
,
float32
(
0
))
}
}
results
.
RowNum
=
int64
(
len
(
results
.
Distances
))
// Send numQueries to RowNum.
results
.
RowNum
=
query
.
NumQueries
// 5. publish result to pulsar
//fmt.Println(results.Entities.Ids)
//fmt.Println(results.Distances)
node
.
PublishSearchResult
(
&
results
)
}
...
...
reader/read_node/segment.go
浏览文件 @
af3c14a8
...
...
@@ -218,8 +218,8 @@ func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord
field_name
:
C
.
CString
(
query
.
FieldName
),
}
resultIds
:=
make
([]
int64
,
query
.
TopK
)
resultDistances
:=
make
([]
float32
,
query
.
TopK
)
resultIds
:=
make
([]
int64
,
int64
(
query
.
TopK
)
*
query
.
NumQueries
)
resultDistances
:=
make
([]
float32
,
int64
(
query
.
TopK
)
*
query
.
NumQueries
)
var
cTimestamp
=
C
.
ulong
(
timestamp
)
var
cResultIds
=
(
*
C
.
long
)(
&
resultIds
[
0
])
...
...
sdk/examples/simple/search.cpp
浏览文件 @
af3c14a8
...
...
@@ -17,6 +17,7 @@
#include "utils/Utils.h"
#include <random>
const
int
NUM_OF_VECTOR
=
1
;
const
int
TOP_K
=
10
;
const
int
LOOP
=
1000
;
...
...
@@ -32,7 +33,7 @@ get_vector_param() {
std
::
normal_distribution
<
float
>
dis
(
0
,
1
);
for
(
int
j
=
0
;
j
<
1
;
++
j
)
{
for
(
int
j
=
0
;
j
<
NUM_OF_VECTOR
;
++
j
)
{
milvus
::
VectorData
vectorData
;
std
::
vector
<
float
>
float_data
;
for
(
int
i
=
0
;
i
<
DIM
;
++
i
)
{
...
...
@@ -44,7 +45,7 @@ get_vector_param() {
}
nlohmann
::
json
vector_param_json
;
vector_param_json
[
"num_queries"
]
=
1
;
vector_param_json
[
"num_queries"
]
=
NUM_OF_VECTOR
;
vector_param_json
[
"topK"
]
=
TOP_K
;
vector_param_json
[
"field_name"
]
=
"field_vec"
;
std
::
string
vector_param_json_string
=
vector_param_json
.
dump
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录