提交 01781103 编写于 作者: B bigsheeper 提交者: yefu.chen

Fix incorrect usage of msgStream and illegal check in master

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 a6690dbc
......@@ -63,6 +63,8 @@ SearchOnSealed(const Schema& schema,
Assert(record.test_readiness(field_offset));
auto indexing_entry = record.get_entry(field_offset);
std::cout << " SearchOnSealed, indexing_entry->metric:" << indexing_entry->metric_type_ << std::endl;
std::cout << " SearchOnSealed, query_info.metric_type_:" << query_info.metric_type_ << std::endl;
Assert(indexing_entry->metric_type_ == GetMetricType(query_info.metric_type_));
auto final = [&] {
......
......@@ -65,12 +65,8 @@ func refreshChannelNames() {
}
func receiveTimeTickMsg(stream *ms.MsgStream) bool {
for {
result := (*stream).Consume()
if len(result.Msgs) > 0 {
return true
}
}
result := (*stream).Consume()
return result != nil
}
func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack {
......@@ -81,6 +77,14 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack {
return &msgPack
}
func mockTimeTickBroadCast(msgStream ms.MsgStream, time Timestamp) error {
timeTick := [][2]uint64{
{0, time},
}
ttMsgPackForDD := getTimeTickMsgPack(timeTick)
return msgStream.Broadcast(ttMsgPackForDD)
}
func TestMaster(t *testing.T) {
Init()
refreshMasterAddress()
......@@ -533,10 +537,15 @@ func TestMaster(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
time.Sleep(1000 * time.Millisecond)
timestampNow := Timestamp(time.Now().Unix())
err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
assert.NoError(t, err)
//consume msg
ddMs := ms.NewPulsarMsgStream(ctx, 1024)
ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
ddMs.SetPulsarClient(pulsarAddr)
ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024)
ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
ddMs.Start()
var consumeMsg ms.MsgStream = ddMs
......@@ -822,11 +831,16 @@ func TestMaster(t *testing.T) {
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
//consume msg
ddMs := ms.NewPulsarMsgStream(ctx, 1024)
ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
ddMs.SetPulsarClient(pulsarAddr)
ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024)
ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
ddMs.Start()
time.Sleep(1000 * time.Millisecond)
timestampNow := Timestamp(time.Now().Unix())
err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
assert.NoError(t, err)
var consumeMsg ms.MsgStream = ddMs
for {
result := consumeMsg.Consume()
......@@ -849,19 +863,19 @@ func TestMaster(t *testing.T) {
writeNodeStream.CreatePulsarProducers(Params.WriteNodeTimeTickChannelNames)
writeNodeStream.Start()
ddMs := ms.NewPulsarMsgStream(ctx, 1024)
ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
ddMs.SetPulsarClient(pulsarAddr)
ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024)
ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
ddMs.Start()
dMMs := ms.NewPulsarMsgStream(ctx, 1024)
dMMs := ms.NewPulsarTtMsgStream(ctx, 1024)
dMMs.SetPulsarClient(pulsarAddr)
dMMs.CreatePulsarConsumers(Params.InsertChannelNames, "DMStream", ms.NewUnmarshalDispatcher(), 1024)
dMMs.CreatePulsarConsumers(Params.InsertChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
dMMs.Start()
k2sMs := ms.NewPulsarMsgStream(ctx, 1024)
k2sMs.SetPulsarClient(pulsarAddr)
k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, "K2SStream", ms.NewUnmarshalDispatcher(), 1024)
k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
k2sMs.Start()
ttsoftmsgs := [][2]uint64{
......@@ -902,10 +916,11 @@ func TestMaster(t *testing.T) {
schemaBytes, err := proto.Marshal(&sch)
assert.Nil(t, err)
////////////////////////////CreateCollection////////////////////////
createCollectionReq := internalpb.CreateCollectionRequest{
MsgType: internalpb.MsgType_kCreateCollection,
ReqID: 1,
Timestamp: uint64(time.Now().Unix()),
Timestamp: Timestamp(time.Now().Unix()),
ProxyID: 1,
Schema: &commonpb.Blob{Value: schemaBytes},
}
......@@ -913,6 +928,11 @@ func TestMaster(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
time.Sleep(1000 * time.Millisecond)
timestampNow := Timestamp(time.Now().Unix())
err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
assert.NoError(t, err)
var consumeMsg ms.MsgStream = ddMs
var createCollectionMsg *ms.CreateCollectionMsg
for {
......@@ -947,6 +967,11 @@ func TestMaster(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
time.Sleep(1000 * time.Millisecond)
timestampNow = Timestamp(time.Now().Unix())
err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
assert.NoError(t, err)
var createPartitionMsg *ms.CreatePartitionMsg
for {
result := consumeMsg.Consume()
......@@ -981,6 +1006,11 @@ func TestMaster(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
time.Sleep(1000 * time.Millisecond)
timestampNow = Timestamp(time.Now().Unix())
err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
assert.NoError(t, err)
var dropPartitionMsg *ms.DropPartitionMsg
for {
result := consumeMsg.Consume()
......@@ -1011,6 +1041,11 @@ func TestMaster(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
time.Sleep(1000 * time.Millisecond)
timestampNow = Timestamp(time.Now().Unix())
err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
assert.NoError(t, err)
var dropCollectionMsg *ms.DropCollectionMsg
for {
result := consumeMsg.Consume()
......
......@@ -46,7 +46,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
pulsarDDStream.Start()
defer pulsarDDStream.Close()
consumeMs := ms.NewPulsarMsgStream(ctx, 1024)
consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024)
consumeMs.SetPulsarClient(pulsarAddr)
consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024)
consumeMs.Start()
......@@ -96,6 +96,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
err = createCollectionTask.WaitToFinish(ctx)
assert.Nil(t, err)
err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12))
assert.NoError(t, err)
var consumeMsg ms.MsgStream = consumeMs
var createCollectionMsg *ms.CreateCollectionMsg
for {
......@@ -118,7 +121,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
dropCollectionReq := internalpb.DropCollectionRequest{
MsgType: internalpb.MsgType_kDropCollection,
ReqID: 1,
Timestamp: 11,
Timestamp: 13,
ProxyID: 1,
CollectionName: &servicepb.CollectionName{CollectionName: sch.Name},
}
......@@ -138,6 +141,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
err = dropCollectionTask.WaitToFinish(ctx)
assert.Nil(t, err)
err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14))
assert.NoError(t, err)
var dropCollectionMsg *ms.DropCollectionMsg
for {
result := consumeMsg.Consume()
......@@ -184,7 +190,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
pulsarDDStream.Start()
defer pulsarDDStream.Close()
consumeMs := ms.NewPulsarMsgStream(ctx, 1024)
consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024)
consumeMs.SetPulsarClient(pulsarAddr)
consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024)
consumeMs.Start()
......@@ -234,6 +240,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
err = createCollectionTask.WaitToFinish(ctx)
assert.Nil(t, err)
err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12))
assert.NoError(t, err)
var consumeMsg ms.MsgStream = consumeMs
var createCollectionMsg *ms.CreateCollectionMsg
for {
......@@ -257,7 +266,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
createPartitionReq := internalpb.CreatePartitionRequest{
MsgType: internalpb.MsgType_kCreatePartition,
ReqID: 1,
Timestamp: 11,
Timestamp: 13,
ProxyID: 1,
PartitionName: &servicepb.PartitionName{
CollectionName: sch.Name,
......@@ -279,6 +288,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
err = createPartitionTask.WaitToFinish(ctx)
assert.Nil(t, err)
err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14))
assert.NoError(t, err)
var createPartitionMsg *ms.CreatePartitionMsg
for {
result := consumeMsg.Consume()
......@@ -301,7 +313,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
dropPartitionReq := internalpb.DropPartitionRequest{
MsgType: internalpb.MsgType_kDropPartition,
ReqID: 1,
Timestamp: 11,
Timestamp: 15,
ProxyID: 1,
PartitionName: &servicepb.PartitionName{
CollectionName: sch.Name,
......@@ -323,6 +335,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
err = dropPartitionTask.WaitToFinish(ctx)
assert.Nil(t, err)
err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(16))
assert.NoError(t, err)
var dropPartitionMsg *ms.DropPartitionMsg
for {
result := consumeMsg.Consume()
......
......@@ -58,6 +58,7 @@ func initTestPulsarStream(ctx context.Context, pulsarAddress string,
return &input, &output
}
func receiveMsg(stream *ms.MsgStream) []uint64 {
receiveCount := 0
var results []uint64
......
......@@ -192,15 +192,15 @@ func TestTt_SoftTtBarrierStart(t *testing.T) {
func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) {
channels := []string{"SoftTtBarrierGetTimeTickClose"}
ttmsgs := [][2]int{
{1, 10},
{2, 20},
{3, 30},
{4, 40},
{1, 30},
{2, 30},
}
inStream, ttStream := producer(channels, ttmsgs)
//ttmsgs := [][2]int{
// {1, 10},
// {2, 20},
// {3, 30},
// {4, 40},
// {1, 30},
// {2, 30},
//}
inStream, ttStream := producer(channels, nil)
defer func() {
(*inStream).Close()
(*ttStream).Close()
......@@ -259,15 +259,15 @@ func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) {
func TestTt_SoftTtBarrierGetTimeTickCancel(t *testing.T) {
channels := []string{"SoftTtBarrierGetTimeTickCancel"}
ttmsgs := [][2]int{
{1, 10},
{2, 20},
{3, 30},
{4, 40},
{1, 30},
{2, 30},
}
inStream, ttStream := producer(channels, ttmsgs)
//ttmsgs := [][2]int{
// {1, 10},
// {2, 20},
// {3, 30},
// {4, 40},
// {1, 30},
// {2, 30},
//}
inStream, ttStream := producer(channels, nil)
defer func() {
(*inStream).Close()
(*ttStream).Close()
......
......@@ -157,6 +157,9 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error {
log.Printf("Warning: Receive empty msgPack")
return nil
}
if len(ms.producers) <= 0 {
return errors.New("nil producer in msg stream")
}
reBucketValues := make([][]int32, len(tsMsgs))
for channelID, tsMsg := range tsMsgs {
hashValues := tsMsg.HashKeys()
......
......@@ -24,12 +24,12 @@ class TestIndexBase:
params=gen_simple_index()
)
def get_simple_index(self, request, connect):
import copy
logging.getLogger().info(request.param)
# TODO: Determine the service mode
# if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return request.param
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return copy.deepcopy(request.param)
@pytest.fixture(
scope="function",
......@@ -287,7 +287,6 @@ class TestIndexBase:
assert len(res) == nq
@pytest.mark.timeout(BUILD_TIMEOUT)
@pytest.mark.skip("test_create_index_multithread_ip")
@pytest.mark.level(2)
def test_create_index_multithread_ip(self, connect, collection, args):
'''
......
......@@ -89,10 +89,11 @@ class TestSearchBase:
params=gen_simple_index()
)
def get_simple_index(self, request, connect):
import copy
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return request.param
return copy.deepcopy(request.param)
@pytest.fixture(
scope="function",
......@@ -256,7 +257,6 @@ class TestSearchBase:
assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
# Pass
@pytest.mark.skip("search_after_index")
@pytest.mark.level(2)
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
......@@ -304,8 +304,6 @@ class TestSearchBase:
assert len(res[0]) == default_top_k
# pass
# should fix, 336 assert fail, insert data don't have partitionTag, But search data have
@pytest.mark.skip("search_index_partition")
@pytest.mark.level(2)
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
......@@ -337,7 +335,6 @@ class TestSearchBase:
assert len(res) == nq
# PASS
@pytest.mark.skip("search_index_partition_B")
@pytest.mark.level(2)
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
......@@ -388,7 +385,6 @@ class TestSearchBase:
assert len(res[0]) == 0
# PASS
@pytest.mark.skip("search_index_partitions")
@pytest.mark.level(2)
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
......@@ -423,7 +419,6 @@ class TestSearchBase:
assert res[1]._distances[0] > epsilon
# Pass
@pytest.mark.skip("search_index_partitions_B")
@pytest.mark.level(2)
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
'''
......@@ -484,7 +479,6 @@ class TestSearchBase:
res = connect.search(collection, query)
# PASS
@pytest.mark.skip("search_ip_after_index")
@pytest.mark.level(2)
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
......@@ -513,8 +507,6 @@ class TestSearchBase:
assert check_id_result(res[0], ids[0])
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
# should fix, nq not correct
@pytest.mark.skip("search_ip_index_partition")
@pytest.mark.level(2)
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
......@@ -548,7 +540,6 @@ class TestSearchBase:
assert len(res) == nq
# PASS
@pytest.mark.skip("search_ip_index_partitions")
@pytest.mark.level(2)
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
......@@ -628,7 +619,6 @@ class TestSearchBase:
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
# Pass
@pytest.mark.skip("test_search_distance_l2_after_index")
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
......@@ -683,7 +673,6 @@ class TestSearchBase:
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
# Pass
@pytest.mark.skip("search_distance_ip_after_index")
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
......@@ -953,8 +942,7 @@ class TestSearchBase:
assert res[i]._distances[0] < epsilon
assert res[i]._distances[1] > epsilon
# should fix
@pytest.mark.skip("query_entities_with_field_less_than_top_k")
@pytest.mark.skip("test_query_entities_with_field_less_than_top_k")
def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
"""
target: test search with field, and let return entities less than topk
......@@ -1754,7 +1742,6 @@ class TestSearchInvalid(object):
yield request.param
# Pass
@pytest.mark.skip("search_with_invalid_params")
@pytest.mark.level(2)
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
'''
......@@ -1776,7 +1763,6 @@ class TestSearchInvalid(object):
res = connect.search(collection, query)
# pass
@pytest.mark.skip("search_with_invalid_params_binary")
@pytest.mark.level(2)
def test_search_with_invalid_params_binary(self, connect, binary_collection):
'''
......@@ -1796,7 +1782,6 @@ class TestSearchInvalid(object):
res = connect.search(binary_collection, query)
# Pass
@pytest.mark.skip("search_with_empty_params")
@pytest.mark.level(2)
def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
'''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册