Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
milvus
milvus
提交
b2e8ba7b
M
milvus
项目概览
milvus
/
milvus
11 个月 前同步成功
通知
261
Star
22476
Fork
2472
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
b2e8ba7b
编写于
9月 22, 2021
作者:
D
dragondriver
提交者:
GitHub
9月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix reduce algorithm in proxy search task (#8206)
Signed-off-by:
N
dragondriver
<
jiquan.long@zilliz.com
>
上级
c1e229cb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
250 addition
and
9 deletion
+250
-9
internal/proxy/task.go
internal/proxy/task.go
+5
-8
internal/proxy/task_test.go
internal/proxy/task_test.go
+245
-1
未找到文件。
internal/proxy/task.go
浏览文件 @
b2e8ba7b
...
...
@@ -83,6 +83,8 @@ const (
CreateAliasTaskName
=
"CreateAliasTask"
DropAliasTaskName
=
"DropAliasTask"
AlterAliasTaskName
=
"AlterAliasTask"
minFloat32
=
-
1
*
float32
(
math
.
MaxFloat32
)
)
type
task
interface
{
...
...
@@ -1755,8 +1757,6 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
}
const
minFloat32
=
-
1
*
float32
(
math
.
MaxFloat32
)
// TODO(yukun): Use parallel function
var
realTopK
int64
=
-
1
var
idx
int64
...
...
@@ -1766,17 +1766,14 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
j
=
0
for
;
j
<
topk
;
j
++
{
valid
:=
true
choice
,
maxDistance
:=
0
,
minFloat32
choice
,
maxDistance
:=
-
1
,
minFloat32
for
q
,
loc
:=
range
locs
{
// query num, the number of ways to merge
if
loc
>=
topk
{
continue
}
curIdx
:=
idx
*
topk
+
loc
id
:=
searchResultData
[
q
]
.
Ids
.
GetIntId
()
.
Data
[
curIdx
]
if
id
==
-
1
{
valid
=
false
}
else
{
if
id
!=
-
1
{
distance
:=
searchResultData
[
q
]
.
Scores
[
curIdx
]
if
distance
>
maxDistance
{
choice
=
q
...
...
@@ -1784,7 +1781,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
}
}
if
!
valid
{
if
choice
==
-
1
{
break
}
choiceOffset
:=
locs
[
choice
]
...
...
internal/proxy/task_test.go
浏览文件 @
b2e8ba7b
...
...
@@ -1903,7 +1903,7 @@ func TestSearchTask_all(t *testing.T) {
for
i
:=
0
;
i
<
nq
;
i
++
{
for
j
:=
0
;
j
<
topk
;
j
++
{
offset
:=
i
*
topk
+
j
score
:=
rand
.
Float32
()
score
:=
float32
(
uniquegenerator
.
GetUniqueIntGeneratorIns
()
.
GetInt
())
// increasingly
id
:=
int64
(
uniquegenerator
.
GetUniqueIntGeneratorIns
()
.
GetInt
())
resultData
.
Scores
[
offset
]
=
score
resultData
.
Ids
.
IdField
.
(
*
schemapb
.
IDs_IntId
)
.
IntId
.
Data
[
offset
]
=
id
...
...
@@ -1981,6 +1981,250 @@ func TestSearchTask_all(t *testing.T) {
wg
.
Wait
()
}
func
TestSearchTask_7803_reduce
(
t
*
testing
.
T
)
{
var
err
error
Params
.
Init
()
Params
.
SearchResultChannelNames
=
[]
string
{
funcutil
.
GenRandomStr
()}
rc
:=
NewRootCoordMock
()
rc
.
Start
()
defer
rc
.
Stop
()
ctx
:=
context
.
Background
()
err
=
InitMetaCache
(
rc
)
assert
.
NoError
(
t
,
err
)
shardsNum
:=
int32
(
2
)
prefix
:=
"TestSearchTask_7803_reduce"
dbName
:=
""
collectionName
:=
prefix
+
funcutil
.
GenRandomStr
()
int64Field
:=
"int64"
floatVecField
:=
"fvec"
dim
:=
128
expr
:=
fmt
.
Sprintf
(
"%s > 0"
,
int64Field
)
nq
:=
10
topk
:=
10
nprobe
:=
10
schema
:=
constructCollectionSchema
(
int64Field
,
floatVecField
,
dim
,
collectionName
)
marshaledSchema
,
err
:=
proto
.
Marshal
(
schema
)
assert
.
NoError
(
t
,
err
)
createColT
:=
&
createCollectionTask
{
Condition
:
NewTaskCondition
(
ctx
),
CreateCollectionRequest
:
&
milvuspb
.
CreateCollectionRequest
{
Base
:
nil
,
DbName
:
dbName
,
CollectionName
:
collectionName
,
Schema
:
marshaledSchema
,
ShardsNum
:
shardsNum
,
},
ctx
:
ctx
,
rootCoord
:
rc
,
result
:
nil
,
schema
:
nil
,
}
assert
.
NoError
(
t
,
createColT
.
OnEnqueue
())
assert
.
NoError
(
t
,
createColT
.
PreExecute
(
ctx
))
assert
.
NoError
(
t
,
createColT
.
Execute
(
ctx
))
assert
.
NoError
(
t
,
createColT
.
PostExecute
(
ctx
))
dmlChannelsFunc
:=
getDmlChannelsFunc
(
ctx
,
rc
)
query
:=
newMockGetChannelsService
()
factory
:=
newSimpleMockMsgStreamFactory
()
chMgr
:=
newChannelsMgrImpl
(
dmlChannelsFunc
,
nil
,
query
.
GetChannels
,
nil
,
factory
)
defer
chMgr
.
removeAllDMLStream
()
defer
chMgr
.
removeAllDQLStream
()
collectionID
,
err
:=
globalMetaCache
.
GetCollectionID
(
ctx
,
collectionName
)
assert
.
NoError
(
t
,
err
)
qc
:=
NewQueryCoordMock
()
qc
.
Start
()
defer
qc
.
Stop
()
status
,
err
:=
qc
.
LoadCollection
(
ctx
,
&
querypb
.
LoadCollectionRequest
{
Base
:
&
commonpb
.
MsgBase
{
MsgType
:
commonpb
.
MsgType_LoadCollection
,
MsgID
:
0
,
Timestamp
:
0
,
SourceID
:
Params
.
ProxyID
,
},
DbID
:
0
,
CollectionID
:
collectionID
,
Schema
:
nil
,
})
assert
.
NoError
(
t
,
err
)
assert
.
Equal
(
t
,
commonpb
.
ErrorCode_Success
,
status
.
ErrorCode
)
req
:=
constructSearchRequest
(
dbName
,
collectionName
,
expr
,
floatVecField
,
nq
,
dim
,
nprobe
,
topk
)
task
:=
&
searchTask
{
Condition
:
NewTaskCondition
(
ctx
),
SearchRequest
:
&
internalpb
.
SearchRequest
{
Base
:
&
commonpb
.
MsgBase
{
MsgType
:
commonpb
.
MsgType_Search
,
MsgID
:
0
,
Timestamp
:
0
,
SourceID
:
Params
.
ProxyID
,
},
ResultChannelID
:
strconv
.
FormatInt
(
Params
.
ProxyID
,
10
),
DbID
:
0
,
CollectionID
:
0
,
PartitionIDs
:
nil
,
Dsl
:
""
,
PlaceholderGroup
:
nil
,
DslType
:
0
,
SerializedExprPlan
:
nil
,
OutputFieldsId
:
nil
,
TravelTimestamp
:
0
,
GuaranteeTimestamp
:
0
,
},
ctx
:
ctx
,
resultBuf
:
make
(
chan
[]
*
internalpb
.
SearchResults
),
result
:
nil
,
query
:
req
,
chMgr
:
chMgr
,
qc
:
qc
,
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
err
=
chMgr
.
createDQLStream
(
collectionID
)
assert
.
NoError
(
t
,
err
)
stream
,
err
:=
chMgr
.
getDQLStream
(
collectionID
)
assert
.
NoError
(
t
,
err
)
var
wg
sync
.
WaitGroup
wg
.
Add
(
1
)
consumeCtx
,
cancel
:=
context
.
WithCancel
(
ctx
)
go
func
()
{
defer
wg
.
Done
()
for
{
select
{
case
<-
consumeCtx
.
Done
()
:
return
case
pack
:=
<-
stream
.
Chan
()
:
for
_
,
msg
:=
range
pack
.
Msgs
{
_
,
ok
:=
msg
.
(
*
msgstream
.
SearchMsg
)
assert
.
True
(
t
,
ok
)
// TODO(dragondriver): construct result according to the request
constructSearchResulstData
:=
func
(
invalidNum
int
)
*
schemapb
.
SearchResultData
{
resultData
:=
&
schemapb
.
SearchResultData
{
NumQueries
:
int64
(
nq
),
TopK
:
int64
(
topk
),
FieldsData
:
nil
,
Scores
:
make
([]
float32
,
nq
*
topk
),
Ids
:
&
schemapb
.
IDs
{
IdField
:
&
schemapb
.
IDs_IntId
{
IntId
:
&
schemapb
.
LongArray
{
Data
:
make
([]
int64
,
nq
*
topk
),
},
},
},
Topks
:
make
([]
int64
,
nq
),
}
for
i
:=
0
;
i
<
nq
;
i
++
{
for
j
:=
0
;
j
<
topk
;
j
++
{
offset
:=
i
*
topk
+
j
if
j
>=
invalidNum
{
resultData
.
Scores
[
offset
]
=
minFloat32
resultData
.
Ids
.
IdField
.
(
*
schemapb
.
IDs_IntId
)
.
IntId
.
Data
[
offset
]
=
-
1
}
else
{
score
:=
float32
(
uniquegenerator
.
GetUniqueIntGeneratorIns
()
.
GetInt
())
// increasingly
id
:=
int64
(
uniquegenerator
.
GetUniqueIntGeneratorIns
()
.
GetInt
())
resultData
.
Scores
[
offset
]
=
score
resultData
.
Ids
.
IdField
.
(
*
schemapb
.
IDs_IntId
)
.
IntId
.
Data
[
offset
]
=
id
}
}
resultData
.
Topks
[
i
]
=
int64
(
topk
)
}
return
resultData
}
result1
:=
&
internalpb
.
SearchResults
{
Base
:
&
commonpb
.
MsgBase
{
MsgType
:
commonpb
.
MsgType_SearchResult
,
MsgID
:
0
,
Timestamp
:
0
,
SourceID
:
0
,
},
Status
:
&
commonpb
.
Status
{
ErrorCode
:
commonpb
.
ErrorCode_Success
,
Reason
:
""
,
},
ResultChannelID
:
""
,
MetricType
:
distance
.
L2
,
NumQueries
:
int64
(
nq
),
TopK
:
int64
(
topk
),
SealedSegmentIDsSearched
:
nil
,
ChannelIDsSearched
:
nil
,
GlobalSealedSegmentIDs
:
nil
,
SlicedBlob
:
nil
,
SlicedNumCount
:
1
,
SlicedOffset
:
0
,
}
resultData
:=
constructSearchResulstData
(
topk
/
2
)
sliceBlob
,
err
:=
proto
.
Marshal
(
resultData
)
assert
.
NoError
(
t
,
err
)
result1
.
SlicedBlob
=
sliceBlob
result2
:=
&
internalpb
.
SearchResults
{
Base
:
&
commonpb
.
MsgBase
{
MsgType
:
commonpb
.
MsgType_SearchResult
,
MsgID
:
0
,
Timestamp
:
0
,
SourceID
:
0
,
},
Status
:
&
commonpb
.
Status
{
ErrorCode
:
commonpb
.
ErrorCode_Success
,
Reason
:
""
,
},
ResultChannelID
:
""
,
MetricType
:
distance
.
L2
,
NumQueries
:
int64
(
nq
),
TopK
:
int64
(
topk
),
SealedSegmentIDsSearched
:
nil
,
ChannelIDsSearched
:
nil
,
GlobalSealedSegmentIDs
:
nil
,
SlicedBlob
:
nil
,
SlicedNumCount
:
1
,
SlicedOffset
:
0
,
}
resultData2
:=
constructSearchResulstData
(
topk
-
topk
/
2
)
sliceBlob2
,
err
:=
proto
.
Marshal
(
resultData2
)
assert
.
NoError
(
t
,
err
)
result2
.
SlicedBlob
=
sliceBlob2
// send search result
task
.
resultBuf
<-
[]
*
internalpb
.
SearchResults
{
result1
,
result2
}
}
}
}
}()
assert
.
NoError
(
t
,
task
.
OnEnqueue
())
assert
.
NoError
(
t
,
task
.
PreExecute
(
ctx
))
assert
.
NoError
(
t
,
task
.
Execute
(
ctx
))
assert
.
NoError
(
t
,
task
.
PostExecute
(
ctx
))
cancel
()
wg
.
Wait
()
}
func
TestSearchTask_Type
(
t
*
testing
.
T
)
{
Params
.
Init
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录