diff --git a/internal/proxynode/impl.go b/internal/proxynode/impl.go index 4891fed1ff2214a5d05ed00acea6fe62c49483ec..7c8b51d86527189cf355591bfd955a7d5e0cca14 100644 --- a/internal/proxynode/impl.go +++ b/internal/proxynode/impl.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/util/typeutil" ) @@ -1170,7 +1171,62 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque } func (node *ProxyNode) Retrieve(ctx context.Context, request *milvuspb.RetrieveRequest) (*milvuspb.RetrieveResults, error) { - return nil, nil + rt := &RetrieveTask{ + ctx: ctx, + Condition: NewTaskCondition(ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: Params.ProxyID, + }, + ResultChannelID: strconv.FormatInt(Params.ProxyID, 10), + }, + queryMsgStream: node.queryMsgStream, + resultBuf: make(chan []*internalpb.RetrieveResults), + retrieve: request, + } + + err := node.sched.DqQueue.Enqueue(rt) + if err != nil { + return &milvuspb.RetrieveResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("Retrieve", + zap.String("role", Params.RoleName), + zap.Int64("msgID", rt.Base.MsgID), + zap.Uint64("timestamp", rt.Base.Timestamp), + zap.String("db", request.DbName), + zap.String("collection", request.CollectionName), + zap.Any("partitions", request.PartitionNames), + zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data))) + defer func() { + log.Debug("Retrieve Done", + zap.Error(err), + zap.String("role", Params.RoleName), + zap.Int64("msgID", rt.Base.MsgID), + zap.Uint64("timestamp", rt.Base.Timestamp), + zap.String("db", request.DbName), + zap.String("collection", request.CollectionName), + zap.Any("partitions", request.PartitionNames), + zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data))) + }() + + err = rt.WaitToFinish() + if err != nil { + return &milvuspb.RetrieveResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + return rt.result, nil } func (node *ProxyNode) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error) { diff --git a/internal/proxynode/task_scheduler.go b/internal/proxynode/task_scheduler.go index a3d097162181f1987326f649cfeda4d59a35c77c..0453d89270af251d0e89c9ddfeffbd32dfb46559 100644 --- a/internal/proxynode/task_scheduler.go +++ b/internal/proxynode/task_scheduler.go @@ -401,13 +401,17 @@ func (sched *TaskScheduler) queryResultLoop() { queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx) queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName) log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames)) + queryResultMsgStream.AsConsumer(Params.RetrieveResultChannelNames, Params.ProxySubName) + log.Debug("proxynode", zap.Strings("Retrieve result channel names", Params.RetrieveResultChannelNames)) log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName)) + queryNodeNum := Params.QueryNodeNum queryResultMsgStream.Start() defer queryResultMsgStream.Close() queryResultBuf := make(map[UniqueID][]*internalpb.SearchResults) + retrieveResultBuf := make(map[UniqueID][]*internalpb.RetrieveResults) for { select { @@ -422,41 +426,75 @@ func (sched *TaskScheduler) queryResultLoop() { for _, tsMsg := range msgPack.Msgs { sp, ctx := trace.StartSpanFromContext(tsMsg.TraceCtx()) tsMsg.SetTraceCtx(ctx) - searchResultMsg, _ := tsMsg.(*msgstream.SearchResultMsg) - reqID := searchResultMsg.Base.MsgID - reqIDStr := strconv.FormatInt(reqID, 10) - t := sched.getTaskByReqID(reqID) - if t == nil { - log.Debug("proxynode", zap.String("QueryResult GetTaskByReqID failed, reqID = ", reqIDStr)) - delete(queryResultBuf, reqID) - continue - } + if searchResultMsg, srOk := tsMsg.(*msgstream.SearchResultMsg); srOk { + reqID := searchResultMsg.Base.MsgID + reqIDStr := strconv.FormatInt(reqID, 10) + t := sched.getTaskByReqID(reqID) + if t == nil { + log.Debug("proxynode", zap.String("QueryResult GetTaskByReqID failed, reqID = ", reqIDStr)) + delete(queryResultBuf, reqID) + continue + } - _, ok = queryResultBuf[reqID] - if !ok { - queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0) - } - queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults) + _, ok = queryResultBuf[reqID] + if !ok { + queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0) + } + queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults) - //t := sched.getTaskByReqID(reqID) - { - colName := t.(*SearchTask).query.CollectionName - log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID]))) + //t := sched.getTaskByReqID(reqID) + { + colName := t.(*SearchTask).query.CollectionName + log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID]))) + } + if len(queryResultBuf[reqID]) == queryNodeNum { + t := sched.getTaskByReqID(reqID) + if t != nil { + qt, ok := t.(*SearchTask) + if ok { + qt.resultBuf <- queryResultBuf[reqID] + delete(queryResultBuf, reqID) + } + } else { + + // log.Printf("task with reqID %v is nil", reqID) + } + } + sp.Finish() } - if len(queryResultBuf[reqID]) == queryNodeNum { + if retrieveResultMsg, rtOk := tsMsg.(*msgstream.RetrieveResultMsg); rtOk { + reqID := retrieveResultMsg.Base.MsgID + reqIDStr := strconv.FormatInt(reqID, 10) t := sched.getTaskByReqID(reqID) - if t != nil { - qt, ok := t.(*SearchTask) - if ok { - qt.resultBuf <- queryResultBuf[reqID] - delete(queryResultBuf, reqID) - } - } else { + if t == nil { + log.Debug("proxynode", zap.String("RetrieveResult GetTaskByReqID failed, reqID = ", reqIDStr)) + delete(retrieveResultBuf, reqID) + continue + } - // log.Printf("task with reqID %v is nil", reqID) + _, ok = retrieveResultBuf[reqID] + if !ok { + retrieveResultBuf[reqID] = make([]*internalpb.RetrieveResults, 0) + } + retrieveResultBuf[reqID] = append(retrieveResultBuf[reqID], &retrieveResultMsg.RetrieveResults) + + { + colName := t.(*RetrieveTask).retrieve.CollectionName + log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(retrieveResultBuf[reqID]))) + } + if len(retrieveResultBuf[reqID]) == queryNodeNum { + t := sched.getTaskByReqID(reqID) + if t != nil { + rt, ok := t.(*RetrieveTask) + if ok { + rt.resultBuf <- retrieveResultBuf[reqID] + delete(retrieveResultBuf, reqID) + } + } else { + } } + sp.Finish() } - sp.Finish() } case <-sched.ctx.Done(): log.Debug("proxynode server is closed ...")