search_service.go 7.0 KB
Newer Older
N
neza2017 已提交
1 2
package reader

X
xige-16 已提交
3
import "C"
N
neza2017 已提交
4 5
import (
	"context"
X
xige-16 已提交
6
	"errors"
B
bigsheeper 已提交
7
	"fmt"
D
dragondriver 已提交
8
	"log"
X
xige-16 已提交
9 10 11
	"sort"

	"github.com/golang/protobuf/proto"
C
cai.zhang 已提交
12

N
neza2017 已提交
13
	"github.com/zilliztech/milvus-distributed/internal/msgstream"
X
xige-16 已提交
14 15 16
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
N
neza2017 已提交
17 18 19
)

type searchService struct {
X
xige-16 已提交
20 21
	ctx    context.Context
	cancel context.CancelFunc
22

23 24 25
	replica      *collectionReplica
	tSafeWatcher *tSafeWatcher

26 27
	searchMsgStream       *msgstream.MsgStream
	searchResultMsgStream *msgstream.MsgStream
B
bigsheeper 已提交
28 29
}

30 31 32 33 34 35 36
type (
	// ResultEntityIds is a list of entity IDs produced by a search.
	ResultEntityIds []UniqueID

	// SearchResult holds the hits of a search as two parallel slices: the
	// entry at index k of ResultIds corresponds to the entry at index k of
	// ResultDistances.
	SearchResult struct {
		ResultIds       []UniqueID
		ResultDistances []float32
	}
)

D
dragondriver 已提交
37
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
N
neza2017 已提交
38
	receiveBufSize := Params.searchMsgStreamReceiveBufSize()
F
FluorineDog 已提交
39
	pulsarBufSize := Params.searchPulsarBufSize()
B
bigsheeper 已提交
40

N
neza2017 已提交
41
	msgStreamURL, err := Params.PulsarAddress()
D
dragondriver 已提交
42 43 44 45
	if err != nil {
		log.Fatal(err)
	}

46
	consumeChannels := []string{"search"}
X
xige-16 已提交
47 48
	consumeSubName := "subSearch"
	searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
N
neza2017 已提交
49
	searchStream.SetPulsarClient(msgStreamURL)
50 51
	unmarshalDispatcher := msgstream.NewUnmarshalDispatcher()
	searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
X
xige-16 已提交
52
	var inputStream msgstream.MsgStream = searchStream
B
bigsheeper 已提交
53

54
	producerChannels := []string{"searchResult"}
X
xige-16 已提交
55
	searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
N
neza2017 已提交
56
	searchResultStream.SetPulsarClient(msgStreamURL)
57
	searchResultStream.CreatePulsarProducers(producerChannels)
X
xige-16 已提交
58
	var outputStream msgstream.MsgStream = searchResultStream
B
bigsheeper 已提交
59

X
xige-16 已提交
60 61 62 63
	searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
	return &searchService{
		ctx:    searchServiceCtx,
		cancel: searchServiceCancel,
B
bigsheeper 已提交
64

65 66 67
		replica:      replica,
		tSafeWatcher: newTSafeWatcher(),

X
xige-16 已提交
68 69 70 71
		searchMsgStream:       &inputStream,
		searchResultMsgStream: &outputStream,
	}
}
B
bigsheeper 已提交
72

X
xige-16 已提交
73
func (ss *searchService) start() {
74
	(*ss.searchMsgStream).Start()
X
xige-16 已提交
75 76 77
	(*ss.searchResultMsgStream).Start()

	go func() {
X
xige-16 已提交
78
		for {
X
xige-16 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
			select {
			case <-ss.ctx.Done():
				return
			default:
				msgPack := (*ss.searchMsgStream).Consume()
				if msgPack == nil || len(msgPack.Msgs) <= 0 {
					continue
				}
				// TODO: add serviceTime check
				err := ss.search(msgPack.Msgs)
				if err != nil {
					fmt.Println("search Failed")
					ss.publishFailedSearchResult()
				}
				fmt.Println("Do search done")
B
bigsheeper 已提交
94 95
			}
		}
X
xige-16 已提交
96 97 98 99 100 101 102
	}()
}

// close stops the service. ss.ctx is cancelled first so the consume loop
// started by start() sees ctx.Done at its next check, rather than possibly
// calling Consume again on an already-closed stream. Then both message
// streams are closed.
func (ss *searchService) close() {
	ss.cancel()
	(*ss.searchMsgStream).Close()
	(*ss.searchResultMsgStream).Close()
}
B
bigsheeper 已提交
104

105 106 107 108 109 110 111 112 113 114 115 116
// register attaches ss.tSafeWatcher to the replica's tSafe so that
// waitNewTSafe can block on tSafe updates.
func (ss *searchService) register() {
	ts := (*ss.replica).getTSafe()
	(*ts).registerTSafeWatcher(ss.tSafeWatcher)
}

// waitNewTSafe waits on ss.tSafeWatcher for a tSafe update and then returns
// the replica's current tSafe timestamp.
func (ss *searchService) waitNewTSafe() Timestamp {
	// Blocks until the data sync service updates tSafe.
	ss.tSafeWatcher.hasUpdate()
	ts := (*ss.replica).getTSafe()
	return (*ts).get()
}

X
xige-16 已提交
117
func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
118

X
xige-16 已提交
119 120 121 122 123 124 125
	type SearchResult struct {
		ResultID       int64
		ResultDistance float32
	}
	// TODO:: cache map[dsl]plan
	// TODO: reBatched search requests
	for _, msg := range searchMessages {
X
xige-16 已提交
126
		searchMsg, ok := msg.(*msgstream.SearchMsg)
X
xige-16 已提交
127
		if !ok {
X
xige-16 已提交
128
			return errors.New("invalid request type = " + string(msg.Type()))
X
xige-16 已提交
129
		}
130

X
xige-16 已提交
131
		searchTimestamp := searchMsg.Timestamp
B
bigsheeper 已提交
132

X
xige-16 已提交
133 134 135 136 137 138 139 140 141
		// TODO:: add serviceable time
		var queryBlob = searchMsg.Query.Value
		query := servicepb.Query{}
		err := proto.Unmarshal(queryBlob, &query)
		if err != nil {
			return errors.New("unmarshal query failed")
		}
		collectionName := query.CollectionName
		partitionTags := query.PartitionTags
D
dragondriver 已提交
142
		collection, err := (*ss.replica).getCollectionByName(collectionName)
X
xige-16 已提交
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
		if err != nil {
			return err
		}
		collectionID := collection.ID()
		dsl := query.Dsl
		plan := CreatePlan(*collection, dsl)
		topK := plan.GetTopK()
		placeHolderGroupBlob := query.PlaceholderGroup
		group := servicepb.PlaceholderGroup{}
		err = proto.Unmarshal(placeHolderGroupBlob, &group)
		if err != nil {
			return err
		}
		placeholderGroup := ParserPlaceholderGroup(plan, placeHolderGroupBlob)
		placeholderGroups := make([]*PlaceholderGroup, 0)
		placeholderGroups = append(placeholderGroups, placeholderGroup)

		// 2d slice for receiving multiple queries's results
		var numQueries int64 = 0
		for _, pg := range placeholderGroups {
			numQueries += pg.GetNumOfQuery()
		}
		var searchResults = make([][]SearchResult, numQueries)
		for i := 0; i < int(numQueries); i++ {
			searchResults[i] = make([]SearchResult, 0)
		}
169

X
xige-16 已提交
170 171
		// 3. Do search in all segments
		for _, partitionTag := range partitionTags {
D
dragondriver 已提交
172
			partition, err := (*ss.replica).getPartitionByTag(collectionID, partitionTag)
X
xige-16 已提交
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
			if err != nil {
				return err
			}
			for _, segment := range partition.segments {
				res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
				if err != nil {
					return err
				}
				for i := 0; int64(i) < numQueries; i++ {
					for j := int64(i) * topK; j < int64(i+1)*topK; j++ {
						searchResults[i] = append(searchResults[i], SearchResult{
							ResultID:       res.ResultIds[j],
							ResultDistance: res.ResultDistances[j],
						})
					}
				}
			}
		}

		// 4. Reduce results
		// TODO::reduce in c++ merge_into func
		for _, temp := range searchResults {
			sort.Slice(temp, func(i, j int) bool {
				return temp[i].ResultDistance < temp[j].ResultDistance
			})
		}
199

X
xige-16 已提交
200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
		for i, tmp := range searchResults {
			if int64(len(tmp)) > topK {
				searchResults[i] = searchResults[i][:topK]
			}
		}

		hits := make([]*servicepb.Hits, 0)
		for _, value := range searchResults {
			hit := servicepb.Hits{}
			score := servicepb.Score{}
			for j := 0; int64(j) < topK; j++ {
				hit.IDs = append(hit.IDs, value[j].ResultID)
				score.Values = append(score.Values, value[j].ResultDistance)
			}
			hit.Scores = append(hit.Scores, &score)
			hits = append(hits, &hit)
		}

		var results = internalpb.SearchResult{
			MsgType:         internalpb.MsgType_kSearchResult,
			Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
			ReqID:           searchMsg.ReqID,
			ProxyID:         searchMsg.ProxyID,
			QueryNodeID:     searchMsg.ProxyID,
			Timestamp:       searchTimestamp,
			ResultChannelID: searchMsg.ResultChannelID,
			Hits:            hits,
		}

		var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: results}
X
xige-16 已提交
230 231 232
		ss.publishSearchResult(tsMsg)
		plan.Delete()
		placeholderGroup.Delete()
X
xige-16 已提交
233 234 235
	}

	return nil
N
neza2017 已提交
236 237
}

X
xige-16 已提交
238
func (ss *searchService) publishSearchResult(res msgstream.TsMsg) {
X
xige-16 已提交
239 240 241 242
	msgPack := msgstream.MsgPack{}
	msgPack.Msgs = append(msgPack.Msgs, res)
	(*ss.searchResultMsgStream).Produce(&msgPack)
}
B
bigsheeper 已提交
243

X
xige-16 已提交
244 245 246 247
func (ss *searchService) publishFailedSearchResult() {
	var errorResults = internalpb.SearchResult{
		MsgType: internalpb.MsgType_kSearchResult,
		Status:  &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
B
bigsheeper 已提交
248 249
	}

X
xige-16 已提交
250 251
	var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: errorResults}
	msgPack := msgstream.MsgPack{}
X
xige-16 已提交
252
	msgPack.Msgs = append(msgPack.Msgs, tsMsg)
X
xige-16 已提交
253
	(*ss.searchResultMsgStream).Produce(&msgPack)
B
bigsheeper 已提交
254
}