search_service.go 12.4 KB
Newer Older
N
neza2017 已提交
1 2 3 4 5 6
package querynode

import "C"
import (
	"context"
	"errors"
X
XuanYang-cn 已提交
7
	"fmt"
G
godchen 已提交
8 9
	"github.com/opentracing/opentracing-go"
	oplog "github.com/opentracing/opentracing-go/log"
N
neza2017 已提交
10 11 12
	"log"
	"sync"

X
XuanYang-cn 已提交
13 14
	"github.com/golang/protobuf/proto"

N
neza2017 已提交
15 16 17 18 19 20 21 22 23 24 25
	"github.com/zilliztech/milvus-distributed/internal/msgstream"
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)

type searchService struct {
	ctx    context.Context
	wait   sync.WaitGroup
	cancel context.CancelFunc

X
XuanYang-cn 已提交
26
	replica      collectionReplica
N
neza2017 已提交
27 28 29 30 31 32 33
	tSafeWatcher *tSafeWatcher

	serviceableTime      Timestamp
	serviceableTimeMutex sync.Mutex

	msgBuffer             chan msgstream.TsMsg
	unsolvedMsg           []msgstream.TsMsg
X
XuanYang-cn 已提交
34 35 36
	searchMsgStream       msgstream.MsgStream
	searchResultMsgStream msgstream.MsgStream
	queryNodeID           UniqueID
N
neza2017 已提交
37 38 39 40
}

// ResultEntityIds is a list of entity IDs.
// NOTE(review): not referenced anywhere in this file; Go naming would be
// ResultEntityIDs, but renaming could break callers elsewhere in the package.
type ResultEntityIds []UniqueID

X
XuanYang-cn 已提交
41
func newSearchService(ctx context.Context, replica collectionReplica) *searchService {
42 43
	receiveBufSize := Params.SearchReceiveBufSize
	pulsarBufSize := Params.SearchPulsarBufSize
N
neza2017 已提交
44

45
	msgStreamURL := Params.PulsarAddress
N
neza2017 已提交
46

47 48
	consumeChannels := Params.SearchChannelNames
	consumeSubName := Params.MsgChannelSubName
N
neza2017 已提交
49 50 51 52 53 54
	searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	searchStream.SetPulsarClient(msgStreamURL)
	unmarshalDispatcher := msgstream.NewUnmarshalDispatcher()
	searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
	var inputStream msgstream.MsgStream = searchStream

55
	producerChannels := Params.SearchResultChannelNames
N
neza2017 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
	searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	searchResultStream.SetPulsarClient(msgStreamURL)
	searchResultStream.CreatePulsarProducers(producerChannels)
	var outputStream msgstream.MsgStream = searchResultStream

	searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
	msgBuffer := make(chan msgstream.TsMsg, receiveBufSize)
	unsolvedMsg := make([]msgstream.TsMsg, 0)
	return &searchService{
		ctx:             searchServiceCtx,
		cancel:          searchServiceCancel,
		serviceableTime: Timestamp(0),
		msgBuffer:       msgBuffer,
		unsolvedMsg:     unsolvedMsg,

		replica:      replica,
		tSafeWatcher: newTSafeWatcher(),

X
XuanYang-cn 已提交
74 75
		searchMsgStream:       inputStream,
		searchResultMsgStream: outputStream,
76
		queryNodeID:           Params.QueryNodeID,
N
neza2017 已提交
77 78 79 80
	}
}

func (ss *searchService) start() {
X
XuanYang-cn 已提交
81 82
	ss.searchMsgStream.Start()
	ss.searchResultMsgStream.Start()
N
neza2017 已提交
83 84 85 86 87 88 89 90
	ss.register()
	ss.wait.Add(2)
	go ss.receiveSearchMsg()
	go ss.doUnsolvedMsgSearch()
	ss.wait.Wait()
}

func (ss *searchService) close() {
X
XuanYang-cn 已提交
91 92 93 94 95 96
	if ss.searchMsgStream != nil {
		ss.searchMsgStream.Close()
	}
	if ss.searchResultMsgStream != nil {
		ss.searchResultMsgStream.Close()
	}
N
neza2017 已提交
97 98 99 100
	ss.cancel()
}

func (ss *searchService) register() {
X
XuanYang-cn 已提交
101 102
	tSafe := ss.replica.getTSafe()
	tSafe.registerTSafeWatcher(ss.tSafeWatcher)
N
neza2017 已提交
103 104 105 106 107
}

func (ss *searchService) waitNewTSafe() Timestamp {
	// block until dataSyncService updating tSafe
	ss.tSafeWatcher.hasUpdate()
X
XuanYang-cn 已提交
108
	timestamp := ss.replica.getTSafe().get()
N
neza2017 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
	return timestamp
}

// getServiceableTime returns the latest serviceable timestamp under
// serviceableTimeMutex.
func (ss *searchService) getServiceableTime() Timestamp {
	ss.serviceableTimeMutex.Lock()
	current := ss.serviceableTime
	ss.serviceableTimeMutex.Unlock()
	return current
}

// setServiceableTime records t as the serviceable timestamp under
// serviceableTimeMutex.
func (ss *searchService) setServiceableTime(t Timestamp) {
	ss.serviceableTimeMutex.Lock()
	defer ss.serviceableTimeMutex.Unlock()
	// TODO: add gracefulTime
	ss.serviceableTime = t
}

func (ss *searchService) receiveSearchMsg() {
	defer ss.wait.Done()
	for {
		select {
		case <-ss.ctx.Done():
			return
		default:
X
XuanYang-cn 已提交
132
			msgPack := ss.searchMsgStream.Consume()
N
neza2017 已提交
133 134 135 136 137
			if msgPack == nil || len(msgPack.Msgs) <= 0 {
				continue
			}
			searchMsg := make([]msgstream.TsMsg, 0)
			serverTime := ss.getServiceableTime()
G
godchen 已提交
138 139 140
			for i, msg := range msgPack.Msgs {
				if msg.BeginTs() > serverTime {
					ss.msgBuffer <- msg
N
neza2017 已提交
141 142 143 144 145
					continue
				}
				searchMsg = append(searchMsg, msgPack.Msgs[i])
			}
			for _, msg := range searchMsg {
G
godchen 已提交
146 147
				span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg")
				msg.SetMsgContext(ctx)
N
neza2017 已提交
148 149
				err := ss.search(msg)
				if err != nil {
Z
zhenshan.cao 已提交
150
					log.Println(err)
G
godchen 已提交
151
					span.LogFields(oplog.Error(err))
Z
zhenshan.cao 已提交
152 153
					err2 := ss.publishFailedSearchResult(msg, err.Error())
					if err2 != nil {
G
godchen 已提交
154
						span.LogFields(oplog.Error(err2))
Z
zhenshan.cao 已提交
155
						log.Println("publish FailedSearchResult failed, error message: ", err2)
B
bigsheeper 已提交
156
					}
N
neza2017 已提交
157
				}
G
godchen 已提交
158
				span.Finish()
N
neza2017 已提交
159
			}
C
cai.zhang 已提交
160
			log.Println("ReceiveSearchMsg, do search done, num of searchMsg = ", len(searchMsg))
N
neza2017 已提交
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
		}
	}
}

func (ss *searchService) doUnsolvedMsgSearch() {
	defer ss.wait.Done()
	for {
		select {
		case <-ss.ctx.Done():
			return
		default:
			serviceTime := ss.waitNewTSafe()
			ss.setServiceableTime(serviceTime)
			searchMsg := make([]msgstream.TsMsg, 0)
			tempMsg := make([]msgstream.TsMsg, 0)
			tempMsg = append(tempMsg, ss.unsolvedMsg...)
			ss.unsolvedMsg = ss.unsolvedMsg[:0]
			for _, msg := range tempMsg {
				if msg.EndTs() <= serviceTime {
					searchMsg = append(searchMsg, msg)
					continue
				}
				ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
			}

B
bigsheeper 已提交
186
			for {
C
cai.zhang 已提交
187 188 189 190
				msgBufferLength := len(ss.msgBuffer)
				if msgBufferLength <= 0 {
					break
				}
N
neza2017 已提交
191 192 193 194 195 196 197
				msg := <-ss.msgBuffer
				if msg.EndTs() <= serviceTime {
					searchMsg = append(searchMsg, msg)
					continue
				}
				ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
			}
B
bigsheeper 已提交
198

N
neza2017 已提交
199 200 201 202 203 204
			if len(searchMsg) <= 0 {
				continue
			}
			for _, msg := range searchMsg {
				err := ss.search(msg)
				if err != nil {
Z
zhenshan.cao 已提交
205
					log.Println(err)
Z
zhenshan.cao 已提交
206 207 208
					err2 := ss.publishFailedSearchResult(msg, err.Error())
					if err2 != nil {
						log.Println("publish FailedSearchResult failed, error message: ", err2)
B
bigsheeper 已提交
209
					}
N
neza2017 已提交
210 211
				}
			}
C
cai.zhang 已提交
212
			log.Println("doUnsolvedMsgSearch, do search done, num of searchMsg = ", len(searchMsg))
N
neza2017 已提交
213 214 215 216 217 218 219
		}
	}
}

// TODO:: cache map[dsl]plan
// TODO: reBatched search requests
func (ss *searchService) search(msg msgstream.TsMsg) error {
G
godchen 已提交
220 221 222
	span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "do search")
	defer span.Finish()
	msg.SetMsgContext(ctx)
N
neza2017 已提交
223 224
	searchMsg, ok := msg.(*msgstream.SearchMsg)
	if !ok {
G
godchen 已提交
225
		span.LogFields(oplog.Error(errors.New("invalid request type = " + string(msg.Type()))))
N
neza2017 已提交
226 227 228 229 230 231 232 233
		return errors.New("invalid request type = " + string(msg.Type()))
	}

	searchTimestamp := searchMsg.Timestamp
	var queryBlob = searchMsg.Query.Value
	query := servicepb.Query{}
	err := proto.Unmarshal(queryBlob, &query)
	if err != nil {
G
godchen 已提交
234
		span.LogFields(oplog.Error(err))
N
neza2017 已提交
235 236 237 238
		return errors.New("unmarshal query failed")
	}
	collectionName := query.CollectionName
	partitionTags := query.PartitionTags
X
XuanYang-cn 已提交
239
	collection, err := ss.replica.getCollectionByName(collectionName)
N
neza2017 已提交
240
	if err != nil {
G
godchen 已提交
241
		span.LogFields(oplog.Error(err))
N
neza2017 已提交
242 243 244 245
		return err
	}
	collectionID := collection.ID()
	dsl := query.Dsl
246 247
	plan, err := createPlan(*collection, dsl)
	if err != nil {
G
godchen 已提交
248
		span.LogFields(oplog.Error(err))
249 250
		return err
	}
N
neza2017 已提交
251
	placeHolderGroupBlob := query.PlaceholderGroup
252 253
	placeholderGroup, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
	if err != nil {
G
godchen 已提交
254
		span.LogFields(oplog.Error(err))
255 256
		return err
	}
N
neza2017 已提交
257 258 259 260
	placeholderGroups := make([]*PlaceholderGroup, 0)
	placeholderGroups = append(placeholderGroups, placeholderGroup)

	searchResults := make([]*SearchResult, 0)
G
godchen 已提交
261
	matchedSegments := make([]*Segment, 0)
N
neza2017 已提交
262 263

	for _, partitionTag := range partitionTags {
X
XuanYang-cn 已提交
264
		hasPartition := ss.replica.hasPartition(collectionID, partitionTag)
265
		if !hasPartition {
G
godchen 已提交
266
			span.LogFields(oplog.Error(errors.New("search Failed, invalid partitionTag")))
267
			return errors.New("search Failed, invalid partitionTag")
N
neza2017 已提交
268
		}
269 270 271
	}

	for _, partitionTag := range partitionTags {
X
XuanYang-cn 已提交
272
		partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
N
neza2017 已提交
273
		for _, segment := range partition.segments {
C
cai.zhang 已提交
274 275
			//fmt.Println("dsl = ", dsl)

N
neza2017 已提交
276
			searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp})
C
cai.zhang 已提交
277

N
neza2017 已提交
278
			if err != nil {
G
godchen 已提交
279
				span.LogFields(oplog.Error(err))
N
neza2017 已提交
280 281 282
				return err
			}
			searchResults = append(searchResults, searchResult)
G
godchen 已提交
283
			matchedSegments = append(matchedSegments, segment)
N
neza2017 已提交
284 285 286
		}
	}

C
cai.zhang 已提交
287
	if len(searchResults) <= 0 {
288 289 290 291 292
		var results = internalpb.SearchResult{
			MsgType:         internalpb.MsgType_kSearchResult,
			Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
			ReqID:           searchMsg.ReqID,
			ProxyID:         searchMsg.ProxyID,
X
XuanYang-cn 已提交
293
			QueryNodeID:     ss.queryNodeID,
294 295 296 297 298
			Timestamp:       searchTimestamp,
			ResultChannelID: searchMsg.ResultChannelID,
			Hits:            nil,
		}
		searchResultMsg := &msgstream.SearchResultMsg{
G
godchen 已提交
299 300 301 302
			BaseMsg: msgstream.BaseMsg{
				MsgCtx:     searchMsg.MsgCtx,
				HashValues: []uint32{uint32(searchMsg.ResultChannelID)},
			},
303 304 305 306
			SearchResult: results,
		}
		err = ss.publishSearchResult(searchResultMsg)
		if err != nil {
G
godchen 已提交
307
			span.LogFields(oplog.Error(err))
308 309
			return err
		}
G
godchen 已提交
310
		span.LogFields(oplog.String("publish search research success", "publish search research success"))
311
		return nil
C
cai.zhang 已提交
312 313
	}

G
godchen 已提交
314 315
	inReduced := make([]bool, len(searchResults))
	numSegment := int64(len(searchResults))
B
bigsheeper 已提交
316 317
	err2 := reduceSearchResults(searchResults, numSegment, inReduced)
	if err2 != nil {
G
godchen 已提交
318
		span.LogFields(oplog.Error(err2))
B
bigsheeper 已提交
319
		return err2
G
godchen 已提交
320 321 322
	}
	err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
	if err != nil {
G
godchen 已提交
323
		span.LogFields(oplog.Error(err))
G
godchen 已提交
324 325 326 327
		return err
	}
	marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
	if err != nil {
G
godchen 已提交
328
		span.LogFields(oplog.Error(err))
G
godchen 已提交
329 330
		return err
	}
N
neza2017 已提交
331 332
	hitsBlob, err := marshaledHits.getHitsBlob()
	if err != nil {
G
godchen 已提交
333
		span.LogFields(oplog.Error(err))
N
neza2017 已提交
334 335 336 337 338
		return err
	}

	var offset int64 = 0
	for index := range placeholderGroups {
G
godchen 已提交
339
		hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
N
neza2017 已提交
340 341 342 343
		if err != nil {
			return err
		}
		hits := make([][]byte, 0)
G
godchen 已提交
344
		for _, len := range hitBlobSizePeerQuery {
N
neza2017 已提交
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
			hits = append(hits, hitsBlob[offset:offset+len])
			//test code to checkout marshaled hits
			//marshaledHit := hitsBlob[offset:offset+len]
			//unMarshaledHit := servicepb.Hits{}
			//err = proto.Unmarshal(marshaledHit, &unMarshaledHit)
			//if err != nil {
			//	return err
			//}
			//fmt.Println("hits msg  = ", unMarshaledHit)
			offset += len
		}
		var results = internalpb.SearchResult{
			MsgType:         internalpb.MsgType_kSearchResult,
			Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
			ReqID:           searchMsg.ReqID,
			ProxyID:         searchMsg.ProxyID,
			QueryNodeID:     searchMsg.ProxyID,
			Timestamp:       searchTimestamp,
			ResultChannelID: searchMsg.ResultChannelID,
			Hits:            hits,
G
GuoRentong 已提交
365
			MetricType:      plan.getMetricType(),
N
neza2017 已提交
366 367
		}
		searchResultMsg := &msgstream.SearchResultMsg{
G
godchen 已提交
368 369 370
			BaseMsg: msgstream.BaseMsg{
				MsgCtx:     searchMsg.MsgCtx,
				HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
N
neza2017 已提交
371 372 373 374
			SearchResult: results,
		}
		err = ss.publishSearchResult(searchResultMsg)
		if err != nil {
G
godchen 已提交
375
			span.LogFields(oplog.Error(err))
N
neza2017 已提交
376 377 378 379 380 381 382 383 384 385 386 387
			return err
		}
	}

	deleteSearchResults(searchResults)
	deleteMarshaledHits(marshaledHits)
	plan.delete()
	placeholderGroup.delete()
	return nil
}

func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
G
godchen 已提交
388 389 390
	span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "publish search result")
	defer span.Finish()
	msg.SetMsgContext(ctx)
X
XuanYang-cn 已提交
391
	fmt.Println("Public SearchResult", msg.HashKeys())
N
neza2017 已提交
392 393
	msgPack := msgstream.MsgPack{}
	msgPack.Msgs = append(msgPack.Msgs, msg)
X
XuanYang-cn 已提交
394
	err := ss.searchResultMsgStream.Produce(&msgPack)
X
XuanYang-cn 已提交
395
	return err
N
neza2017 已提交
396 397
}

G
godchen 已提交
398
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error {
G
godchen 已提交
399 400 401
	span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg")
	defer span.Finish()
	msg.SetMsgContext(ctx)
N
neza2017 已提交
402 403 404 405 406 407 408
	msgPack := msgstream.MsgPack{}
	searchMsg, ok := msg.(*msgstream.SearchMsg)
	if !ok {
		return errors.New("invalid request type = " + string(msg.Type()))
	}
	var results = internalpb.SearchResult{
		MsgType:         internalpb.MsgType_kSearchResult,
G
godchen 已提交
409
		Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: errMsg},
N
neza2017 已提交
410 411 412 413 414 415 416 417 418
		ReqID:           searchMsg.ReqID,
		ProxyID:         searchMsg.ProxyID,
		QueryNodeID:     searchMsg.ProxyID,
		Timestamp:       searchMsg.Timestamp,
		ResultChannelID: searchMsg.ResultChannelID,
		Hits:            [][]byte{},
	}

	tsMsg := &msgstream.SearchResultMsg{
X
XuanYang-cn 已提交
419
		BaseMsg:      msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
N
neza2017 已提交
420 421 422
		SearchResult: results,
	}
	msgPack.Msgs = append(msgPack.Msgs, tsMsg)
X
XuanYang-cn 已提交
423
	err := ss.searchResultMsgStream.Produce(&msgPack)
N
neza2017 已提交
424 425 426 427 428 429
	if err != nil {
		return err
	}

	return nil
}