Commit 7ad9b362 authored by Q quicksilver, committed by yefu.chen

Update reviewdog/action-hadolint GitHub Action version to v1.16.1

Signed-off-by: quicksilver <zhifeng.zhang@zilliz.com>
Parent 1b743e5c
......@@ -29,7 +29,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v2
- name: Check Dockerfile
- uses: reviewdog/action-hadolint@v1
+ uses: reviewdog/action-hadolint@v1.16.1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
reporter: github-pr-check # Default is github-pr-check
......
......@@ -25,7 +25,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v2
- name: Check Dockerfile
- uses: reviewdog/action-hadolint@v1
+ uses: reviewdog/action-hadolint@v1.16.1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
reporter: github-pr-check # Default is github-pr-check
......
......@@ -100,12 +100,12 @@ func TestDataSyncService_Start(t *testing.T) {
var ddMsgStream msgstream.MsgStream = ddStream
ddMsgStream.Start()
- err = insertMsgStream.Produce(ctx, &msgPack)
+ err = insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
- err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
+ err = insertMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
- err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack)
+ err = ddMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
......
......@@ -11,6 +11,7 @@ import (
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/kv"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/log"
......@@ -18,6 +19,8 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/storage"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
)
type ddNode struct {
......@@ -69,7 +72,7 @@ func (ddNode *ddNode) Name() string {
return "ddNode"
}
- func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+ func (ddNode *ddNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
if len(in) != 1 {
log.Error("Invalid operate message input in ddNode", zap.Int("input length", len(in)))
......@@ -83,7 +86,13 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
}
if msMsg == nil {
- return []Msg{}, ctx
+ return []Msg{}
}
+ var spans []opentracing.Span
+ for _, msg := range msMsg.TsMessages() {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ spans = append(spans, sp)
+ msg.SetTraceCtx(ctx)
+ }
ddNode.ddMsg = &ddMsg{
......@@ -165,8 +174,12 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
default:
}
+ for _, span := range spans {
+ span.Finish()
+ }
var res Msg = ddNode.ddMsg
- return []Msg{res}, ctx
+ return []Msg{res}
}
/*
......@@ -245,6 +258,10 @@ func flushTxn(ddlData *sync.Map,
}
func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ msg.SetTraceCtx(ctx)
+ defer sp.Finish()
collectionID := msg.CollectionID
// add collection
......@@ -295,6 +312,10 @@ func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
dropCollection will drop collection in ddRecords but won't drop collection in replica
*/
func (ddNode *ddNode) dropCollection(msg *msgstream.DropCollectionMsg) {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ msg.SetTraceCtx(ctx)
+ defer sp.Finish()
collectionID := msg.CollectionID
// remove collection
......@@ -327,6 +348,10 @@ func (ddNode *ddNode) dropCollection(msg *msgstream.DropCollectionMsg) {
}
func (ddNode *ddNode) createPartition(msg *msgstream.CreatePartitionMsg) {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ msg.SetTraceCtx(ctx)
+ defer sp.Finish()
partitionID := msg.PartitionID
collectionID := msg.CollectionID
......@@ -363,6 +388,9 @@ func (ddNode *ddNode) createPartition(msg *msgstream.CreatePartitionMsg) {
}
func (ddNode *ddNode) dropPartition(msg *msgstream.DropPartitionMsg) {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ msg.SetTraceCtx(ctx)
+ defer sp.Finish()
partitionID := msg.PartitionID
collectionID := msg.CollectionID
......
......@@ -160,5 +160,5 @@ func TestFlowGraphDDNode_Operate(t *testing.T) {
msgStream := flowgraph.GenerateMsgStreamMsg(tsMessages, Timestamp(0), Timestamp(3),
startPos, startPos)
var inMsg Msg = msgStream
- ddNode.Operate(ctx, []Msg{inMsg})
+ ddNode.Operate([]Msg{inMsg})
}
......
package datanode
import (
"context"
"math"
"go.uber.org/zap"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
"go.uber.org/zap"
)
type filterDmNode struct {
......@@ -21,7 +22,7 @@ func (fdmNode *filterDmNode) Name() string {
return "fdmNode"
}
- func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+ func (fdmNode *filterDmNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
if len(in) != 2 {
log.Error("Invalid operate message input in filterDmNode", zap.Int("input length", len(in)))
......@@ -41,7 +42,13 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont
}
if msgStreamMsg == nil || ddMsg == nil {
- return []Msg{}, ctx
+ return []Msg{}
}
+ var spans []opentracing.Span
+ for _, msg := range msgStreamMsg.TsMessages() {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ spans = append(spans, sp)
+ msg.SetTraceCtx(ctx)
+ }
fdmNode.ddMsg = ddMsg
......@@ -77,11 +84,18 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont
iMsg.endPositions = append(iMsg.endPositions, msgStreamMsg.EndPositions()...)
iMsg.gcRecord = ddMsg.gcRecord
var res Msg = &iMsg
- return []Msg{res}, ctx
+ for _, sp := range spans {
+ sp.Finish()
+ }
+ return []Msg{res}
}
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
// No dd record, do all insert requests.
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ msg.SetTraceCtx(ctx)
+ defer sp.Finish()
records, ok := fdmNode.ddMsg.collectionRecords[msg.CollectionID]
if !ok {
return msg
......
package datanode
import (
"context"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"go.uber.org/zap"
......@@ -17,7 +17,7 @@ func (gcNode *gcNode) Name() string {
return "gcNode"
}
- func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+ func (gcNode *gcNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
if len(in) != 1 {
log.Error("Invalid operate message input in gcNode", zap.Int("input length", len(in)))
......@@ -31,7 +31,7 @@ func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
}
if gcMsg == nil {
- return []Msg{}, ctx
+ return []Msg{}
}
// drop collections
......@@ -42,7 +42,7 @@ func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
}
}
- return nil, ctx
+ return nil
}
func newGCNode(replica Replica) *gcNode {
......
......@@ -11,11 +11,14 @@ import (
"go.uber.org/zap"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/kv"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/storage"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
......@@ -31,26 +34,25 @@ const (
type (
InsertData = storage.InsertData
Blob = storage.Blob
- insertBufferNode struct {
- BaseNode
- insertBuffer *insertBuffer
- replica Replica
- flushMeta *binlogMeta
- flushMap sync.Map
- minIOKV kv.Base
- timeTickStream msgstream.MsgStream
- segmentStatisticsStream msgstream.MsgStream
- completeFlushStream msgstream.MsgStream
- }
- insertBuffer struct {
- insertData map[UniqueID]*InsertData // SegmentID to InsertData
- maxSize int32
- }
)
+ type insertBufferNode struct {
+ BaseNode
+ insertBuffer *insertBuffer
+ replica Replica
+ flushMeta *binlogMeta
+ flushMap sync.Map
+ minIOKV kv.Base
+ timeTickStream msgstream.MsgStream
+ segmentStatisticsStream msgstream.MsgStream
+ completeFlushStream msgstream.MsgStream
+ }
+ type insertBuffer struct {
+ insertData map[UniqueID]*InsertData // SegmentID to InsertData
+ maxSize int32
+ }
func (ib *insertBuffer) size(segmentID UniqueID) int32 {
if ib.insertData == nil || len(ib.insertData) <= 0 {
......@@ -85,7 +87,7 @@ func (ibNode *insertBufferNode) Name() string {
return "ibNode"
}
- func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+ func (ibNode *insertBufferNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
if len(in) != 1 {
log.Error("Invalid operate message input in insertBufferNode", zap.Int("input length", len(in)))
......@@ -99,12 +101,20 @@ func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, c
}
if iMsg == nil {
- return []Msg{}, ctx
+ return []Msg{}
}
+ var spans []opentracing.Span
+ for _, msg := range iMsg.insertMessages {
+ sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+ spans = append(spans, sp)
+ msg.SetTraceCtx(ctx)
+ }
// Updating segment statistics
uniqueSeg := make(map[UniqueID]int64)
for _, msg := range iMsg.insertMessages {
currentSegID := msg.GetSegmentID()
collID := msg.GetCollectionID()
partitionID := msg.GetPartitionID()
......@@ -537,8 +547,11 @@ func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, c
gcRecord: iMsg.gcRecord,
timeRange: iMsg.timeRange,
}
+ for _, sp := range spans {
+ sp.Finish()
+ }
- return []Msg{res}, ctx
+ return []Msg{res}
}
func flushSegmentTxn(collMeta *etcdpb.CollectionMeta, segID UniqueID, partitionID UniqueID, collID UniqueID,
......@@ -639,7 +652,7 @@ func (ibNode *insertBufferNode) completeFlush(segID UniqueID, finishCh <-chan bo
}
msgPack.Msgs = append(msgPack.Msgs, msg)
- err := ibNode.completeFlushStream.Produce(context.TODO(), &msgPack)
+ err := ibNode.completeFlushStream.Produce(&msgPack)
if err != nil {
log.Error(".. Produce complete flush msg failed ..", zap.Error(err))
}
......@@ -663,7 +676,7 @@ func (ibNode *insertBufferNode) writeHardTimeTick(ts Timestamp) error {
},
}
msgPack.Msgs = append(msgPack.Msgs, &timeTickMsg)
- return ibNode.timeTickStream.Produce(context.TODO(), &msgPack)
+ return ibNode.timeTickStream.Produce(&msgPack)
}
func (ibNode *insertBufferNode) updateSegStatistics(segIDs []UniqueID) error {
......@@ -698,7 +711,7 @@ func (ibNode *insertBufferNode) updateSegStatistics(segIDs []UniqueID) error {
var msgPack = msgstream.MsgPack{
Msgs: []msgstream.TsMsg{msg},
}
- return ibNode.segmentStatisticsStream.Produce(context.TODO(), &msgPack)
+ return ibNode.segmentStatisticsStream.Produce(&msgPack)
}
func (ibNode *insertBufferNode) getCollectionSchemaByID(collectionID UniqueID) (*schemapb.CollectionSchema, error) {
......
......@@ -52,7 +52,7 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) {
iBNode := newInsertBufferNode(ctx, newBinlogMeta(), replica, msFactory)
inMsg := genInsertMsg()
var iMsg flowgraph.Msg = &inMsg
- iBNode.Operate(ctx, []flowgraph.Msg{iMsg})
+ iBNode.Operate([]flowgraph.Msg{iMsg})
}
func genInsertMsg() insertMsg {
......
......@@ -11,55 +11,53 @@ type (
MsgStreamMsg = flowgraph.MsgStreamMsg
)
- type (
- key2SegMsg struct {
- tsMessages []msgstream.TsMsg
- timeRange TimeRange
- }
+ type key2SegMsg struct {
+ tsMessages []msgstream.TsMsg
+ timeRange TimeRange
+ }
- ddMsg struct {
- collectionRecords map[UniqueID][]*metaOperateRecord
- partitionRecords map[UniqueID][]*metaOperateRecord
- flushMessages []*flushMsg
- gcRecord *gcRecord
- timeRange TimeRange
- }
+ type ddMsg struct {
+ collectionRecords map[UniqueID][]*metaOperateRecord
+ partitionRecords map[UniqueID][]*metaOperateRecord
+ flushMessages []*flushMsg
+ gcRecord *gcRecord
+ timeRange TimeRange
+ }
- metaOperateRecord struct {
- createOrDrop bool // create: true, drop: false
- timestamp Timestamp
- }
+ type metaOperateRecord struct {
+ createOrDrop bool // create: true, drop: false
+ timestamp Timestamp
+ }
- insertMsg struct {
- insertMessages []*msgstream.InsertMsg
- flushMessages []*flushMsg
- gcRecord *gcRecord
- timeRange TimeRange
- startPositions []*internalpb.MsgPosition
- endPositions []*internalpb.MsgPosition
- }
+ type insertMsg struct {
+ insertMessages []*msgstream.InsertMsg
+ flushMessages []*flushMsg
+ gcRecord *gcRecord
+ timeRange TimeRange
+ startPositions []*internalpb.MsgPosition
+ endPositions []*internalpb.MsgPosition
+ }
- deleteMsg struct {
- deleteMessages []*msgstream.DeleteMsg
- timeRange TimeRange
- }
+ type deleteMsg struct {
+ deleteMessages []*msgstream.DeleteMsg
+ timeRange TimeRange
+ }
- gcMsg struct {
- gcRecord *gcRecord
- timeRange TimeRange
- }
+ type gcMsg struct {
+ gcRecord *gcRecord
+ timeRange TimeRange
+ }
- gcRecord struct {
- collections []UniqueID
- }
+ type gcRecord struct {
+ collections []UniqueID
+ }
- flushMsg struct {
- msgID UniqueID
- timestamp Timestamp
- segmentIDs []UniqueID
- collectionID UniqueID
- }
- )
+ type flushMsg struct {
+ msgID UniqueID
+ timestamp Timestamp
+ segmentIDs []UniqueID
+ collectionID UniqueID
+ }
func (ksMsg *key2SegMsg) TimeTick() Timestamp {
return ksMsg.timeRange.timestampMax
......
......@@ -307,7 +307,7 @@ func (s *Server) startStatsChannel(ctx context.Context) {
return
default:
}
- msgPack, _ := statsStream.Consume()
+ msgPack := statsStream.Consume()
for _, msg := range msgPack.Msgs {
statistics, ok := msg.(*msgstream.SegmentStatisticsMsg)
if !ok {
......@@ -338,7 +338,7 @@ func (s *Server) startSegmentFlushChannel(ctx context.Context) {
return
default:
}
- msgPack, _ := flushStream.Consume()
+ msgPack := flushStream.Consume()
for _, msg := range msgPack.Msgs {
if msg.Type() != commonpb.MsgType_SegmentFlushDone {
continue
......@@ -368,7 +368,7 @@ func (s *Server) startDDChannel(ctx context.Context) {
return
default:
}
- msgPack, ctx := ddStream.Consume()
+ msgPack := ddStream.Consume()
for _, msg := range msgPack.Msgs {
if err := s.ddHandler.HandleDDMsg(ctx, msg); err != nil {
log.Error("handle dd msg error", zap.Error(err))
......@@ -622,10 +622,10 @@ func (s *Server) openNewSegment(ctx context.Context, collectionID UniqueID, part
Segment: segmentInfo,
},
}
- msgPack := &msgstream.MsgPack{
+ msgPack := msgstream.MsgPack{
Msgs: []msgstream.TsMsg{infoMsg},
}
- if err = s.segmentInfoStream.Produce(ctx, msgPack); err != nil {
+ if err = s.segmentInfoStream.Produce(&msgPack); err != nil {
return err
}
return nil
......
......@@ -445,10 +445,10 @@ func (c *Core) setMsgStreams() error {
TimeTickMsg: timeTickResult,
}
msgPack.Msgs = append(msgPack.Msgs, timeTickMsg)
- if err := timeTickStream.Broadcast(c.ctx, &msgPack); err != nil {
+ if err := timeTickStream.Broadcast(&msgPack); err != nil {
return err
}
- if err := ddStream.Broadcast(c.ctx, &msgPack); err != nil {
+ if err := ddStream.Broadcast(&msgPack); err != nil {
return err
}
return nil
......@@ -457,6 +457,7 @@ func (c *Core) setMsgStreams() error {
c.DdCreateCollectionReq = func(ctx context.Context, req *internalpb.CreateCollectionRequest) error {
msgPack := ms.MsgPack{}
baseMsg := ms.BaseMsg{
+ Ctx: ctx,
BeginTimestamp: req.Base.Timestamp,
EndTimestamp: req.Base.Timestamp,
HashValues: []uint32{0},
......@@ -466,7 +467,7 @@ func (c *Core) setMsgStreams() error {
CreateCollectionRequest: *req,
}
msgPack.Msgs = append(msgPack.Msgs, collMsg)
- if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+ if err := ddStream.Broadcast(&msgPack); err != nil {
return err
}
return nil
......@@ -475,6 +476,7 @@ func (c *Core) setMsgStreams() error {
c.DdDropCollectionReq = func(ctx context.Context, req *internalpb.DropCollectionRequest) error {
msgPack := ms.MsgPack{}
baseMsg := ms.BaseMsg{
+ Ctx: ctx,
BeginTimestamp: req.Base.Timestamp,
EndTimestamp: req.Base.Timestamp,
HashValues: []uint32{0},
......@@ -484,7 +486,7 @@ func (c *Core) setMsgStreams() error {
DropCollectionRequest: *req,
}
msgPack.Msgs = append(msgPack.Msgs, collMsg)
- if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+ if err := ddStream.Broadcast(&msgPack); err != nil {
return err
}
return nil
......@@ -493,6 +495,7 @@ func (c *Core) setMsgStreams() error {
c.DdCreatePartitionReq = func(ctx context.Context, req *internalpb.CreatePartitionRequest) error {
msgPack := ms.MsgPack{}
baseMsg := ms.BaseMsg{
+ Ctx: ctx,
BeginTimestamp: req.Base.Timestamp,
EndTimestamp: req.Base.Timestamp,
HashValues: []uint32{0},
......@@ -502,7 +505,7 @@ func (c *Core) setMsgStreams() error {
CreatePartitionRequest: *req,
}
msgPack.Msgs = append(msgPack.Msgs, collMsg)
- if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+ if err := ddStream.Broadcast(&msgPack); err != nil {
return err
}
return nil
......@@ -511,6 +514,7 @@ func (c *Core) setMsgStreams() error {
c.DdDropPartitionReq = func(ctx context.Context, req *internalpb.DropPartitionRequest) error {
msgPack := ms.MsgPack{}
baseMsg := ms.BaseMsg{
+ Ctx: ctx,
BeginTimestamp: req.Base.Timestamp,
EndTimestamp: req.Base.Timestamp,
HashValues: []uint32{0},
......@@ -520,7 +524,7 @@ func (c *Core) setMsgStreams() error {
DropPartitionRequest: *req,
}
msgPack.Msgs = append(msgPack.Msgs, collMsg)
- if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+ if err := ddStream.Broadcast(&msgPack); err != nil {
return err
}
return nil
......
......@@ -274,7 +274,7 @@ func TestMasterService(t *testing.T) {
TimeTickMsg: timeTickResult,
}
msgPack.Msgs = append(msgPack.Msgs, timeTickMsg)
- err := proxyTimeTickStream.Broadcast(ctx, &msgPack)
+ err := proxyTimeTickStream.Broadcast(&msgPack)
assert.Nil(t, err)
ttmsg, ok := <-timeTickStream.Chan()
......@@ -585,7 +585,7 @@ func TestMasterService(t *testing.T) {
},
}
msgPack.Msgs = append(msgPack.Msgs, segMsg)
- err = dataServiceSegmentStream.Broadcast(ctx, &msgPack)
+ err = dataServiceSegmentStream.Broadcast(&msgPack)
assert.Nil(t, err)
time.Sleep(time.Second)
......@@ -744,7 +744,7 @@ func TestMasterService(t *testing.T) {
},
}
msgPack.Msgs = append(msgPack.Msgs, segMsg)
- err = dataServiceSegmentStream.Broadcast(ctx, &msgPack)
+ err = dataServiceSegmentStream.Broadcast(&msgPack)
assert.Nil(t, err)
time.Sleep(time.Second)
......@@ -765,7 +765,7 @@ func TestMasterService(t *testing.T) {
},
}
msgPack.Msgs = []ms.TsMsg{flushMsg}
- err = dataServiceSegmentStream.Broadcast(ctx, &msgPack)
+ err = dataServiceSegmentStream.Broadcast(&msgPack)
assert.Nil(t, err)
time.Sleep(time.Second)
......
......@@ -6,8 +6,6 @@ import (
"fmt"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
......@@ -15,6 +13,7 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
"go.uber.org/zap"
)
type reqTask interface {
......
......@@ -94,7 +94,7 @@ func (mms *MemMsgStream) AsConsumer(channels []string, groupName string) {
}
}
- func (mms *MemMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
+ func (mms *MemMsgStream) Produce(pack *msgstream.MsgPack) error {
tsMsgs := pack.Msgs
if len(tsMsgs) <= 0 {
log.Printf("Warning: Receive empty msgPack")
......@@ -150,7 +150,7 @@ func (mms *MemMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) e
return nil
}
- func (mms *MemMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
+ func (mms *MemMsgStream) Broadcast(msgPack *msgstream.MsgPack) error {
for _, channelName := range mms.producers {
err := Mmq.Produce(channelName, msgPack)
if err != nil {
......@@ -161,18 +161,18 @@ func (mms *MemMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error
return nil
}
- func (mms *MemMsgStream) Consume() (*msgstream.MsgPack, context.Context) {
+ func (mms *MemMsgStream) Consume() *msgstream.MsgPack {
for {
select {
case cm, ok := <-mms.receiveBuf:
if !ok {
log.Println("buf chan closed")
- return nil, nil
+ return nil
}
- return cm, nil
+ return cm
case <-mms.ctx.Done():
log.Printf("context closed")
- return nil, nil
+ return nil
}
}
}
......
......@@ -101,7 +101,7 @@ func TestStream_GlobalMmq_Func(t *testing.T) {
if err != nil {
log.Fatalf("global mmq produce error = %v", err)
}
- cm, _ := consumerStreams[0].Consume()
+ cm := consumerStreams[0].Consume()
assert.Equal(t, cm, &msg, "global mmq consume error")
err = Mmq.Broadcast(&msg)
......@@ -109,7 +109,7 @@ func TestStream_GlobalMmq_Func(t *testing.T) {
log.Fatalf("global mmq broadcast error = %v", err)
}
for _, cs := range consumerStreams {
- cm, _ := cs.Consume()
+ cm := cs.Consume()
assert.Equal(t, cm, &msg, "global mmq consume error")
}
......@@ -142,12 +142,12 @@ func TestStream_MemMsgStream_Produce(t *testing.T) {
msgPack := msgstream.MsgPack{}
var hashValue uint32 = 2
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1, hashValue))
- err := produceStream.Produce(context.Background(), &msgPack)
+ err := produceStream.Produce(&msgPack)
if err != nil {
log.Fatalf("new msgstream error = %v", err)
}
- msg, _ := consumerStreams[hashValue].Consume()
+ msg := consumerStreams[hashValue].Consume()
if msg == nil {
log.Fatalf("msgstream consume error")
}
......@@ -167,13 +167,13 @@ func TestStream_MemMsgStream_BroadCast(t *testing.T) {
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1, 100))
- err := produceStream.Broadcast(context.Background(), &msgPack)
+ err := produceStream.Broadcast(&msgPack)
if err != nil {
log.Fatalf("new msgstream error = %v", err)
}
for _, consumer := range consumerStreams {
- msg, _ := consumer.Consume()
+ msg := consumer.Consume()
if msg == nil {
log.Fatalf("msgstream consume error")
}
......
package msgstream
import (
"context"
"errors"
"github.com/golang/protobuf/proto"
......@@ -13,6 +14,8 @@ type MsgType = commonpb.MsgType
type MarshalType = interface{}
type TsMsg interface {
+ TraceCtx() context.Context
+ SetTraceCtx(ctx context.Context)
ID() UniqueID
BeginTs() Timestamp
EndTs() Timestamp
......@@ -25,6 +28,7 @@ type TsMsg interface {
}
type BaseMsg struct {
+ Ctx context.Context
BeginTimestamp Timestamp
EndTimestamp Timestamp
HashValues []uint32
......@@ -66,6 +70,13 @@ type InsertMsg struct {
internalpb.InsertRequest
}
+ func (it *InsertMsg) TraceCtx() context.Context {
+ return it.BaseMsg.Ctx
+ }
+ func (it *InsertMsg) SetTraceCtx(ctx context.Context) {
+ it.BaseMsg.Ctx = ctx
+ }
func (it *InsertMsg) ID() UniqueID {
return it.Base.MsgID
}
......@@ -118,6 +129,14 @@ type FlushCompletedMsg struct {
internalpb.SegmentFlushCompletedMsg
}
+ func (fl *FlushCompletedMsg) TraceCtx() context.Context {
+ return fl.BaseMsg.Ctx
+ }
+ func (fl *FlushCompletedMsg) SetTraceCtx(ctx context.Context) {
+ fl.BaseMsg.Ctx = ctx
+ }
func (fl *FlushCompletedMsg) ID() UniqueID {
return fl.Base.MsgID
}
......@@ -160,6 +179,14 @@ type FlushMsg struct {
internalpb.FlushMsg
}
+ func (fl *FlushMsg) TraceCtx() context.Context {
+ return fl.BaseMsg.Ctx
+ }
+ func (fl *FlushMsg) SetTraceCtx(ctx context.Context) {
+ fl.BaseMsg.Ctx = ctx
+ }
func (fl *FlushMsg) ID() UniqueID {
return fl.Base.MsgID
}
......@@ -201,6 +228,14 @@ type DeleteMsg struct {
internalpb.DeleteRequest
}
+ func (dt *DeleteMsg) TraceCtx() context.Context {
+ return dt.BaseMsg.Ctx
+ }
+ func (dt *DeleteMsg) SetTraceCtx(ctx context.Context) {
+ dt.BaseMsg.Ctx = ctx
+ }
func (dt *DeleteMsg) ID() UniqueID {
return dt.Base.MsgID
}
......@@ -254,6 +289,14 @@ type SearchMsg struct {
internalpb.SearchRequest
}
+ func (st *SearchMsg) TraceCtx() context.Context {
+ return st.BaseMsg.Ctx
+ }
+ func (st *SearchMsg) SetTraceCtx(ctx context.Context) {
+ st.BaseMsg.Ctx = ctx
+ }
func (st *SearchMsg) ID() UniqueID {
return st.Base.MsgID
}
......@@ -295,6 +338,14 @@ type SearchResultMsg struct {
internalpb.SearchResults
}
+ func (srt *SearchResultMsg) TraceCtx() context.Context {
+ return srt.BaseMsg.Ctx
+ }
+ func (srt *SearchResultMsg) SetTraceCtx(ctx context.Context) {
+ srt.BaseMsg.Ctx = ctx
+ }
func (srt *SearchResultMsg) ID() UniqueID {
return srt.Base.MsgID
}
......@@ -336,6 +387,14 @@ type TimeTickMsg struct {
internalpb.TimeTickMsg
}
+ func (tst *TimeTickMsg) TraceCtx() context.Context {
+ return tst.BaseMsg.Ctx
+ }
+ func (tst *TimeTickMsg) SetTraceCtx(ctx context.Context) {
+ tst.BaseMsg.Ctx = ctx
+ }
func (tst *TimeTickMsg) ID() UniqueID {
return tst.Base.MsgID
}
......@@ -378,6 +437,14 @@ type QueryNodeStatsMsg struct {
internalpb.QueryNodeStats
}
+ func (qs *QueryNodeStatsMsg) TraceCtx() context.Context {
+ return qs.BaseMsg.Ctx
+ }
+ func (qs *QueryNodeStatsMsg) SetTraceCtx(ctx context.Context) {
+ qs.BaseMsg.Ctx = ctx
+ }
func (qs *QueryNodeStatsMsg) ID() UniqueID {
return qs.Base.MsgID
}
......@@ -417,6 +484,14 @@ type SegmentStatisticsMsg struct {
internalpb.SegmentStatistics
}
+ func (ss *SegmentStatisticsMsg) TraceCtx() context.Context {
+ return ss.BaseMsg.Ctx
+ }
+ func (ss *SegmentStatisticsMsg) SetTraceCtx(ctx context.Context) {
+ ss.BaseMsg.Ctx = ctx
+ }
func (ss *SegmentStatisticsMsg) ID() UniqueID {
return ss.Base.MsgID
}
......@@ -466,6 +541,14 @@ type CreateCollectionMsg struct {
internalpb.CreateCollectionRequest
}
+ func (cc *CreateCollectionMsg) TraceCtx() context.Context {
+ return cc.BaseMsg.Ctx
+ }
+ func (cc *CreateCollectionMsg) SetTraceCtx(ctx context.Context) {
+ cc.BaseMsg.Ctx = ctx
+ }
func (cc *CreateCollectionMsg) ID() UniqueID {
return cc.Base.MsgID
}
......@@ -507,6 +590,14 @@ type DropCollectionMsg struct {
internalpb.DropCollectionRequest
}
+ func (dc *DropCollectionMsg) TraceCtx() context.Context {
+ return dc.BaseMsg.Ctx
+ }
+ func (dc *DropCollectionMsg) SetTraceCtx(ctx context.Context) {
+ dc.BaseMsg.Ctx = ctx
+ }
func (dc *DropCollectionMsg) ID() UniqueID {
return dc.Base.MsgID
}
......@@ -548,15 +639,23 @@ type CreatePartitionMsg struct {
internalpb.CreatePartitionRequest
}
- func (cc *CreatePartitionMsg) ID() UniqueID {
- return cc.Base.MsgID
+ func (cp *CreatePartitionMsg) TraceCtx() context.Context {
+ return cp.BaseMsg.Ctx
}
- func (cc *CreatePartitionMsg) Type() MsgType {
- return cc.Base.MsgType
+ func (cp *CreatePartitionMsg) SetTraceCtx(ctx context.Context) {
+ cp.BaseMsg.Ctx = ctx
}
+ func (cp *CreatePartitionMsg) ID() UniqueID {
+ return cp.Base.MsgID
+ }
- func (cc *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
+ func (cp *CreatePartitionMsg) Type() MsgType {
+ return cp.Base.MsgType
+ }
+ func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
createPartitionMsg := input.(*CreatePartitionMsg)
createPartitionRequest := &createPartitionMsg.CreatePartitionRequest
mb, err := proto.Marshal(createPartitionRequest)
......@@ -566,7 +665,7 @@ func (cc *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
return mb, nil
}
- func (cc *CreatePartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
+ func (cp *CreatePartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
createPartitionRequest := internalpb.CreatePartitionRequest{}
in, err := ConvertToByteArray(input)
if err != nil {
......@@ -589,15 +688,23 @@ type DropPartitionMsg struct {
internalpb.DropPartitionRequest
}
- func (dc *DropPartitionMsg) ID() UniqueID {
- return dc.Base.MsgID
+ func (dp *DropPartitionMsg) TraceCtx() context.Context {
+ return dp.BaseMsg.Ctx
}
- func (dc *DropPartitionMsg) Type() MsgType {
- return dc.Base.MsgType
+ func (dp *DropPartitionMsg) SetTraceCtx(ctx context.Context) {
+ dp.BaseMsg.Ctx = ctx
}
+ func (dp *DropPartitionMsg) ID() UniqueID {
+ return dp.Base.MsgID
+ }
+ func (dp *DropPartitionMsg) Type() MsgType {
+ return dp.Base.MsgType
+ }
- func (dc *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
+ func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
dropPartitionMsg := input.(*DropPartitionMsg)
dropPartitionRequest := &dropPartitionMsg.DropPartitionRequest
mb, err := proto.Marshal(dropPartitionRequest)
......@@ -607,7 +714,7 @@ func (dc *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
return mb, nil
}
- func (dc *DropPartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
+ func (dp *DropPartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
dropPartitionRequest := internalpb.DropPartitionRequest{}
in, err := ConvertToByteArray(input)
if err != nil {
......@@ -630,6 +737,14 @@ type LoadIndexMsg struct {
internalpb.LoadIndex
}
+ func (lim *LoadIndexMsg) TraceCtx() context.Context {
+ return lim.BaseMsg.Ctx
+ }
+ func (lim *LoadIndexMsg) SetTraceCtx(ctx context.Context) {
+ lim.BaseMsg.Ctx = ctx
+ }
func (lim *LoadIndexMsg) ID() UniqueID {
return lim.Base.MsgID
}
......@@ -669,6 +784,14 @@ type SegmentInfoMsg struct {
datapb.SegmentMsg
}
+ func (sim *SegmentInfoMsg) TraceCtx() context.Context {
+ return sim.BaseMsg.Ctx
+ }
+ func (sim *SegmentInfoMsg) SetTraceCtx(ctx context.Context) {
+ sim.BaseMsg.Ctx = ctx
+ }
func (sim *SegmentInfoMsg) ID() UniqueID {
return sim.Base.MsgID
}
......
......@@ -30,9 +30,9 @@ type MsgStream interface {
AsConsumer(channels []string, subName string)
SetRepackFunc(repackFunc RepackFunc)
- Produce(context.Context, *MsgPack) error
- Broadcast(context.Context, *MsgPack) error
- Consume() (*MsgPack, context.Context)
+ Produce(*MsgPack) error
+ Broadcast(*MsgPack) error
+ Consume() *MsgPack
Seek(offset *MsgPosition) error
}
......
......@@ -132,7 +132,6 @@ func getInsertTask(reqID msgstream.UniqueID, hashValue uint32) msgstream.TsMsg {
}
func TestStream_task_Insert(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
......@@ -155,13 +154,13 @@ func TestStream_task_Insert(t *testing.T) {
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Start()
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
receiveCount := 0
for {
- result, _ := outputStream.Consume()
+ result := outputStream.Consume()
if len(result.Msgs) > 0 {
msgs := result.Msgs
for _, v := range msgs {
......
......@@ -11,9 +11,9 @@ import (
"github.com/apache/pulsar-client-go/pulsar"
"github.com/golang/protobuf/proto"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/msgstream/util"
......@@ -54,8 +54,6 @@ type PulsarMsgStream struct {
producerLock *sync.Mutex
consumerLock *sync.Mutex
consumerReflects []reflect.SelectCase
- scMap *sync.Map
}
func newPulsarMsgStream(ctx context.Context,
......@@ -99,7 +97,6 @@ func newPulsarMsgStream(ctx context.Context,
producerLock: &sync.Mutex{},
consumerLock: &sync.Mutex{},
wait: &sync.WaitGroup{},
- scMap: &sync.Map{},
}
return stream, nil
......@@ -195,7 +192,7 @@ func (ms *PulsarMsgStream) Close() {
}
}
- func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
+ func (ms *PulsarMsgStream) Produce(msgPack *msgstream.MsgPack) error {
tsMsgs := msgPack.Msgs
if len(tsMsgs) <= 0 {
log.Debug("Warning: Receive empty msgPack")
......@@ -257,7 +254,7 @@ func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error
msg := &pulsar.ProducerMessage{Payload: m, Properties: map[string]string{}}
- sp, spanCtx := trace.MsgSpanFromCtx(ctx, v.Msgs[i])
+ sp, spanCtx := trace.MsgSpanFromCtx(v.Msgs[i].TraceCtx(), v.Msgs[i])
trace.InjectContextToPulsarMsgProperties(sp.Context(), msg.Properties)
if _, err := ms.producers[channel].Send(
......@@ -274,7 +271,7 @@ func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error
return nil
}
- func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
+ func (ms *PulsarMsgStream) Broadcast(msgPack *msgstream.MsgPack) error {
for _, v := range msgPack.Msgs {
mb, err := v.Marshal(v)
if err != nil {
......@@ -288,7 +285,7 @@ func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) erro
msg := &pulsar.ProducerMessage{Payload: m, Properties: map[string]string{}}
- sp, spanCtx := trace.MsgSpanFromCtx(ctx, v)
+ sp, spanCtx := trace.MsgSpanFromCtx(v.TraceCtx(), v)
trace.InjectContextToPulsarMsgProperties(sp.Context(), msg.Properties)
ms.producerLock.Lock()
......@@ -308,31 +305,18 @@ func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) erro
return nil
}
- func (ms *PulsarMsgStream) Consume() (*MsgPack, context.Context) {
+ func (ms *PulsarMsgStream) Consume() *msgstream.MsgPack {
for {
select {
case cm, ok := <-ms.receiveBuf:
if !ok {
log.Debug("buf chan closed")
- return nil, nil
+ return nil
}
- var ctx context.Context
- var opts []opentracing.StartSpanOption
- for _, msg := range cm.Msgs {
- sc, loaded := ms.scMap.LoadAndDelete(msg.ID())
- if loaded {
- opts = append(opts, opentracing.ChildOf(sc.(opentracing.SpanContext)))
- }
- }
- if len(opts) != 0 {
- ctx = context.Background()
- }
- sp, ctx := trace.StartSpanFromContext(ctx, opts...)
- sp.Finish()
- return cm, ctx
+ return cm
case <-ms.ctx.Done():
//log.Debug("context closed")
- return nil, nil
+ return nil
}
}
}
......@@ -368,7 +352,7 @@ func (ms *PulsarMsgStream) receiveMsg(consumer Consumer) {
sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
if ok {
- ms.scMap.Store(tsMsg.ID(), sp.Context())
+ tsMsg.SetTraceCtx(opentracing.ContextWithSpan(context.Background(), sp))
}
msgPack := MsgPack{Msgs: []TsMsg{tsMsg}}
......@@ -460,6 +444,10 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
log.Error("Failed to unmarshal tsMsg", zap.Error(err))
continue
}
+ sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
+ if ok {
+ tsMsg.SetTraceCtx(opentracing.ContextWithSpan(context.Background(), sp))
+ }
tsMsg.SetPosition(&msgstream.MsgPosition{
ChannelName: filepath.Base(pulsarMsg.Topic()),
......@@ -736,7 +724,7 @@ func (ms *PulsarTtMsgStream) findTimeTick(consumer Consumer,
sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
if ok {
- ms.scMap.Store(tsMsg.ID(), sp.Context())
+ tsMsg.SetTraceCtx(opentracing.ContextWithSpan(context.Background(), sp))
}
ms.unsolvedMutex.Lock()
......
......@@ -223,7 +223,7 @@ func initPulsarTtStream(pulsarAddress string,
func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
receiveCount := 0
for {
- result, _ := outputStream.Consume()
+ result := outputStream.Consume()
if len(result.Msgs) > 0 {
msgs := result.Msgs
for _, v := range msgs {
......@@ -238,7 +238,6 @@ func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
}
func TestStream_PulsarMsgStream_Insert(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -250,7 +249,7 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -262,7 +261,6 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
}
func TestStream_PulsarMsgStream_Delete(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c := funcutil.RandomString(8)
producerChannels := []string{c}
......@@ -273,7 +271,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
//msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Delete, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -283,7 +281,6 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
}
func TestStream_PulsarMsgStream_Search(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c := funcutil.RandomString(8)
producerChannels := []string{c}
......@@ -295,7 +292,7 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -305,7 +302,6 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) {
}
func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c := funcutil.RandomString(8)
producerChannels := []string{c}
......@@ -316,7 +312,7 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -326,7 +322,6 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
}
func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c := funcutil.RandomString(8)
producerChannels := []string{c}
......@@ -337,7 +332,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -347,7 +342,6 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
}
func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -359,7 +353,7 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Broadcast(ctx, &msgPack)
+ err := inputStream.Broadcast(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -369,7 +363,6 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
}
func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -381,7 +374,7 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3, 3))
inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName, repackFunc)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -391,7 +384,6 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
}
func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -436,7 +428,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
outputStream.Start()
var output msgstream.MsgStream = outputStream
- err := (*inputStream).Produce(ctx, &msgPack)
+ err := (*inputStream).Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -446,7 +438,6 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
}
func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -489,7 +480,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
outputStream.Start()
var output msgstream.MsgStream = outputStream
- err := (*inputStream).Produce(ctx, &msgPack)
+ err := (*inputStream).Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -499,7 +490,6 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
}
func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -522,7 +512,7 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
outputStream.Start()
var output msgstream.MsgStream = outputStream
- err := (*inputStream).Produce(ctx, &msgPack)
+ err := (*inputStream).Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -532,7 +522,6 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
}
func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -549,15 +538,15 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5))
inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Broadcast(ctx, &msgPack0)
+ err := inputStream.Broadcast(&msgPack0)
if err != nil {
log.Fatalf("broadcast error = %v", err)
}
- err = inputStream.Produce(ctx, &msgPack1)
+ err = inputStream.Produce(&msgPack1)
if err != nil {
log.Fatalf("produce error = %v", err)
}
- err = inputStream.Broadcast(ctx, &msgPack2)
+ err = inputStream.Broadcast(&msgPack2)
if err != nil {
log.Fatalf("broadcast error = %v", err)
}
......@@ -567,7 +556,6 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
}
func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -595,23 +583,23 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
msgPack5.Msgs = append(msgPack5.Msgs, getTimeTickMsg(15, 15, 15))
inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Broadcast(ctx, &msgPack0)
+ err := inputStream.Broadcast(&msgPack0)
assert.Nil(t, err)
- err = inputStream.Produce(ctx, &msgPack1)
+ err = inputStream.Produce(&msgPack1)
assert.Nil(t, err)
- err = inputStream.Broadcast(ctx, &msgPack2)
+ err = inputStream.Broadcast(&msgPack2)
assert.Nil(t, err)
- err = inputStream.Produce(ctx, &msgPack3)
+ err = inputStream.Produce(&msgPack3)
assert.Nil(t, err)
- err = inputStream.Broadcast(ctx, &msgPack4)
+ err = inputStream.Broadcast(&msgPack4)
assert.Nil(t, err)
outputStream.Consume()
- receivedMsg, _ := outputStream.Consume()
+ receivedMsg := outputStream.Consume()
for _, position := range receivedMsg.StartPositions {
outputStream.Seek(position)
}
- err = inputStream.Broadcast(ctx, &msgPack5)
+ err = inputStream.Broadcast(&msgPack5)
assert.Nil(t, err)
//seekMsg, _ := outputStream.Consume()
//for _, msg := range seekMsg.Msgs {
......@@ -622,7 +610,6 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
}
func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
- ctx := context.Background()
pulsarAddress, _ := Params.Load("_PulsarAddress")
c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
producerChannels := []string{c1, c2}
......@@ -640,15 +627,15 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5))
inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Broadcast(ctx, &msgPack0)
+ err := inputStream.Broadcast(&msgPack0)
if err != nil {
log.Fatalf("broadcast error = %v", err)
}
- err = inputStream.Produce(ctx, &msgPack1)
+ err = inputStream.Produce(&msgPack1)
if err != nil {
log.Fatalf("produce error = %v", err)
}
- err = inputStream.Broadcast(ctx, &msgPack2)
+ err = inputStream.Broadcast(&msgPack2)
if err != nil {
log.Fatalf("broadcast error = %v", err)
}
......
......@@ -161,7 +161,7 @@ func (rms *RmqMsgStream) AsConsumer(channels []string, groupName string) {
}
}
- func (rms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
+ func (rms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error {
tsMsgs := pack.Msgs
if len(tsMsgs) <= 0 {
log.Debug("Warning: Receive empty msgPack")
......@@ -228,7 +228,7 @@ func (rms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) e
return nil
}
- func (rms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
+ func (rms *RmqMsgStream) Broadcast(msgPack *msgstream.MsgPack) error {
for _, v := range msgPack.Msgs {
mb, err := v.Marshal(v)
if err != nil {
......@@ -255,18 +255,18 @@ func (rms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error
return nil
}
- func (rms *RmqMsgStream) Consume() (*msgstream.MsgPack, context.Context) {
+ func (rms *RmqMsgStream) Consume() *msgstream.MsgPack {
for {
select {
case cm, ok := <-rms.receiveBuf:
if !ok {
log.Debug("buf chan closed")
- return nil, nil
+ return nil
}
- return cm, nil
+ return cm
case <-rms.ctx.Done():
//log.Debug("context closed")
- return nil, nil
+ return nil
}
}
}
......
......@@ -239,7 +239,7 @@ func initRmqTtStream(producerChannels []string,
func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
receiveCount := 0
for {
- result, _ := outputStream.Consume()
+ result := outputStream.Consume()
if len(result.Msgs) > 0 {
msgs := result.Msgs
for _, v := range msgs {
......@@ -254,7 +254,6 @@ func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
}
func TestStream_RmqMsgStream_Insert(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerGroupName := "InsertGroup"
......@@ -266,7 +265,7 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
rocksdbName := "/tmp/rocksmq_insert"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerGroupName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -276,7 +275,6 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
}
func TestStream_RmqMsgStream_Delete(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"delete"}
consumerChannels := []string{"delete"}
consumerSubName := "subDelete"
......@@ -287,7 +285,7 @@ func TestStream_RmqMsgStream_Delete(t *testing.T) {
rocksdbName := "/tmp/rocksmq_delete"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -296,7 +294,6 @@ func TestStream_RmqMsgStream_Delete(t *testing.T) {
}
func TestStream_RmqMsgStream_Search(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"search"}
consumerChannels := []string{"search"}
consumerSubName := "subSearch"
......@@ -308,7 +305,7 @@ func TestStream_RmqMsgStream_Search(t *testing.T) {
rocksdbName := "/tmp/rocksmq_search"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -317,8 +314,6 @@ func TestStream_RmqMsgStream_Search(t *testing.T) {
}
func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"searchResult"}
consumerChannels := []string{"searchResult"}
consumerSubName := "subSearchResult"
......@@ -330,7 +325,7 @@ func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
rocksdbName := "/tmp/rocksmq_searchresult"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -339,7 +334,6 @@ func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
}
func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"timeTick"}
consumerChannels := []string{"timeTick"}
consumerSubName := "subTimeTick"
......@@ -351,7 +345,7 @@ func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
rocksdbName := "/tmp/rocksmq_timetick"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -360,7 +354,6 @@ func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
}
func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerSubName := "subInsert"
......@@ -372,7 +365,7 @@ func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
rocksdbName := "/tmp/rocksmq_broadcast"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Broadcast(ctx, &msgPack)
+ err := inputStream.Broadcast(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -381,8 +374,6 @@ func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
}
func TestStream_RmqMsgStream_RepackFunc(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerSubName := "subInsert"
......@@ -394,7 +385,7 @@ func TestStream_RmqMsgStream_RepackFunc(t *testing.T) {
rocksdbName := "/tmp/rocksmq_repackfunc"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName, repackFunc)
- err := inputStream.Produce(ctx, &msgPack)
+ err := inputStream.Produce(&msgPack)
if err != nil {
log.Fatalf("produce error = %v", err)
}
......@@ -403,8 +394,6 @@ func TestStream_RmqMsgStream_RepackFunc(t *testing.T) {
}
func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
- ctx := context.Background()
producerChannels := []string{"insert1", "insert2"}
consumerChannels := []string{"insert1", "insert2"}
consumerSubName := "subInsert"
......@@ -423,15 +412,15 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqTtStream(producerChannels, consumerChannels, consumerSubName)
- err := inputStream.Broadcast(ctx, &msgPack0)
+ err := inputStream.Broadcast(&msgPack0)
if err != nil {
log.Fatalf("broadcast error = %v", err)
}
- err = inputStream.Produce(ctx, &msgPack1)
+ err = inputStream.Produce(&msgPack1)
if err != nil {
log.Fatalf("produce error = %v", err)
}
- err = inputStream.Broadcast(ctx, &msgPack2)
+ err = inputStream.Broadcast(&msgPack2)
if err != nil {
log.Fatalf("broadcast error = %v", err)
}
......
......@@ -57,6 +57,9 @@ func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
}
insertMsg := &msgstream.InsertMsg{
+ BaseMsg: BaseMsg{
+ Ctx: request.TraceCtx(),
+ },
InsertRequest: sliceRequest,
}
result[key].Msgs = append(result[key].Msgs, insertMsg)
......@@ -103,6 +106,9 @@ func DeleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
}
deleteMsg := &msgstream.DeleteMsg{
+ BaseMsg: BaseMsg{
+ Ctx: request.TraceCtx(),
+ },
DeleteRequest: sliceRequest,
}
result[key].Msgs = append(result[key].Msgs, deleteMsg)
......
......@@ -183,6 +183,7 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
// if slice, todo: a common function to calculate size of slice,
// if map, a little complicated
size := 0
+ size += int(unsafe.Sizeof(msg.Ctx))
size += int(unsafe.Sizeof(msg.BeginTimestamp))
size += int(unsafe.Sizeof(msg.EndTimestamp))
size += int(unsafe.Sizeof(msg.HashValues))
......@@ -262,6 +263,9 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
RowData: []*commonpb.Blob{row},
}
insertMsg := &msgstream.InsertMsg{
+ BaseMsg: msgstream.BaseMsg{
+ Ctx: request.TraceCtx(),
+ },
InsertRequest: sliceRequest,
}
if together { // all rows with same hash value are accumulated to only one message
......
......@@ -52,7 +52,7 @@ const (
)
type task interface {
- Ctx() context.Context
+ TraceCtx() context.Context
ID() UniqueID // return ReqID
SetID(uid UniqueID) // set ReqID
Name() string
......@@ -79,7 +79,7 @@ type InsertTask struct {
rowIDAllocator *allocator.IDAllocator
}
- func (it *InsertTask) Ctx() context.Context {
+ func (it *InsertTask) TraceCtx() context.Context {
return it.ctx
}
......@@ -185,7 +185,8 @@ func (it *InsertTask) Execute(ctx context.Context) error {
}
var tsMsg msgstream.TsMsg = &it.BaseInsertTask
- msgPack := &msgstream.MsgPack{
+ it.BaseMsg.Ctx = ctx
+ msgPack := msgstream.MsgPack{
BeginTs: it.BeginTs(),
EndTs: it.EndTs(),
Msgs: make([]msgstream.TsMsg, 1),
......@@ -231,7 +232,7 @@ func (it *InsertTask) Execute(ctx context.Context) error {
return err
}
- err = stream.Produce(ctx, msgPack)
+ err = stream.Produce(&msgPack)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
......@@ -255,7 +256,7 @@ type CreateCollectionTask struct {
schema *schemapb.CollectionSchema
}
- func (cct *CreateCollectionTask) Ctx() context.Context {
+ func (cct *CreateCollectionTask) TraceCtx() context.Context {
return cct.ctx
}
......@@ -403,7 +404,7 @@ type DropCollectionTask struct {
result *commonpb.Status
}
- func (dct *DropCollectionTask) Ctx() context.Context {
+ func (dct *DropCollectionTask) TraceCtx() context.Context {
return dct.ctx
}
......@@ -484,7 +485,7 @@ type SearchTask struct {
query *milvuspb.SearchRequest
}
- func (st *SearchTask) Ctx() context.Context {
+ func (st *SearchTask) TraceCtx() context.Context {
return st.ctx
}
......@@ -596,18 +597,19 @@ func (st *SearchTask) Execute(ctx context.Context) error {
var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
SearchRequest: *st.SearchRequest,
BaseMsg: msgstream.BaseMsg{
+ Ctx: ctx,
HashValues: []uint32{uint32(Params.ProxyID)},
BeginTimestamp: st.Base.Timestamp,
EndTimestamp: st.Base.Timestamp,
},
}
- msgPack := &msgstream.MsgPack{
+ msgPack := msgstream.MsgPack{
BeginTs: st.Base.Timestamp,
EndTs: st.Base.Timestamp,
Msgs: make([]msgstream.TsMsg, 1),
}
msgPack.Msgs[0] = tsMsg
- err := st.queryMsgStream.Produce(ctx, msgPack)
+ err := st.queryMsgStream.Produce(&msgPack)
log.Debug("proxynode", zap.Int("length of searchMsg", len(msgPack.Msgs)))
if err != nil {
log.Debug("proxynode", zap.String("send search request failed", err.Error()))
......@@ -990,7 +992,7 @@ func printSearchResult(partialSearchResult *internalpb.SearchResults) {
func (st *SearchTask) PostExecute(ctx context.Context) error {
for {
select {
- case <-st.Ctx().Done():
+ case <-st.TraceCtx().Done():
log.Debug("proxynode", zap.Int64("SearchTask: wait to finish failed, timeout!, taskID:", st.ID()))
return fmt.Errorf("SearchTask:wait to finish failed, timeout: %d", st.ID())
case searchResults := <-st.resultBuf:
......@@ -1073,7 +1075,7 @@ type HasCollectionTask struct {
result *milvuspb.BoolResponse
}
- func (hct *HasCollectionTask) Ctx() context.Context {
+ func (hct *HasCollectionTask) TraceCtx() context.Context {
return hct.ctx
}
......@@ -1144,7 +1146,7 @@ type DescribeCollectionTask struct {
result *milvuspb.DescribeCollectionResponse
}
- func (dct *DescribeCollectionTask) Ctx() context.Context {
+ func (dct *DescribeCollectionTask) TraceCtx() context.Context {
return dct.ctx
}
......@@ -1215,7 +1217,7 @@ type GetCollectionsStatisticsTask struct {
result *milvuspb.GetCollectionStatisticsResponse
}
- func (g *GetCollectionsStatisticsTask) Ctx() context.Context {
+ func (g *GetCollectionsStatisticsTask) TraceCtx() context.Context {
return g.ctx
}
......@@ -1302,7 +1304,7 @@ type ShowCollectionsTask struct {
result *milvuspb.ShowCollectionsResponse
}
func (sct *ShowCollectionsTask) Ctx() context.Context {
func (sct *ShowCollectionsTask) TraceCtx() context.Context {
return sct.ctx
}
......@@ -1370,7 +1372,7 @@ type CreatePartitionTask struct {
result *commonpb.Status
}
func (cpt *CreatePartitionTask) Ctx() context.Context {
func (cpt *CreatePartitionTask) TraceCtx() context.Context {
return cpt.ctx
}
......@@ -1447,7 +1449,7 @@ type DropPartitionTask struct {
result *commonpb.Status
}
func (dpt *DropPartitionTask) Ctx() context.Context {
func (dpt *DropPartitionTask) TraceCtx() context.Context {
return dpt.ctx
}
......@@ -1524,7 +1526,7 @@ type HasPartitionTask struct {
result *milvuspb.BoolResponse
}
func (hpt *HasPartitionTask) Ctx() context.Context {
func (hpt *HasPartitionTask) TraceCtx() context.Context {
return hpt.ctx
}
......@@ -1600,7 +1602,7 @@ type ShowPartitionsTask struct {
result *milvuspb.ShowPartitionsResponse
}
func (spt *ShowPartitionsTask) Ctx() context.Context {
func (spt *ShowPartitionsTask) TraceCtx() context.Context {
return spt.ctx
}
......@@ -1671,7 +1673,7 @@ type CreateIndexTask struct {
result *commonpb.Status
}
func (cit *CreateIndexTask) Ctx() context.Context {
func (cit *CreateIndexTask) TraceCtx() context.Context {
return cit.ctx
}
......@@ -1749,7 +1751,7 @@ type DescribeIndexTask struct {
result *milvuspb.DescribeIndexResponse
}
func (dit *DescribeIndexTask) Ctx() context.Context {
func (dit *DescribeIndexTask) TraceCtx() context.Context {
return dit.ctx
}
......@@ -1832,7 +1834,7 @@ type DropIndexTask struct {
result *commonpb.Status
}
func (dit *DropIndexTask) Ctx() context.Context {
func (dit *DropIndexTask) TraceCtx() context.Context {
return dit.ctx
}
......@@ -1911,7 +1913,7 @@ type GetIndexStateTask struct {
result *milvuspb.GetIndexStateResponse
}
func (gist *GetIndexStateTask) Ctx() context.Context {
func (gist *GetIndexStateTask) TraceCtx() context.Context {
return gist.ctx
}
......@@ -2142,7 +2144,7 @@ type FlushTask struct {
result *commonpb.Status
}
func (ft *FlushTask) Ctx() context.Context {
func (ft *FlushTask) TraceCtx() context.Context {
return ft.ctx
}
......@@ -2228,7 +2230,7 @@ type LoadCollectionTask struct {
result *commonpb.Status
}
func (lct *LoadCollectionTask) Ctx() context.Context {
func (lct *LoadCollectionTask) TraceCtx() context.Context {
return lct.ctx
}
......@@ -2323,7 +2325,7 @@ type ReleaseCollectionTask struct {
result *commonpb.Status
}
func (rct *ReleaseCollectionTask) Ctx() context.Context {
func (rct *ReleaseCollectionTask) TraceCtx() context.Context {
return rct.ctx
}
......@@ -2404,6 +2406,10 @@ type LoadPartitionTask struct {
result *commonpb.Status
}
func (lpt *LoadPartitionTask) TraceCtx() context.Context {
return lpt.ctx
}
func (lpt *LoadPartitionTask) ID() UniqueID {
return lpt.Base.MsgID
}
......@@ -2495,7 +2501,7 @@ type ReleasePartitionTask struct {
result *commonpb.Status
}
func (rpt *ReleasePartitionTask) Ctx() context.Context {
func (rpt *ReleasePartitionTask) TraceCtx() context.Context {
return rpt.ctx
}
......
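Every task type in the hunks above gets the same one-line treatment: the accessor is renamed from Ctx() to TraceCtx(), making explicit that it hands out the carrier for the tracing span rather than a lifecycle context. A minimal sketch of the resulting shape, using hypothetical stand-in types (task, demoTask) rather than the repo's exact definitions:

package main

import (
	"context"
	"fmt"
)

// task mirrors the proxynode task interface after this commit: the
// context accessor is named TraceCtx to signal its tracing role.
type task interface {
	TraceCtx() context.Context
	ID() int64 // the repo aliases this as UniqueID
	Name() string
}

// demoTask is a hypothetical stand-in for types like InsertTask.
type demoTask struct {
	ctx context.Context
	id  int64
}

func (t *demoTask) TraceCtx() context.Context { return t.ctx }
func (t *demoTask) ID() int64                 { return t.id }
func (t *demoTask) Name() string              { return "demoTask" }

func main() {
	var t task = &demoTask{ctx: context.Background(), id: 42}
	fmt.Println(t.Name(), t.ID(), t.TraceCtx() != nil)
}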
......@@ -302,7 +302,7 @@ func (sched *TaskScheduler) getTaskByReqID(collMeta UniqueID) task {
}
func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
span, ctx := trace.StartSpanFromContext(t.Ctx(),
span, ctx := trace.StartSpanFromContext(t.TraceCtx(),
opentracing.Tags{
"Type": t.Name(),
"ID": t.ID(),
......@@ -409,6 +409,8 @@ func (sched *TaskScheduler) queryResultLoop() {
continue
}
for _, tsMsg := range msgPack.Msgs {
sp, ctx := trace.StartSpanFromContext(tsMsg.TraceCtx())
tsMsg.SetTraceCtx(ctx)
searchResultMsg, _ := tsMsg.(*msgstream.SearchResultMsg)
reqID := searchResultMsg.Base.MsgID
reqIDStr := strconv.FormatInt(reqID, 10)
......@@ -443,6 +445,7 @@ func (sched *TaskScheduler) queryResultLoop() {
// log.Printf("task with reqID %v is nil", reqID)
}
}
sp.Finish()
}
case <-sched.ctx.Done():
log.Debug("proxynode server is closed ...")
......
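processTask still opens exactly one span per task, now rooted at t.TraceCtx(), and queryResultLoop starts and finishes a span per consumed result message. A rough equivalent of the task-level span using plain opentracing-go; the repo routes this through its internal trace helper, so the explicit operation name below is an assumption for illustration:

package main

import (
	"context"
	"fmt"

	"github.com/opentracing/opentracing-go"
)

// process mimics the scheduler's per-task span: opened from the task's
// TraceCtx, tagged with its type and ID, finished when the task is done.
func process(taskCtx context.Context, name string, id int64) {
	span, ctx := opentracing.StartSpanFromContext(taskCtx, "processTask",
		opentracing.Tags{
			"Type": name,
			"ID":   id,
		})
	defer span.Finish()
	_ = ctx // the real code threads this into PreExecute/Execute/PostExecute
	fmt.Println("processed", name, id)
}

func main() {
	process(context.Background(), "SearchTask", 7)
}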
......@@ -86,7 +86,7 @@ func (tt *timeTick) tick() error {
},
}
msgPack.Msgs = append(msgPack.Msgs, timeTickMsg)
err := tt.tickMsgStream.Produce(tt.ctx, &msgPack)
err := tt.tickMsgStream.Produce(&msgPack)
if err != nil {
log.Warn("proxynode", zap.String("error", err.Error()))
}
......
......@@ -55,7 +55,7 @@ func (tt *TimeTick) Start() error {
log.Debug("proxyservice", zap.Stringer("msg type", msg.Type()))
}
for _, channel := range tt.channels {
err = channel.Broadcast(tt.ctx, &msgPack)
err = channel.Broadcast(&msgPack)
if err != nil {
log.Error("proxyservice", zap.String("send time tick error", err.Error()))
}
......
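With messages carrying their own contexts, Produce and Broadcast drop the context parameter at every call site, the two time-tick loops above included. A self-contained sketch of the narrowed interface; MsgPack and MsgStream here are toy stand-ins for the msgstream package, not its real definitions:

package main

import "fmt"

type MsgPack struct{ Msgs []string }

// MsgStream shows only the signature change: no context argument, the
// per-message trace contexts ride inside the pack's messages.
type MsgStream interface {
	Produce(pack *MsgPack) error
	Broadcast(pack *MsgPack) error
}

type fakeStream struct{}

func (fakeStream) Produce(pack *MsgPack) error   { fmt.Println("produce", len(pack.Msgs)); return nil }
func (fakeStream) Broadcast(pack *MsgPack) error { fmt.Println("broadcast", len(pack.Msgs)); return nil }

func main() {
	var s MsgStream = fakeStream{}
	pack := MsgPack{Msgs: []string{"timetick"}}
	_ = s.Produce(&pack)
	_ = s.Broadcast(&pack)
}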
package querynode
import (
"context"
"encoding/binary"
"math"
"testing"
......@@ -16,8 +15,6 @@ import (
// NOTE: start pulsar before test
func TestDataSyncService_Start(t *testing.T) {
ctx := context.Background()
collectionID := UniqueID(0)
node := newQueryNodeMock()
......@@ -127,10 +124,10 @@ func TestDataSyncService_Start(t *testing.T) {
var insertMsgStream msgstream.MsgStream = insertStream
insertMsgStream.Start()
err = insertMsgStream.Produce(ctx, &msgPack)
err = insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
......
package querynode
import (
"context"
"github.com/golang/protobuf/proto"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
"go.uber.org/zap"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
)
type ddNode struct {
......@@ -21,7 +22,7 @@ func (ddNode *ddNode) Name() string {
return "ddNode"
}
func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (ddNode *ddNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do filterDmNode operation")
if len(in) != 1 {
......@@ -35,6 +36,13 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
// TODO: add error handling
}
var spans []opentracing.Span
for _, msg := range msMsg.TsMessages() {
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
spans = append(spans, sp)
msg.SetTraceCtx(ctx)
}
var ddMsg = ddMsg{
collectionRecords: make(map[UniqueID][]metaOperateRecord),
partitionRecords: make(map[UniqueID][]metaOperateRecord),
......@@ -74,7 +82,10 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
//}
var res Msg = ddNode.ddMsg
return []Msg{res}, ctx
for _, span := range spans {
span.Finish()
}
return []Msg{res}
}
func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
......
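The querynode flow-graph nodes (ddNode here, filterDmNode, insertNode, and serviceTimeNode below) all move to the new contract: Operate(in []flowgraph.Msg) []flowgraph.Msg, with a span opened from each message's TraceCtx and finished before the result goes downstream. A compressed, runnable sketch of that contract; tsMsg is a stand-in for msgstream.TsMsg:

package main

import (
	"context"
	"fmt"

	"github.com/opentracing/opentracing-go"
)

// tsMsg stands in for msgstream.TsMsg: the message owns its trace context.
type tsMsg struct{ ctx context.Context }

func (m *tsMsg) TraceCtx() context.Context       { return m.ctx }
func (m *tsMsg) SetTraceCtx(ctx context.Context) { m.ctx = ctx }

// operate follows the new node contract: no context parameter or
// return; one span per message, all finished before returning.
func operate(msgs []*tsMsg) []string {
	var spans []opentracing.Span
	for _, msg := range msgs {
		sp, ctx := opentracing.StartSpanFromContext(msg.TraceCtx(), "ddNode")
		spans = append(spans, sp)
		msg.SetTraceCtx(ctx)
	}
	out := []string{"ddMsg"} // placeholder for the node's real output
	for _, sp := range spans {
		sp.Finish()
	}
	return out
}

func main() {
	msgs := []*tsMsg{{ctx: context.Background()}, {ctx: context.Background()}}
	fmt.Println(operate(msgs))
}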
package querynode
import (
"context"
"fmt"
"go.uber.org/zap"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
"go.uber.org/zap"
)
type filterDmNode struct {
......@@ -21,7 +22,7 @@ func (fdmNode *filterDmNode) Name() string {
return "fdmNode"
}
func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (fdmNode *filterDmNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do filterDmNode operation")
if len(in) != 1 {
......@@ -36,7 +37,14 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont
}
if msgStreamMsg == nil {
return []Msg{}, ctx
return []Msg{}
}
var spans []opentracing.Span
for _, msg := range msgStreamMsg.TsMessages() {
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
spans = append(spans, sp)
msg.SetTraceCtx(ctx)
}
var iMsg = insertMsg{
......@@ -61,11 +69,16 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont
}
var res Msg = &iMsg
return []Msg{res}, ctx
for _, sp := range spans {
sp.Finish()
}
return []Msg{res}
}
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
msg.SetTraceCtx(ctx)
defer sp.Finish()
// check if collection and partition exist
collection := fdmNode.replica.hasCollection(msg.CollectionID)
partition := fdmNode.replica.hasPartition(msg.PartitionID)
......
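Per-message helpers such as filterInvalidInsertMessage use the defer variant of the same pattern: open the span, write the enriched context back onto the message so downstream stages continue the same trace, and let defer close the span on every return path. A sketch under the same stand-in types; the collection-existence check is a simplified placeholder for the replica lookups:

package main

import (
	"context"
	"fmt"

	"github.com/opentracing/opentracing-go"
)

type insertMsg struct {
	ctx          context.Context
	collectionID int64
}

func (m *insertMsg) TraceCtx() context.Context       { return m.ctx }
func (m *insertMsg) SetTraceCtx(ctx context.Context) { m.ctx = ctx }

// filterInvalid drops messages for unknown collections; the span covers
// the whole validation via defer.
func filterInvalid(msg *insertMsg, known map[int64]bool) *insertMsg {
	sp, ctx := opentracing.StartSpanFromContext(msg.TraceCtx(), "filterInvalid")
	msg.SetTraceCtx(ctx)
	defer sp.Finish()

	if !known[msg.collectionID] {
		return nil
	}
	return msg
}

func main() {
	known := map[int64]bool{1: true}
	fmt.Println(filterInvalid(&insertMsg{ctx: context.Background(), collectionID: 1}, known) != nil)
	fmt.Println(filterInvalid(&insertMsg{ctx: context.Background(), collectionID: 2}, known) != nil)
}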
package querynode
import (
"context"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"go.uber.org/zap"
......@@ -17,7 +17,7 @@ func (gcNode *gcNode) Name() string {
return "gcNode"
}
func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (gcNode *gcNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do gcNode operation")
if len(in) != 1 {
......@@ -51,7 +51,7 @@ func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
// }
//}
return nil, ctx
return nil
}
func newGCNode(replica ReplicaInterface) *gcNode {
......
......@@ -4,10 +4,12 @@ import (
"context"
"sync"
"go.uber.org/zap"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
"go.uber.org/zap"
)
type insertNode struct {
......@@ -28,7 +30,7 @@ func (iNode *insertNode) Name() string {
return "iNode"
}
func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do insertNode operation")
if len(in) != 1 {
......@@ -50,7 +52,14 @@ func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.
}
if iMsg == nil {
return []Msg{}, ctx
return []Msg{}
}
var spans []opentracing.Span
for _, msg := range iMsg.insertMessages {
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
spans = append(spans, sp)
msg.SetTraceCtx(ctx)
}
// 1. hash insertMessages to insertData
......@@ -108,7 +117,10 @@ func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.
gcRecord: iMsg.gcRecord,
timeRange: iMsg.timeRange,
}
return []Msg{res}, ctx
for _, sp := range spans {
sp.Finish()
}
return []Msg{res}
}
func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) {
......
......@@ -3,12 +3,12 @@ package querynode
import (
"context"
"go.uber.org/zap"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
"go.uber.org/zap"
)
type serviceTimeNode struct {
......@@ -22,7 +22,7 @@ func (stNode *serviceTimeNode) Name() string {
return "stNode"
}
func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do serviceTimeNode operation")
if len(in) != 1 {
......@@ -37,7 +37,7 @@ func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, co
}
if serviceTimeMsg == nil {
return []Msg{}, ctx
return []Msg{}
}
// update service time
......@@ -57,7 +57,7 @@ func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, co
gcRecord: serviceTimeMsg.gcRecord,
timeRange: serviceTimeMsg.timeRange,
}
return []Msg{res}, ctx
return []Msg{res}
}
func (stNode *serviceTimeNode) sendTimeTick(ts Timestamp) error {
......@@ -78,7 +78,7 @@ func (stNode *serviceTimeNode) sendTimeTick(ts Timestamp) error {
},
}
msgPack.Msgs = append(msgPack.Msgs, &timeTickMsg)
return stNode.timeTickMsgStream.Produce(context.TODO(), &msgPack)
return stNode.timeTickMsgStream.Produce(&msgPack)
}
func newServiceTimeNode(ctx context.Context, replica ReplicaInterface, factory msgstream.Factory, collectionID UniqueID) *serviceTimeNode {
......
......@@ -1038,16 +1038,16 @@ func doInsert(ctx context.Context, collectionID UniqueID, partitionID UniqueID,
var ddMsgStream msgstream.MsgStream = ddStream
ddMsgStream.Start()
err = insertMsgStream.Produce(ctx, &msgPack)
err = insertMsgStream.Produce(&msgPack)
if err != nil {
return err
}
err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
if err != nil {
return err
}
err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = ddMsgStream.Broadcast(&timeTickMsgPack)
if err != nil {
return err
}
......@@ -1104,11 +1104,11 @@ func sentTimeTick(ctx context.Context) error {
var ddMsgStream msgstream.MsgStream = ddStream
ddMsgStream.Start()
err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
if err != nil {
return err
}
err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = ddMsgStream.Broadcast(&timeTickMsgPack)
if err != nil {
return err
}
......
......@@ -15,6 +15,7 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
)
type searchCollection struct {
......@@ -99,6 +100,9 @@ func (s *searchCollection) setServiceableTime(t Timestamp) {
}
func (s *searchCollection) emptySearch(searchMsg *msgstream.SearchMsg) {
sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
defer sp.Finish()
searchMsg.SetTraceCtx(ctx)
err := s.search(searchMsg)
if err != nil {
log.Error(err.Error())
......@@ -164,6 +168,8 @@ func (s *searchCollection) doUnsolvedMsgSearch() {
continue
}
for _, sm := range searchMsg {
sp, ctx := trace.StartSpanFromContext(sm.TraceCtx())
sm.SetTraceCtx(ctx)
err := s.search(sm)
if err != nil {
log.Error(err.Error())
......@@ -172,6 +178,7 @@ func (s *searchCollection) doUnsolvedMsgSearch() {
log.Error("publish FailedSearchResult failed", zap.Error(err2))
}
}
sp.Finish()
}
log.Debug("doUnsolvedMsgSearch, do search done", zap.Int("num of searchMsg", len(searchMsg)))
}
......@@ -181,6 +188,9 @@ func (s *searchCollection) doUnsolvedMsgSearch() {
// TODO:: cache map[dsl]plan
// TODO: reBatched search requests
func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
defer sp.Finish()
searchMsg.SetTraceCtx(ctx)
searchTimestamp := searchMsg.Base.Timestamp
var queryBlob = searchMsg.Query.Value
query := milvuspb.SearchRequest{}
......@@ -266,7 +276,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
resultChannelInt, _ := strconv.ParseInt(searchMsg.ResultChannelID, 10, 64)
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}},
BaseMsg: msgstream.BaseMsg{Ctx: searchMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
SearchResults: internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
......@@ -328,7 +338,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
resultChannelInt, _ := strconv.ParseInt(searchMsg.ResultChannelID, 10, 64)
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}},
BaseMsg: msgstream.BaseMsg{Ctx: searchMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
SearchResults: internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
......@@ -368,19 +378,19 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
func (s *searchCollection) publishSearchResult(msg msgstream.TsMsg) error {
// span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "publish search result")
// defer span.Finish()
// msg.SetMsgContext(ctx)
span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
defer span.Finish()
msg.SetTraceCtx(ctx)
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, msg)
err := s.searchResultMsgStream.Produce(context.TODO(), &msgPack)
err := s.searchResultMsgStream.Produce(&msgPack)
return err
}
func (s *searchCollection) publishFailedSearchResult(searchMsg *msgstream.SearchMsg, errMsg string) error {
// span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg")
// defer span.Finish()
// msg.SetMsgContext(ctx)
span, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
defer span.Finish()
searchMsg.SetTraceCtx(ctx)
//log.Debug("Public fail SearchResult!")
msgPack := msgstream.MsgPack{}
......@@ -401,7 +411,7 @@ func (s *searchCollection) publishFailedSearchResult(searchMsg *msgstream.Search
}
msgPack.Msgs = append(msgPack.Msgs, searchResultMsg)
err := s.searchResultMsgStream.Produce(context.TODO(), &msgPack)
err := s.searchResultMsgStream.Produce(&msgPack)
if err != nil {
return err
}
......
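Two details carry the trace through the search path: the result message now copies the originating SearchMsg's BaseMsg.Ctx, and publishSearchResult / publishFailedSearchResult replace their commented-out span code with live spans started from the message's TraceCtx. A toy version of why copying Ctx keeps the trace intact across the publish hop; baseMsg only approximates msgstream.BaseMsg:

package main

import (
	"context"
	"fmt"

	"github.com/opentracing/opentracing-go"
)

type baseMsg struct {
	Ctx        context.Context
	HashValues []uint32
}

// publish opens a span on the message's own context, as the rewritten
// publishSearchResult does, so the result stays in the search's trace.
func publish(msg *baseMsg) error {
	span, ctx := opentracing.StartSpanFromContext(msg.Ctx, "publishSearchResult")
	defer span.Finish()
	msg.Ctx = ctx
	fmt.Println("published result on channel hash", msg.HashValues)
	return nil
}

func main() {
	search := &baseMsg{Ctx: context.Background(), HashValues: []uint32{3}}
	// the result message inherits the search message's context
	result := &baseMsg{Ctx: search.Ctx, HashValues: search.HashValues}
	_ = publish(result)
}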
......@@ -6,6 +6,7 @@ import (
"errors"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
"go.uber.org/zap"
"strconv"
"strings"
......@@ -77,7 +78,7 @@ func (s *searchService) consumeSearch() {
case <-s.ctx.Done():
return
default:
msgPack, _ := s.searchMsgStream.Consume()
msgPack := s.searchMsgStream.Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
......@@ -87,6 +88,8 @@ func (s *searchService) consumeSearch() {
if !ok {
continue
}
sp, ctx := trace.StartSpanFromContext(sm.BaseMsg.Ctx)
sm.BaseMsg.Ctx = ctx
err := s.collectionCheck(sm.CollectionID)
if err != nil {
s.emptySearchCollection.emptySearch(sm)
......@@ -98,6 +101,7 @@ func (s *searchService) consumeSearch() {
s.startSearchCollection(sm.CollectionID)
}
sc.msgBuffer <- sm
sp.Finish()
}
log.Debug("do empty search done", zap.Int("num of searchMsg", emptySearchNum))
}
......
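Consume now returns just the *MsgPack; the discarded second value disappears from every call site because each consumed message already carries its trace context. A sketch of a consume loop in that style, with all types as local stand-ins rather than the msgstream API:

package main

import (
	"context"
	"fmt"

	"github.com/opentracing/opentracing-go"
)

type msg struct{ ctx context.Context }
type msgPack struct{ msgs []*msg }

type stream interface{ Consume() *msgPack }

// fakeStream yields one pack and then nil, standing in for a real consumer.
type fakeStream struct{ sent bool }

func (s *fakeStream) Consume() *msgPack {
	if s.sent {
		return nil
	}
	s.sent = true
	return &msgPack{msgs: []*msg{{ctx: context.Background()}}}
}

func main() {
	var s stream = &fakeStream{}
	for pack := s.Consume(); pack != nil; pack = s.Consume() {
		for _, m := range pack.msgs {
			sp, ctx := opentracing.StartSpanFromContext(m.ctx, "consumeSearch")
			m.ctx = ctx // handed on before the message is buffered
			fmt.Println("buffered search msg")
			sp.Finish()
		}
	}
}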
......@@ -19,8 +19,6 @@ import (
)
func TestSearch_Search(t *testing.T) {
ctx := context.Background()
collectionID := UniqueID(0)
node := newQueryNodeMock()
......@@ -108,7 +106,7 @@ func TestSearch_Search(t *testing.T) {
searchStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx)
searchStream.AsProducer(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(ctx, &msgPackSearch)
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory)
......@@ -203,12 +201,12 @@ func TestSearch_Search(t *testing.T) {
var ddMsgStream msgstream.MsgStream = ddStream
ddMsgStream.Start()
err = insertMsgStream.Produce(ctx, &msgPack)
err = insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = ddMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
......@@ -221,8 +219,6 @@ func TestSearch_Search(t *testing.T) {
}
func TestSearch_SearchMultiSegments(t *testing.T) {
ctx := context.Background()
collectionID := UniqueID(0)
pulsarURL := Params.PulsarAddress
......@@ -310,7 +306,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
searchStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx)
searchStream.AsProducer(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(ctx, &msgPackSearch)
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory)
......@@ -409,12 +405,12 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
var ddMsgStream msgstream.MsgStream = ddStream
ddMsgStream.Start()
err = insertMsgStream.Produce(ctx, &msgPack)
err = insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack)
err = ddMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
......
......@@ -91,7 +91,7 @@ func (sService *statsService) publicStatistic(fieldStats []*internalpb.FieldStat
var msgPack = msgstream.MsgPack{
Msgs: []msgstream.TsMsg{msg},
}
err := sService.statsStream.Produce(context.TODO(), &msgPack)
err := sService.statsStream.Produce(&msgPack)
if err != nil {
log.Error(err.Error())
}
......
......@@ -97,7 +97,7 @@ func (ttBarrier *softTimeTickBarrier) Start() {
return
default:
}
ttmsgs, _ := ttBarrier.ttStream.Consume()
ttmsgs := ttBarrier.ttStream.Consume()
if len(ttmsgs.Msgs) > 0 {
for _, timetickmsg := range ttmsgs.Msgs {
ttmsg := timetickmsg.(*ms.TimeTickMsg)
......@@ -161,7 +161,7 @@ func (ttBarrier *hardTimeTickBarrier) Start() {
return
default:
}
ttmsgs, _ := ttBarrier.ttStream.Consume()
ttmsgs := ttBarrier.ttStream.Consume()
if len(ttmsgs.Msgs) > 0 {
log.Debug("receive tt msg")
for _, timetickmsg := range ttmsgs.Msgs {
......
......@@ -41,7 +41,7 @@ func (watcher *MsgTimeTickWatcher) StartBackgroundLoop(ctx context.Context) {
msgPack := &ms.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, msg)
for _, stream := range watcher.streams {
if err := stream.Broadcast(ctx, msgPack); err != nil {
if err := stream.Broadcast(msgPack); err != nil {
log.Warn("stream broadcast failed", zap.Error(err))
}
}
......
......@@ -17,7 +17,7 @@ func (fg *TimeTickedFlowGraph) AddNode(node Node) {
nodeName := node.Name()
nodeCtx := nodeCtx{
node: node,
inputChannels: make([]chan *MsgWithCtx, 0),
inputChannels: make([]chan Msg, 0),
downstreamInputChanIdx: make(map[string]int),
}
fg.nodeCtx[nodeName] = &nodeCtx
......@@ -51,7 +51,7 @@ func (fg *TimeTickedFlowGraph) SetEdges(nodeName string, in []string, out []stri
return errors.New(errMsg)
}
maxQueueLength := outNode.node.MaxQueueLength()
outNode.inputChannels = append(outNode.inputChannels, make(chan *MsgWithCtx, maxQueueLength))
outNode.inputChannels = append(outNode.inputChannels, make(chan Msg, maxQueueLength))
currentNode.downstream[i] = outNode
}
......
......@@ -68,43 +68,43 @@ func (a *nodeA) Name() string {
return "NodeA"
}
func (a *nodeA) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
return append(in, in...), nil
func (a *nodeA) Operate(in []Msg) []Msg {
return append(in, in...)
}
func (b *nodeB) Name() string {
return "NodeB"
}
func (b *nodeB) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (b *nodeB) Operate(in []Msg) []Msg {
messages := make([]*intMsg, 0)
for _, msg := range msg2IntMsg(in) {
messages = append(messages, &intMsg{
num: math.Pow(msg.num, 2),
})
}
return intMsg2Msg(messages), nil
return intMsg2Msg(messages)
}
func (c *nodeC) Name() string {
return "NodeC"
}
func (c *nodeC) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (c *nodeC) Operate(in []Msg) []Msg {
messages := make([]*intMsg, 0)
for _, msg := range msg2IntMsg(in) {
messages = append(messages, &intMsg{
num: math.Sqrt(msg.num),
})
}
return intMsg2Msg(messages), nil
return intMsg2Msg(messages)
}
func (d *nodeD) Name() string {
return "NodeD"
}
func (d *nodeD) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
func (d *nodeD) Operate(in []Msg) []Msg {
messages := make([]*intMsg, 0)
outLength := len(in) / 2
inMessages := msg2IntMsg(in)
......@@ -117,7 +117,7 @@ func (d *nodeD) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context)
d.d = messages[0].num
d.resChan <- d.d
fmt.Println("flow graph result:", d.d)
return intMsg2Msg(messages), nil
return intMsg2Msg(messages)
}
func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
......@@ -129,12 +129,8 @@ func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
time.Sleep(time.Millisecond * time.Duration(500))
var num = float64(rand.Int() % 100)
var msg Msg = &intMsg{num: num}
var msgWithContext = &MsgWithCtx{
ctx: ctx,
msg: msg,
}
a := nodeA{}
fg.nodeCtx[a.Name()].inputChannels[0] <- msgWithContext
fg.nodeCtx[a.Name()].inputChannels[0] <- msg
fmt.Println("send number", num, "to node", a.Name())
res, ok := receiveResult(ctx, fg)
if !ok {
......@@ -254,7 +250,7 @@ func TestTimeTickedFlowGraph_Start(t *testing.T) {
// init node A
nodeCtxA := fg.nodeCtx[a.Name()]
nodeCtxA.inputChannels = []chan *MsgWithCtx{make(chan *MsgWithCtx, 10)}
nodeCtxA.inputChannels = []chan Msg{make(chan Msg, 10)}
go fg.Start()
......
package flowgraph
import (
"context"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
)
......@@ -28,17 +25,20 @@ func (inNode *InputNode) InStream() *msgstream.MsgStream {
}
// empty input and return one *Msg
func (inNode *InputNode) Operate(ctx context.Context, msgs []Msg) ([]Msg, context.Context) {
func (inNode *InputNode) Operate(in []Msg) []Msg {
//fmt.Println("Do InputNode operation")
msgPack, ctx := (*inNode.inStream).Consume()
sp, ctx := trace.StartSpanFromContext(ctx, opentracing.Tag{Key: "NodeName", Value: inNode.Name()})
defer sp.Finish()
msgPack := (*inNode.inStream).Consume()
// TODO: add status
if msgPack == nil {
return nil, ctx
return nil
}
var spans []opentracing.Span
for _, msg := range msgPack.Msgs {
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
spans = append(spans, sp)
msg.SetTraceCtx(ctx)
}
var msgStreamMsg Msg = &MsgStreamMsg{
......@@ -49,7 +49,11 @@ func (inNode *InputNode) Operate(ctx context.Context, msgs []Msg) ([]Msg, contex
endPositions: msgPack.EndPositions,
}
return []Msg{msgStreamMsg}, ctx
for _, span := range spans {
span.Finish()
}
return []Msg{msgStreamMsg}
}
func NewInputNode(inStream *msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode {
......
package flowgraph
import "github.com/zilliztech/milvus-distributed/internal/msgstream"
import (
"github.com/zilliztech/milvus-distributed/internal/msgstream"
)
type Msg interface {
TimeTick() Timestamp
......
......@@ -6,16 +6,13 @@ import (
"log"
"sync"
"time"
"github.com/opentracing/opentracing-go"
"github.com/zilliztech/milvus-distributed/internal/util/trace"
)
type Node interface {
Name() string
MaxQueueLength() int32
MaxParallelism() int32
Operate(ctx context.Context, in []Msg) ([]Msg, context.Context)
Operate(in []Msg) []Msg
IsInputNode() bool
}
......@@ -26,7 +23,7 @@ type BaseNode struct {
type nodeCtx struct {
node Node
inputChannels []chan *MsgWithCtx
inputChannels []chan Msg
inputMessages []Msg
downstream []*nodeCtx
downstreamInputChanIdx map[string]int
......@@ -35,11 +32,6 @@ type nodeCtx struct {
NumCompletedTasks int64
}
type MsgWithCtx struct {
ctx context.Context
msg Msg
}
func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
if nodeCtx.node.IsInputNode() {
// fmt.Println("start InputNode.inStream")
......@@ -60,17 +52,13 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
// inputs from inputsMessages for Operate
inputs := make([]Msg, 0)
var msgCtx context.Context
var res []Msg
var sp opentracing.Span
if !nodeCtx.node.IsInputNode() {
msgCtx = nodeCtx.collectInputMessages(ctx)
nodeCtx.collectInputMessages(ctx)
inputs = nodeCtx.inputMessages
}
n := nodeCtx.node
res, msgCtx = n.Operate(msgCtx, inputs)
sp, msgCtx = trace.StartSpanFromContext(msgCtx)
sp.SetTag("node name", n.Name())
res = n.Operate(inputs)
downstreamLength := len(nodeCtx.downstreamInputChanIdx)
if len(nodeCtx.downstream) < downstreamLength {
......@@ -84,10 +72,9 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
w := sync.WaitGroup{}
for i := 0; i < downstreamLength; i++ {
w.Add(1)
go nodeCtx.downstream[i].ReceiveMsg(msgCtx, &w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()])
go nodeCtx.downstream[i].ReceiveMsg(&w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()])
}
w.Wait()
sp.Finish()
}
}
}
......@@ -99,18 +86,14 @@ func (nodeCtx *nodeCtx) Close() {
}
}
func (nodeCtx *nodeCtx) ReceiveMsg(ctx context.Context, wg *sync.WaitGroup, msg Msg, inputChanIdx int) {
sp, ctx := trace.StartSpanFromContext(ctx)
defer sp.Finish()
nodeCtx.inputChannels[inputChanIdx] <- &MsgWithCtx{ctx: ctx, msg: msg}
func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int) {
nodeCtx.inputChannels[inputChanIdx] <- msg
//fmt.Println((*nodeCtx.node).Name(), "receive to input channel ", inputChanIdx)
wg.Done()
}
func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Context {
var opts []opentracing.StartSpanOption
func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) {
inputsNum := len(nodeCtx.inputChannels)
nodeCtx.inputMessages = make([]Msg, inputsNum)
......@@ -121,29 +104,17 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Co
channel := nodeCtx.inputChannels[i]
select {
case <-exitCtx.Done():
return nil
case msgWithCtx, ok := <-channel:
return
case msg, ok := <-channel:
if !ok {
// TODO: add status
log.Println("input channel closed")
return nil
}
nodeCtx.inputMessages[i] = msgWithCtx.msg
if msgWithCtx.ctx != nil {
sp, _ := trace.StartSpanFromContext(msgWithCtx.ctx)
opts = append(opts, opentracing.ChildOf(sp.Context()))
sp.Finish()
return
}
nodeCtx.inputMessages[i] = msg
}
}
var ctx context.Context
var sp opentracing.Span
if len(opts) != 0 {
sp, ctx = trace.StartSpanFromContext(context.Background(), opts...)
defer sp.Finish()
}
// timeTick alignment check
if len(nodeCtx.inputMessages) > 1 {
t := nodeCtx.inputMessages[0].TimeTick()
......@@ -169,7 +140,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Co
log.Println("input channel closed")
return
}
nodeCtx.inputMessages[i] = msg.msg
nodeCtx.inputMessages[i] = msg
}
}
}
......@@ -183,7 +154,6 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Co
}
}
return ctx
}
func (node *BaseNode) MaxQueueLength() int32 {
......
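The net effect on the flow-graph runtime is the removal above: the MsgWithCtx wrapper and its per-hop span bookkeeping disappear, channels move plain Msg values, and collectInputMessages no longer has a context to return. A self-contained miniature of the simplified plumbing:

package main

import (
	"fmt"
	"sync"
)

// Msg matches the flowgraph interface's shape; intMsg is a toy payload.
type Msg interface{ TimeTick() uint64 }

type intMsg struct{ t uint64 }

func (m intMsg) TimeTick() uint64 { return m.t }

// nodeCtx keeps only what the simplified runtime needs: bare channels
// of Msg, with no context wrapper travelling alongside.
type nodeCtx struct {
	inputChannels []chan Msg
	inputMessages []Msg
}

func (n *nodeCtx) receiveMsg(wg *sync.WaitGroup, m Msg, idx int) {
	n.inputChannels[idx] <- m
	wg.Done()
}

func (n *nodeCtx) collectInputMessages() {
	n.inputMessages = make([]Msg, len(n.inputChannels))
	for i, ch := range n.inputChannels {
		n.inputMessages[i] = <-ch
	}
}

func main() {
	n := &nodeCtx{inputChannels: []chan Msg{make(chan Msg, 1)}}
	var wg sync.WaitGroup
	wg.Add(1)
	go n.receiveMsg(&wg, intMsg{t: 1}, 0)
	wg.Wait()
	n.collectInputMessages()
	fmt.Println("collected tick", n.inputMessages[0].TimeTick())
}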