From 42b687bf48f7e2cd8123b26c1f829aca44998d57 Mon Sep 17 00:00:00 2001 From: dragondriver Date: Mon, 6 Sep 2021 20:49:04 +0800 Subject: [PATCH] Add unittest for task scheduler (#7508) Signed-off-by: dragondriver --- internal/distributed/grpcconfigs/configs.go | 11 + internal/proxy/channels_mgr.go | 75 --- internal/proxy/channels_mgr_test.go | 14 - internal/proxy/channels_time_ticker.go | 8 - internal/proxy/channels_time_ticker_test.go | 11 - internal/proxy/impl.go | 52 +- internal/proxy/interface_def.go | 40 ++ internal/proxy/mock.go | 188 +++++++ internal/proxy/naive_unique_id_generator.go | 49 ++ .../proxy/naive_unique_id_generator_test.go | 32 ++ internal/proxy/proxy.go | 4 +- internal/proxy/task.go | 27 +- internal/proxy/task_scheduler.go | 243 ++++---- internal/proxy/task_scheduler_test.go | 528 ++++++++++++++++++ internal/proxy/test_utils.go | 21 + internal/proxy/timestamp.go | 5 - internal/util/funcutil/set.go | 28 + internal/util/funcutil/set_test.go | 40 ++ 18 files changed, 1096 insertions(+), 280 deletions(-) create mode 100644 internal/proxy/interface_def.go create mode 100644 internal/proxy/naive_unique_id_generator.go create mode 100644 internal/proxy/naive_unique_id_generator_test.go create mode 100644 internal/proxy/task_scheduler_test.go create mode 100644 internal/util/funcutil/set.go create mode 100644 internal/util/funcutil/set_test.go diff --git a/internal/distributed/grpcconfigs/configs.go b/internal/distributed/grpcconfigs/configs.go index 9d36e24f6..027876d04 100644 --- a/internal/distributed/grpcconfigs/configs.go +++ b/internal/distributed/grpcconfigs/configs.go @@ -1,3 +1,14 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + package grpcconfigs import "math" diff --git a/internal/proxy/channels_mgr.go b/internal/proxy/channels_mgr.go index ab4b7fefe..b734f4173 100644 --- a/internal/proxy/channels_mgr.go +++ b/internal/proxy/channels_mgr.go @@ -14,7 +14,6 @@ package proxy import ( "context" "fmt" - "math/rand" "runtime" "sort" "sync" @@ -37,83 +36,9 @@ type channelsMgr interface { removeAllDMLStream() error } -type ( - uniqueIntGenerator interface { - get() int - } - naiveUniqueIntGenerator struct { - now int - mtx sync.Mutex - } -) - -func (generator *naiveUniqueIntGenerator) get() int { - generator.mtx.Lock() - defer func() { - generator.now++ - generator.mtx.Unlock() - }() - return generator.now -} - -func newNaiveUniqueIntGenerator() *naiveUniqueIntGenerator { - return &naiveUniqueIntGenerator{ - now: 0, - } -} - -var uniqueIntGeneratorIns uniqueIntGenerator -var getUniqueIntGeneratorInsOnce sync.Once - -func getUniqueIntGeneratorIns() uniqueIntGenerator { - getUniqueIntGeneratorInsOnce.Do(func() { - uniqueIntGeneratorIns = newNaiveUniqueIntGenerator() - }) - return uniqueIntGeneratorIns -} - type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error) type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) -type getChannelsService interface { - GetChannels(collectionID UniqueID) (map[vChan]pChan, error) -} - -type mockGetChannelsService struct { - collectionID2Channels map[UniqueID]map[vChan]pChan -} - -func newMockGetChannelsService() *mockGetChannelsService { - return &mockGetChannelsService{ - collectionID2Channels: make(map[UniqueID]map[vChan]pChan), - } -} - -func genUniqueStr() string { - l := rand.Uint64()%100 + 1 - b := make([]byte, l) - if _, err := rand.Read(b); err != nil { - return "" - } - return fmt.Sprintf("%X", b) -} - -func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) { - channels, ok := m.collectionID2Channels[collectionID] - if ok { - return channels, nil - } - - channels = make(map[vChan]pChan) - l := rand.Uint64()%10 + 1 - for i := 0; uint64(i) < l; i++ { - channels[genUniqueStr()] = genUniqueStr() - } - - m.collectionID2Channels[collectionID] = channels - return channels, nil -} - type streamType int const ( diff --git a/internal/proxy/channels_mgr_test.go b/internal/proxy/channels_mgr_test.go index 2a82207ad..03c9bd334 100644 --- a/internal/proxy/channels_mgr_test.go +++ b/internal/proxy/channels_mgr_test.go @@ -19,20 +19,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNaiveUniqueIntGenerator_get(t *testing.T) { - exists := make(map[int]bool) - num := 10 - - generator := newNaiveUniqueIntGenerator() - - for i := 0; i < num; i++ { - g := generator.get() - _, ok := exists[g] - assert.False(t, ok) - exists[g] = true - } -} - func TestChannelsMgrImpl_getChannels(t *testing.T) { master := newMockGetChannelsService() query := newMockGetChannelsService() diff --git a/internal/proxy/channels_time_ticker.go b/internal/proxy/channels_time_ticker.go index 458713864..15fffaad3 100644 --- a/internal/proxy/channels_time_ticker.go +++ b/internal/proxy/channels_time_ticker.go @@ -24,14 +24,6 @@ import ( // ticker can update ts only when the minTs greater than the ts of ticker, we can use maxTs to update current later type getPChanStatisticsFuncType func() (map[pChan]*pChanStatistics, error) -// use interface tsoAllocator to keep channelsTimeTickerImpl testable -type tsoAllocator interface { - //Start() error - AllocOne() (Timestamp, error) - //Alloc(count uint32) ([]Timestamp, error) - //ClearCache() -} - type channelsTimeTicker interface { start() error close() error diff --git a/internal/proxy/channels_time_ticker_test.go b/internal/proxy/channels_time_ticker_test.go index 7c3df47f5..56af187b3 100644 --- a/internal/proxy/channels_time_ticker_test.go +++ b/internal/proxy/channels_time_ticker_test.go @@ -25,17 +25,6 @@ import ( "github.com/stretchr/testify/assert" ) -type mockTsoAllocator struct { -} - -func (tso *mockTsoAllocator) AllocOne() (Timestamp, error) { - return Timestamp(time.Now().UnixNano()), nil -} - -func newMockTsoAllocator() *mockTsoAllocator { - return &mockTsoAllocator{} -} - func newGetStatisticsFunc(pchans []pChan) getPChanStatisticsFuncType { totalPchan := len(pchans) pchanNum := rand.Uint64()%(uint64(totalPchan)) + 1 diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 7f5ca2390..7b2132acd 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -134,7 +134,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.Any("schema", request.Schema)) - err := node.sched.DdQueue.Enqueue(cct) + err := node.sched.ddQueue.Enqueue(cct) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -188,7 +188,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.String("collection", request.CollectionName)) - err := node.sched.DdQueue.Enqueue(dct) + err := node.sched.ddQueue.Enqueue(dct) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -240,7 +240,7 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.String("collection", request.CollectionName)) - err := node.sched.DdQueue.Enqueue(hct) + err := node.sched.ddQueue.Enqueue(hct) if err != nil { return &milvuspb.BoolResponse{ Status: &commonpb.Status{ @@ -294,7 +294,7 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.String("collection", request.CollectionName)) - err := node.sched.DdQueue.Enqueue(lct) + err := node.sched.ddQueue.Enqueue(lct) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -345,7 +345,7 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.String("collection", request.CollectionName)) - err := node.sched.DdQueue.Enqueue(rct) + err := node.sched.ddQueue.Enqueue(rct) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -397,7 +397,7 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.String("collection", request.CollectionName)) - err := node.sched.DdQueue.Enqueue(dct) + err := node.sched.ddQueue.Enqueue(dct) if err != nil { return &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ @@ -453,7 +453,7 @@ func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvusp zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.String("collection", request.CollectionName)) - err := node.sched.DdQueue.Enqueue(g) + err := node.sched.ddQueue.Enqueue(g) if err != nil { return &milvuspb.GetCollectionStatisticsResponse{ Status: &commonpb.Status{ @@ -509,7 +509,7 @@ func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCo log.Debug("ShowCollections enqueue", zap.String("role", Params.RoleName), zap.Any("request", request)) - err := node.sched.DdQueue.Enqueue(sct) + err := node.sched.ddQueue.Enqueue(sct) if err != nil { return &milvuspb.ShowCollectionsResponse{ Status: &commonpb.Status{ @@ -560,7 +560,7 @@ func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.Create zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.String("partition", request.PartitionName)) - err := node.sched.DdQueue.Enqueue(cpt) + err := node.sched.ddQueue.Enqueue(cpt) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -613,7 +613,7 @@ func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPart zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.String("partition", request.PartitionName)) - err := node.sched.DdQueue.Enqueue(dpt) + err := node.sched.ddQueue.Enqueue(dpt) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -668,7 +668,7 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.String("partition", request.PartitionName)) - err := node.sched.DdQueue.Enqueue(hpt) + err := node.sched.ddQueue.Enqueue(hpt) if err != nil { return &milvuspb.BoolResponse{ Status: &commonpb.Status{ @@ -726,7 +726,7 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.Any("partitions", request.PartitionNames)) - err := node.sched.DdQueue.Enqueue(lpt) + err := node.sched.ddQueue.Enqueue(lpt) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -779,7 +779,7 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.Any("partitions", request.PartitionNames)) - err := node.sched.DdQueue.Enqueue(rpt) + err := node.sched.ddQueue.Enqueue(rpt) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -834,7 +834,7 @@ func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.String("partition", request.PartitionName)) - err := node.sched.DdQueue.Enqueue(g) + err := node.sched.ddQueue.Enqueue(g) if err != nil { return &milvuspb.GetPartitionStatisticsResponse{ Status: &commonpb.Status{ @@ -893,7 +893,7 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar log.Debug("ShowPartitions enqueue", zap.String("role", Params.RoleName), zap.Any("request", request)) - err := node.sched.DdQueue.Enqueue(spt) + err := node.sched.ddQueue.Enqueue(spt) if err != nil { return &milvuspb.ShowPartitionsResponse{ Status: &commonpb.Status{ @@ -943,7 +943,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde zap.String("collection", request.CollectionName), zap.String("field", request.FieldName), zap.Any("extra_params", request.ExtraParams)) - err := node.sched.DdQueue.Enqueue(cit) + err := node.sched.ddQueue.Enqueue(cit) if err != nil { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -1001,7 +1001,7 @@ func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.Describe zap.String("collection", request.CollectionName), zap.String("field", request.FieldName), zap.String("index name", request.IndexName)) - err := node.sched.DdQueue.Enqueue(dit) + err := node.sched.ddQueue.Enqueue(dit) if err != nil { return &milvuspb.DescribeIndexResponse{ Status: &commonpb.Status{ @@ -1065,7 +1065,7 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq zap.String("collection", request.CollectionName), zap.String("field", request.FieldName), zap.String("index name", request.IndexName)) - err := node.sched.DdQueue.Enqueue(dit) + err := node.sched.ddQueue.Enqueue(dit) if err != nil { return &commonpb.Status{ @@ -1127,7 +1127,7 @@ func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb. zap.String("collection", request.CollectionName), zap.String("field", request.FieldName), zap.String("index name", request.IndexName)) - err := node.sched.DdQueue.Enqueue(gibpt) + err := node.sched.ddQueue.Enqueue(gibpt) if err != nil { return &milvuspb.GetIndexBuildProgressResponse{ Status: &commonpb.Status{ @@ -1192,7 +1192,7 @@ func (node *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndex zap.String("collection", request.CollectionName), zap.String("field", request.FieldName), zap.String("index name", request.IndexName)) - err := node.sched.DdQueue.Enqueue(dipt) + err := node.sched.ddQueue.Enqueue(dipt) if err != nil { return &milvuspb.GetIndexStateResponse{ Status: &commonpb.Status{ @@ -1299,7 +1299,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) ErrorCode: commonpb.ErrorCode_Success, }, } - err = node.sched.DmQueue.Enqueue(it) + err = node.sched.dmQueue.Enqueue(it) log.Debug("Insert Task Enqueue", zap.Int64("msgID", it.BaseInsertTask.InsertRequest.Base.MsgID), @@ -1366,7 +1366,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) zap.String("collection", request.CollectionName), zap.String("partition", request.PartitionName), zap.String("expr", request.Expr)) - err := node.sched.DmQueue.Enqueue(dt) + err := node.sched.dmQueue.Enqueue(dt) if err != nil { return &milvuspb.MutationResult{ Status: &commonpb.Status{ @@ -1439,7 +1439,7 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) zap.Any("dsl", request.Dsl), zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)), zap.Any("OutputFields", request.OutputFields)) - err := node.sched.DqQueue.Enqueue(qt) + err := node.sched.dqQueue.Enqueue(qt) if err != nil { return &milvuspb.SearchResults{ Status: &commonpb.Status{ @@ -1519,7 +1519,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (* zap.String("role", Params.RoleName), zap.String("db", request.DbName), zap.Any("collections", request.CollectionNames)) - err := node.sched.DdQueue.Enqueue(ft) + err := node.sched.ddQueue.Enqueue(ft) if err != nil { resp.Status.Reason = err.Error() return resp, nil @@ -1587,7 +1587,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* zap.String("collection", queryRequest.CollectionName), zap.Any("partitions", queryRequest.PartitionNames)) - err := node.sched.DqQueue.Enqueue(qt) + err := node.sched.dqQueue.Enqueue(qt) if err != nil { return &milvuspb.QueryResults{ Status: &commonpb.Status{ @@ -1669,7 +1669,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista ids: ids.IdArray, } - err := node.sched.DqQueue.Enqueue(qt) + err := node.sched.dqQueue.Enqueue(qt) if err != nil { return &milvuspb.QueryResults{ Status: &commonpb.Status{ diff --git a/internal/proxy/interface_def.go b/internal/proxy/interface_def.go new file mode 100644 index 000000000..bb35249c9 --- /dev/null +++ b/internal/proxy/interface_def.go @@ -0,0 +1,40 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package proxy + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +// use interface tsoAllocator to keep other components testable +// include: channelsTimeTickerImpl, baseTaskQueue, taskScheduler +type tsoAllocator interface { + AllocOne() (Timestamp, error) +} + +// use interface idAllocatorInterface to keep other components testable +// include: baseTaskQueue, taskScheduler +type idAllocatorInterface interface { + AllocOne() (UniqueID, error) +} + +// use timestampAllocatorInterface to keep other components testable +// include: TimestampAllocator +type timestampAllocatorInterface interface { + AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) +} + +type getChannelsService interface { + GetChannels(collectionID UniqueID) (map[vChan]pChan, error) +} diff --git a/internal/proxy/mock.go b/internal/proxy/mock.go index 0778491ee..11272d15f 100644 --- a/internal/proxy/mock.go +++ b/internal/proxy/mock.go @@ -13,6 +13,7 @@ package proxy import ( "context" + "math/rand" "time" "github.com/milvus-io/milvus/internal/proto/commonpb" @@ -37,3 +38,190 @@ func (tso *mockTimestampAllocatorInterface) AllocTimestamp(ctx context.Context, func newMockTimestampAllocatorInterface() timestampAllocatorInterface { return &mockTimestampAllocatorInterface{} } + +type mockTsoAllocator struct { +} + +func (tso *mockTsoAllocator) AllocOne() (Timestamp, error) { + return Timestamp(time.Now().UnixNano()), nil +} + +func newMockTsoAllocator() tsoAllocator { + return &mockTsoAllocator{} +} + +type mockIDAllocatorInterface struct { +} + +func (m *mockIDAllocatorInterface) AllocOne() (UniqueID, error) { + return UniqueID(getUniqueIntGeneratorIns().get()), nil +} + +func newMockIDAllocatorInterface() idAllocatorInterface { + return &mockIDAllocatorInterface{} +} + +type mockGetChannelsService struct { + collectionID2Channels map[UniqueID]map[vChan]pChan +} + +func newMockGetChannelsService() *mockGetChannelsService { + return &mockGetChannelsService{ + collectionID2Channels: make(map[UniqueID]map[vChan]pChan), + } +} + +func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) { + channels, ok := m.collectionID2Channels[collectionID] + if ok { + return channels, nil + } + + channels = make(map[vChan]pChan) + l := rand.Uint64()%10 + 1 + for i := 0; uint64(i) < l; i++ { + channels[genUniqueStr()] = genUniqueStr() + } + + m.collectionID2Channels[collectionID] = channels + return channels, nil +} + +type mockTask struct { + *TaskCondition + id UniqueID + name string + tType commonpb.MsgType + ts Timestamp +} + +func (m *mockTask) TraceCtx() context.Context { + return m.TaskCondition.ctx +} + +func (m *mockTask) ID() UniqueID { + return m.id +} + +func (m *mockTask) SetID(uid UniqueID) { + m.id = uid +} + +func (m *mockTask) Name() string { + return m.name +} + +func (m *mockTask) Type() commonpb.MsgType { + return m.tType +} + +func (m *mockTask) BeginTs() Timestamp { + return m.ts +} + +func (m *mockTask) EndTs() Timestamp { + return m.ts +} + +func (m *mockTask) SetTs(ts Timestamp) { + m.ts = ts +} + +func (m *mockTask) OnEnqueue() error { + return nil +} + +func (m *mockTask) PreExecute(ctx context.Context) error { + return nil +} + +func (m *mockTask) Execute(ctx context.Context) error { + return nil +} + +func (m *mockTask) PostExecute(ctx context.Context) error { + return nil +} + +func newMockTask(ctx context.Context) *mockTask { + return &mockTask{ + TaskCondition: NewTaskCondition(ctx), + id: UniqueID(getUniqueIntGeneratorIns().get()), + name: genUniqueStr(), + tType: commonpb.MsgType_Undefined, + ts: Timestamp(time.Now().Nanosecond()), + } +} + +func newDefaultMockTask() *mockTask { + return newMockTask(context.Background()) +} + +type mockDdlTask struct { + *mockTask +} + +func newMockDdlTask(ctx context.Context) *mockDdlTask { + return &mockDdlTask{ + mockTask: newMockTask(ctx), + } +} + +func newDefaultMockDdlTask() *mockDdlTask { + return newMockDdlTask(context.Background()) +} + +type mockDmlTask struct { + *mockTask + vchans []vChan + pchans []pChan +} + +func (m *mockDmlTask) getChannels() ([]vChan, error) { + return m.vchans, nil +} + +func (m *mockDmlTask) getPChanStats() (map[pChan]pChanStatistics, error) { + ret := make(map[pChan]pChanStatistics) + for _, pchan := range m.pchans { + ret[pchan] = pChanStatistics{ + minTs: m.ts, + maxTs: m.ts, + } + } + return ret, nil +} + +func newMockDmlTask(ctx context.Context) *mockDmlTask { + shardNum := 2 + + vchans := make([]vChan, 0, shardNum) + pchans := make([]pChan, 0, shardNum) + + for i := 0; i < shardNum; i++ { + vchans = append(vchans, genUniqueStr()) + pchans = append(pchans, genUniqueStr()) + } + + return &mockDmlTask{ + mockTask: newMockTask(ctx), + } +} + +func newDefaultMockDmlTask() *mockDmlTask { + return newMockDmlTask(context.Background()) +} + +type mockDqlTask struct { + *mockTask +} + +func newMockDqlTask(ctx context.Context) *mockDqlTask { + return &mockDqlTask{ + mockTask: newMockTask(ctx), + } +} + +func newDefaultMockDqlTask() *mockDqlTask { + return newMockDqlTask(context.Background()) +} diff --git a/internal/proxy/naive_unique_id_generator.go b/internal/proxy/naive_unique_id_generator.go new file mode 100644 index 000000000..a50f7f0b2 --- /dev/null +++ b/internal/proxy/naive_unique_id_generator.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package proxy + +import "sync" + +type ( + uniqueIntGenerator interface { + get() int + } + naiveUniqueIntGenerator struct { + now int + mtx sync.Mutex + } +) + +func (generator *naiveUniqueIntGenerator) get() int { + generator.mtx.Lock() + defer func() { + generator.now++ + generator.mtx.Unlock() + }() + return generator.now +} + +func newNaiveUniqueIntGenerator() *naiveUniqueIntGenerator { + return &naiveUniqueIntGenerator{ + now: 0, + } +} + +var uniqueIntGeneratorIns uniqueIntGenerator +var getUniqueIntGeneratorInsOnce sync.Once + +func getUniqueIntGeneratorIns() uniqueIntGenerator { + getUniqueIntGeneratorInsOnce.Do(func() { + uniqueIntGeneratorIns = newNaiveUniqueIntGenerator() + }) + return uniqueIntGeneratorIns +} diff --git a/internal/proxy/naive_unique_id_generator_test.go b/internal/proxy/naive_unique_id_generator_test.go new file mode 100644 index 000000000..ea3d27c12 --- /dev/null +++ b/internal/proxy/naive_unique_id_generator_test.go @@ -0,0 +1,32 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package proxy + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNaiveUniqueIntGenerator_get(t *testing.T) { + exists := make(map[int]bool) + num := 10 + + generator := newNaiveUniqueIntGenerator() + + for i := 0; i < num; i++ { + g := generator.get() + _, ok := exists[g] + assert.False(t, ok) + exists[g] = true + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 6461b3a81..03aaeef85 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -63,7 +63,7 @@ type Proxy struct { chMgr channelsMgr - sched *TaskScheduler + sched *taskScheduler tick *timeTick chTicker channelsTimeTicker @@ -256,7 +256,7 @@ func (node *Proxy) Init() error { chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory) node.chMgr = chMgr - node.sched, err = NewTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory) + node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory) if err != nil { return err } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index f327cda5b..61b9ca05b 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -100,7 +100,6 @@ type dmlTask interface { task getChannels() ([]vChan, error) getPChanStats() (map[pChan]pChanStatistics, error) - getChannelsTimerTicker() channelsTimeTicker } type BaseInsertTask = msgstream.InsertMsg @@ -155,10 +154,6 @@ func (it *InsertTask) EndTs() Timestamp { return it.EndTimestamp } -func (it *InsertTask) getChannelsTimerTicker() channelsTimeTicker { - return it.chTicker -} - func (it *InsertTask) getPChanStats() (map[pChan]pChanStatistics, error) { ret := make(map[pChan]pChanStatistics) @@ -192,6 +187,17 @@ func (it *InsertTask) getChannels() ([]pChan, error) { return nil, err } channels, err = it.chMgr.getChannels(collID) + if err == nil { + for _, pchan := range channels { + err := it.chTicker.addPChan(pchan) + if err != nil { + log.Warn("failed to add pchan to channels time ticker", + zap.Error(err), + zap.Int64("collection id", collID), + zap.String("pchan", pchan)) + } + } + } } return channels, err } @@ -1023,6 +1029,17 @@ func (it *InsertTask) Execute(ctx context.Context) error { it.result.Status.Reason = err.Error() return err } + channels, err := it.chMgr.getChannels(collID) + if err == nil { + for _, pchan := range channels { + err := it.chTicker.addPChan(pchan) + if err != nil { + log.Warn("failed to add pchan to channels time ticker", + zap.Error(err), + zap.String("pchan", pchan)) + } + } + } stream, err = it.chMgr.getDMLStream(collID) if err != nil { it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index ea9fdec13..b1299b345 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -19,9 +19,10 @@ import ( "strconv" "sync" + "github.com/milvus-io/milvus/internal/util/funcutil" + "go.uber.org/zap" - "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" @@ -31,7 +32,7 @@ import ( oplog "github.com/opentracing/opentracing-go/log" ) -type TaskQueue interface { +type taskQueue interface { utChan() <-chan int utEmpty() bool utFull() bool @@ -45,7 +46,10 @@ type TaskQueue interface { Enqueue(t task) error } -type BaseTaskQueue struct { +// TODO(dragondriver): load from config +const maxTaskNum = 1024 + +type baseTaskQueue struct { unissuedTasks *list.List activeTasks map[UniqueID]task utLock sync.RWMutex @@ -56,24 +60,25 @@ type BaseTaskQueue struct { utBufChan chan int // to block scheduler - sched *TaskScheduler + tsoAllocatorIns tsoAllocator + idAllocatorIns idAllocatorInterface } -func (queue *BaseTaskQueue) utChan() <-chan int { +func (queue *baseTaskQueue) utChan() <-chan int { return queue.utBufChan } -func (queue *BaseTaskQueue) utEmpty() bool { +func (queue *baseTaskQueue) utEmpty() bool { queue.utLock.RLock() defer queue.utLock.RUnlock() return queue.unissuedTasks.Len() == 0 } -func (queue *BaseTaskQueue) utFull() bool { +func (queue *baseTaskQueue) utFull() bool { return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum } -func (queue *BaseTaskQueue) addUnissuedTask(t task) error { +func (queue *baseTaskQueue) addUnissuedTask(t task) error { queue.utLock.Lock() defer queue.utLock.Unlock() @@ -85,7 +90,7 @@ func (queue *BaseTaskQueue) addUnissuedTask(t task) error { return nil } -func (queue *BaseTaskQueue) FrontUnissuedTask() task { +func (queue *baseTaskQueue) FrontUnissuedTask() task { queue.utLock.RLock() defer queue.utLock.RUnlock() @@ -97,7 +102,7 @@ func (queue *BaseTaskQueue) FrontUnissuedTask() task { return queue.unissuedTasks.Front().Value.(task) } -func (queue *BaseTaskQueue) PopUnissuedTask() task { +func (queue *baseTaskQueue) PopUnissuedTask() task { queue.utLock.Lock() defer queue.utLock.Unlock() @@ -112,7 +117,7 @@ func (queue *BaseTaskQueue) PopUnissuedTask() task { return ft.Value.(task) } -func (queue *BaseTaskQueue) AddActiveTask(t task) { +func (queue *baseTaskQueue) AddActiveTask(t task) { queue.atLock.Lock() defer queue.atLock.Unlock() tID := t.ID() @@ -124,7 +129,7 @@ func (queue *BaseTaskQueue) AddActiveTask(t task) { queue.activeTasks[tID] = t } -func (queue *BaseTaskQueue) PopActiveTask(tID UniqueID) task { +func (queue *baseTaskQueue) PopActiveTask(tID UniqueID) task { queue.atLock.Lock() defer queue.atLock.Unlock() t, ok := queue.activeTasks[tID] @@ -137,7 +142,7 @@ func (queue *BaseTaskQueue) PopActiveTask(tID UniqueID) task { return t } -func (queue *BaseTaskQueue) getTaskByReqID(reqID UniqueID) task { +func (queue *baseTaskQueue) getTaskByReqID(reqID UniqueID) task { queue.utLock.RLock() defer queue.utLock.RUnlock() for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() { @@ -157,7 +162,7 @@ func (queue *BaseTaskQueue) getTaskByReqID(reqID UniqueID) task { return nil } -func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool { +func (queue *baseTaskQueue) TaskDoneTest(ts Timestamp) bool { queue.utLock.RLock() defer queue.utLock.RUnlock() for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() { @@ -177,19 +182,19 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool { return true } -func (queue *BaseTaskQueue) Enqueue(t task) error { +func (queue *baseTaskQueue) Enqueue(t task) error { err := t.OnEnqueue() if err != nil { return err } - ts, err := queue.sched.tsoAllocator.AllocOne() + ts, err := queue.tsoAllocatorIns.AllocOne() if err != nil { return err } t.SetTs(ts) - reqID, err := queue.sched.idAllocator.AllocOne() + reqID, err := queue.idAllocatorIns.AllocOne() if err != nil { return err } @@ -198,8 +203,21 @@ func (queue *BaseTaskQueue) Enqueue(t task) error { return queue.addUnissuedTask(t) } -type DdTaskQueue struct { - BaseTaskQueue +func newBaseTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *baseTaskQueue { + return &baseTaskQueue{ + unissuedTasks: list.New(), + activeTasks: make(map[UniqueID]task), + utLock: sync.RWMutex{}, + atLock: sync.RWMutex{}, + maxTaskNum: maxTaskNum, + utBufChan: make(chan int, maxTaskNum), + tsoAllocatorIns: tsoAllocatorIns, + idAllocatorIns: idAllocatorIns, + } +} + +type ddTaskQueue struct { + *baseTaskQueue lock sync.Mutex } @@ -208,19 +226,19 @@ type pChanStatInfo struct { tsSet map[Timestamp]struct{} } -type DmTaskQueue struct { - BaseTaskQueue +type dmTaskQueue struct { + *baseTaskQueue lock sync.Mutex statsLock sync.RWMutex pChanStatisticsInfos map[pChan]*pChanStatInfo } -func (queue *DmTaskQueue) Enqueue(t task) error { +func (queue *dmTaskQueue) Enqueue(t task) error { queue.lock.Lock() defer queue.lock.Unlock() - err := queue.BaseTaskQueue.Enqueue(t) + err := queue.baseTaskQueue.Enqueue(t) if err != nil { return err } @@ -229,13 +247,13 @@ func (queue *DmTaskQueue) Enqueue(t task) error { return nil } -func (queue *DmTaskQueue) PopActiveTask(tID UniqueID) task { +func (queue *dmTaskQueue) PopActiveTask(tID UniqueID) task { queue.atLock.Lock() defer queue.atLock.Unlock() t, ok := queue.activeTasks[tID] if ok { delete(queue.activeTasks, tID) - log.Debug("Proxy DmTaskQueue popPChanStats", zap.Any("tID", t.ID())) + log.Debug("Proxy dmTaskQueue popPChanStats", zap.Any("tID", t.ID())) queue.popPChanStats(t) } else { log.Debug("Proxy task not in active task list!", zap.Any("tID", tID)) @@ -243,11 +261,11 @@ func (queue *DmTaskQueue) PopActiveTask(tID UniqueID) task { return t } -func (queue *DmTaskQueue) addPChanStats(t task) error { +func (queue *dmTaskQueue) addPChanStats(t task) error { if dmT, ok := t.(dmlTask); ok { stats, err := dmT.getPChanStats() if err != nil { - log.Debug("Proxy DmTaskQueue addPChanStats", zap.Any("tID", t.ID()), + log.Debug("Proxy dmTaskQueue addPChanStats", zap.Any("tID", t.ID()), zap.Any("stats", stats), zap.Error(err)) return err } @@ -262,7 +280,6 @@ func (queue *DmTaskQueue) addPChanStats(t task) error { }, } queue.pChanStatisticsInfos[cName] = info - dmT.getChannelsTimerTicker().addPChan(cName) } else { if info.minTs > stat.minTs { queue.pChanStatisticsInfos[cName].minTs = stat.minTs @@ -280,7 +297,7 @@ func (queue *DmTaskQueue) addPChanStats(t task) error { return nil } -func (queue *DmTaskQueue) popPChanStats(t task) error { +func (queue *dmTaskQueue) popPChanStats(t task) error { if dmT, ok := t.(dmlTask); ok { channels, err := dmT.getChannels() if err != nil { @@ -306,12 +323,12 @@ func (queue *DmTaskQueue) popPChanStats(t task) error { } queue.statsLock.Unlock() } else { - return fmt.Errorf("Proxy DmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID()) + return fmt.Errorf("Proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID()) } return nil } -func (queue *DmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) { +func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) { ret := make(map[pChan]*pChanStatistics) queue.statsLock.RLock() @@ -325,60 +342,39 @@ func (queue *DmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error return ret, nil } -type DqTaskQueue struct { - BaseTaskQueue +type dqTaskQueue struct { + *baseTaskQueue } -func (queue *DdTaskQueue) Enqueue(t task) error { +func (queue *ddTaskQueue) Enqueue(t task) error { queue.lock.Lock() defer queue.lock.Unlock() - return queue.BaseTaskQueue.Enqueue(t) + return queue.baseTaskQueue.Enqueue(t) } -func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue { - return &DdTaskQueue{ - BaseTaskQueue: BaseTaskQueue{ - unissuedTasks: list.New(), - activeTasks: make(map[UniqueID]task), - maxTaskNum: 1024, - utBufChan: make(chan int, 1024), - sched: sched, - }, +func newDdTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *ddTaskQueue { + return &ddTaskQueue{ + baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns), } } -func NewDmTaskQueue(sched *TaskScheduler) *DmTaskQueue { - return &DmTaskQueue{ - BaseTaskQueue: BaseTaskQueue{ - unissuedTasks: list.New(), - activeTasks: make(map[UniqueID]task), - maxTaskNum: 1024, - utBufChan: make(chan int, 1024), - sched: sched, - }, +func newDmTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *dmTaskQueue { + return &dmTaskQueue{ + baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns), pChanStatisticsInfos: make(map[pChan]*pChanStatInfo), } } -func NewDqTaskQueue(sched *TaskScheduler) *DqTaskQueue { - return &DqTaskQueue{ - BaseTaskQueue: BaseTaskQueue{ - unissuedTasks: list.New(), - activeTasks: make(map[UniqueID]task), - maxTaskNum: 1024, - utBufChan: make(chan int, 1024), - sched: sched, - }, +func newDqTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *dqTaskQueue { + return &dqTaskQueue{ + baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns), } } -type TaskScheduler struct { - DdQueue TaskQueue - DmQueue *DmTaskQueue - DqQueue TaskQueue - - idAllocator *allocator.IDAllocator - tsoAllocator *TimestampAllocator +type taskScheduler struct { + ddQueue taskQueue + dmQueue *dmTaskQueue + dqQueue taskQueue wg sync.WaitGroup ctx context.Context @@ -387,51 +383,49 @@ type TaskScheduler struct { msFactory msgstream.Factory } -func NewTaskScheduler(ctx context.Context, - idAllocator *allocator.IDAllocator, - tsoAllocator *TimestampAllocator, - factory msgstream.Factory) (*TaskScheduler, error) { +func newTaskScheduler(ctx context.Context, + idAllocatorIns idAllocatorInterface, + tsoAllocatorIns tsoAllocator, + factory msgstream.Factory) (*taskScheduler, error) { ctx1, cancel := context.WithCancel(ctx) - s := &TaskScheduler{ - idAllocator: idAllocator, - tsoAllocator: tsoAllocator, - ctx: ctx1, - cancel: cancel, - msFactory: factory, + s := &taskScheduler{ + ctx: ctx1, + cancel: cancel, + msFactory: factory, } - s.DdQueue = NewDdTaskQueue(s) - s.DmQueue = NewDmTaskQueue(s) - s.DqQueue = NewDqTaskQueue(s) + s.ddQueue = newDdTaskQueue(tsoAllocatorIns, idAllocatorIns) + s.dmQueue = newDmTaskQueue(tsoAllocatorIns, idAllocatorIns) + s.dqQueue = newDqTaskQueue(tsoAllocatorIns, idAllocatorIns) return s, nil } -func (sched *TaskScheduler) scheduleDdTask() task { - return sched.DdQueue.PopUnissuedTask() +func (sched *taskScheduler) scheduleDdTask() task { + return sched.ddQueue.PopUnissuedTask() } -func (sched *TaskScheduler) scheduleDmTask() task { - return sched.DmQueue.PopUnissuedTask() +func (sched *taskScheduler) scheduleDmTask() task { + return sched.dmQueue.PopUnissuedTask() } -func (sched *TaskScheduler) scheduleDqTask() task { - return sched.DqQueue.PopUnissuedTask() +func (sched *taskScheduler) scheduleDqTask() task { + return sched.dqQueue.PopUnissuedTask() } -func (sched *TaskScheduler) getTaskByReqID(collMeta UniqueID) task { - if t := sched.DdQueue.getTaskByReqID(collMeta); t != nil { +func (sched *taskScheduler) getTaskByReqID(collMeta UniqueID) task { + if t := sched.ddQueue.getTaskByReqID(collMeta); t != nil { return t } - if t := sched.DmQueue.getTaskByReqID(collMeta); t != nil { + if t := sched.dmQueue.getTaskByReqID(collMeta); t != nil { return t } - if t := sched.DqQueue.getTaskByReqID(collMeta); t != nil { + if t := sched.dqQueue.getTaskByReqID(collMeta); t != nil { return t } return nil } -func (sched *TaskScheduler) processTask(t task, q TaskQueue) { +func (sched *taskScheduler) processTask(t task, q taskQueue) { span, ctx := trace.StartSpanFromContext(t.TraceCtx(), opentracing.Tags{ "Type": t.Name(), @@ -469,47 +463,47 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) { err = t.PostExecute(ctx) } -func (sched *TaskScheduler) definitionLoop() { +func (sched *taskScheduler) definitionLoop() { defer sched.wg.Done() for { select { case <-sched.ctx.Done(): return - case <-sched.DdQueue.utChan(): - if !sched.DdQueue.utEmpty() { + case <-sched.ddQueue.utChan(): + if !sched.ddQueue.utEmpty() { t := sched.scheduleDdTask() - sched.processTask(t, sched.DdQueue) + sched.processTask(t, sched.ddQueue) } } } } -func (sched *TaskScheduler) manipulationLoop() { +func (sched *taskScheduler) manipulationLoop() { defer sched.wg.Done() for { select { case <-sched.ctx.Done(): return - case <-sched.DmQueue.utChan(): - if !sched.DmQueue.utEmpty() { + case <-sched.dmQueue.utChan(): + if !sched.dmQueue.utEmpty() { t := sched.scheduleDmTask() - go sched.processTask(t, sched.DmQueue) + go sched.processTask(t, sched.dmQueue) } } } } -func (sched *TaskScheduler) queryLoop() { +func (sched *taskScheduler) queryLoop() { defer sched.wg.Done() for { select { case <-sched.ctx.Done(): return - case <-sched.DqQueue.utChan(): - if !sched.DqQueue.utEmpty() { + case <-sched.dqQueue.utChan(): + if !sched.dqQueue.utEmpty() { t := sched.scheduleDqTask() - go sched.processTask(t, sched.DqQueue) + go sched.processTask(t, sched.dqQueue) } else { log.Debug("query queue is empty ...") } @@ -561,25 +555,6 @@ func newQueryResultBuf() *queryResultBuf { } } -func setContain(m1, m2 map[interface{}]struct{}) bool { - log.Debug("Proxy task_scheduler setContain", zap.Any("len(m1)", len(m1)), - zap.Any("len(m2)", len(m2))) - if len(m1) < len(m2) { - return false - } - - for k2 := range m2 { - _, ok := m1[k2] - log.Debug("Proxy task_scheduler setContain", zap.Any("k2", fmt.Sprintf("%v", k2)), - zap.Any("ok", ok)) - if !ok { - return false - } - } - - return true -} - func (sr *resultBufHeader) readyToReduce() bool { if sr.haveError { log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("haveError", true)) @@ -608,7 +583,7 @@ func (sr *resultBufHeader) readyToReduce() bool { sealedGlobalSegmentIDsStrMap[x.(int64)] = 1 } - ret1 := setContain(sr.receivedVChansSet, sr.usedVChans) + ret1 := funcutil.SetContain(sr.receivedVChansSet, sr.usedVChans) log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("receivedVChansSet", receivedVChansSetStrMap), zap.Any("usedVChans", usedVChansSetStrMap), zap.Any("receivedSealedSegmentIDsSet", sealedSegmentIDsStrMap), @@ -618,7 +593,7 @@ func (sr *resultBufHeader) readyToReduce() bool { if !ret1 { return false } - ret := setContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet) + ret := funcutil.SetContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet) log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("ret", ret)) return ret } @@ -658,7 +633,7 @@ func (qr *queryResultBuf) addPartialResult(result *internalpb.RetrieveResults) { result.GlobalSealedSegmentIDs) } -func (sched *TaskScheduler) collectResultLoop() { +func (sched *taskScheduler) collectResultLoop() { defer sched.wg.Done() queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx) @@ -862,7 +837,7 @@ func (sched *TaskScheduler) collectResultLoop() { } } -func (sched *TaskScheduler) Start() error { +func (sched *taskScheduler) Start() error { sched.wg.Add(1) go sched.definitionLoop() @@ -878,17 +853,17 @@ func (sched *TaskScheduler) Start() error { return nil } -func (sched *TaskScheduler) Close() { +func (sched *taskScheduler) Close() { sched.cancel() sched.wg.Wait() } -func (sched *TaskScheduler) TaskDoneTest(ts Timestamp) bool { - ddTaskDone := sched.DdQueue.TaskDoneTest(ts) - dmTaskDone := sched.DmQueue.TaskDoneTest(ts) +func (sched *taskScheduler) TaskDoneTest(ts Timestamp) bool { + ddTaskDone := sched.ddQueue.TaskDoneTest(ts) + dmTaskDone := sched.dmQueue.TaskDoneTest(ts) return ddTaskDone && dmTaskDone } -func (sched *TaskScheduler) getPChanStatistics() (map[pChan]*pChanStatistics, error) { - return sched.DmQueue.getPChanStatsInfo() +func (sched *taskScheduler) getPChanStatistics() (map[pChan]*pChanStatistics, error) { + return sched.dmQueue.getPChanStatsInfo() } diff --git a/internal/proxy/task_scheduler_test.go b/internal/proxy/task_scheduler_test.go new file mode 100644 index 000000000..cc92d325f --- /dev/null +++ b/internal/proxy/task_scheduler_test.go @@ -0,0 +1,528 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package proxy + +import ( + "context" + "math/rand" + "sync" + "testing" + "time" + + "github.com/milvus-io/milvus/internal/msgstream" + + "github.com/stretchr/testify/assert" +) + +func TestBaseTaskQueue(t *testing.T) { + var err error + var unissuedTask task + var activeTask task + var done bool + + tsoAllocatorIns := newMockTsoAllocator() + idAllocatorIns := newMockIDAllocatorInterface() + queue := newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns) + assert.NotNil(t, queue) + + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + st := newDefaultMockTask() + stID := st.ID() + stTs := st.BeginTs() + + // no task in queue + + unissuedTask = queue.FrontUnissuedTask() + assert.Nil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(stID) + assert.Nil(t, unissuedTask) + + unissuedTask = queue.PopUnissuedTask() + assert.Nil(t, unissuedTask) + + done = queue.TaskDoneTest(stTs) + assert.True(t, done) + + // task enqueue, only one task in queue + + err = queue.Enqueue(st) + assert.NoError(t, err) + + assert.False(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + assert.Equal(t, 1, queue.unissuedTasks.Len()) + assert.Equal(t, 1, len(queue.utChan())) + + unissuedTask = queue.FrontUnissuedTask() + assert.NotNil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, unissuedTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + unissuedTask = queue.PopUnissuedTask() + assert.NotNil(t, unissuedTask) + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test active list, no task in queue + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.Nil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.Nil(t, activeTask) + + // test active list, no task in unissued list, only one task in active list + + queue.AddActiveTask(unissuedTask) + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test utFull + queue.maxTaskNum = 10 // not accurate, full also means utBufChan block + for i := 0; i < int(queue.maxTaskNum); i++ { + err = queue.Enqueue(newDefaultMockTask()) + assert.Nil(t, err) + } + assert.True(t, queue.utFull()) + err = queue.Enqueue(newDefaultMockTask()) + assert.NotNil(t, err) +} + +func TestDdTaskQueue(t *testing.T) { + var err error + var unissuedTask task + var activeTask task + var done bool + + tsoAllocatorIns := newMockTsoAllocator() + idAllocatorIns := newMockIDAllocatorInterface() + queue := newDdTaskQueue(tsoAllocatorIns, idAllocatorIns) + assert.NotNil(t, queue) + + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + st := newDefaultMockDdlTask() + stID := st.ID() + stTs := st.BeginTs() + + // no task in queue + + unissuedTask = queue.FrontUnissuedTask() + assert.Nil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(stID) + assert.Nil(t, unissuedTask) + + unissuedTask = queue.PopUnissuedTask() + assert.Nil(t, unissuedTask) + + done = queue.TaskDoneTest(stTs) + assert.True(t, done) + + // task enqueue, only one task in queue + + err = queue.Enqueue(st) + assert.NoError(t, err) + + assert.False(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + assert.Equal(t, 1, queue.unissuedTasks.Len()) + assert.Equal(t, 1, len(queue.utChan())) + + unissuedTask = queue.FrontUnissuedTask() + assert.NotNil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, unissuedTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + unissuedTask = queue.PopUnissuedTask() + assert.NotNil(t, unissuedTask) + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test active list, no task in queue + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.Nil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.Nil(t, activeTask) + + // test active list, no task in unissued list, only one task in active list + + queue.AddActiveTask(unissuedTask) + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test utFull + queue.maxTaskNum = 10 // not accurate, full also means utBufChan block + for i := 0; i < int(queue.maxTaskNum); i++ { + err = queue.Enqueue(newDefaultMockDdlTask()) + assert.Nil(t, err) + } + assert.True(t, queue.utFull()) + err = queue.Enqueue(newDefaultMockDdlTask()) + assert.NotNil(t, err) +} + +// test the logic of queue +func TestDmTaskQueue_Basic(t *testing.T) { + var err error + var unissuedTask task + var activeTask task + var done bool + + tsoAllocatorIns := newMockTsoAllocator() + idAllocatorIns := newMockIDAllocatorInterface() + queue := newDmTaskQueue(tsoAllocatorIns, idAllocatorIns) + assert.NotNil(t, queue) + + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + st := newDefaultMockDmlTask() + stID := st.ID() + stTs := st.BeginTs() + + // no task in queue + + unissuedTask = queue.FrontUnissuedTask() + assert.Nil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(stID) + assert.Nil(t, unissuedTask) + + unissuedTask = queue.PopUnissuedTask() + assert.Nil(t, unissuedTask) + + done = queue.TaskDoneTest(stTs) + assert.True(t, done) + + // task enqueue, only one task in queue + + err = queue.Enqueue(st) + assert.NoError(t, err) + + assert.False(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + assert.Equal(t, 1, queue.unissuedTasks.Len()) + assert.Equal(t, 1, len(queue.utChan())) + + unissuedTask = queue.FrontUnissuedTask() + assert.NotNil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, unissuedTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + unissuedTask = queue.PopUnissuedTask() + assert.NotNil(t, unissuedTask) + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test active list, no task in queue + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.Nil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.Nil(t, activeTask) + + // test active list, no task in unissued list, only one task in active list + + queue.AddActiveTask(unissuedTask) + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test utFull + queue.maxTaskNum = 10 // not accurate, full also means utBufChan block + for i := 0; i < int(queue.maxTaskNum); i++ { + err = queue.Enqueue(newDefaultMockDmlTask()) + assert.Nil(t, err) + } + assert.True(t, queue.utFull()) + err = queue.Enqueue(newDefaultMockDmlTask()) + assert.NotNil(t, err) +} + +// test the timestamp statistics +func TestDmTaskQueue_TimestampStatistics(t *testing.T) { + var err error + var unissuedTask task + + tsoAllocatorIns := newMockTsoAllocator() + idAllocatorIns := newMockIDAllocatorInterface() + queue := newDmTaskQueue(tsoAllocatorIns, idAllocatorIns) + assert.NotNil(t, queue) + + st := newDefaultMockDmlTask() + stPChans := st.pchans + + err = queue.Enqueue(st) + assert.NoError(t, err) + + stats, err := queue.getPChanStatsInfo() + assert.NoError(t, err) + assert.Equal(t, len(stPChans), len(stats)) + unissuedTask = queue.FrontUnissuedTask() + assert.NotNil(t, unissuedTask) + for _, stat := range stats { + assert.Equal(t, unissuedTask.BeginTs(), stat.minTs) + assert.Equal(t, unissuedTask.EndTs(), stat.maxTs) + } + + unissuedTask = queue.PopUnissuedTask() + assert.NotNil(t, unissuedTask) + assert.True(t, queue.utEmpty()) + + queue.AddActiveTask(unissuedTask) + + queue.PopActiveTask(unissuedTask.ID()) + + stats, err = queue.getPChanStatsInfo() + assert.NoError(t, err) + assert.Zero(t, len(stats)) +} + +func TestDqTaskQueue(t *testing.T) { + var err error + var unissuedTask task + var activeTask task + var done bool + + tsoAllocatorIns := newMockTsoAllocator() + idAllocatorIns := newMockIDAllocatorInterface() + queue := newDqTaskQueue(tsoAllocatorIns, idAllocatorIns) + assert.NotNil(t, queue) + + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + st := newDefaultMockDqlTask() + stID := st.ID() + stTs := st.BeginTs() + + // no task in queue + + unissuedTask = queue.FrontUnissuedTask() + assert.Nil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(stID) + assert.Nil(t, unissuedTask) + + unissuedTask = queue.PopUnissuedTask() + assert.Nil(t, unissuedTask) + + done = queue.TaskDoneTest(stTs) + assert.True(t, done) + + // task enqueue, only one task in queue + + err = queue.Enqueue(st) + assert.NoError(t, err) + + assert.False(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + assert.Equal(t, 1, queue.unissuedTasks.Len()) + assert.Equal(t, 1, len(queue.utChan())) + + unissuedTask = queue.FrontUnissuedTask() + assert.NotNil(t, unissuedTask) + + unissuedTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, unissuedTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + unissuedTask = queue.PopUnissuedTask() + assert.NotNil(t, unissuedTask) + assert.True(t, queue.utEmpty()) + assert.False(t, queue.utFull()) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test active list, no task in queue + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.Nil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.Nil(t, activeTask) + + // test active list, no task in unissued list, only one task in active list + + queue.AddActiveTask(unissuedTask) + + activeTask = queue.getTaskByReqID(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.False(t, done) + + activeTask = queue.PopActiveTask(unissuedTask.ID()) + assert.NotNil(t, activeTask) + + done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1) + assert.True(t, done) + + // test utFull + queue.maxTaskNum = 10 // not accurate, full also means utBufChan block + for i := 0; i < int(queue.maxTaskNum); i++ { + err = queue.Enqueue(newDefaultMockDqlTask()) + assert.Nil(t, err) + } + assert.True(t, queue.utFull()) + err = queue.Enqueue(newDefaultMockDqlTask()) + assert.NotNil(t, err) +} + +func TestTaskScheduler(t *testing.T) { + var err error + + ctx := context.Background() + tsoAllocatorIns := newMockTsoAllocator() + idAllocatorIns := newMockIDAllocatorInterface() + factory := msgstream.NewSimpleMsgStreamFactory() + + sched, err := newTaskScheduler(ctx, idAllocatorIns, tsoAllocatorIns, factory) + assert.NoError(t, err) + assert.NotNil(t, sched) + + err = sched.Start() + assert.NoError(t, err) + defer sched.Close() + + assert.True(t, sched.TaskDoneTest(Timestamp(time.Now().Nanosecond()))) + + stats, err := sched.getPChanStatistics() + assert.NoError(t, err) + assert.Equal(t, 0, len(stats)) + + ddNum := rand.Int() % 10 + dmNum := rand.Int() % 10 + dqNum := rand.Int() % 10 + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for i := 0; i < ddNum; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + err := sched.ddQueue.Enqueue(newDefaultMockDdlTask()) + assert.NoError(t, err) + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for i := 0; i < dmNum; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + err := sched.dmQueue.Enqueue(newDefaultMockDmlTask()) + assert.NoError(t, err) + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for i := 0; i < dqNum; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + err := sched.dqQueue.Enqueue(newDefaultMockDqlTask()) + assert.NoError(t, err) + }() + } + }() + + wg.Wait() +} diff --git a/internal/proxy/test_utils.go b/internal/proxy/test_utils.go index bd360cd21..24f21aef0 100644 --- a/internal/proxy/test_utils.go +++ b/internal/proxy/test_utils.go @@ -1,11 +1,32 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + package proxy import ( + "fmt" "math/rand" "github.com/milvus-io/milvus/internal/proto/schemapb" ) +func genUniqueStr() string { + l := rand.Uint64()%100 + 1 + b := make([]byte, l) + if _, err := rand.Read(b); err != nil { + return "" + } + return fmt.Sprintf("%X", b) +} + func generateBoolArray(numRows int) []bool { ret := make([]bool, 0, numRows) for i := 0; i < numRows; i++ { diff --git a/internal/proxy/timestamp.go b/internal/proxy/timestamp.go index 529592556..18f12838e 100644 --- a/internal/proxy/timestamp.go +++ b/internal/proxy/timestamp.go @@ -20,11 +20,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) -// use timestampAllocatorInterface to keep TimestampAllocator testable -type timestampAllocatorInterface interface { - AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) -} - type TimestampAllocator struct { ctx context.Context tso timestampAllocatorInterface diff --git a/internal/util/funcutil/set.go b/internal/util/funcutil/set.go new file mode 100644 index 000000000..344194d18 --- /dev/null +++ b/internal/util/funcutil/set.go @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package funcutil + +// SetContain returns true if set m1 contains set m2 +func SetContain(m1, m2 map[interface{}]struct{}) bool { + if len(m1) < len(m2) { + return false + } + + for k2 := range m2 { + _, ok := m1[k2] + if !ok { + return false + } + } + + return true +} diff --git a/internal/util/funcutil/set_test.go b/internal/util/funcutil/set_test.go new file mode 100644 index 000000000..2f67100cd --- /dev/null +++ b/internal/util/funcutil/set_test.go @@ -0,0 +1,40 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package funcutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSetContain(t *testing.T) { + key1 := "key1" + key2 := "key2" + key3 := "key3" + + // len(m1) < len(m2) + m1 := make(map[interface{}]struct{}) + m2 := make(map[interface{}]struct{}) + m1[key1] = struct{}{} + m2[key1] = struct{}{} + m2[key2] = struct{}{} + assert.False(t, SetContain(m1, m2)) + + // len(m1) >= len(m2), but m2 contains other key not in m1 + m1[key3] = struct{}{} + assert.False(t, SetContain(m1, m2)) + + // m1 contains m2 + m1[key2] = struct{}{} + assert.True(t, SetContain(m1, m2)) +} -- GitLab