diff --git a/internal/master/timesync.go b/internal/master/timesync.go index 3a883b32a12f175976930788f92ba105990eac72..6e85ecbf607a8f3ee71062b90086344164b6360d 100644 --- a/internal/master/timesync.go +++ b/internal/master/timesync.go @@ -4,6 +4,7 @@ import ( "context" "log" "math" + "sync/atomic" "github.com/zilliztech/milvus-distributed/internal/errors" ms "github.com/zilliztech/milvus-distributed/internal/msgstream" @@ -19,12 +20,11 @@ type ( softTimeTickBarrier struct { peer2LastTt map[UniqueID]Timestamp minTtInterval Timestamp - lastTt Timestamp + lastTt int64 outTt chan Timestamp ttStream ms.MsgStream ctx context.Context - closeCh chan struct{} // close goroutinue in Start() - closed bool + cancel context.CancelFunc } hardTimeTickBarrier struct { @@ -32,46 +32,36 @@ type ( outTt chan Timestamp ttStream ms.MsgStream ctx context.Context - closeCh chan struct{} // close goroutinue in Start() - closed bool + cancel context.CancelFunc } ) func (ttBarrier *softTimeTickBarrier) GetTimeTick() (Timestamp, error) { - isEmpty := true - for { - - if ttBarrier.closed { + select { + case <-ttBarrier.ctx.Done(): + return 0, errors.Errorf("[GetTimeTick] closed.") + case ts, ok := <-ttBarrier.outTt: + if !ok { return 0, errors.Errorf("[GetTimeTick] closed.") } - - select { - case ts := <-ttBarrier.outTt: - isEmpty = false - ttBarrier.lastTt = ts - - default: - if isEmpty || ttBarrier.closed { - continue + num := len(ttBarrier.outTt) + for i := 0; i < num; i++ { + ts, ok = <-ttBarrier.outTt + if !ok { + return 0, errors.Errorf("[GetTimeTick] closed.") } - return ttBarrier.lastTt, nil } + atomic.StoreInt64(&(ttBarrier.lastTt), int64(ts)) + return ts, nil } } func (ttBarrier *softTimeTickBarrier) Start() error { - ttBarrier.closeCh = make(chan struct{}, 1) go func() { for { select { - - case <-ttBarrier.closeCh: - log.Printf("[TtBarrierStart] closed\n") - return - case <-ttBarrier.ctx.Done(): log.Printf("[TtBarrierStart] %s\n", ttBarrier.ctx.Err()) - ttBarrier.closed = true return case ttmsgs := <-ttBarrier.ttStream.Chan(): @@ -91,8 +81,8 @@ func (ttBarrier *softTimeTickBarrier) Start() error { // get a legal Timestamp ts := ttBarrier.minTimestamp() - - if ttBarrier.lastTt != 0 && ttBarrier.minTtInterval > ts-ttBarrier.lastTt { + lastTt := atomic.LoadInt64(&(ttBarrier.lastTt)) + if ttBarrier.lastTt != 0 && ttBarrier.minTtInterval > ts-Timestamp(lastTt) { continue } @@ -100,8 +90,6 @@ func (ttBarrier *softTimeTickBarrier) Start() error { } } } - - default: } } }() @@ -122,9 +110,7 @@ func newSoftTimeTickBarrier(ctx context.Context, sttbarrier.minTtInterval = minTtInterval sttbarrier.ttStream = *ttStream sttbarrier.outTt = make(chan Timestamp, 1024) - sttbarrier.ctx = ctx - sttbarrier.closed = false - + sttbarrier.ctx, sttbarrier.cancel = context.WithCancel(ctx) sttbarrier.peer2LastTt = make(map[UniqueID]Timestamp) for _, id := range peerIds { sttbarrier.peer2LastTt[id] = Timestamp(0) @@ -137,12 +123,7 @@ func newSoftTimeTickBarrier(ctx context.Context, } func (ttBarrier *softTimeTickBarrier) Close() { - - if ttBarrier.closeCh != nil { - ttBarrier.closeCh <- struct{}{} - } - - ttBarrier.closed = true + ttBarrier.cancel() } func (ttBarrier *softTimeTickBarrier) minTimestamp() Timestamp { @@ -156,36 +137,25 @@ func (ttBarrier *softTimeTickBarrier) minTimestamp() Timestamp { } func (ttBarrier *hardTimeTickBarrier) GetTimeTick() (Timestamp, error) { - for { - - if ttBarrier.closed { + select { + case <-ttBarrier.ctx.Done(): + return 0, errors.Errorf("[GetTimeTick] closed.") + case ts, ok := <-ttBarrier.outTt: + if !ok { return 0, errors.Errorf("[GetTimeTick] closed.") } - - select { - case ts := <-ttBarrier.outTt: - return ts, nil - default: - } + return ts, nil } } func (ttBarrier *hardTimeTickBarrier) Start() error { - ttBarrier.closeCh = make(chan struct{}, 1) - go func() { // Last timestamp synchronized state := Timestamp(0) for { select { - - case <-ttBarrier.closeCh: - log.Printf("[TtBarrierStart] closed\n") - return - case <-ttBarrier.ctx.Done(): log.Printf("[TtBarrierStart] %s\n", ttBarrier.ctx.Err()) - ttBarrier.closed = true return case ttmsgs := <-ttBarrier.ttStream.Chan(): @@ -217,7 +187,6 @@ func (ttBarrier *hardTimeTickBarrier) Start() error { } } } - default: } } }() @@ -246,8 +215,7 @@ func newHardTimeTickBarrier(ctx context.Context, sttbarrier := hardTimeTickBarrier{} sttbarrier.ttStream = *ttStream sttbarrier.outTt = make(chan Timestamp, 1024) - sttbarrier.ctx = ctx - sttbarrier.closed = false + sttbarrier.ctx, sttbarrier.cancel = context.WithCancel(ctx) sttbarrier.peer2Tt = make(map[UniqueID]Timestamp) for _, id := range peerIds { @@ -261,8 +229,5 @@ func newHardTimeTickBarrier(ctx context.Context, } func (ttBarrier *hardTimeTickBarrier) Close() { - if ttBarrier.closeCh != nil { - ttBarrier.closeCh <- struct{}{} - } - ttBarrier.closed = true + ttBarrier.cancel() }