From 62bee091d6d5dd686a49c845e6ef8bd8b4c635e5 Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Tue, 22 Sep 2020 11:21:19 +0800 Subject: [PATCH] Fix search failed about topK Signed-off-by: bigsheeper --- core/src/dog_segment/SegmentNaive.cpp | 1 + core/unittest/test_c_api.cpp | 2 +- reader/message_client/message_client.go | 8 +++ reader/read_node/meta.go | 7 ++- reader/read_node/query_node.go | 69 +++++++++++++------------ reader/read_node/segment.go | 29 +++++------ reader/read_node/segment_service.go | 66 ++++++++++++++--------- reader/read_node/segment_test.go | 3 +- 8 files changed, 107 insertions(+), 78 deletions(-) diff --git a/core/src/dog_segment/SegmentNaive.cpp b/core/src/dog_segment/SegmentNaive.cpp index 0500793f0..64ef13290 100644 --- a/core/src/dog_segment/SegmentNaive.cpp +++ b/core/src/dog_segment/SegmentNaive.cpp @@ -587,6 +587,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) { if(record_.ack_responder_.GetAck() < 1024 * 4) { return Status(SERVER_BUILD_INDEX_ERROR, "too few elements"); } + index_meta_ = remote_index_meta; for (auto&[index_name, entry]: index_meta_->get_entries()) { assert(entry.index_name == index_name); const auto &field = (*schema_)[entry.field_name]; diff --git a/core/unittest/test_c_api.cpp b/core/unittest/test_c_api.cpp index d2144c667..d2618ac2f 100644 --- a/core/unittest/test_c_api.cpp +++ b/core/unittest/test_c_api.cpp @@ -238,7 +238,7 @@ TEST(CApiTest, BuildIndexTest) { CQueryInfo queryInfo{1, 10, "fakevec"}; auto sea_res = Search( - segment, queryInfo, 1, query_raw_data.data(), DIM, result_ids, result_distances); + segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances); assert(sea_res == 0); DeleteCollection(collection); diff --git a/reader/message_client/message_client.go b/reader/message_client/message_client.go index d64cc3cd3..64245f849 100644 --- a/reader/message_client/message_client.go +++ b/reader/message_client/message_client.go @@ -39,6 +39,14 @@ type MessageClient struct { MessageClientID int } +func (mc *MessageClient) GetTimeNow() uint64 { + msg, ok := <-mc.timeSyncCfg.TimeSync() + if !ok { + fmt.Println("cnn't get data from timesync chan") + } + return msg.Timestamp +} + func (mc *MessageClient) TimeSyncStart() uint64 { return mc.timestampBatchStart } diff --git a/reader/read_node/meta.go b/reader/read_node/meta.go index 1a65dd758..b2ff16488 100644 --- a/reader/read_node/meta.go +++ b/reader/read_node/meta.go @@ -96,7 +96,12 @@ func (node *QueryNode) processSegmentCreate(id string, value string) { if collection != nil { partition := collection.GetPartitionByName(segment.PartitionTag) if partition != nil { - partition.NewSegment(int64(segment.SegmentID)) // todo change all to uint64 + newSegmentID := int64(segment.SegmentID) // todo change all to uint64 + // start new segment and add it into partition.OpenedSegments + newSegment := partition.NewSegment(newSegmentID) + newSegment.SegmentStatus = SegmentOpened + partition.OpenedSegments = append(partition.OpenedSegments, newSegment) + node.SegmentsMap[newSegmentID] = newSegment } } // segment.CollectionName diff --git a/reader/read_node/query_node.go b/reader/read_node/query_node.go index 26bf0e1f3..663d44cdd 100644 --- a/reader/read_node/query_node.go +++ b/reader/read_node/query_node.go @@ -14,7 +14,9 @@ package reader import "C" import ( + "encoding/json" "fmt" + "log" "sort" "sync" "sync/atomic" @@ -56,6 +58,12 @@ type QueryNodeDataBuffer struct { validSearchBuffer []bool } +type QueryInfo struct { + NumQueries int64 `json:"num_queries"` + TopK int `json:"topK"` + FieldName string `json:"field_name"` +} + type QueryNode struct { QueryNodeId uint64 Collections []*Collection @@ -463,6 +471,19 @@ func (node *QueryNode) DoDelete(segmentID int64, deleteIDs *[]int64, deleteTimes return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS} } +func (node *QueryNode) QueryJson2Info(queryJson *string) *QueryInfo { + var query QueryInfo + var err = json.Unmarshal([]byte(*queryJson), &query) + + if err != nil { + log.Printf("Unmarshal query json failed") + return nil + } + + fmt.Println(query) + return &query +} + func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { // TODO: use client id to publish results to different clients // var clientId = (*(searchMessages[0])).ClientId @@ -475,16 +496,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { // Traverse all messages in the current messageClient. // TODO: Do not receive batched search requests for _, msg := range searchMessages { - var collectionName = searchMessages[0].CollectionName - var targetCollection, err = node.GetCollectionByCollectionName(collectionName) - if err != nil { - fmt.Println(err.Error()) - return msgPb.Status{ErrorCode: 1} - } - var resultsTmp = make([]SearchResultTmp, 0) - // TODO: get top-k's k from queryString - const TopK = 1 var timestamp = msg.Timestamp var vector = msg.Records @@ -498,36 +510,27 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { return msgPb.Status{ErrorCode: 1} } - // 2. Do search in all segments - for _, partition := range targetCollection.Partitions { - for _, openSegment := range partition.OpenedSegments { - var res, err = openSegment.SegmentSearch(queryJson, timestamp, vector) - if err != nil { - fmt.Println(err.Error()) - return msgPb.Status{ErrorCode: 1} - } - fmt.Println(res.ResultIds) - for i := 0; i < len(res.ResultIds); i++ { - resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]}) - } + // 2. Get query information from query json + query := node.QueryJson2Info(&queryJson) + + // 3. Do search in all segments + for _, segment := range node.SegmentsMap { + var res, err = segment.SegmentSearch(query, timestamp, vector) + if err != nil { + fmt.Println(err.Error()) + return msgPb.Status{ErrorCode: 1} } - for _, closedSegment := range partition.ClosedSegments { - var res, err = closedSegment.SegmentSearch(queryJson, timestamp, vector) - if err != nil { - fmt.Println(err.Error()) - return msgPb.Status{ErrorCode: 1} - } - for i := 0; i <= len(res.ResultIds); i++ { - resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]}) - } + fmt.Println(res.ResultIds) + for i := 0; i < len(res.ResultIds); i++ { + resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]}) } } - // 2. Reduce results + // 4. Reduce results sort.Slice(resultsTmp, func(i, j int) bool { return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance }) - resultsTmp = resultsTmp[:TopK] + resultsTmp = resultsTmp[:query.TopK] var entities = msgPb.Entities{ Ids: make([]int64, 0), } @@ -547,7 +550,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { results.RowNum = int64(len(results.Distances)) - // 3. publish result to pulsar + // 5. publish result to pulsar node.PublishSearchResult(&results) } diff --git a/reader/read_node/segment.go b/reader/read_node/segment.go index 6889ccd74..b52e5cfc5 100644 --- a/reader/read_node/segment.go +++ b/reader/read_node/segment.go @@ -13,7 +13,6 @@ package reader */ import "C" import ( - "encoding/json" "fmt" "github.com/czs007/suvlim/errors" msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" @@ -75,12 +74,21 @@ func (s *Segment) CloseSegment(collection* Collection) error { Close(CSegmentBase c_segment); */ var status = C.Close(s.SegmentPtr) + s.SegmentStatus = SegmentClosed + if status != 0 { return errors.New("Close segment failed, error code = " + strconv.Itoa(int(status))) } // Build index after closing segment - go s.buildIndex(collection) + s.SegmentStatus = SegmentIndexing + s.buildIndex(collection) + + // TODO: remove redundant segment indexed status + // Change segment status to indexed + s.SegmentStatus = SegmentIndexed + + s.SegmentStatus = SegmentClosed return nil } @@ -182,7 +190,7 @@ func (s *Segment) SegmentDelete(offset int64, entityIDs *[]int64, timestamps *[] return nil } -func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) { +func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) { /* int Search(CSegmentBase c_segment, @@ -193,20 +201,7 @@ func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord long int* result_ids, float* result_distances); */ - type QueryInfo struct { - NumQueries int64 `json:"num_queries"` - TopK int `json:"topK"` - FieldName string `json:"field_name"` - } - - type CQueryInfo C.CQueryInfo - - var query QueryInfo - var err = json.Unmarshal([]byte(queryJson), &query) - if err != nil { - return nil, err - } - fmt.Println(query) + //type CQueryInfo C.CQueryInfo cQuery := C.CQueryInfo{ num_queries: C.long(query.NumQueries), diff --git a/reader/read_node/segment_service.go b/reader/read_node/segment_service.go index 9f5ad5ae4..7bed55212 100644 --- a/reader/read_node/segment_service.go +++ b/reader/read_node/segment_service.go @@ -10,24 +10,22 @@ import ( ) func (node *QueryNode) SegmentsManagement() { - node.queryNodeTimeSync.UpdateTSOTimeSync() - var timeNow = node.queryNodeTimeSync.TSOTimeSync + //node.queryNodeTimeSync.UpdateTSOTimeSync() + //var timeNow = node.queryNodeTimeSync.TSOTimeSync + + timeNow := node.messageClient.GetTimeNow() + for _, collection := range node.Collections { for _, partition := range collection.Partitions { for _, oldSegment := range partition.OpenedSegments { // TODO: check segment status if timeNow >= oldSegment.SegmentCloseTime { - // start new segment and add it into partition.OpenedSegments - // TODO: get segmentID from master - var segmentID int64 = 0 - var newSegment = partition.NewSegment(segmentID) - newSegment.SegmentCloseTime = timeNow + SegmentLifetime - partition.OpenedSegments = append(partition.OpenedSegments, newSegment) - node.SegmentsMap[segmentID] = newSegment - // close old segment and move it into partition.ClosedSegments - // TODO: check status - var _ = oldSegment.CloseSegment(collection) + if oldSegment.SegmentStatus == SegmentClosed { + log.Println("Never reach here, Opened segment cannot be closed") + continue + } + go oldSegment.CloseSegment(collection) partition.ClosedSegments = append(partition.ClosedSegments, oldSegment) } } @@ -47,20 +45,38 @@ func (node *QueryNode) SegmentManagementService() { func (node *QueryNode) SegmentStatistic(sleepMillisecondTime int) { var statisticData = make([]masterPb.SegmentStat, 0) - for _, collection := range node.Collections { - for _, partition := range collection.Partitions { - for _, openedSegment := range partition.OpenedSegments { - currentMemSize := openedSegment.GetMemSize() - memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000) - stat := masterPb.SegmentStat{ - // TODO: set master pb's segment id type from uint64 to int64 - SegmentId: uint64(openedSegment.SegmentId), - MemorySize: currentMemSize, - MemoryRate: memIncreaseRate, - } - statisticData = append(statisticData, stat) - } + //for _, collection := range node.Collections { + // for _, partition := range collection.Partitions { + // for _, openedSegment := range partition.OpenedSegments { + // currentMemSize := openedSegment.GetMemSize() + // memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000) + // stat := masterPb.SegmentStat{ + // // TODO: set master pb's segment id type from uint64 to int64 + // SegmentId: uint64(openedSegment.SegmentId), + // MemorySize: currentMemSize, + // MemoryRate: memIncreaseRate, + // } + // statisticData = append(statisticData, stat) + // } + // } + //} + + for segmentID, segment := range node.SegmentsMap { + currentMemSize := segment.GetMemSize() + memIncreaseRate := float32((int64(currentMemSize))-(int64(segment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000) + segment.LastMemSize = currentMemSize + + //segmentStatus := segment.SegmentStatus + //segmentNumOfRows := segment.GetRowCount() + + stat := masterPb.SegmentStat{ + // TODO: set master pb's segment id type from uint64 to int64 + SegmentId: uint64(segmentID), + MemorySize: currentMemSize, + MemoryRate: memIncreaseRate, } + + statisticData = append(statisticData, stat) } var status = node.PublicStatistic(&statisticData) diff --git a/reader/read_node/segment_test.go b/reader/read_node/segment_test.go index 8595e7a88..9c2fd653a 100644 --- a/reader/read_node/segment_test.go +++ b/reader/read_node/segment_test.go @@ -143,7 +143,8 @@ func TestSegment_SegmentSearch(t *testing.T) { var vectorRecord = msgPb.VectorRowRecord{ FloatData: queryRawData, } - var searchRes, searchErr = segment.SegmentSearch(queryJson, timestamps[N/2], &vectorRecord) + query := node.QueryJson2Info(&queryJson) + var searchRes, searchErr = segment.SegmentSearch(query, timestamps[N/2], &vectorRecord) assert.NoError(t, searchErr) fmt.Println(searchRes) -- GitLab