flow_graph_filter_dm_node.go 4.5 KB
Newer Older
N
neza2017 已提交
1
package querynode
2 3

import (
G
godchen 已提交
4
	"context"
C
cai.zhang 已提交
5
	"log"
B
bigsheeper 已提交
6
	"math"
C
cai.zhang 已提交
7

G
godchen 已提交
8
	"github.com/opentracing/opentracing-go"
9
	"github.com/zilliztech/milvus-distributed/internal/msgstream"
10
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
11 12 13 14 15
	internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)

type filterDmNode struct {
	BaseNode
B
bigsheeper 已提交
16
	ddMsg *ddMsg
17 18 19 20 21 22 23
}

// Name reports the fixed name of this node, "fdmNode".
func (fdmNode *filterDmNode) Name() string {
	const nodeName = "fdmNode"
	return nodeName
}

func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
C
cai.zhang 已提交
24
	//fmt.Println("Do filterDmNode operation")
25

26
	if len(in) != 2 {
C
cai.zhang 已提交
27
		log.Println("Invalid operate message input in filterDmNode, input length = ", len(in))
28 29 30
		// TODO: add error handling
	}

31
	msgStreamMsg, ok := (*in[0]).(*MsgStreamMsg)
32 33 34 35 36
	if !ok {
		log.Println("type assertion failed for MsgStreamMsg")
		// TODO: add error handling
	}

G
godchen 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
	var childs []opentracing.Span
	tracer := opentracing.GlobalTracer()
	if tracer != nil && msgStreamMsg != nil {
		for _, msg := range msgStreamMsg.TsMessages() {
			if msg.Type() == internalPb.MsgType_kInsert || msg.Type() == internalPb.MsgType_kSearch {
				var child opentracing.Span
				ctx := msg.GetMsgContext()
				if parent := opentracing.SpanFromContext(ctx); parent != nil {
					child = tracer.StartSpan("pass filter node",
						opentracing.FollowsFrom(parent.Context()))
				} else {
					child = tracer.StartSpan("pass filter node")
				}
				child.SetTag("hash keys", msg.HashKeys())
				child.SetTag("start time", msg.BeginTs())
				child.SetTag("end time", msg.EndTs())
				msg.SetMsgContext(opentracing.ContextWithSpan(ctx, child))
				childs = append(childs, child)
			}
		}
	}

59 60 61 62 63 64 65
	ddMsg, ok := (*in[1]).(*ddMsg)
	if !ok {
		log.Println("type assertion failed for ddMsg")
		// TODO: add error handling
	}
	fdmNode.ddMsg = ddMsg

66 67 68
	var iMsg = insertMsg{
		insertMessages: make([]*msgstream.InsertMsg, 0),
		timeRange: TimeRange{
69 70
			timestampMin: msgStreamMsg.TimestampMin(),
			timestampMax: msgStreamMsg.TimestampMax(),
71 72
		},
	}
G
godchen 已提交
73
	for key, msg := range msgStreamMsg.TsMessages() {
X
xige-16 已提交
74
		switch msg.Type() {
75
		case internalPb.MsgType_kInsert:
G
godchen 已提交
76 77 78 79 80 81 82 83
			var ctx2 context.Context
			if childs != nil {
				if childs[key] != nil {
					ctx2 = opentracing.ContextWithSpan(msg.GetMsgContext(), childs[key])
				} else {
					ctx2 = context.Background()
				}
			}
84 85
			resMsg := fdmNode.filterInvalidInsertMessage(msg.(*msgstream.InsertMsg))
			if resMsg != nil {
G
godchen 已提交
86
				resMsg.SetMsgContext(ctx2)
87 88
				iMsg.insertMessages = append(iMsg.insertMessages, resMsg)
			}
89 90 91
		// case internalPb.MsgType_kDelete:
		// dmMsg.deleteMessages = append(dmMsg.deleteMessages, (*msg).(*msgstream.DeleteTask))
		default:
X
xige-16 已提交
92
			log.Println("Non supporting message type:", msg.Type())
93 94 95
		}
	}

B
bigsheeper 已提交
96
	iMsg.gcRecord = ddMsg.gcRecord
97
	var res Msg = &iMsg
G
godchen 已提交
98 99 100 101

	for _, child := range childs {
		child.Finish()
	}
102 103 104
	return []*Msg{&res}
}

105 106 107 108 109 110 111
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
	// No dd record, do all insert requests.
	records, ok := fdmNode.ddMsg.collectionRecords[msg.CollectionName]
	if !ok {
		return msg
	}

112 113 114 115
	// TODO: If the last record is drop type, all insert requests are invalid.
	//if !records[len(records)-1].createOrDrop {
	//	return nil
	//}
116 117 118 119 120 121 122

	// Filter insert requests before last record.
	if len(msg.RowIDs) != len(msg.Timestamps) || len(msg.RowIDs) != len(msg.RowData) {
		// TODO: what if the messages are misaligned? Here, we ignore those messages and print error
		log.Println("Error, misaligned messages detected")
		return nil
	}
B
bigsheeper 已提交
123

124 125 126
	tmpTimestamps := make([]Timestamp, 0)
	tmpRowIDs := make([]int64, 0)
	tmpRowData := make([]*commonpb.Blob, 0)
B
bigsheeper 已提交
127 128 129 130 131 132 133 134 135 136 137 138 139

	// calculate valid time range
	timeBegin := Timestamp(0)
	timeEnd := Timestamp(math.MaxUint64)
	for _, record := range records {
		if record.createOrDrop && timeBegin < record.timestamp {
			timeBegin = record.timestamp
		}
		if !record.createOrDrop && timeEnd > record.timestamp {
			timeEnd = record.timestamp
		}
	}

140
	for i, t := range msg.Timestamps {
B
bigsheeper 已提交
141
		if t >= timeBegin && t <= timeEnd {
142 143 144 145 146
			tmpTimestamps = append(tmpTimestamps, t)
			tmpRowIDs = append(tmpRowIDs, msg.RowIDs[i])
			tmpRowData = append(tmpRowData, msg.RowData[i])
		}
	}
B
bigsheeper 已提交
147 148 149 150 151

	if len(tmpRowIDs) <= 0 {
		return nil
	}

152 153 154 155 156 157
	msg.Timestamps = tmpTimestamps
	msg.RowIDs = tmpRowIDs
	msg.RowData = tmpRowData
	return msg
}

158
func newFilteredDmNode() *filterDmNode {
159 160
	maxQueueLength := Params.FlowGraphMaxQueueLength
	maxParallelism := Params.FlowGraphMaxParallelism
F
FluorineDog 已提交
161

162 163 164 165 166 167 168 169
	baseNode := BaseNode{}
	baseNode.SetMaxQueueLength(maxQueueLength)
	baseNode.SetMaxParallelism(maxParallelism)

	return &filterDmNode{
		BaseNode: baseNode,
	}
}