未验证 提交 2e7c7a1c 编写于 作者: T ThreadDao 提交者: GitHub

fix multi thread case by override join method (#4281)

Signed-off-by: ThreadDao <zongyufen@foxmail.com>
上级 7d393ce3
......@@ -2,7 +2,7 @@ import pdb
import copy
import logging
import itertools
from time import sleep
import time
import threading
from multiprocessing import Process
import sklearn.preprocessing
......@@ -172,7 +172,7 @@ class TestCreateCollection:
collection_names.append(collection_name)
connect.create_collection(collection_name, default_fields)
for i in range(threads_num):
t = threading.Thread(target=create, args=())
t = TestThread(target=create, args=())
threads.append(t)
t.start()
time.sleep(0.2)
......
......@@ -59,12 +59,12 @@ class TestDropCollection:
collection_names = []
def create():
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uniq_id)
collection_names.append(collection_name)
connect.create_collection(collection_name, default_fields)
connect.drop_collection(collection_name)
for i in range(threads_num):
t = threading.Thread(target=create, args=())
t = TestThread(target=create, args=())
threads.append(t)
t.start()
time.sleep(0.2)
......
import pdb
import pytest
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
import time
from utils import *
from constants import *
uid = "collection_info"
class TestInfoBase:
@pytest.fixture(
......@@ -49,7 +46,7 @@ class TestInfoBase:
The following cases are used to test `get_collection_info` function, no data in collection
******************************************************************
"""
def test_info_collection_fields(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields, check info returned
......@@ -60,8 +57,8 @@ class TestInfoBase:
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": default_segment_row_limit
"fields": [filter_field, vector_field],
"segment_row_limit": default_segment_row_limit
}
connect.create_collection(collection_name, fields)
res = connect.get_collection_info(collection_name)
......@@ -123,20 +120,19 @@ class TestInfoBase:
def test_get_collection_info_multithread(self, connect):
'''
target: test create collection with multithread
method: create collection using multithread,
method: create collection using multithread,
expected: collections are created
'''
threads_num = 4
threads_num = 4
threads = []
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def get_info():
res = connect.get_collection_info(connect, collection_name)
# assert
connect.get_collection_info(collection_name)
for i in range(threads_num):
t = threading.Thread(target=get_info, args=())
t = TestThread(target=get_info)
threads.append(t)
t.start()
time.sleep(0.2)
......@@ -159,8 +155,8 @@ class TestInfoBase:
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": default_segment_row_limit
"fields": [filter_field, vector_field],
"segment_row_limit": default_segment_row_limit
}
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
......@@ -199,6 +195,7 @@ class TestInfoInvalid(object):
"""
Test get collection info with invalid params
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
......@@ -206,7 +203,6 @@ class TestInfoInvalid(object):
def get_collection_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_get_collection_info_with_invalid_collectionname(self, connect, get_collection_name):
collection_name = get_collection_name
......
......@@ -3,7 +3,7 @@ import pytest
import logging
import itertools
import threading
from time import sleep
import time
from multiprocessing import Process
from utils import *
from constants import *
......@@ -57,9 +57,10 @@ class TestHasCollection:
connect.create_collection(collection_name, default_fields)
def has():
assert not assert_collection(connect, collection_name)
assert connect.has_collection(collection_name)
# assert not assert_collection(connect, collection_name)
for i in range(threads_num):
t = threading.Thread(target=has, args=())
t = TestThread(target=has, args=())
threads.append(t)
t.start()
time.sleep(0.2)
......
import pdb
import pytest
import logging
import itertools
import threading
from time import sleep
from multiprocessing import Process
import time
from utils import *
from constants import *
uid = "list_collections"
class TestListCollections:
"""
******************************************************************
The following cases are used to test `list_collections` function
******************************************************************
"""
def test_list_collections(self, connect, collection):
'''
target: test list collections
......@@ -71,20 +68,16 @@ class TestListCollections:
@pytest.mark.level(2)
def test_list_collections_multithread(self, connect):
'''
target: test create collection with multithread
method: create collection using multithread,
expected: collections are created
'''
threads_num = 4
threads_num = 10
threads = []
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def _list():
def list():
assert collection_name in connect.list_collections()
for i in range(threads_num):
t = threading.Thread(target=_list, args=())
t = TestThread(target=list)
threads.append(t)
t.start()
time.sleep(0.2)
......
......@@ -546,15 +546,15 @@ class TestInsertBase:
def insert(thread_i):
logging.getLogger().info("In thread-%d" % thread_i)
res_ids = milvus.bulk_insert(collection, default_entities)
milvus.bulk_insert(collection, default_entities)
milvus.flush([collection])
for i in range(thread_num):
x = threading.Thread(target=insert, args=(i,))
threads.append(x)
x.start()
for th in threads:
th.join()
t = TestThread(target=insert, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
res_count = milvus.count_entities(collection)
assert res_count == thread_num * default_nb
......
import time
import pdb
import copy
import threading
import logging
from multiprocessing import Pool, Process
import pytest
......@@ -834,14 +833,14 @@ class TestSearchBase:
entities, ids = init_data(milvus, collection)
def search(milvus):
res = connect.search(collection, default_query)
res = milvus.search(collection, default_query)
assert len(res) == 1
assert res[0]._entities[0].id in ids
assert res[0]._distances[0] < epsilon
for i in range(threads_num):
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
t = threading.Thread(target=search, args=(milvus,))
t = TestThread(target=search, args=(milvus,))
threads.append(t)
t.start()
time.sleep(0.2)
......@@ -868,13 +867,13 @@ class TestSearchBase:
entities, ids = init_data(milvus, collection)
def search(milvus):
res = connect.search(collection, default_query)
res = milvus.search(collection, default_query)
assert len(res) == 1
assert res[0]._entities[0].id in ids
assert res[0]._distances[0] < epsilon
for i in range(threads_num):
t = threading.Thread(target=search, args=(milvus,))
t = TestThread(target=search, args=(milvus,))
threads.append(t)
t.start()
time.sleep(0.2)
......
......@@ -17,6 +17,7 @@ default_single_query = {
}
}
class TestFlushBase:
"""
******************************************************************
......@@ -240,17 +241,18 @@ class TestFlushBase:
ids.extend(tmp_ids)
disable_flush(connect)
status = connect.delete_entity_by_id(collection, ids)
def flush():
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
logging.error("start flush")
milvus.flush([collection])
logging.error("end flush")
p = threading.Thread(target=flush, args=())
p = TestThread(target=flush, args=())
p.start()
time.sleep(0.2)
logging.error("start count")
res = connect.count_entities(collection, timeout = 10)
res = connect.count_entities(collection, timeout=10)
p.join()
res = connect.count_entities(collection)
assert res == 0
......@@ -275,7 +277,7 @@ class TestFlushBase:
status = connect.delete_entity_by_id(collection, delete_ids)
connect.flush([collection])
res = future.result()
res_count = connect.count_entities(collection, timeout = 120)
res_count = connect.count_entities(collection, timeout=120)
assert res_count == loops * default_nb - len(delete_ids)
......
......@@ -146,7 +146,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
connect.bulk_insert(collection, default_entities)
def build(connect):
connect.create_index(collection, field_name, default_index)
......@@ -155,7 +155,7 @@ class TestIndexBase:
threads = []
for i in range(threads_num):
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
t = threading.Thread(target=build, args=(m,))
t = TestThread(target=build, args=(m,))
threads.append(t)
t.start()
time.sleep(0.2)
......@@ -289,7 +289,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
connect.bulk_insert(collection, default_entities)
def build(connect):
default_index["metric_type"] = "IP"
......@@ -299,7 +299,7 @@ class TestIndexBase:
threads = []
for i in range(threads_num):
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
t = threading.Thread(target=build, args=(m,))
t = TestThread(target=build, args=(m,))
threads.append(t)
t.start()
time.sleep(0.2)
......
......@@ -5,7 +5,8 @@ import pdb
import string
import struct
import logging
import time, datetime
import threading
import time
import copy
import numpy as np
from sklearn import preprocessing
......@@ -245,7 +246,7 @@ def gen_default_fields(auto_id=True):
{"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": default_dim}},
],
"segment_row_limit": default_segment_row_limit,
"auto_id" : auto_id
"auto_id": auto_id
}
return default_fields
......@@ -258,7 +259,7 @@ def gen_binary_default_fields(auto_id=True):
{"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}}
],
"segment_row_limit": default_segment_row_limit,
"auto_id" : auto_id
"auto_id": auto_id
}
return default_fields
......@@ -441,7 +442,7 @@ def gen_invalid_range():
def gen_valid_ranges():
ranges = [
{"GT": 0, "LT": default_nb//2},
{"GT": 0, "LT": default_nb // 2},
{"GT": default_nb // 2, "LT": default_nb * 2},
{"GT": 0},
{"LT": default_nb},
......@@ -969,3 +970,20 @@ def restart_server(helm_release_name):
# logging.error("Restart pod: %s timeout" % pod_name_tmp)
# res = False
return res
class TestThread(threading.Thread):
    """Thread subclass that propagates exceptions from the target to join().

    A plain ``threading.Thread`` swallows any exception raised inside its
    target, so multi-threaded test cases pass even when a worker fails.
    This wrapper captures the exception in ``run()`` and re-raises it in the
    caller of ``join()``, making worker failures fail the test.
    """

    def __init__(self, target, args=()):
        """
        :param target: callable to run in the new thread
        :param args: positional arguments passed to *target*
        """
        super(TestThread, self).__init__(target=target, args=args)
        # Initialize here so join() is safe even if the thread never started
        # (original version raised AttributeError in that case).
        self.exc = None

    def run(self):
        """Run the target, capturing any exception instead of losing it."""
        self.exc = None
        try:
            super(TestThread, self).run()
        except BaseException as e:
            # Catch BaseException so even SystemExit/KeyboardInterrupt in the
            # worker is surfaced to the joining thread.
            self.exc = e

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception it captured.

        :param timeout: optional seconds to wait, same as
            ``threading.Thread.join`` — the original override dropped this
            parameter, breaking callers that pass a timeout.
        """
        super(TestThread, self).join(timeout)
        if self.exc:
            raise self.exc
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册