未验证 提交 db944cd0 编写于 作者: C congqixia 提交者: GitHub

Refactor flowgraph and related invocation (#8770)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
上级 6837e8a0
......@@ -282,7 +282,7 @@ func (node *DataNode) NewDataSyncService(vchan *datapb.VchannelInfo) error {
zap.Int64("Collection ID", vchan.GetCollectionID()),
zap.String("Vchannel name", vchan.GetChannelName()),
)
go dataSyncService.start()
dataSyncService.start()
return nil
}
......
......@@ -191,7 +191,7 @@ func TestDataSyncService_Start(t *testing.T) {
assert.Nil(t, err)
// sync.replica.addCollection(collMeta.ID, collMeta.Schema)
go sync.start()
sync.start()
timeRange := TimeRange{
timestampMin: 0,
......
......@@ -45,6 +45,6 @@ func newDmInputNode(ctx context.Context, factory msgstream.Factory, collID Uniqu
}
var stream msgstream.MsgStream = insertStream
node := flowgraph.NewInputNode(&stream, "dmInputNode", maxQueueLength, maxParallelism)
node := flowgraph.NewInputNode(stream, "dmInputNode", maxQueueLength, maxParallelism)
return node, nil
}
......@@ -99,7 +99,7 @@ func (dsService *dataSyncService) startCollectionFlowGraph(collectionID UniqueID
if _, ok := dsService.collectionFlowGraphs[collectionID][channel]; ok {
// start flow graph
log.Debug("start collection flow graph", zap.Any("channel", channel))
go dsService.collectionFlowGraphs[collectionID][channel].flowGraph.Start()
dsService.collectionFlowGraphs[collectionID][channel].flowGraph.Start()
}
}
return nil
......@@ -169,7 +169,7 @@ func (dsService *dataSyncService) startPartitionFlowGraph(partitionID UniqueID,
if _, ok := dsService.partitionFlowGraphs[partitionID][channel]; ok {
// start flow graph
log.Debug("start partition flow graph", zap.Any("channel", channel))
go dsService.partitionFlowGraphs[partitionID][channel].flowGraph.Start()
dsService.partitionFlowGraphs[partitionID][channel].flowGraph.Start()
}
}
return nil
......
......@@ -114,7 +114,7 @@ func (q *queryNodeFlowGraph) newDmInputNode(ctx context.Context, factory msgstre
maxQueueLength := Params.FlowGraphMaxQueueLength
maxParallelism := Params.FlowGraphMaxParallelism
node := flowgraph.NewInputNode(&insertStream, "dmlInputNode", maxQueueLength, maxParallelism)
node := flowgraph.NewInputNode(insertStream, "dmlInputNode", maxQueueLength, maxParallelism)
return node
}
......
......@@ -20,8 +20,6 @@ import (
// TimeTickedFlowGraph flowgraph with input from tt msg stream
type TimeTickedFlowGraph struct {
ctx context.Context
cancel context.CancelFunc
nodeCtx map[NodeName]*nodeCtx
stopOnce sync.Once
startOnce sync.Once
......@@ -34,6 +32,7 @@ func (fg *TimeTickedFlowGraph) AddNode(node Node) {
node: node,
inputChannels: make([]chan Msg, 0),
downstreamInputChanIdx: make(map[string]int),
closeCh: make(chan struct{}),
}
fg.nodeCtx[nodeName] = &nodeCtx
}
......@@ -80,7 +79,7 @@ func (fg *TimeTickedFlowGraph) Start() {
wg := sync.WaitGroup{}
for _, v := range fg.nodeCtx {
wg.Add(1)
go v.Start(fg.ctx, &wg)
v.Start(&wg)
}
wg.Wait()
})
......@@ -93,16 +92,12 @@ func (fg *TimeTickedFlowGraph) Close() {
// maybe need to stop in order
v.Close()
}
fg.cancel()
})
}
// NewTimeTickedFlowGraph create timetick flowgraph
func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph {
ctx1, cancel := context.WithCancel(ctx)
flowGraph := TimeTickedFlowGraph{
ctx: ctx1,
cancel: cancel,
nodeCtx: make(map[string]*nodeCtx),
}
......
......@@ -282,10 +282,9 @@ func TestTimeTickedFlowGraph_SetEdges(t *testing.T) {
func TestTimeTickedFlowGraph_Start(t *testing.T) {
fg, inputChan, outputChan, cancel := createExampleFlowGraph()
defer cancel()
go fg.Start()
fg.Start()
// input
time.Sleep(10 * time.Millisecond)
go func() {
for i := 0; i < 10; i++ {
a := float64(rand.Int())
......
......@@ -12,39 +12,51 @@
package flowgraph
import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/opentracing/opentracing-go"
oplog "github.com/opentracing/opentracing-go/log"
"go.uber.org/zap"
)
// InputNode is the entry point of flowgraph
type InputNode struct {
BaseNode
inStream *msgstream.MsgStream
inStream msgstream.MsgStream
name string
}
// IsInputNode reports whether this node is an input node.
// For InputNode it always returns true.
func (inNode *InputNode) IsInputNode() bool {
return true
}
// Start starts the underlying message stream by calling Start on inStream.
func (inNode *InputNode) Start() {
inNode.inStream.Start()
}
// Close closes the underlying message stream and then emits a debug log
// entry that carries the node name.
func (inNode *InputNode) Close() {
inNode.inStream.Close()
log.Debug("message stream closed",
zap.String("node name", inNode.name),
)
}
// Name returns the value of the node's name field.
func (inNode *InputNode) Name() string {
return inNode.name
}
func (inNode *InputNode) InStream() *msgstream.MsgStream {
// InStream returns the internal MsgStream
func (inNode *InputNode) InStream() msgstream.MsgStream {
return inNode.inStream
}
// empty input and return one *Msg
func (inNode *InputNode) Operate(in []Msg) []Msg {
//fmt.Println("Do InputNode operation")
msgPack := (*inNode.inStream).Consume()
msgPack := inNode.inStream.Consume()
// TODO: add status
if msgPack == nil {
......@@ -73,7 +85,8 @@ func (inNode *InputNode) Operate(in []Msg) []Msg {
return []Msg{msgStreamMsg}
}
func NewInputNode(inStream *msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode {
// NewInputNode composes an InputNode with provided MsgStream, name and parameters
func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode {
baseNode := BaseNode{}
baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism)
......
......@@ -14,7 +14,6 @@ package flowgraph
import (
"context"
"os"
"sync"
"testing"
"github.com/milvus-io/milvus/internal/msgstream"
......@@ -40,10 +39,10 @@ func TestInputNode(t *testing.T) {
nodeName := "input_node"
inputNode := &InputNode{
inStream: &msgStream,
inStream: msgStream,
name: nodeName,
}
inputNode.Close()
defer inputNode.Close()
isInputNode := inputNode.IsInputNode()
assert.True(t, isInputNode)
......@@ -54,18 +53,8 @@ func TestInputNode(t *testing.T) {
stream := inputNode.InStream()
assert.NotNil(t, stream)
var waitGroup sync.WaitGroup
OperateFunc := func() {
msgs := make([]Msg, 0)
output := inputNode.Operate(msgs)
assert.Greater(t, len(output), 0)
msgStream.Close()
waitGroup.Done()
}
waitGroup.Add(1)
go OperateFunc()
waitGroup.Wait()
output := inputNode.Operate([]Msg{})
assert.Greater(t, len(output), 0)
}
func Test_NewInputNode(t *testing.T) {
......
......@@ -12,7 +12,6 @@
package flowgraph
import (
"context"
"fmt"
"sync"
"time"
......@@ -22,20 +21,24 @@ import (
"github.com/milvus-io/milvus/internal/log"
)
// Node is the interface that defines the behavior of a flowgraph node.
type Node interface {
// Name returns the node name; nodeCtx uses it as the key into
// downstreamInputChanIdx when delivering results downstream.
Name() string
MaxQueueLength() int32
MaxParallelism() int32
// Operate takes the collected input messages and returns the results
// that the worker loop delivers to downstream nodes.
Operate(in []Msg) []Msg
// IsInputNode reports whether the node is an input node; the worker loop
// skips collecting upstream messages when it returns true.
IsInputNode() bool
// Start is called by nodeCtx.Start before the worker goroutine is launched.
Start()
// Close is called by nodeCtx.Close before the worker is notified to quit.
Close()
}
// BaseNode defines some common node attributes and behavior.
type BaseNode struct {
// presumably assigned through SetMaxQueueLength — its body is not visible here; confirm
maxQueueLength int32
// assigned through SetMaxParallelism
maxParallelism int32
}
// nodeCtx maintains the running context for a Node in flowgraph
type nodeCtx struct {
node Node
inputChannels []chan Msg
......@@ -43,33 +46,25 @@ type nodeCtx struct {
downstream []*nodeCtx
downstreamInputChanIdx map[string]int
NumActiveTasks int64
NumCompletedTasks int64
closeCh chan struct{}
}
func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
if nodeCtx.node.IsInputNode() {
inStream, ok := nodeCtx.node.(*InputNode)
if !ok {
log.Error("Invalid inputNode")
}
(*inStream.inStream).Start()
}
// Start invokes the wrapped Node's Start method, launches a worker goroutine
// running nodeCtx.work, and then calls wg.Done to signal the caller.
func (nodeCtx *nodeCtx) Start(wg *sync.WaitGroup) {
nodeCtx.node.Start()
go nodeCtx.work()
// Done fires as soon as the worker is launched, not when work returns.
wg.Done()
}
// work handles node work spinning
// 1. collectMessage from upstream or just produce Msg from InputNode
// 2. invoke node.Operate
// 3. deliver the Operate result to downstream nodes
func (nodeCtx *nodeCtx) work() {
for {
select {
case <-ctx.Done():
if nodeCtx.node.IsInputNode() {
inStream, ok := nodeCtx.node.(*InputNode)
if !ok {
log.Error("Invalid inputNode")
}
(*inStream.inStream).Close()
log.Debug("message stream closed",
zap.Any("node name", inStream.name),
)
}
wg.Done()
case <-nodeCtx.closeCh:
return
default:
// inputs from inputsMessages for Operate
......@@ -77,7 +72,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
var res []Msg
if !nodeCtx.node.IsInputNode() {
nodeCtx.collectInputMessages(ctx)
nodeCtx.collectInputMessages()
inputs = nodeCtx.inputMessages
}
n := nodeCtx.node
......@@ -95,23 +90,23 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
w := sync.WaitGroup{}
for i := 0; i < downstreamLength; i++ {
w.Add(1)
go nodeCtx.downstream[i].ReceiveMsg(&w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()])
go nodeCtx.downstream[i].deliverMsg(&w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()])
}
w.Wait()
}
}
}
// Close handles cleanup: it closes the wrapped Node and then closes closeCh
// to notify the worker to quit.
//
// NOTE(review): closing an already-closed channel panics in Go, so Close must
// not be called more than once per nodeCtx — the flowgraph appears to guard
// this with stopOnce; confirm against the caller.
func (nodeCtx *nodeCtx) Close() {
// The channel-closing code below is kept commented out; the original note
// cites a data race with the send into inputChannels.
// data race with nodeCtx.ReceiveMsg { nodeCtx.inputChannels[inputChanIdx] <- msg }
//for _, channel := range nodeCtx.inputChannels {
// close(channel)
// log.Warn("close inputChannel")
//}
// close Node
nodeCtx.node.Close()
// notify worker
close(nodeCtx.closeCh)
}
func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int) {
// deliverMsg tries to put the Msg to specified downstream channel
func (nodeCtx *nodeCtx) deliverMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int) {
defer wg.Done()
defer func() {
err := recover()
......@@ -119,11 +114,13 @@ func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int
log.Warn(fmt.Sprintln(err))
}
}()
nodeCtx.inputChannels[inputChanIdx] <- msg
//fmt.Println((*nodeCtx.node).Name(), "receive to input channel ", inputChanIdx)
select {
case <-nodeCtx.closeCh:
case nodeCtx.inputChannels[inputChanIdx] <- msg:
}
}
func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
func (nodeCtx *nodeCtx) collectInputMessages() {
inputsNum := len(nodeCtx.inputChannels)
nodeCtx.inputMessages = make([]Msg, inputsNum)
......@@ -133,7 +130,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
for i := 0; i < inputsNum; i++ {
channel := nodeCtx.inputChannels[i]
select {
case <-exitCtx.Done():
case <-nodeCtx.closeCh:
return
case msg, ok := <-channel:
if !ok {
......@@ -163,7 +160,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
log.Debug("try to align timestamp", zap.Uint64("t1", latestTime), zap.Uint64("t2", nodeCtx.inputMessages[i].TimeTick()))
channel := nodeCtx.inputChannels[i]
select {
case <-exitCtx.Done():
case <-nodeCtx.closeCh:
return
case msg, ok := <-channel:
if !ok {
......@@ -181,8 +178,8 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
case <-time.After(10 * time.Second):
panic("Fatal, misaligned time tick, please restart pulsar")
case <-sign:
case <-nodeCtx.closeCh:
}
}
}
......@@ -202,10 +199,13 @@ func (node *BaseNode) SetMaxParallelism(n int32) {
node.maxParallelism = n
}
// IsInputNode reports whether the node is an input node.
// BaseNode always returns false.
func (node *BaseNode) IsInputNode() bool {
return false
}
func (node *BaseNode) Close() {
//TODO
}
// Start implements Node; the base node does nothing when it starts.
func (node *BaseNode) Start() {}
// Close implements Node; the base node does nothing when it closes.
func (node *BaseNode) Close() {}
......@@ -71,7 +71,7 @@ func TestNodeCtx_Start(t *testing.T) {
nodeName := "input_node"
inputNode := &InputNode{
inStream: &msgStream,
inStream: msgStream,
name: nodeName,
}
......@@ -79,18 +79,16 @@ func TestNodeCtx_Start(t *testing.T) {
node: inputNode,
inputChannels: make([]chan Msg, 2),
downstreamInputChanIdx: make(map[string]int),
closeCh: make(chan struct{}),
}
for i := 0; i < len(node.inputChannels); i++ {
node.inputChannels[i] = make(chan Msg)
}
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
var waitGroup sync.WaitGroup
waitGroup.Add(1)
go node.Start(ctx, &waitGroup)
node.Start(&waitGroup)
node.Close()
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册