未验证 提交 49443e8a 编写于 作者: Y yukun 提交者: GitHub

Add retrieve taskscheduler implementation (#5353)

See also: #5257 
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 bfc057d5
...@@ -28,6 +28,7 @@ import ( ...@@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
...@@ -1170,7 +1171,62 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque ...@@ -1170,7 +1171,62 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque
} }
func (node *ProxyNode) Retrieve(ctx context.Context, request *milvuspb.RetrieveRequest) (*milvuspb.RetrieveResults, error) { func (node *ProxyNode) Retrieve(ctx context.Context, request *milvuspb.RetrieveRequest) (*milvuspb.RetrieveResults, error) {
return nil, nil rt := &RetrieveTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: Params.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyID, 10),
},
queryMsgStream: node.queryMsgStream,
resultBuf: make(chan []*internalpb.RetrieveResults),
retrieve: request,
}
err := node.sched.DqQueue.Enqueue(rt)
if err != nil {
return &milvuspb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("Retrieve",
zap.String("role", Params.RoleName),
zap.Int64("msgID", rt.Base.MsgID),
zap.Uint64("timestamp", rt.Base.Timestamp),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
defer func() {
log.Debug("Retrieve Done",
zap.Error(err),
zap.String("role", Params.RoleName),
zap.Int64("msgID", rt.Base.MsgID),
zap.Uint64("timestamp", rt.Base.Timestamp),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
}()
err = rt.WaitToFinish()
if err != nil {
return &milvuspb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
return rt.result, nil
} }
func (node *ProxyNode) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error) { func (node *ProxyNode) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error) {
......
...@@ -401,13 +401,17 @@ func (sched *TaskScheduler) queryResultLoop() { ...@@ -401,13 +401,17 @@ func (sched *TaskScheduler) queryResultLoop() {
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx) queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName) queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames)) log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
queryResultMsgStream.AsConsumer(Params.RetrieveResultChannelNames, Params.ProxySubName)
log.Debug("proxynode", zap.Strings("Retrieve result channel names", Params.RetrieveResultChannelNames))
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName)) log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))
queryNodeNum := Params.QueryNodeNum queryNodeNum := Params.QueryNodeNum
queryResultMsgStream.Start() queryResultMsgStream.Start()
defer queryResultMsgStream.Close() defer queryResultMsgStream.Close()
queryResultBuf := make(map[UniqueID][]*internalpb.SearchResults) queryResultBuf := make(map[UniqueID][]*internalpb.SearchResults)
retrieveResultBuf := make(map[UniqueID][]*internalpb.RetrieveResults)
for { for {
select { select {
...@@ -422,41 +426,75 @@ func (sched *TaskScheduler) queryResultLoop() { ...@@ -422,41 +426,75 @@ func (sched *TaskScheduler) queryResultLoop() {
for _, tsMsg := range msgPack.Msgs { for _, tsMsg := range msgPack.Msgs {
sp, ctx := trace.StartSpanFromContext(tsMsg.TraceCtx()) sp, ctx := trace.StartSpanFromContext(tsMsg.TraceCtx())
tsMsg.SetTraceCtx(ctx) tsMsg.SetTraceCtx(ctx)
searchResultMsg, _ := tsMsg.(*msgstream.SearchResultMsg) if searchResultMsg, srOk := tsMsg.(*msgstream.SearchResultMsg); srOk {
reqID := searchResultMsg.Base.MsgID reqID := searchResultMsg.Base.MsgID
reqIDStr := strconv.FormatInt(reqID, 10) reqIDStr := strconv.FormatInt(reqID, 10)
t := sched.getTaskByReqID(reqID) t := sched.getTaskByReqID(reqID)
if t == nil { if t == nil {
log.Debug("proxynode", zap.String("QueryResult GetTaskByReqID failed, reqID = ", reqIDStr)) log.Debug("proxynode", zap.String("QueryResult GetTaskByReqID failed, reqID = ", reqIDStr))
delete(queryResultBuf, reqID) delete(queryResultBuf, reqID)
continue continue
} }
_, ok = queryResultBuf[reqID] _, ok = queryResultBuf[reqID]
if !ok { if !ok {
queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0) queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0)
} }
queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults) queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults)
//t := sched.getTaskByReqID(reqID) //t := sched.getTaskByReqID(reqID)
{ {
colName := t.(*SearchTask).query.CollectionName colName := t.(*SearchTask).query.CollectionName
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID]))) log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID])))
}
if len(queryResultBuf[reqID]) == queryNodeNum {
t := sched.getTaskByReqID(reqID)
if t != nil {
qt, ok := t.(*SearchTask)
if ok {
qt.resultBuf <- queryResultBuf[reqID]
delete(queryResultBuf, reqID)
}
} else {
// log.Printf("task with reqID %v is nil", reqID)
}
}
sp.Finish()
} }
if len(queryResultBuf[reqID]) == queryNodeNum { if retrieveResultMsg, rtOk := tsMsg.(*msgstream.RetrieveResultMsg); rtOk {
reqID := retrieveResultMsg.Base.MsgID
reqIDStr := strconv.FormatInt(reqID, 10)
t := sched.getTaskByReqID(reqID) t := sched.getTaskByReqID(reqID)
if t != nil { if t == nil {
qt, ok := t.(*SearchTask) log.Debug("proxynode", zap.String("RetrieveResult GetTaskByReqID failed, reqID = ", reqIDStr))
if ok { delete(retrieveResultBuf, reqID)
qt.resultBuf <- queryResultBuf[reqID] continue
delete(queryResultBuf, reqID) }
}
} else {
// log.Printf("task with reqID %v is nil", reqID) _, ok = retrieveResultBuf[reqID]
if !ok {
retrieveResultBuf[reqID] = make([]*internalpb.RetrieveResults, 0)
}
retrieveResultBuf[reqID] = append(retrieveResultBuf[reqID], &retrieveResultMsg.RetrieveResults)
{
colName := t.(*RetrieveTask).retrieve.CollectionName
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(retrieveResultBuf[reqID])))
}
if len(retrieveResultBuf[reqID]) == queryNodeNum {
t := sched.getTaskByReqID(reqID)
if t != nil {
rt, ok := t.(*RetrieveTask)
if ok {
rt.resultBuf <- retrieveResultBuf[reqID]
delete(retrieveResultBuf, reqID)
}
} else {
}
} }
sp.Finish()
} }
sp.Finish()
} }
case <-sched.ctx.Done(): case <-sched.ctx.Done():
log.Debug("proxynode server is closed ...") log.Debug("proxynode server is closed ...")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册