Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
milvus
milvus
提交
f4566731
M
milvus
项目概览
milvus
/
milvus
11 个月 前同步成功
通知
260
Star
22476
Fork
2472
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f4566731
编写于
1月 09, 2021
作者:
X
xige-16
提交者:
yefu.chen
1月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix search error in regression test
Signed-off-by:
N
xige-16
<
xi.ge@zilliz.com
>
上级
5dfe9448
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
51 addition
and
20 deletion
+51
-20
internal/core/src/segcore/load_index_c.cpp
internal/core/src/segcore/load_index_c.cpp
+1
-1
internal/msgstream/msgstream.go
internal/msgstream/msgstream.go
+12
-5
internal/proxy/task.go
internal/proxy/task.go
+0
-3
internal/querynode/search_service.go
internal/querynode/search_service.go
+25
-5
tests/python/test_search.py
tests/python/test_search.py
+13
-6
未找到文件。
internal/core/src/segcore/load_index_c.cpp
浏览文件 @
f4566731
...
...
@@ -133,7 +133,7 @@ AppendBinaryIndex(CBinarySet c_binary_set, void* index_binary, int64_t index_siz
auto
binary_set
=
(
milvus
::
knowhere
::
BinarySet
*
)
c_binary_set
;
std
::
string
index_key
(
c_index_key
);
uint8_t
*
index
=
(
uint8_t
*
)
index_binary
;
std
::
shared_ptr
<
uint8_t
[]
>
data
(
index
);
std
::
shared_ptr
<
uint8_t
[]
>
data
(
index
,
[](
void
*
)
{}
);
binary_set
->
Append
(
index_key
,
data
,
index_size
);
auto
status
=
CStatus
();
...
...
internal/msgstream/msgstream.go
浏览文件 @
f4566731
...
...
@@ -353,15 +353,16 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() {
default
:
wg
:=
sync
.
WaitGroup
{}
mu
:=
sync
.
Mutex
{}
findMapMutex
:=
sync
.
RWMutex
{}
for
i
:=
0
;
i
<
len
(
ms
.
consumers
);
i
++
{
if
isChannelReady
[
i
]
{
continue
}
wg
.
Add
(
1
)
go
ms
.
findTimeTick
(
i
,
eofMsgTimeStamp
,
&
wg
,
&
mu
)
go
ms
.
findTimeTick
(
i
,
eofMsgTimeStamp
,
&
wg
,
&
mu
,
&
findMapMutex
)
}
wg
.
Wait
()
timeStamp
,
ok
:=
checkTimeTickMsg
(
eofMsgTimeStamp
,
isChannelReady
)
timeStamp
,
ok
:=
checkTimeTickMsg
(
eofMsgTimeStamp
,
isChannelReady
,
&
findMapMutex
)
if
!
ok
||
timeStamp
<=
ms
.
lastTimeStamp
{
log
.
Printf
(
"All timeTick's timestamps are inconsistent"
)
continue
...
...
@@ -394,7 +395,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() {
func
(
ms
*
PulsarTtMsgStream
)
findTimeTick
(
channelIndex
int
,
eofMsgMap
map
[
int
]
Timestamp
,
wg
*
sync
.
WaitGroup
,
mu
*
sync
.
Mutex
)
{
mu
*
sync
.
Mutex
,
findMapMutex
*
sync
.
RWMutex
)
{
defer
wg
.
Done
()
for
{
select
{
...
...
@@ -421,7 +423,9 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int,
log
.
Printf
(
"Failed to unmarshal, error = %v"
,
err
)
}
if
headerMsg
.
MsgType
==
internalPb
.
MsgType_kTimeTick
{
findMapMutex
.
Lock
()
eofMsgMap
[
channelIndex
]
=
tsMsg
.
(
*
TimeTickMsg
)
.
Timestamp
findMapMutex
.
Unlock
()
return
}
mu
.
Lock
()
...
...
@@ -470,7 +474,7 @@ func (ms *InMemMsgStream) Chan() <- chan *MsgPack {
}
*/
func
checkTimeTickMsg
(
msg
map
[
int
]
Timestamp
,
isChannelReady
[]
bool
)
(
Timestamp
,
bool
)
{
func
checkTimeTickMsg
(
msg
map
[
int
]
Timestamp
,
isChannelReady
[]
bool
,
mu
*
sync
.
RWMutex
)
(
Timestamp
,
bool
)
{
checkMap
:=
make
(
map
[
Timestamp
]
int
)
var
maxTime
Timestamp
=
0
for
_
,
v
:=
range
msg
{
...
...
@@ -485,7 +489,10 @@ func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp,
}
return
maxTime
,
true
}
for
i
,
v
:=
range
msg
{
for
i
:=
range
msg
{
mu
.
RLock
()
v
:=
msg
[
i
]
mu
.
Unlock
()
if
v
!=
maxTime
{
isChannelReady
[
i
]
=
false
}
else
{
...
...
internal/proxy/task.go
浏览文件 @
f4566731
...
...
@@ -374,9 +374,6 @@ func (qt *QueryTask) PreExecute() error {
}
}
qt
.
MsgType
=
internalpb
.
MsgType_kSearch
if
qt
.
query
.
PartitionTags
==
nil
||
len
(
qt
.
query
.
PartitionTags
)
<=
0
{
qt
.
query
.
PartitionTags
=
[]
string
{
Params
.
defaultPartitionTag
()}
}
queryBytes
,
err
:=
proto
.
Marshal
(
qt
.
query
)
if
err
!=
nil
{
return
err
...
...
internal/querynode/search_service.go
浏览文件 @
f4566731
...
...
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log"
"regexp"
"sync"
"github.com/golang/protobuf/proto"
...
...
@@ -223,7 +224,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
return
errors
.
New
(
"unmarshal query failed"
)
}
collectionName
:=
query
.
CollectionName
partitionTags
:=
query
.
PartitionTags
partitionTags
InQuery
:=
query
.
PartitionTags
collection
,
err
:=
ss
.
replica
.
getCollectionByName
(
collectionName
)
if
err
!=
nil
{
return
err
...
...
@@ -245,11 +246,29 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
searchResults
:=
make
([]
*
SearchResult
,
0
)
matchedSegments
:=
make
([]
*
Segment
,
0
)
for
_
,
partitionTag
:=
range
partitionTags
{
partition
,
err
:=
ss
.
replica
.
getPartitionByTag
(
collectionID
,
partitionTag
)
if
err
!=
nil
{
continue
fmt
.
Println
(
"search msg's partitionTag = "
,
partitionTagsInQuery
)
var
partitionTagsInCol
[]
string
for
_
,
partition
:=
range
collection
.
partitions
{
partitionTag
:=
partition
.
partitionTag
partitionTagsInCol
=
append
(
partitionTagsInCol
,
partitionTag
)
}
var
searchPartitionTag
[]
string
if
len
(
partitionTagsInQuery
)
==
0
{
searchPartitionTag
=
partitionTagsInCol
}
else
{
for
_
,
tag
:=
range
partitionTagsInCol
{
for
_
,
toMatchTag
:=
range
partitionTagsInQuery
{
re
:=
regexp
.
MustCompile
(
"^"
+
toMatchTag
+
"$"
)
if
re
.
MatchString
(
tag
)
{
searchPartitionTag
=
append
(
searchPartitionTag
,
tag
)
}
}
}
}
for
_
,
partitionTag
:=
range
searchPartitionTag
{
partition
,
_
:=
ss
.
replica
.
getPartitionByTag
(
collectionID
,
partitionTag
)
for
_
,
segment
:=
range
partition
.
segments
{
//fmt.Println("dsl = ", dsl)
...
...
@@ -360,6 +379,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
}
func
(
ss
*
searchService
)
publishFailedSearchResult
(
msg
msgstream
.
TsMsg
,
errMsg
string
)
error
{
fmt
.
Println
(
"Public fail SearchResult!"
)
msgPack
:=
msgstream
.
MsgPack
{}
searchMsg
,
ok
:=
msg
.
(
*
msgstream
.
SearchMsg
)
if
!
ok
{
...
...
tests/python/test_search.py
浏览文件 @
f4566731
...
...
@@ -255,7 +255,7 @@ class TestSearchBase:
assert
res2
[
0
][
0
].
id
==
res
[
0
][
1
].
id
assert
res2
[
0
][
0
].
entity
.
get
(
"int64"
)
==
res
[
0
][
1
].
entity
.
get
(
"int64"
)
#
p
ass
#
P
ass
@
pytest
.
mark
.
skip
(
"search_after_index"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_after_index
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
,
get_nq
):
...
...
@@ -303,6 +303,7 @@ class TestSearchBase:
assert
len
(
res
)
==
nq
assert
len
(
res
[
0
])
==
default_top_k
# should fix, 336 assert fail, insert data don't have partitionTag, But search data have
@
pytest
.
mark
.
skip
(
"search_index_partition"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_index_partition
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
,
get_nq
):
...
...
@@ -334,7 +335,7 @@ class TestSearchBase:
res
=
connect
.
search
(
collection
,
query
,
partition_tags
=
[
default_tag
])
assert
len
(
res
)
==
nq
#
pass
#
PASS
@
pytest
.
mark
.
skip
(
"search_index_partition_B"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_index_partition_B
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
,
get_nq
):
...
...
@@ -385,6 +386,7 @@ class TestSearchBase:
assert
len
(
res
)
==
nq
assert
len
(
res
[
0
])
==
0
# PASS
@
pytest
.
mark
.
skip
(
"search_index_partitions"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_index_partitions
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
):
...
...
@@ -419,6 +421,7 @@ class TestSearchBase:
assert
res
[
0
].
_distances
[
0
]
>
epsilon
assert
res
[
1
].
_distances
[
0
]
>
epsilon
# Pass
@
pytest
.
mark
.
skip
(
"search_index_partitions_B"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_index_partitions_B
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
):
...
...
@@ -479,7 +482,7 @@ class TestSearchBase:
with
pytest
.
raises
(
Exception
)
as
e
:
res
=
connect
.
search
(
collection
,
query
)
#
pass
#
PASS
@
pytest
.
mark
.
skip
(
"search_ip_after_index"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_ip_after_index
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
,
get_nq
):
...
...
@@ -509,6 +512,7 @@ class TestSearchBase:
assert
check_id_result
(
res
[
0
],
ids
[
0
])
assert
res
[
0
].
_distances
[
0
]
>=
1
-
gen_inaccuracy
(
res
[
0
].
_distances
[
0
])
# should fix, nq not correct
@
pytest
.
mark
.
skip
(
"search_ip_index_partition"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_ip_index_partition
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
,
get_nq
):
...
...
@@ -542,6 +546,7 @@ class TestSearchBase:
res
=
connect
.
search
(
collection
,
query
,
partition_tags
=
[
default_tag
])
assert
len
(
res
)
==
nq
# PASS
@
pytest
.
mark
.
skip
(
"search_ip_index_partitions"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_ip_index_partitions
(
self
,
connect
,
collection
,
get_simple_index
,
get_top_k
):
...
...
@@ -621,6 +626,7 @@ class TestSearchBase:
res
=
connect
.
search
(
collection
,
query
)
assert
abs
(
np
.
sqrt
(
res
[
0
].
_distances
[
0
])
-
min
(
distance_0
,
distance_1
))
<=
gen_inaccuracy
(
res
[
0
].
_distances
[
0
])
# Pass
@
pytest
.
mark
.
skip
(
"test_search_distance_l2_after_index"
)
def
test_search_distance_l2_after_index
(
self
,
connect
,
id_collection
,
get_simple_index
):
'''
...
...
@@ -675,7 +681,7 @@ class TestSearchBase:
res
=
connect
.
search
(
collection
,
query
)
assert
abs
(
res
[
0
].
_distances
[
0
]
-
max
(
distance_0
,
distance_1
))
<=
epsilon
#
p
ass
#
P
ass
@
pytest
.
mark
.
skip
(
"search_distance_ip_after_index"
)
def
test_search_distance_ip_after_index
(
self
,
connect
,
id_collection
,
get_simple_index
):
'''
...
...
@@ -946,6 +952,7 @@ class TestSearchBase:
assert
res
[
i
].
_distances
[
0
]
<
epsilon
assert
res
[
i
].
_distances
[
1
]
>
epsilon
# should fix
@
pytest
.
mark
.
skip
(
"query_entities_with_field_less_than_top_k"
)
def
test_query_entities_with_field_less_than_top_k
(
self
,
connect
,
id_collection
):
"""
...
...
@@ -1745,7 +1752,7 @@ class TestSearchInvalid(object):
def
get_search_params
(
self
,
request
):
yield
request
.
param
#
p
ass
#
P
ass
@
pytest
.
mark
.
skip
(
"search_with_invalid_params"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_with_invalid_params
(
self
,
connect
,
collection
,
get_simple_index
,
get_search_params
):
...
...
@@ -1787,7 +1794,7 @@ class TestSearchInvalid(object):
with
pytest
.
raises
(
Exception
)
as
e
:
res
=
connect
.
search
(
binary_collection
,
query
)
#
p
ass
#
P
ass
@
pytest
.
mark
.
skip
(
"search_with_empty_params"
)
@
pytest
.
mark
.
level
(
2
)
def
test_search_with_empty_params
(
self
,
connect
,
collection
,
args
,
get_simple_index
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录