Unverified commit 42b687bf authored by dragondriver and committed by GitHub

Add unittest for task scheduler (#7508)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
Parent 7025a6e9
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package grpcconfigs
import "math"
......
......@@ -14,7 +14,6 @@ package proxy
import (
"context"
"fmt"
"math/rand"
"runtime"
"sort"
"sync"
......@@ -37,83 +36,9 @@ type channelsMgr interface {
removeAllDMLStream() error
}
type (
uniqueIntGenerator interface {
get() int
}
naiveUniqueIntGenerator struct {
now int
mtx sync.Mutex
}
)
func (generator *naiveUniqueIntGenerator) get() int {
generator.mtx.Lock()
defer func() {
generator.now++
generator.mtx.Unlock()
}()
return generator.now
}
func newNaiveUniqueIntGenerator() *naiveUniqueIntGenerator {
return &naiveUniqueIntGenerator{
now: 0,
}
}
var uniqueIntGeneratorIns uniqueIntGenerator
var getUniqueIntGeneratorInsOnce sync.Once
func getUniqueIntGeneratorIns() uniqueIntGenerator {
getUniqueIntGeneratorInsOnce.Do(func() {
uniqueIntGeneratorIns = newNaiveUniqueIntGenerator()
})
return uniqueIntGeneratorIns
}
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)
type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error)
type getChannelsService interface {
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
}
type mockGetChannelsService struct {
collectionID2Channels map[UniqueID]map[vChan]pChan
}
func newMockGetChannelsService() *mockGetChannelsService {
return &mockGetChannelsService{
collectionID2Channels: make(map[UniqueID]map[vChan]pChan),
}
}
func genUniqueStr() string {
l := rand.Uint64()%100 + 1
b := make([]byte, l)
if _, err := rand.Read(b); err != nil {
return ""
}
return fmt.Sprintf("%X", b)
}
func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
channels, ok := m.collectionID2Channels[collectionID]
if ok {
return channels, nil
}
channels = make(map[vChan]pChan)
l := rand.Uint64()%10 + 1
for i := 0; uint64(i) < l; i++ {
channels[genUniqueStr()] = genUniqueStr()
}
m.collectionID2Channels[collectionID] = channels
return channels, nil
}
type streamType int
const (
......
......@@ -19,20 +19,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestNaiveUniqueIntGenerator_get(t *testing.T) {
exists := make(map[int]bool)
num := 10
generator := newNaiveUniqueIntGenerator()
for i := 0; i < num; i++ {
g := generator.get()
_, ok := exists[g]
assert.False(t, ok)
exists[g] = true
}
}
func TestChannelsMgrImpl_getChannels(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
......
......@@ -24,14 +24,6 @@ import (
// the ticker can only update its ts when minTs is greater than the ticker's current ts; maxTs can be used to advance the current ts later
type getPChanStatisticsFuncType func() (map[pChan]*pChanStatistics, error)
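// Hedged illustration (not part of this change): getPChanStatisticsFuncType is just a
// closure returning per-pchan min/max timestamps, so a test can build one from a fixed
// map. The helper name newFixedStatisticsFunc is hypothetical.
func newFixedStatisticsFunc(stats map[pChan]*pChanStatistics) getPChanStatisticsFuncType {
	return func() (map[pChan]*pChanStatistics, error) {
		return stats, nil
	}
}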
// use interface tsoAllocator to keep channelsTimeTickerImpl testable
type tsoAllocator interface {
//Start() error
AllocOne() (Timestamp, error)
//Alloc(count uint32) ([]Timestamp, error)
//ClearCache()
}
type channelsTimeTicker interface {
start() error
close() error
......
......@@ -25,17 +25,6 @@ import (
"github.com/stretchr/testify/assert"
)
type mockTsoAllocator struct {
}
func (tso *mockTsoAllocator) AllocOne() (Timestamp, error) {
return Timestamp(time.Now().UnixNano()), nil
}
func newMockTsoAllocator() *mockTsoAllocator {
return &mockTsoAllocator{}
}
func newGetStatisticsFunc(pchans []pChan) getPChanStatisticsFuncType {
totalPchan := len(pchans)
pchanNum := rand.Uint64()%(uint64(totalPchan)) + 1
......
......@@ -134,7 +134,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("schema", request.Schema))
err := node.sched.DdQueue.Enqueue(cct)
err := node.sched.ddQueue.Enqueue(cct)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -188,7 +188,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
err := node.sched.DdQueue.Enqueue(dct)
err := node.sched.ddQueue.Enqueue(dct)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -240,7 +240,7 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
err := node.sched.DdQueue.Enqueue(hct)
err := node.sched.ddQueue.Enqueue(hct)
if err != nil {
return &milvuspb.BoolResponse{
Status: &commonpb.Status{
......@@ -294,7 +294,7 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
err := node.sched.DdQueue.Enqueue(lct)
err := node.sched.ddQueue.Enqueue(lct)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -345,7 +345,7 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
err := node.sched.DdQueue.Enqueue(rct)
err := node.sched.ddQueue.Enqueue(rct)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -397,7 +397,7 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
err := node.sched.DdQueue.Enqueue(dct)
err := node.sched.ddQueue.Enqueue(dct)
if err != nil {
return &milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
......@@ -453,7 +453,7 @@ func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvusp
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
err := node.sched.DdQueue.Enqueue(g)
err := node.sched.ddQueue.Enqueue(g)
if err != nil {
return &milvuspb.GetCollectionStatisticsResponse{
Status: &commonpb.Status{
......@@ -509,7 +509,7 @@ func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCo
log.Debug("ShowCollections enqueue",
zap.String("role", Params.RoleName),
zap.Any("request", request))
err := node.sched.DdQueue.Enqueue(sct)
err := node.sched.ddQueue.Enqueue(sct)
if err != nil {
return &milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
......@@ -560,7 +560,7 @@ func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.Create
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName))
err := node.sched.DdQueue.Enqueue(cpt)
err := node.sched.ddQueue.Enqueue(cpt)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -613,7 +613,7 @@ func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPart
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName))
err := node.sched.DdQueue.Enqueue(dpt)
err := node.sched.ddQueue.Enqueue(dpt)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -668,7 +668,7 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName))
err := node.sched.DdQueue.Enqueue(hpt)
err := node.sched.ddQueue.Enqueue(hpt)
if err != nil {
return &milvuspb.BoolResponse{
Status: &commonpb.Status{
......@@ -726,7 +726,7 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames))
err := node.sched.DdQueue.Enqueue(lpt)
err := node.sched.ddQueue.Enqueue(lpt)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -779,7 +779,7 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames))
err := node.sched.DdQueue.Enqueue(rpt)
err := node.sched.ddQueue.Enqueue(rpt)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -834,7 +834,7 @@ func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName))
err := node.sched.DdQueue.Enqueue(g)
err := node.sched.ddQueue.Enqueue(g)
if err != nil {
return &milvuspb.GetPartitionStatisticsResponse{
Status: &commonpb.Status{
......@@ -893,7 +893,7 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar
log.Debug("ShowPartitions enqueue",
zap.String("role", Params.RoleName),
zap.Any("request", request))
err := node.sched.DdQueue.Enqueue(spt)
err := node.sched.ddQueue.Enqueue(spt)
if err != nil {
return &milvuspb.ShowPartitionsResponse{
Status: &commonpb.Status{
......@@ -943,7 +943,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
zap.String("collection", request.CollectionName),
zap.String("field", request.FieldName),
zap.Any("extra_params", request.ExtraParams))
err := node.sched.DdQueue.Enqueue(cit)
err := node.sched.ddQueue.Enqueue(cit)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......@@ -1001,7 +1001,7 @@ func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.Describe
zap.String("collection", request.CollectionName),
zap.String("field", request.FieldName),
zap.String("index name", request.IndexName))
err := node.sched.DdQueue.Enqueue(dit)
err := node.sched.ddQueue.Enqueue(dit)
if err != nil {
return &milvuspb.DescribeIndexResponse{
Status: &commonpb.Status{
......@@ -1065,7 +1065,7 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq
zap.String("collection", request.CollectionName),
zap.String("field", request.FieldName),
zap.String("index name", request.IndexName))
err := node.sched.DdQueue.Enqueue(dit)
err := node.sched.ddQueue.Enqueue(dit)
if err != nil {
return &commonpb.Status{
......@@ -1127,7 +1127,7 @@ func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb.
zap.String("collection", request.CollectionName),
zap.String("field", request.FieldName),
zap.String("index name", request.IndexName))
err := node.sched.DdQueue.Enqueue(gibpt)
err := node.sched.ddQueue.Enqueue(gibpt)
if err != nil {
return &milvuspb.GetIndexBuildProgressResponse{
Status: &commonpb.Status{
......@@ -1192,7 +1192,7 @@ func (node *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndex
zap.String("collection", request.CollectionName),
zap.String("field", request.FieldName),
zap.String("index name", request.IndexName))
err := node.sched.DdQueue.Enqueue(dipt)
err := node.sched.ddQueue.Enqueue(dipt)
if err != nil {
return &milvuspb.GetIndexStateResponse{
Status: &commonpb.Status{
......@@ -1299,7 +1299,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
ErrorCode: commonpb.ErrorCode_Success,
},
}
err = node.sched.DmQueue.Enqueue(it)
err = node.sched.dmQueue.Enqueue(it)
log.Debug("Insert Task Enqueue",
zap.Int64("msgID", it.BaseInsertTask.InsertRequest.Base.MsgID),
......@@ -1366,7 +1366,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName),
zap.String("expr", request.Expr))
err := node.sched.DmQueue.Enqueue(dt)
err := node.sched.dmQueue.Enqueue(dt)
if err != nil {
return &milvuspb.MutationResult{
Status: &commonpb.Status{
......@@ -1439,7 +1439,7 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
zap.Any("dsl", request.Dsl),
zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)),
zap.Any("OutputFields", request.OutputFields))
err := node.sched.DqQueue.Enqueue(qt)
err := node.sched.dqQueue.Enqueue(qt)
if err != nil {
return &milvuspb.SearchResults{
Status: &commonpb.Status{
......@@ -1519,7 +1519,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
zap.String("role", Params.RoleName),
zap.String("db", request.DbName),
zap.Any("collections", request.CollectionNames))
err := node.sched.DdQueue.Enqueue(ft)
err := node.sched.ddQueue.Enqueue(ft)
if err != nil {
resp.Status.Reason = err.Error()
return resp, nil
......@@ -1587,7 +1587,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
zap.String("collection", queryRequest.CollectionName),
zap.Any("partitions", queryRequest.PartitionNames))
err := node.sched.DqQueue.Enqueue(qt)
err := node.sched.dqQueue.Enqueue(qt)
if err != nil {
return &milvuspb.QueryResults{
Status: &commonpb.Status{
......@@ -1669,7 +1669,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
ids: ids.IdArray,
}
err := node.sched.DqQueue.Enqueue(qt)
err := node.sched.dqQueue.Enqueue(qt)
if err != nil {
return &milvuspb.QueryResults{
Status: &commonpb.Status{
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"context"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
)
// tsoAllocator is defined as an interface so that the components that depend on it
// (channelsTimeTickerImpl, baseTaskQueue, taskScheduler) remain testable
type tsoAllocator interface {
AllocOne() (Timestamp, error)
}
// idAllocatorInterface is defined as an interface so that the components that depend on it
// (baseTaskQueue, taskScheduler) remain testable
type idAllocatorInterface interface {
AllocOne() (UniqueID, error)
}
// timestampAllocatorInterface is defined as an interface so that the component that depends on it
// (TimestampAllocator) remains testable
type timestampAllocatorInterface interface {
AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)
}
type getChannelsService interface {
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
}
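// Hedged sketch (illustration only, not part of this change): a component that needs IDs
// and timestamps can depend on these narrow interfaces instead of the concrete allocators,
// so tests can inject mocks. The type name exampleConsumer is hypothetical.
type exampleConsumer struct {
	tsoAllocatorIns tsoAllocator
	idAllocatorIns  idAllocatorInterface
}

func (c *exampleConsumer) nextIDAndTs() (UniqueID, Timestamp, error) {
	id, err := c.idAllocatorIns.AllocOne()
	if err != nil {
		return 0, 0, err
	}
	ts, err := c.tsoAllocatorIns.AllocOne()
	if err != nil {
		return 0, 0, err
	}
	return id, ts, nil
}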
......@@ -13,6 +13,7 @@ package proxy
import (
"context"
"math/rand"
"time"
"github.com/milvus-io/milvus/internal/proto/commonpb"
......@@ -37,3 +38,190 @@ func (tso *mockTimestampAllocatorInterface) AllocTimestamp(ctx context.Context,
func newMockTimestampAllocatorInterface() timestampAllocatorInterface {
return &mockTimestampAllocatorInterface{}
}
type mockTsoAllocator struct {
}
func (tso *mockTsoAllocator) AllocOne() (Timestamp, error) {
return Timestamp(time.Now().UnixNano()), nil
}
func newMockTsoAllocator() tsoAllocator {
return &mockTsoAllocator{}
}
type mockIDAllocatorInterface struct {
}
func (m *mockIDAllocatorInterface) AllocOne() (UniqueID, error) {
return UniqueID(getUniqueIntGeneratorIns().get()), nil
}
func newMockIDAllocatorInterface() idAllocatorInterface {
return &mockIDAllocatorInterface{}
}
type mockGetChannelsService struct {
collectionID2Channels map[UniqueID]map[vChan]pChan
}
func newMockGetChannelsService() *mockGetChannelsService {
return &mockGetChannelsService{
collectionID2Channels: make(map[UniqueID]map[vChan]pChan),
}
}
func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
channels, ok := m.collectionID2Channels[collectionID]
if ok {
return channels, nil
}
channels = make(map[vChan]pChan)
l := rand.Uint64()%10 + 1
for i := 0; uint64(i) < l; i++ {
channels[genUniqueStr()] = genUniqueStr()
}
m.collectionID2Channels[collectionID] = channels
return channels, nil
}
type mockTask struct {
*TaskCondition
id UniqueID
name string
tType commonpb.MsgType
ts Timestamp
}
func (m *mockTask) TraceCtx() context.Context {
return m.TaskCondition.ctx
}
func (m *mockTask) ID() UniqueID {
return m.id
}
func (m *mockTask) SetID(uid UniqueID) {
m.id = uid
}
func (m *mockTask) Name() string {
return m.name
}
func (m *mockTask) Type() commonpb.MsgType {
return m.tType
}
func (m *mockTask) BeginTs() Timestamp {
return m.ts
}
func (m *mockTask) EndTs() Timestamp {
return m.ts
}
func (m *mockTask) SetTs(ts Timestamp) {
m.ts = ts
}
func (m *mockTask) OnEnqueue() error {
return nil
}
func (m *mockTask) PreExecute(ctx context.Context) error {
return nil
}
func (m *mockTask) Execute(ctx context.Context) error {
return nil
}
func (m *mockTask) PostExecute(ctx context.Context) error {
return nil
}
func newMockTask(ctx context.Context) *mockTask {
return &mockTask{
TaskCondition: NewTaskCondition(ctx),
id: UniqueID(getUniqueIntGeneratorIns().get()),
name: genUniqueStr(),
tType: commonpb.MsgType_Undefined,
ts: Timestamp(time.Now().Nanosecond()),
}
}
func newDefaultMockTask() *mockTask {
return newMockTask(context.Background())
}
type mockDdlTask struct {
*mockTask
}
func newMockDdlTask(ctx context.Context) *mockDdlTask {
return &mockDdlTask{
mockTask: newMockTask(ctx),
}
}
func newDefaultMockDdlTask() *mockDdlTask {
return newMockDdlTask(context.Background())
}
type mockDmlTask struct {
*mockTask
vchans []vChan
pchans []pChan
}
func (m *mockDmlTask) getChannels() ([]vChan, error) {
return m.vchans, nil
}
func (m *mockDmlTask) getPChanStats() (map[pChan]pChanStatistics, error) {
ret := make(map[pChan]pChanStatistics)
for _, pchan := range m.pchans {
ret[pchan] = pChanStatistics{
minTs: m.ts,
maxTs: m.ts,
}
}
return ret, nil
}
func newMockDmlTask(ctx context.Context) *mockDmlTask {
shardNum := 2
vchans := make([]vChan, 0, shardNum)
pchans := make([]pChan, 0, shardNum)
for i := 0; i < shardNum; i++ {
vchans = append(vchans, genUniqueStr())
pchans = append(pchans, genUniqueStr())
}
return &mockDmlTask{
mockTask: newMockTask(ctx),
vchans:   vchans,
pchans:   pchans,
}
}
func newDefaultMockDmlTask() *mockDmlTask {
return newMockDmlTask(context.Background())
}
type mockDqlTask struct {
*mockTask
}
func newMockDqlTask(ctx context.Context) *mockDqlTask {
return &mockDqlTask{
mockTask: newMockTask(ctx),
}
}
func newDefaultMockDqlTask() *mockDqlTask {
return newMockDqlTask(context.Background())
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import "sync"
type (
uniqueIntGenerator interface {
get() int
}
naiveUniqueIntGenerator struct {
now int
mtx sync.Mutex
}
)
func (generator *naiveUniqueIntGenerator) get() int {
generator.mtx.Lock()
defer func() {
generator.now++
generator.mtx.Unlock()
}()
return generator.now
}
func newNaiveUniqueIntGenerator() *naiveUniqueIntGenerator {
return &naiveUniqueIntGenerator{
now: 0,
}
}
var uniqueIntGeneratorIns uniqueIntGenerator
var getUniqueIntGeneratorInsOnce sync.Once
func getUniqueIntGeneratorIns() uniqueIntGenerator {
getUniqueIntGeneratorInsOnce.Do(func() {
uniqueIntGeneratorIns = newNaiveUniqueIntGenerator()
})
return uniqueIntGeneratorIns
}
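// Hedged usage sketch (not part of this change): the singleton generator is guarded by a
// mutex and hands out monotonically increasing ints, so concurrent callers always receive
// distinct values; the mock ID allocator elsewhere in this change simply wraps it.
// The helper name exampleNextInts is hypothetical.
func exampleNextInts(n int) []int {
	gen := getUniqueIntGeneratorIns()
	out := make([]int, 0, n)
	for i := 0; i < n; i++ {
		out = append(out, gen.get())
	}
	return out
}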
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNaiveUniqueIntGenerator_get(t *testing.T) {
exists := make(map[int]bool)
num := 10
generator := newNaiveUniqueIntGenerator()
for i := 0; i < num; i++ {
g := generator.get()
_, ok := exists[g]
assert.False(t, ok)
exists[g] = true
}
}
......@@ -63,7 +63,7 @@ type Proxy struct {
chMgr channelsMgr
sched *TaskScheduler
sched *taskScheduler
tick *timeTick
chTicker channelsTimeTicker
......@@ -256,7 +256,7 @@ func (node *Proxy) Init() error {
chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
node.chMgr = chMgr
node.sched, err = NewTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
if err != nil {
return err
}
......
......@@ -100,7 +100,6 @@ type dmlTask interface {
task
getChannels() ([]vChan, error)
getPChanStats() (map[pChan]pChanStatistics, error)
getChannelsTimerTicker() channelsTimeTicker
}
type BaseInsertTask = msgstream.InsertMsg
......@@ -155,10 +154,6 @@ func (it *InsertTask) EndTs() Timestamp {
return it.EndTimestamp
}
func (it *InsertTask) getChannelsTimerTicker() channelsTimeTicker {
return it.chTicker
}
func (it *InsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
ret := make(map[pChan]pChanStatistics)
......@@ -192,6 +187,17 @@ func (it *InsertTask) getChannels() ([]pChan, error) {
return nil, err
}
channels, err = it.chMgr.getChannels(collID)
if err == nil {
for _, pchan := range channels {
err := it.chTicker.addPChan(pchan)
if err != nil {
log.Warn("failed to add pchan to channels time ticker",
zap.Error(err),
zap.Int64("collection id", collID),
zap.String("pchan", pchan))
}
}
}
}
return channels, err
}
......@@ -1023,6 +1029,17 @@ func (it *InsertTask) Execute(ctx context.Context) error {
it.result.Status.Reason = err.Error()
return err
}
channels, err := it.chMgr.getChannels(collID)
if err == nil {
for _, pchan := range channels {
err := it.chTicker.addPChan(pchan)
if err != nil {
log.Warn("failed to add pchan to channels time ticker",
zap.Error(err),
zap.String("pchan", pchan))
}
}
}
stream, err = it.chMgr.getDMLStream(collID)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
......
......@@ -19,9 +19,10 @@ import (
"strconv"
"sync"
"github.com/milvus-io/milvus/internal/util/funcutil"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
......@@ -31,7 +32,7 @@ import (
oplog "github.com/opentracing/opentracing-go/log"
)
type TaskQueue interface {
type taskQueue interface {
utChan() <-chan int
utEmpty() bool
utFull() bool
......@@ -45,7 +46,10 @@ type TaskQueue interface {
Enqueue(t task) error
}
type BaseTaskQueue struct {
// TODO(dragondriver): load from config
const maxTaskNum = 1024
type baseTaskQueue struct {
unissuedTasks *list.List
activeTasks map[UniqueID]task
utLock sync.RWMutex
......@@ -56,24 +60,25 @@ type BaseTaskQueue struct {
utBufChan chan int // to block scheduler
sched *TaskScheduler
tsoAllocatorIns tsoAllocator
idAllocatorIns idAllocatorInterface
}
func (queue *BaseTaskQueue) utChan() <-chan int {
func (queue *baseTaskQueue) utChan() <-chan int {
return queue.utBufChan
}
func (queue *BaseTaskQueue) utEmpty() bool {
func (queue *baseTaskQueue) utEmpty() bool {
queue.utLock.RLock()
defer queue.utLock.RUnlock()
return queue.unissuedTasks.Len() == 0
}
func (queue *BaseTaskQueue) utFull() bool {
func (queue *baseTaskQueue) utFull() bool {
return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum
}
func (queue *BaseTaskQueue) addUnissuedTask(t task) error {
func (queue *baseTaskQueue) addUnissuedTask(t task) error {
queue.utLock.Lock()
defer queue.utLock.Unlock()
......@@ -85,7 +90,7 @@ func (queue *BaseTaskQueue) addUnissuedTask(t task) error {
return nil
}
func (queue *BaseTaskQueue) FrontUnissuedTask() task {
func (queue *baseTaskQueue) FrontUnissuedTask() task {
queue.utLock.RLock()
defer queue.utLock.RUnlock()
......@@ -97,7 +102,7 @@ func (queue *BaseTaskQueue) FrontUnissuedTask() task {
return queue.unissuedTasks.Front().Value.(task)
}
func (queue *BaseTaskQueue) PopUnissuedTask() task {
func (queue *baseTaskQueue) PopUnissuedTask() task {
queue.utLock.Lock()
defer queue.utLock.Unlock()
......@@ -112,7 +117,7 @@ func (queue *BaseTaskQueue) PopUnissuedTask() task {
return ft.Value.(task)
}
func (queue *BaseTaskQueue) AddActiveTask(t task) {
func (queue *baseTaskQueue) AddActiveTask(t task) {
queue.atLock.Lock()
defer queue.atLock.Unlock()
tID := t.ID()
......@@ -124,7 +129,7 @@ func (queue *BaseTaskQueue) AddActiveTask(t task) {
queue.activeTasks[tID] = t
}
func (queue *BaseTaskQueue) PopActiveTask(tID UniqueID) task {
func (queue *baseTaskQueue) PopActiveTask(tID UniqueID) task {
queue.atLock.Lock()
defer queue.atLock.Unlock()
t, ok := queue.activeTasks[tID]
......@@ -137,7 +142,7 @@ func (queue *BaseTaskQueue) PopActiveTask(tID UniqueID) task {
return t
}
func (queue *BaseTaskQueue) getTaskByReqID(reqID UniqueID) task {
func (queue *baseTaskQueue) getTaskByReqID(reqID UniqueID) task {
queue.utLock.RLock()
defer queue.utLock.RUnlock()
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
......@@ -157,7 +162,7 @@ func (queue *BaseTaskQueue) getTaskByReqID(reqID UniqueID) task {
return nil
}
func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
func (queue *baseTaskQueue) TaskDoneTest(ts Timestamp) bool {
queue.utLock.RLock()
defer queue.utLock.RUnlock()
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
......@@ -177,19 +182,19 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
return true
}
func (queue *BaseTaskQueue) Enqueue(t task) error {
func (queue *baseTaskQueue) Enqueue(t task) error {
err := t.OnEnqueue()
if err != nil {
return err
}
ts, err := queue.sched.tsoAllocator.AllocOne()
ts, err := queue.tsoAllocatorIns.AllocOne()
if err != nil {
return err
}
t.SetTs(ts)
reqID, err := queue.sched.idAllocator.AllocOne()
reqID, err := queue.idAllocatorIns.AllocOne()
if err != nil {
return err
}
......@@ -198,8 +203,21 @@ func (queue *BaseTaskQueue) Enqueue(t task) error {
return queue.addUnissuedTask(t)
}
type DdTaskQueue struct {
BaseTaskQueue
func newBaseTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *baseTaskQueue {
return &baseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[UniqueID]task),
utLock: sync.RWMutex{},
atLock: sync.RWMutex{},
maxTaskNum: maxTaskNum,
utBufChan: make(chan int, maxTaskNum),
tsoAllocatorIns: tsoAllocatorIns,
idAllocatorIns: idAllocatorIns,
}
}
type ddTaskQueue struct {
*baseTaskQueue
lock sync.Mutex
}
......@@ -208,19 +226,19 @@ type pChanStatInfo struct {
tsSet map[Timestamp]struct{}
}
type DmTaskQueue struct {
BaseTaskQueue
type dmTaskQueue struct {
*baseTaskQueue
lock sync.Mutex
statsLock sync.RWMutex
pChanStatisticsInfos map[pChan]*pChanStatInfo
}
func (queue *DmTaskQueue) Enqueue(t task) error {
func (queue *dmTaskQueue) Enqueue(t task) error {
queue.lock.Lock()
defer queue.lock.Unlock()
err := queue.BaseTaskQueue.Enqueue(t)
err := queue.baseTaskQueue.Enqueue(t)
if err != nil {
return err
}
......@@ -229,13 +247,13 @@ func (queue *DmTaskQueue) Enqueue(t task) error {
return nil
}
func (queue *DmTaskQueue) PopActiveTask(tID UniqueID) task {
func (queue *dmTaskQueue) PopActiveTask(tID UniqueID) task {
queue.atLock.Lock()
defer queue.atLock.Unlock()
t, ok := queue.activeTasks[tID]
if ok {
delete(queue.activeTasks, tID)
log.Debug("Proxy DmTaskQueue popPChanStats", zap.Any("tID", t.ID()))
log.Debug("Proxy dmTaskQueue popPChanStats", zap.Any("tID", t.ID()))
queue.popPChanStats(t)
} else {
log.Debug("Proxy task not in active task list!", zap.Any("tID", tID))
......@@ -243,11 +261,11 @@ func (queue *DmTaskQueue) PopActiveTask(tID UniqueID) task {
return t
}
func (queue *DmTaskQueue) addPChanStats(t task) error {
func (queue *dmTaskQueue) addPChanStats(t task) error {
if dmT, ok := t.(dmlTask); ok {
stats, err := dmT.getPChanStats()
if err != nil {
log.Debug("Proxy DmTaskQueue addPChanStats", zap.Any("tID", t.ID()),
log.Debug("Proxy dmTaskQueue addPChanStats", zap.Any("tID", t.ID()),
zap.Any("stats", stats), zap.Error(err))
return err
}
......@@ -262,7 +280,6 @@ func (queue *DmTaskQueue) addPChanStats(t task) error {
},
}
queue.pChanStatisticsInfos[cName] = info
dmT.getChannelsTimerTicker().addPChan(cName)
} else {
if info.minTs > stat.minTs {
queue.pChanStatisticsInfos[cName].minTs = stat.minTs
......@@ -280,7 +297,7 @@ func (queue *DmTaskQueue) addPChanStats(t task) error {
return nil
}
func (queue *DmTaskQueue) popPChanStats(t task) error {
func (queue *dmTaskQueue) popPChanStats(t task) error {
if dmT, ok := t.(dmlTask); ok {
channels, err := dmT.getChannels()
if err != nil {
......@@ -306,12 +323,12 @@ func (queue *DmTaskQueue) popPChanStats(t task) error {
}
queue.statsLock.Unlock()
} else {
return fmt.Errorf("Proxy DmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
return fmt.Errorf("Proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
}
return nil
}
func (queue *DmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
ret := make(map[pChan]*pChanStatistics)
queue.statsLock.RLock()
......@@ -325,60 +342,39 @@ func (queue *DmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error
return ret, nil
}
type DqTaskQueue struct {
BaseTaskQueue
type dqTaskQueue struct {
*baseTaskQueue
}
func (queue *DdTaskQueue) Enqueue(t task) error {
func (queue *ddTaskQueue) Enqueue(t task) error {
queue.lock.Lock()
defer queue.lock.Unlock()
return queue.BaseTaskQueue.Enqueue(t)
return queue.baseTaskQueue.Enqueue(t)
}
func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue {
return &DdTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[UniqueID]task),
maxTaskNum: 1024,
utBufChan: make(chan int, 1024),
sched: sched,
},
func newDdTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *ddTaskQueue {
return &ddTaskQueue{
baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns),
}
}
func NewDmTaskQueue(sched *TaskScheduler) *DmTaskQueue {
return &DmTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[UniqueID]task),
maxTaskNum: 1024,
utBufChan: make(chan int, 1024),
sched: sched,
},
func newDmTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *dmTaskQueue {
return &dmTaskQueue{
baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns),
pChanStatisticsInfos: make(map[pChan]*pChanStatInfo),
}
}
func NewDqTaskQueue(sched *TaskScheduler) *DqTaskQueue {
return &DqTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[UniqueID]task),
maxTaskNum: 1024,
utBufChan: make(chan int, 1024),
sched: sched,
},
func newDqTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *dqTaskQueue {
return &dqTaskQueue{
baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns),
}
}
type TaskScheduler struct {
DdQueue TaskQueue
DmQueue *DmTaskQueue
DqQueue TaskQueue
idAllocator *allocator.IDAllocator
tsoAllocator *TimestampAllocator
type taskScheduler struct {
ddQueue taskQueue
dmQueue *dmTaskQueue
dqQueue taskQueue
wg sync.WaitGroup
ctx context.Context
......@@ -387,51 +383,49 @@ type TaskScheduler struct {
msFactory msgstream.Factory
}
func NewTaskScheduler(ctx context.Context,
idAllocator *allocator.IDAllocator,
tsoAllocator *TimestampAllocator,
factory msgstream.Factory) (*TaskScheduler, error) {
func newTaskScheduler(ctx context.Context,
idAllocatorIns idAllocatorInterface,
tsoAllocatorIns tsoAllocator,
factory msgstream.Factory) (*taskScheduler, error) {
ctx1, cancel := context.WithCancel(ctx)
s := &TaskScheduler{
idAllocator: idAllocator,
tsoAllocator: tsoAllocator,
ctx: ctx1,
cancel: cancel,
msFactory: factory,
s := &taskScheduler{
ctx: ctx1,
cancel: cancel,
msFactory: factory,
}
s.DdQueue = NewDdTaskQueue(s)
s.DmQueue = NewDmTaskQueue(s)
s.DqQueue = NewDqTaskQueue(s)
s.ddQueue = newDdTaskQueue(tsoAllocatorIns, idAllocatorIns)
s.dmQueue = newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
s.dqQueue = newDqTaskQueue(tsoAllocatorIns, idAllocatorIns)
return s, nil
}
func (sched *TaskScheduler) scheduleDdTask() task {
return sched.DdQueue.PopUnissuedTask()
func (sched *taskScheduler) scheduleDdTask() task {
return sched.ddQueue.PopUnissuedTask()
}
func (sched *TaskScheduler) scheduleDmTask() task {
return sched.DmQueue.PopUnissuedTask()
func (sched *taskScheduler) scheduleDmTask() task {
return sched.dmQueue.PopUnissuedTask()
}
func (sched *TaskScheduler) scheduleDqTask() task {
return sched.DqQueue.PopUnissuedTask()
func (sched *taskScheduler) scheduleDqTask() task {
return sched.dqQueue.PopUnissuedTask()
}
func (sched *TaskScheduler) getTaskByReqID(collMeta UniqueID) task {
if t := sched.DdQueue.getTaskByReqID(collMeta); t != nil {
func (sched *taskScheduler) getTaskByReqID(collMeta UniqueID) task {
if t := sched.ddQueue.getTaskByReqID(collMeta); t != nil {
return t
}
if t := sched.DmQueue.getTaskByReqID(collMeta); t != nil {
if t := sched.dmQueue.getTaskByReqID(collMeta); t != nil {
return t
}
if t := sched.DqQueue.getTaskByReqID(collMeta); t != nil {
if t := sched.dqQueue.getTaskByReqID(collMeta); t != nil {
return t
}
return nil
}
func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
func (sched *taskScheduler) processTask(t task, q taskQueue) {
span, ctx := trace.StartSpanFromContext(t.TraceCtx(),
opentracing.Tags{
"Type": t.Name(),
......@@ -469,47 +463,47 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
err = t.PostExecute(ctx)
}
func (sched *TaskScheduler) definitionLoop() {
func (sched *taskScheduler) definitionLoop() {
defer sched.wg.Done()
for {
select {
case <-sched.ctx.Done():
return
case <-sched.DdQueue.utChan():
if !sched.DdQueue.utEmpty() {
case <-sched.ddQueue.utChan():
if !sched.ddQueue.utEmpty() {
t := sched.scheduleDdTask()
sched.processTask(t, sched.DdQueue)
sched.processTask(t, sched.ddQueue)
}
}
}
}
func (sched *TaskScheduler) manipulationLoop() {
func (sched *taskScheduler) manipulationLoop() {
defer sched.wg.Done()
for {
select {
case <-sched.ctx.Done():
return
case <-sched.DmQueue.utChan():
if !sched.DmQueue.utEmpty() {
case <-sched.dmQueue.utChan():
if !sched.dmQueue.utEmpty() {
t := sched.scheduleDmTask()
go sched.processTask(t, sched.DmQueue)
go sched.processTask(t, sched.dmQueue)
}
}
}
}
func (sched *TaskScheduler) queryLoop() {
func (sched *taskScheduler) queryLoop() {
defer sched.wg.Done()
for {
select {
case <-sched.ctx.Done():
return
case <-sched.DqQueue.utChan():
if !sched.DqQueue.utEmpty() {
case <-sched.dqQueue.utChan():
if !sched.dqQueue.utEmpty() {
t := sched.scheduleDqTask()
go sched.processTask(t, sched.DqQueue)
go sched.processTask(t, sched.dqQueue)
} else {
log.Debug("query queue is empty ...")
}
......@@ -561,25 +555,6 @@ func newQueryResultBuf() *queryResultBuf {
}
}
func setContain(m1, m2 map[interface{}]struct{}) bool {
log.Debug("Proxy task_scheduler setContain", zap.Any("len(m1)", len(m1)),
zap.Any("len(m2)", len(m2)))
if len(m1) < len(m2) {
return false
}
for k2 := range m2 {
_, ok := m1[k2]
log.Debug("Proxy task_scheduler setContain", zap.Any("k2", fmt.Sprintf("%v", k2)),
zap.Any("ok", ok))
if !ok {
return false
}
}
return true
}
func (sr *resultBufHeader) readyToReduce() bool {
if sr.haveError {
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("haveError", true))
......@@ -608,7 +583,7 @@ func (sr *resultBufHeader) readyToReduce() bool {
sealedGlobalSegmentIDsStrMap[x.(int64)] = 1
}
ret1 := setContain(sr.receivedVChansSet, sr.usedVChans)
ret1 := funcutil.SetContain(sr.receivedVChansSet, sr.usedVChans)
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("receivedVChansSet", receivedVChansSetStrMap),
zap.Any("usedVChans", usedVChansSetStrMap),
zap.Any("receivedSealedSegmentIDsSet", sealedSegmentIDsStrMap),
......@@ -618,7 +593,7 @@ func (sr *resultBufHeader) readyToReduce() bool {
if !ret1 {
return false
}
ret := setContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
ret := funcutil.SetContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("ret", ret))
return ret
}
......@@ -658,7 +633,7 @@ func (qr *queryResultBuf) addPartialResult(result *internalpb.RetrieveResults) {
result.GlobalSealedSegmentIDs)
}
func (sched *TaskScheduler) collectResultLoop() {
func (sched *taskScheduler) collectResultLoop() {
defer sched.wg.Done()
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
......@@ -862,7 +837,7 @@ func (sched *TaskScheduler) collectResultLoop() {
}
}
func (sched *TaskScheduler) Start() error {
func (sched *taskScheduler) Start() error {
sched.wg.Add(1)
go sched.definitionLoop()
......@@ -878,17 +853,17 @@ func (sched *TaskScheduler) Start() error {
return nil
}
func (sched *TaskScheduler) Close() {
func (sched *taskScheduler) Close() {
sched.cancel()
sched.wg.Wait()
}
func (sched *TaskScheduler) TaskDoneTest(ts Timestamp) bool {
ddTaskDone := sched.DdQueue.TaskDoneTest(ts)
dmTaskDone := sched.DmQueue.TaskDoneTest(ts)
func (sched *taskScheduler) TaskDoneTest(ts Timestamp) bool {
ddTaskDone := sched.ddQueue.TaskDoneTest(ts)
dmTaskDone := sched.dmQueue.TaskDoneTest(ts)
return ddTaskDone && dmTaskDone
}
func (sched *TaskScheduler) getPChanStatistics() (map[pChan]*pChanStatistics, error) {
return sched.DmQueue.getPChanStatsInfo()
func (sched *taskScheduler) getPChanStatistics() (map[pChan]*pChanStatistics, error) {
return sched.dmQueue.getPChanStatsInfo()
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"context"
"math/rand"
"sync"
"testing"
"time"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/stretchr/testify/assert"
)
func TestBaseTaskQueue(t *testing.T) {
var err error
var unissuedTask task
var activeTask task
var done bool
tsoAllocatorIns := newMockTsoAllocator()
idAllocatorIns := newMockIDAllocatorInterface()
queue := newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns)
assert.NotNil(t, queue)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
st := newDefaultMockTask()
stID := st.ID()
stTs := st.BeginTs()
// no task in queue
unissuedTask = queue.FrontUnissuedTask()
assert.Nil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(stID)
assert.Nil(t, unissuedTask)
unissuedTask = queue.PopUnissuedTask()
assert.Nil(t, unissuedTask)
done = queue.TaskDoneTest(stTs)
assert.True(t, done)
// task enqueue, only one task in queue
err = queue.Enqueue(st)
assert.NoError(t, err)
assert.False(t, queue.utEmpty())
assert.False(t, queue.utFull())
assert.Equal(t, 1, queue.unissuedTasks.Len())
assert.Equal(t, 1, len(queue.utChan()))
unissuedTask = queue.FrontUnissuedTask()
assert.NotNil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, unissuedTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
unissuedTask = queue.PopUnissuedTask()
assert.NotNil(t, unissuedTask)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test active list, no task in queue
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.Nil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.Nil(t, activeTask)
// test active list, no task in unissued list, only one task in active list
queue.AddActiveTask(unissuedTask)
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test utFull
queue.maxTaskNum = 10 // not exact: a truly full queue also means utBufChan would block
for i := 0; i < int(queue.maxTaskNum); i++ {
err = queue.Enqueue(newDefaultMockTask())
assert.Nil(t, err)
}
assert.True(t, queue.utFull())
err = queue.Enqueue(newDefaultMockTask())
assert.NotNil(t, err)
}
func TestDdTaskQueue(t *testing.T) {
var err error
var unissuedTask task
var activeTask task
var done bool
tsoAllocatorIns := newMockTsoAllocator()
idAllocatorIns := newMockIDAllocatorInterface()
queue := newDdTaskQueue(tsoAllocatorIns, idAllocatorIns)
assert.NotNil(t, queue)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
st := newDefaultMockDdlTask()
stID := st.ID()
stTs := st.BeginTs()
// no task in queue
unissuedTask = queue.FrontUnissuedTask()
assert.Nil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(stID)
assert.Nil(t, unissuedTask)
unissuedTask = queue.PopUnissuedTask()
assert.Nil(t, unissuedTask)
done = queue.TaskDoneTest(stTs)
assert.True(t, done)
// task enqueue, only one task in queue
err = queue.Enqueue(st)
assert.NoError(t, err)
assert.False(t, queue.utEmpty())
assert.False(t, queue.utFull())
assert.Equal(t, 1, queue.unissuedTasks.Len())
assert.Equal(t, 1, len(queue.utChan()))
unissuedTask = queue.FrontUnissuedTask()
assert.NotNil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, unissuedTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
unissuedTask = queue.PopUnissuedTask()
assert.NotNil(t, unissuedTask)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test active list, no task in queue
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.Nil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.Nil(t, activeTask)
// test active list, no task in unissued list, only one task in active list
queue.AddActiveTask(unissuedTask)
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test utFull
queue.maxTaskNum = 10 // not exact: a truly full queue also means utBufChan would block
for i := 0; i < int(queue.maxTaskNum); i++ {
err = queue.Enqueue(newDefaultMockDdlTask())
assert.Nil(t, err)
}
assert.True(t, queue.utFull())
err = queue.Enqueue(newDefaultMockDdlTask())
assert.NotNil(t, err)
}
// test the basic logic of the queue
func TestDmTaskQueue_Basic(t *testing.T) {
var err error
var unissuedTask task
var activeTask task
var done bool
tsoAllocatorIns := newMockTsoAllocator()
idAllocatorIns := newMockIDAllocatorInterface()
queue := newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
assert.NotNil(t, queue)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
st := newDefaultMockDmlTask()
stID := st.ID()
stTs := st.BeginTs()
// no task in queue
unissuedTask = queue.FrontUnissuedTask()
assert.Nil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(stID)
assert.Nil(t, unissuedTask)
unissuedTask = queue.PopUnissuedTask()
assert.Nil(t, unissuedTask)
done = queue.TaskDoneTest(stTs)
assert.True(t, done)
// task enqueue, only one task in queue
err = queue.Enqueue(st)
assert.NoError(t, err)
assert.False(t, queue.utEmpty())
assert.False(t, queue.utFull())
assert.Equal(t, 1, queue.unissuedTasks.Len())
assert.Equal(t, 1, len(queue.utChan()))
unissuedTask = queue.FrontUnissuedTask()
assert.NotNil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, unissuedTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
unissuedTask = queue.PopUnissuedTask()
assert.NotNil(t, unissuedTask)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test active list, no task in queue
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.Nil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.Nil(t, activeTask)
// test active list, no task in unissued list, only one task in active list
queue.AddActiveTask(unissuedTask)
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test utFull
queue.maxTaskNum = 10 // not exact: a truly full queue also means utBufChan would block
for i := 0; i < int(queue.maxTaskNum); i++ {
err = queue.Enqueue(newDefaultMockDmlTask())
assert.Nil(t, err)
}
assert.True(t, queue.utFull())
err = queue.Enqueue(newDefaultMockDmlTask())
assert.NotNil(t, err)
}
// test the timestamp statistics
func TestDmTaskQueue_TimestampStatistics(t *testing.T) {
var err error
var unissuedTask task
tsoAllocatorIns := newMockTsoAllocator()
idAllocatorIns := newMockIDAllocatorInterface()
queue := newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
assert.NotNil(t, queue)
st := newDefaultMockDmlTask()
stPChans := st.pchans
err = queue.Enqueue(st)
assert.NoError(t, err)
stats, err := queue.getPChanStatsInfo()
assert.NoError(t, err)
assert.Equal(t, len(stPChans), len(stats))
unissuedTask = queue.FrontUnissuedTask()
assert.NotNil(t, unissuedTask)
for _, stat := range stats {
assert.Equal(t, unissuedTask.BeginTs(), stat.minTs)
assert.Equal(t, unissuedTask.EndTs(), stat.maxTs)
}
unissuedTask = queue.PopUnissuedTask()
assert.NotNil(t, unissuedTask)
assert.True(t, queue.utEmpty())
queue.AddActiveTask(unissuedTask)
queue.PopActiveTask(unissuedTask.ID())
stats, err = queue.getPChanStatsInfo()
assert.NoError(t, err)
assert.Zero(t, len(stats))
}
func TestDqTaskQueue(t *testing.T) {
var err error
var unissuedTask task
var activeTask task
var done bool
tsoAllocatorIns := newMockTsoAllocator()
idAllocatorIns := newMockIDAllocatorInterface()
queue := newDqTaskQueue(tsoAllocatorIns, idAllocatorIns)
assert.NotNil(t, queue)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
st := newDefaultMockDqlTask()
stID := st.ID()
stTs := st.BeginTs()
// no task in queue
unissuedTask = queue.FrontUnissuedTask()
assert.Nil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(stID)
assert.Nil(t, unissuedTask)
unissuedTask = queue.PopUnissuedTask()
assert.Nil(t, unissuedTask)
done = queue.TaskDoneTest(stTs)
assert.True(t, done)
// task enqueue, only one task in queue
err = queue.Enqueue(st)
assert.NoError(t, err)
assert.False(t, queue.utEmpty())
assert.False(t, queue.utFull())
assert.Equal(t, 1, queue.unissuedTasks.Len())
assert.Equal(t, 1, len(queue.utChan()))
unissuedTask = queue.FrontUnissuedTask()
assert.NotNil(t, unissuedTask)
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, unissuedTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
unissuedTask = queue.PopUnissuedTask()
assert.NotNil(t, unissuedTask)
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test active list, no task in queue
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.Nil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.Nil(t, activeTask)
// test active list, no task in unissued list, only one task in active list
queue.AddActiveTask(unissuedTask)
activeTask = queue.getTaskByReqID(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.False(t, done)
activeTask = queue.PopActiveTask(unissuedTask.ID())
assert.NotNil(t, activeTask)
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
assert.True(t, done)
// test utFull
queue.maxTaskNum = 10 // not exact: a truly full queue also means utBufChan would block
for i := 0; i < int(queue.maxTaskNum); i++ {
err = queue.Enqueue(newDefaultMockDqlTask())
assert.Nil(t, err)
}
assert.True(t, queue.utFull())
err = queue.Enqueue(newDefaultMockDqlTask())
assert.NotNil(t, err)
}
func TestTaskScheduler(t *testing.T) {
var err error
ctx := context.Background()
tsoAllocatorIns := newMockTsoAllocator()
idAllocatorIns := newMockIDAllocatorInterface()
factory := msgstream.NewSimpleMsgStreamFactory()
sched, err := newTaskScheduler(ctx, idAllocatorIns, tsoAllocatorIns, factory)
assert.NoError(t, err)
assert.NotNil(t, sched)
err = sched.Start()
assert.NoError(t, err)
defer sched.Close()
assert.True(t, sched.TaskDoneTest(Timestamp(time.Now().Nanosecond())))
stats, err := sched.getPChanStatistics()
assert.NoError(t, err)
assert.Equal(t, 0, len(stats))
ddNum := rand.Int() % 10
dmNum := rand.Int() % 10
dqNum := rand.Int() % 10
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < ddNum; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := sched.ddQueue.Enqueue(newDefaultMockDdlTask())
assert.NoError(t, err)
}()
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < dmNum; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := sched.dmQueue.Enqueue(newDefaultMockDmlTask())
assert.NoError(t, err)
}()
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < dqNum; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := sched.dqQueue.Enqueue(newDefaultMockDqlTask())
assert.NoError(t, err)
}()
}
}()
wg.Wait()
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"fmt"
"math/rand"
"github.com/milvus-io/milvus/internal/proto/schemapb"
)
func genUniqueStr() string {
l := rand.Uint64()%100 + 1
b := make([]byte, l)
if _, err := rand.Read(b); err != nil {
return ""
}
return fmt.Sprintf("%X", b)
}
func generateBoolArray(numRows int) []bool {
ret := make([]bool, 0, numRows)
for i := 0; i < numRows; i++ {
......
......@@ -20,11 +20,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
)
// use timestampAllocatorInterface to keep TimestampAllocator testable
type timestampAllocatorInterface interface {
AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)
}
type TimestampAllocator struct {
ctx context.Context
tso timestampAllocatorInterface
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package funcutil
// SetContain returns true if set m1 contains set m2
func SetContain(m1, m2 map[interface{}]struct{}) bool {
if len(m1) < len(m2) {
return false
}
for k2 := range m2 {
_, ok := m1[k2]
if !ok {
return false
}
}
return true
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package funcutil
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSetContain(t *testing.T) {
key1 := "key1"
key2 := "key2"
key3 := "key3"
// len(m1) < len(m2)
m1 := make(map[interface{}]struct{})
m2 := make(map[interface{}]struct{})
m1[key1] = struct{}{}
m2[key1] = struct{}{}
m2[key2] = struct{}{}
assert.False(t, SetContain(m1, m2))
// len(m1) >= len(m2), but m2 contains other key not in m1
m1[key3] = struct{}{}
assert.False(t, SetContain(m1, m2))
// m1 contains m2
m1[key2] = struct{}{}
assert.True(t, SetContain(m1, m2))
}