提交 62bee091 编写于 作者: B bigsheeper 提交者: yefu.chen

Fix search failed about topK

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 93b9a06b
......@@ -587,6 +587,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
if(record_.ack_responder_.GetAck() < 1024 * 4) {
return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
}
index_meta_ = remote_index_meta;
for (auto&[index_name, entry]: index_meta_->get_entries()) {
assert(entry.index_name == index_name);
const auto &field = (*schema_)[entry.field_name];
......
......@@ -238,7 +238,7 @@ TEST(CApiTest, BuildIndexTest) {
CQueryInfo queryInfo{1, 10, "fakevec"};
auto sea_res = Search(
segment, queryInfo, 1, query_raw_data.data(), DIM, result_ids, result_distances);
segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances);
assert(sea_res == 0);
DeleteCollection(collection);
......
......@@ -39,6 +39,14 @@ type MessageClient struct {
MessageClientID int
}
func (mc *MessageClient) GetTimeNow() uint64 {
msg, ok := <-mc.timeSyncCfg.TimeSync()
if !ok {
fmt.Println("cnn't get data from timesync chan")
}
return msg.Timestamp
}
func (mc *MessageClient) TimeSyncStart() uint64 {
return mc.timestampBatchStart
}
......
......@@ -96,7 +96,12 @@ func (node *QueryNode) processSegmentCreate(id string, value string) {
if collection != nil {
partition := collection.GetPartitionByName(segment.PartitionTag)
if partition != nil {
partition.NewSegment(int64(segment.SegmentID)) // todo change all to uint64
newSegmentID := int64(segment.SegmentID) // todo change all to uint64
// start new segment and add it into partition.OpenedSegments
newSegment := partition.NewSegment(newSegmentID)
newSegment.SegmentStatus = SegmentOpened
partition.OpenedSegments = append(partition.OpenedSegments, newSegment)
node.SegmentsMap[newSegmentID] = newSegment
}
}
// segment.CollectionName
......
......@@ -14,7 +14,9 @@ package reader
import "C"
import (
"encoding/json"
"fmt"
"log"
"sort"
"sync"
"sync/atomic"
......@@ -56,6 +58,12 @@ type QueryNodeDataBuffer struct {
validSearchBuffer []bool
}
type QueryInfo struct {
NumQueries int64 `json:"num_queries"`
TopK int `json:"topK"`
FieldName string `json:"field_name"`
}
type QueryNode struct {
QueryNodeId uint64
Collections []*Collection
......@@ -463,6 +471,19 @@ func (node *QueryNode) DoDelete(segmentID int64, deleteIDs *[]int64, deleteTimes
return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS}
}
func (node *QueryNode) QueryJson2Info(queryJson *string) *QueryInfo {
var query QueryInfo
var err = json.Unmarshal([]byte(*queryJson), &query)
if err != nil {
log.Printf("Unmarshal query json failed")
return nil
}
fmt.Println(query)
return &query
}
func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// TODO: use client id to publish results to different clients
// var clientId = (*(searchMessages[0])).ClientId
......@@ -475,16 +496,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// Traverse all messages in the current messageClient.
// TODO: Do not receive batched search requests
for _, msg := range searchMessages {
var collectionName = searchMessages[0].CollectionName
var targetCollection, err = node.GetCollectionByCollectionName(collectionName)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
var resultsTmp = make([]SearchResultTmp, 0)
// TODO: get top-k's k from queryString
const TopK = 1
var timestamp = msg.Timestamp
var vector = msg.Records
......@@ -498,36 +510,27 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
return msgPb.Status{ErrorCode: 1}
}
// 2. Do search in all segments
for _, partition := range targetCollection.Partitions {
for _, openSegment := range partition.OpenedSegments {
var res, err = openSegment.SegmentSearch(queryJson, timestamp, vector)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
fmt.Println(res.ResultIds)
for i := 0; i < len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
}
// 2. Get query information from query json
query := node.QueryJson2Info(&queryJson)
// 3. Do search in all segments
for _, segment := range node.SegmentsMap {
var res, err = segment.SegmentSearch(query, timestamp, vector)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
for _, closedSegment := range partition.ClosedSegments {
var res, err = closedSegment.SegmentSearch(queryJson, timestamp, vector)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
for i := 0; i <= len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
}
fmt.Println(res.ResultIds)
for i := 0; i < len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
}
}
// 2. Reduce results
// 4. Reduce results
sort.Slice(resultsTmp, func(i, j int) bool {
return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
})
resultsTmp = resultsTmp[:TopK]
resultsTmp = resultsTmp[:query.TopK]
var entities = msgPb.Entities{
Ids: make([]int64, 0),
}
......@@ -547,7 +550,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
results.RowNum = int64(len(results.Distances))
// 3. publish result to pulsar
// 5. publish result to pulsar
node.PublishSearchResult(&results)
}
......
......@@ -13,7 +13,6 @@ package reader
*/
import "C"
import (
"encoding/json"
"fmt"
"github.com/czs007/suvlim/errors"
msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
......@@ -75,12 +74,21 @@ func (s *Segment) CloseSegment(collection* Collection) error {
Close(CSegmentBase c_segment);
*/
var status = C.Close(s.SegmentPtr)
s.SegmentStatus = SegmentClosed
if status != 0 {
return errors.New("Close segment failed, error code = " + strconv.Itoa(int(status)))
}
// Build index after closing segment
go s.buildIndex(collection)
s.SegmentStatus = SegmentIndexing
s.buildIndex(collection)
// TODO: remove redundant segment indexed status
// Change segment status to indexed
s.SegmentStatus = SegmentIndexed
s.SegmentStatus = SegmentClosed
return nil
}
......@@ -182,7 +190,7 @@ func (s *Segment) SegmentDelete(offset int64, entityIDs *[]int64, timestamps *[]
return nil
}
func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) {
func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) {
/*
int
Search(CSegmentBase c_segment,
......@@ -193,20 +201,7 @@ func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord
long int* result_ids,
float* result_distances);
*/
type QueryInfo struct {
NumQueries int64 `json:"num_queries"`
TopK int `json:"topK"`
FieldName string `json:"field_name"`
}
type CQueryInfo C.CQueryInfo
var query QueryInfo
var err = json.Unmarshal([]byte(queryJson), &query)
if err != nil {
return nil, err
}
fmt.Println(query)
//type CQueryInfo C.CQueryInfo
cQuery := C.CQueryInfo{
num_queries: C.long(query.NumQueries),
......
......@@ -10,24 +10,22 @@ import (
)
func (node *QueryNode) SegmentsManagement() {
node.queryNodeTimeSync.UpdateTSOTimeSync()
var timeNow = node.queryNodeTimeSync.TSOTimeSync
//node.queryNodeTimeSync.UpdateTSOTimeSync()
//var timeNow = node.queryNodeTimeSync.TSOTimeSync
timeNow := node.messageClient.GetTimeNow()
for _, collection := range node.Collections {
for _, partition := range collection.Partitions {
for _, oldSegment := range partition.OpenedSegments {
// TODO: check segment status
if timeNow >= oldSegment.SegmentCloseTime {
// start new segment and add it into partition.OpenedSegments
// TODO: get segmentID from master
var segmentID int64 = 0
var newSegment = partition.NewSegment(segmentID)
newSegment.SegmentCloseTime = timeNow + SegmentLifetime
partition.OpenedSegments = append(partition.OpenedSegments, newSegment)
node.SegmentsMap[segmentID] = newSegment
// close old segment and move it into partition.ClosedSegments
// TODO: check status
var _ = oldSegment.CloseSegment(collection)
if oldSegment.SegmentStatus == SegmentClosed {
log.Println("Never reach here, Opened segment cannot be closed")
continue
}
go oldSegment.CloseSegment(collection)
partition.ClosedSegments = append(partition.ClosedSegments, oldSegment)
}
}
......@@ -47,20 +45,38 @@ func (node *QueryNode) SegmentManagementService() {
func (node *QueryNode) SegmentStatistic(sleepMillisecondTime int) {
var statisticData = make([]masterPb.SegmentStat, 0)
for _, collection := range node.Collections {
for _, partition := range collection.Partitions {
for _, openedSegment := range partition.OpenedSegments {
currentMemSize := openedSegment.GetMemSize()
memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
stat := masterPb.SegmentStat{
// TODO: set master pb's segment id type from uint64 to int64
SegmentId: uint64(openedSegment.SegmentId),
MemorySize: currentMemSize,
MemoryRate: memIncreaseRate,
}
statisticData = append(statisticData, stat)
}
//for _, collection := range node.Collections {
// for _, partition := range collection.Partitions {
// for _, openedSegment := range partition.OpenedSegments {
// currentMemSize := openedSegment.GetMemSize()
// memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
// stat := masterPb.SegmentStat{
// // TODO: set master pb's segment id type from uint64 to int64
// SegmentId: uint64(openedSegment.SegmentId),
// MemorySize: currentMemSize,
// MemoryRate: memIncreaseRate,
// }
// statisticData = append(statisticData, stat)
// }
// }
//}
for segmentID, segment := range node.SegmentsMap {
currentMemSize := segment.GetMemSize()
memIncreaseRate := float32((int64(currentMemSize))-(int64(segment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
segment.LastMemSize = currentMemSize
//segmentStatus := segment.SegmentStatus
//segmentNumOfRows := segment.GetRowCount()
stat := masterPb.SegmentStat{
// TODO: set master pb's segment id type from uint64 to int64
SegmentId: uint64(segmentID),
MemorySize: currentMemSize,
MemoryRate: memIncreaseRate,
}
statisticData = append(statisticData, stat)
}
var status = node.PublicStatistic(&statisticData)
......
......@@ -143,7 +143,8 @@ func TestSegment_SegmentSearch(t *testing.T) {
var vectorRecord = msgPb.VectorRowRecord{
FloatData: queryRawData,
}
var searchRes, searchErr = segment.SegmentSearch(queryJson, timestamps[N/2], &vectorRecord)
query := node.QueryJson2Info(&queryJson)
var searchRes, searchErr = segment.SegmentSearch(query, timestamps[N/2], &vectorRecord)
assert.NoError(t, searchErr)
fmt.Println(searchRes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册