Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
milvus
milvus
提交
2e7c7a1c
M
milvus
项目概览
milvus
/
milvus
9 个月 前同步成功
通知
260
Star
22476
Fork
2472
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
2e7c7a1c
编写于
11月 25, 2020
作者:
T
ThreadDao
提交者:
GitHub
11月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix multi thread case by override join method (#4281)
Signed-off-by:
N
ThreadDao
<
zongyufen@foxmail.com
>
上级
7d393ce3
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
69 addition
and
60 deletion
+69
-60
tests/milvus_python_test/collection/test_create_collection.py
...s/milvus_python_test/collection/test_create_collection.py
+2
-2
tests/milvus_python_test/collection/test_drop_collection.py
tests/milvus_python_test/collection/test_drop_collection.py
+2
-2
tests/milvus_python_test/collection/test_get_collection_info.py
...milvus_python_test/collection/test_get_collection_info.py
+12
-16
tests/milvus_python_test/collection/test_has_collection.py
tests/milvus_python_test/collection/test_has_collection.py
+4
-3
tests/milvus_python_test/collection/test_list_collections.py
tests/milvus_python_test/collection/test_list_collections.py
+7
-14
tests/milvus_python_test/entity/test_bulk_insert.py
tests/milvus_python_test/entity/test_bulk_insert.py
+6
-6
tests/milvus_python_test/entity/test_search.py
tests/milvus_python_test/entity/test_search.py
+4
-5
tests/milvus_python_test/test_flush.py
tests/milvus_python_test/test_flush.py
+6
-4
tests/milvus_python_test/test_index.py
tests/milvus_python_test/test_index.py
+4
-4
tests/milvus_python_test/utils.py
tests/milvus_python_test/utils.py
+22
-4
未找到文件。
tests/milvus_python_test/collection/test_create_collection.py
浏览文件 @
2e7c7a1c
...
...
@@ -2,7 +2,7 @@ import pdb
import
copy
import
logging
import
itertools
from
time
import
sleep
import
time
import
threading
from
multiprocessing
import
Process
import
sklearn.preprocessing
...
...
@@ -172,7 +172,7 @@ class TestCreateCollection:
collection_names
.
append
(
collection_name
)
connect
.
create_collection
(
collection_name
,
default_fields
)
for
i
in
range
(
threads_num
):
t
=
threading
.
Thread
(
target
=
create
,
args
=
())
t
=
Test
Thread
(
target
=
create
,
args
=
())
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
tests/milvus_python_test/collection/test_drop_collection.py
浏览文件 @
2e7c7a1c
...
...
@@ -59,12 +59,12 @@ class TestDropCollection:
collection_names
=
[]
def
create
():
collection_name
=
gen_unique_str
(
collection
_id
)
collection_name
=
gen_unique_str
(
uniq
_id
)
collection_names
.
append
(
collection_name
)
connect
.
create_collection
(
collection_name
,
default_fields
)
connect
.
drop_collection
(
collection_name
)
for
i
in
range
(
threads_num
):
t
=
threading
.
Thread
(
target
=
create
,
args
=
())
t
=
Test
Thread
(
target
=
create
,
args
=
())
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
tests/milvus_python_test/collection/test_get_collection_info.py
浏览文件 @
2e7c7a1c
import
pdb
import
pytest
import
logging
import
itertools
from
time
import
sleep
import
threading
from
multiprocessing
import
Process
import
time
from
utils
import
*
from
constants
import
*
uid
=
"collection_info"
class
TestInfoBase
:
@
pytest
.
fixture
(
...
...
@@ -49,7 +46,7 @@ class TestInfoBase:
The following cases are used to test `get_collection_info` function, no data in collection
******************************************************************
"""
def
test_info_collection_fields
(
self
,
connect
,
get_filter_field
,
get_vector_field
):
'''
target: test create normal collection with different fields, check info returned
...
...
@@ -60,8 +57,8 @@ class TestInfoBase:
vector_field
=
get_vector_field
collection_name
=
gen_unique_str
(
uid
)
fields
=
{
"fields"
:
[
filter_field
,
vector_field
],
"segment_row_limit"
:
default_segment_row_limit
"fields"
:
[
filter_field
,
vector_field
],
"segment_row_limit"
:
default_segment_row_limit
}
connect
.
create_collection
(
collection_name
,
fields
)
res
=
connect
.
get_collection_info
(
collection_name
)
...
...
@@ -123,20 +120,19 @@ class TestInfoBase:
def
test_get_collection_info_multithread
(
self
,
connect
):
'''
target: test create collection with multithread
method: create collection using multithread,
method: create collection using multithread,
expected: collections are created
'''
threads_num
=
4
threads_num
=
4
threads
=
[]
collection_name
=
gen_unique_str
(
uid
)
connect
.
create_collection
(
collection_name
,
default_fields
)
def
get_info
():
res
=
connect
.
get_collection_info
(
connect
,
collection_name
)
# assert
connect
.
get_collection_info
(
collection_name
)
for
i
in
range
(
threads_num
):
t
=
threading
.
Thread
(
target
=
get_info
,
args
=
()
)
t
=
TestThread
(
target
=
get_info
)
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
@@ -159,8 +155,8 @@ class TestInfoBase:
vector_field
=
get_vector_field
collection_name
=
gen_unique_str
(
uid
)
fields
=
{
"fields"
:
[
filter_field
,
vector_field
],
"segment_row_limit"
:
default_segment_row_limit
"fields"
:
[
filter_field
,
vector_field
],
"segment_row_limit"
:
default_segment_row_limit
}
connect
.
create_collection
(
collection_name
,
fields
)
entities
=
gen_entities_by_fields
(
fields
[
"fields"
],
default_nb
,
vector_field
[
"params"
][
"dim"
])
...
...
@@ -199,6 +195,7 @@ class TestInfoInvalid(object):
"""
Test get collection info with invalid params
"""
@
pytest
.
fixture
(
scope
=
"function"
,
params
=
gen_invalid_strs
()
...
...
@@ -206,7 +203,6 @@ class TestInfoInvalid(object):
def
get_collection_name
(
self
,
request
):
yield
request
.
param
@
pytest
.
mark
.
level
(
2
)
def
test_get_collection_info_with_invalid_collectionname
(
self
,
connect
,
get_collection_name
):
collection_name
=
get_collection_name
...
...
tests/milvus_python_test/collection/test_has_collection.py
浏览文件 @
2e7c7a1c
...
...
@@ -3,7 +3,7 @@ import pytest
import
logging
import
itertools
import
threading
from
time
import
sleep
import
time
from
multiprocessing
import
Process
from
utils
import
*
from
constants
import
*
...
...
@@ -57,9 +57,10 @@ class TestHasCollection:
connect
.
create_collection
(
collection_name
,
default_fields
)
def
has
():
assert
not
assert_collection
(
connect
,
collection_name
)
assert
connect
.
has_collection
(
collection_name
)
# assert not assert_collection(connect, collection_name)
for
i
in
range
(
threads_num
):
t
=
threading
.
Thread
(
target
=
has
,
args
=
())
t
=
Test
Thread
(
target
=
has
,
args
=
())
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
tests/milvus_python_test/collection/test_list_collections.py
浏览文件 @
2e7c7a1c
import
pdb
import
pytest
import
logging
import
itertools
import
threading
from
time
import
sleep
from
multiprocessing
import
Process
import
time
from
utils
import
*
from
constants
import
*
uid
=
"list_collections"
class
TestListCollections
:
"""
******************************************************************
The following cases are used to test `list_collections` function
******************************************************************
"""
def
test_list_collections
(
self
,
connect
,
collection
):
'''
target: test list collections
...
...
@@ -71,20 +68,16 @@ class TestListCollections:
@
pytest
.
mark
.
level
(
2
)
def
test_list_collections_multithread
(
self
,
connect
):
'''
target: test create collection with multithread
method: create collection using multithread,
expected: collections are created
'''
threads_num
=
4
threads_num
=
10
threads
=
[]
collection_name
=
gen_unique_str
(
uid
)
connect
.
create_collection
(
collection_name
,
default_fields
)
def
_
list
():
def
list
():
assert
collection_name
in
connect
.
list_collections
()
for
i
in
range
(
threads_num
):
t
=
threading
.
Thread
(
target
=
_list
,
args
=
()
)
t
=
TestThread
(
target
=
list
)
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
tests/milvus_python_test/entity/test_bulk_insert.py
浏览文件 @
2e7c7a1c
...
...
@@ -546,15 +546,15 @@ class TestInsertBase:
def
insert
(
thread_i
):
logging
.
getLogger
().
info
(
"In thread-%d"
%
thread_i
)
res_ids
=
milvus
.
bulk_insert
(
collection
,
default_entities
)
milvus
.
bulk_insert
(
collection
,
default_entities
)
milvus
.
flush
([
collection
])
for
i
in
range
(
thread_num
):
x
=
threading
.
Thread
(
target
=
insert
,
args
=
(
i
,))
threads
.
append
(
x
)
x
.
start
()
for
t
h
in
threads
:
t
h
.
join
()
t
=
Test
Thread
(
target
=
insert
,
args
=
(
i
,))
threads
.
append
(
t
)
t
.
start
()
for
t
in
threads
:
t
.
join
()
res_count
=
milvus
.
count_entities
(
collection
)
assert
res_count
==
thread_num
*
default_nb
...
...
tests/milvus_python_test/entity/test_search.py
浏览文件 @
2e7c7a1c
import
time
import
pdb
import
copy
import
threading
import
logging
from
multiprocessing
import
Pool
,
Process
import
pytest
...
...
@@ -834,14 +833,14 @@ class TestSearchBase:
entities
,
ids
=
init_data
(
milvus
,
collection
)
def
search
(
milvus
):
res
=
connect
.
search
(
collection
,
default_query
)
res
=
milvus
.
search
(
collection
,
default_query
)
assert
len
(
res
)
==
1
assert
res
[
0
].
_entities
[
0
].
id
in
ids
assert
res
[
0
].
_distances
[
0
]
<
epsilon
for
i
in
range
(
threads_num
):
milvus
=
get_milvus
(
args
[
"ip"
],
args
[
"port"
],
handler
=
args
[
"handler"
])
t
=
threading
.
Thread
(
target
=
search
,
args
=
(
milvus
,))
t
=
Test
Thread
(
target
=
search
,
args
=
(
milvus
,))
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
@@ -868,13 +867,13 @@ class TestSearchBase:
entities
,
ids
=
init_data
(
milvus
,
collection
)
def
search
(
milvus
):
res
=
connect
.
search
(
collection
,
default_query
)
res
=
milvus
.
search
(
collection
,
default_query
)
assert
len
(
res
)
==
1
assert
res
[
0
].
_entities
[
0
].
id
in
ids
assert
res
[
0
].
_distances
[
0
]
<
epsilon
for
i
in
range
(
threads_num
):
t
=
threading
.
Thread
(
target
=
search
,
args
=
(
milvus
,))
t
=
Test
Thread
(
target
=
search
,
args
=
(
milvus
,))
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
tests/milvus_python_test/test_flush.py
浏览文件 @
2e7c7a1c
...
...
@@ -17,6 +17,7 @@ default_single_query = {
}
}
class
TestFlushBase
:
"""
******************************************************************
...
...
@@ -240,17 +241,18 @@ class TestFlushBase:
ids
.
extend
(
tmp_ids
)
disable_flush
(
connect
)
status
=
connect
.
delete_entity_by_id
(
collection
,
ids
)
def
flush
():
milvus
=
get_milvus
(
args
[
"ip"
],
args
[
"port"
],
handler
=
args
[
"handler"
])
logging
.
error
(
"start flush"
)
milvus
.
flush
([
collection
])
logging
.
error
(
"end flush"
)
p
=
threading
.
Thread
(
target
=
flush
,
args
=
())
p
=
Test
Thread
(
target
=
flush
,
args
=
())
p
.
start
()
time
.
sleep
(
0.2
)
logging
.
error
(
"start count"
)
res
=
connect
.
count_entities
(
collection
,
timeout
=
10
)
res
=
connect
.
count_entities
(
collection
,
timeout
=
10
)
p
.
join
()
res
=
connect
.
count_entities
(
collection
)
assert
res
==
0
...
...
@@ -275,7 +277,7 @@ class TestFlushBase:
status
=
connect
.
delete_entity_by_id
(
collection
,
delete_ids
)
connect
.
flush
([
collection
])
res
=
future
.
result
()
res_count
=
connect
.
count_entities
(
collection
,
timeout
=
120
)
res_count
=
connect
.
count_entities
(
collection
,
timeout
=
120
)
assert
res_count
==
loops
*
default_nb
-
len
(
delete_ids
)
...
...
tests/milvus_python_test/test_index.py
浏览文件 @
2e7c7a1c
...
...
@@ -146,7 +146,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids
=
connect
.
bulk_insert
(
collection
,
default_entities
)
connect
.
bulk_insert
(
collection
,
default_entities
)
def
build
(
connect
):
connect
.
create_index
(
collection
,
field_name
,
default_index
)
...
...
@@ -155,7 +155,7 @@ class TestIndexBase:
threads
=
[]
for
i
in
range
(
threads_num
):
m
=
get_milvus
(
host
=
args
[
"ip"
],
port
=
args
[
"port"
],
handler
=
args
[
"handler"
])
t
=
threading
.
Thread
(
target
=
build
,
args
=
(
m
,))
t
=
Test
Thread
(
target
=
build
,
args
=
(
m
,))
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
@@ -289,7 +289,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids
=
connect
.
bulk_insert
(
collection
,
default_entities
)
connect
.
bulk_insert
(
collection
,
default_entities
)
def
build
(
connect
):
default_index
[
"metric_type"
]
=
"IP"
...
...
@@ -299,7 +299,7 @@ class TestIndexBase:
threads
=
[]
for
i
in
range
(
threads_num
):
m
=
get_milvus
(
host
=
args
[
"ip"
],
port
=
args
[
"port"
],
handler
=
args
[
"handler"
])
t
=
threading
.
Thread
(
target
=
build
,
args
=
(
m
,))
t
=
Test
Thread
(
target
=
build
,
args
=
(
m
,))
threads
.
append
(
t
)
t
.
start
()
time
.
sleep
(
0.2
)
...
...
tests/milvus_python_test/utils.py
浏览文件 @
2e7c7a1c
...
...
@@ -5,7 +5,8 @@ import pdb
import
string
import
struct
import
logging
import
time
,
datetime
import
threading
import
time
import
copy
import
numpy
as
np
from
sklearn
import
preprocessing
...
...
@@ -245,7 +246,7 @@ def gen_default_fields(auto_id=True):
{
"name"
:
default_float_vec_field_name
,
"type"
:
DataType
.
FLOAT_VECTOR
,
"params"
:
{
"dim"
:
default_dim
}},
],
"segment_row_limit"
:
default_segment_row_limit
,
"auto_id"
:
auto_id
"auto_id"
:
auto_id
}
return
default_fields
...
...
@@ -258,7 +259,7 @@ def gen_binary_default_fields(auto_id=True):
{
"name"
:
default_binary_vec_field_name
,
"type"
:
DataType
.
BINARY_VECTOR
,
"params"
:
{
"dim"
:
default_dim
}}
],
"segment_row_limit"
:
default_segment_row_limit
,
"auto_id"
:
auto_id
"auto_id"
:
auto_id
}
return
default_fields
...
...
@@ -441,7 +442,7 @@ def gen_invalid_range():
def
gen_valid_ranges
():
ranges
=
[
{
"GT"
:
0
,
"LT"
:
default_nb
//
2
},
{
"GT"
:
0
,
"LT"
:
default_nb
//
2
},
{
"GT"
:
default_nb
//
2
,
"LT"
:
default_nb
*
2
},
{
"GT"
:
0
},
{
"LT"
:
default_nb
},
...
...
@@ -969,3 +970,20 @@ def restart_server(helm_release_name):
# logging.error("Restart pod: %s timeout" % pod_name_tmp)
# res = False
return
res
class
TestThread
(
threading
.
Thread
):
def
__init__
(
self
,
target
,
args
=
()):
threading
.
Thread
.
__init__
(
self
,
target
=
target
,
args
=
args
)
def
run
(
self
):
self
.
exc
=
None
try
:
super
(
TestThread
,
self
).
run
()
except
BaseException
as
e
:
self
.
exc
=
e
def
join
(
self
):
super
(
TestThread
,
self
).
join
()
if
self
.
exc
:
raise
self
.
exc
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录