提交 450ea631 编写于 作者: X xige-16 提交者: yefu.chen

Fix queryService assign queryChannel failure

Signed-off-by: Nxige-16 <xi.ge@zilliz.com>
上级 549bda93
......@@ -51,7 +51,6 @@ func (s *Server) Init() error {
if err := s.queryService.Init(); err != nil {
return err
}
s.queryService.SetEnableGrpc(true)
return nil
}
......
......@@ -650,9 +650,8 @@ func TestMasterService(t *testing.T) {
rsp, err := core.DescribeIndex(req)
assert.Nil(t, err)
assert.Equal(t, rsp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS)
assert.Equal(t, len(rsp.IndexDescriptions), 2)
idxNames := []string{rsp.IndexDescriptions[0].IndexName, rsp.IndexDescriptions[1].IndexName}
assert.ElementsMatch(t, idxNames, []string{"testColl_index_100", Params.DefaultIndexName})
assert.Equal(t, len(rsp.IndexDescriptions), 1)
assert.Equal(t, rsp.IndexDescriptions[0].IndexName, Params.DefaultIndexName)
})
t.Run("flush segment", func(t *testing.T) {
......
......@@ -701,7 +701,7 @@ func (mt *metaTable) GetNotIndexedSegments(collName string, fieldName string, id
mt.indexID2Meta[idx.IndexID] = *idxInfo
k2 := path.Join(IndexMetaPrefix, strconv.FormatInt(idx.IndexID, 10))
v2 := proto.MarshalTextString(idxInfo)
v2 := proto.MarshalTextString(idx)
meta := map[string]string{k1: v1, k2: v2}
err = mt.client.MultiSave(meta)
......@@ -751,20 +751,40 @@ func (mt *metaTable) GetIndexByName(collName string, fieldName string, indexName
if !ok {
return nil, errors.Errorf("collection %s not found", collName)
}
fieldSchema, err := mt.GetFieldSchema(collName, fieldName)
fileSchema, err := mt.GetFieldSchema(collName, fieldName)
if err != nil {
return nil, err
}
rstIndex := make([]pb.IndexInfo, 0, len(collMeta.FieldIndexes))
for _, idx := range collMeta.FieldIndexes {
if idx.FiledID == fieldSchema.FieldID {
idxInfo, ok := mt.indexID2Meta[idx.IndexID]
if !ok {
return nil, errors.Errorf("index id = %d not found", idx.IndexID)
}
if indexName == "" || idxInfo.IndexName == indexName {
rstIndex = append(rstIndex, idxInfo)
existMap := map[typeutil.UniqueID]bool{}
for _, partID := range collMeta.PartitionIDs {
partMeta, ok := mt.partitionID2Meta[partID]
if ok {
for _, segID := range partMeta.SegmentIDs {
idxMeta, ok := mt.segID2IndexMeta[segID]
if !ok {
continue
}
for idxID, segMeta := range *idxMeta {
if segMeta.FieldID != fileSchema.FieldID {
continue
}
idxMeta, ok := mt.indexID2Meta[idxID]
if !ok {
continue
}
if _, ok = existMap[idxID]; ok {
continue
}
if indexName == "" {
rstIndex = append(rstIndex, idxMeta)
} else if idxMeta.IndexName == indexName {
rstIndex = append(rstIndex, idxMeta)
}
existMap[idxID] = true
}
}
}
}
......
......@@ -109,7 +109,7 @@ func (p *ParamTable) Init() {
p.initQueryNodeID()
p.initQueryNodeNum()
p.initQueryTimeTickChannelName()
//p.initQueryTimeTickChannelName()
p.initQueryTimeTickReceiveBufSize()
p.initMinioEndPoint()
......@@ -140,14 +140,14 @@ func (p *ParamTable) Init() {
p.initDDReceiveBufSize()
p.initDDPulsarBufSize()
p.initSearchChannelNames()
p.initSearchResultChannelNames()
//p.initSearchChannelNames()
//p.initSearchResultChannelNames()
p.initSearchReceiveBufSize()
p.initSearchPulsarBufSize()
p.initSearchResultReceiveBufSize()
p.initStatsPublishInterval()
p.initStatsChannelName()
//p.initStatsChannelName()
p.initStatsReceiveBufSize()
})
}
......
......@@ -14,12 +14,12 @@ import "C"
import (
"context"
"errors"
"fmt"
"io"
"log"
"sync/atomic"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms"
"github.com/zilliztech/milvus-distributed/internal/msgstream/rmqms"
......@@ -118,21 +118,39 @@ func Init() {
func (node *QueryNode) Init() error {
registerReq := &queryPb.RegisterNodeRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kNone,
SourceID: Params.QueryNodeID,
},
Address: &commonpb.Address{
Ip: Params.QueryNodeIP,
Port: Params.QueryNodePort,
},
}
response, err := node.queryClient.RegisterNode(registerReq)
resp, err := node.queryClient.RegisterNode(registerReq)
if err != nil {
panic(err)
}
if response.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
panic(response.Status.Reason)
if resp.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
panic(resp.Status.Reason)
}
for _, kv := range resp.InitParams.StartParams {
switch kv.Key {
case "StatsChannelName":
Params.StatsChannelName = kv.Value
case "TimeTickChannelName":
Params.QueryTimeTickChannelName = kv.Value
case "QueryChannelName":
Params.SearchChannelNames = append(Params.SearchChannelNames, kv.Value)
case "QueryResultChannelName":
Params.SearchResultChannelNames = append(Params.SearchResultChannelNames, kv.Value)
default:
return errors.Errorf("Invalid key: %v", kv.Key)
}
}
Params.QueryNodeID = response.InitParams.NodeID
fmt.Println("QueryNodeID is", Params.QueryNodeID)
if node.masterClient == nil {
......
......@@ -27,6 +27,10 @@ type queryServiceMock struct{}
func setup() {
Params.Init()
Params.initQueryTimeTickChannelName()
Params.initSearchResultChannelNames()
Params.initStatsChannelName()
Params.initSearchChannelNames()
Params.MetaRootPath = "/etcd/test/root/querynode"
}
......
......@@ -44,11 +44,11 @@ func newSearchService(ctx context.Context, replica collectionReplica, factory ms
searchResultStream, _ := factory.NewMsgStream(ctx)
// query node doesn't need to consumer any search or search result channel actively.
//consumeChannels := Params.SearchChannelNames
//consumeSubName := Params.MsgChannelSubName
//searchStream.AsConsumer(consumeChannels, consumeSubName)
//producerChannels := Params.SearchResultChannelNames
//searchResultStream.AsProducer(producerChannels)
consumeChannels := Params.SearchChannelNames
consumeSubName := Params.MsgChannelSubName
searchStream.AsConsumer(consumeChannels, consumeSubName)
producerChannels := Params.SearchResultChannelNames
searchResultStream.AsProducer(producerChannels)
searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
msgBuffer := make(chan msgstream.TsMsg, receiveBufSize)
......
......@@ -18,8 +18,8 @@ type metaReplica interface {
getPartitionStates(dbID UniqueID, collectionID UniqueID, partitionIDs []UniqueID) ([]*querypb.PartitionStates, error)
releaseCollection(dbID UniqueID, collectionID UniqueID) error
releasePartition(dbID UniqueID, collectionID UniqueID, partitionID UniqueID) error
addDmChannels(dbID UniqueID, collectionID UniqueID, channels2NodeID map[string]UniqueID) error
getAssignedNodeIDByChannelName(dbID UniqueID, collectionID UniqueID, channel string) (UniqueID, error)
addDmChannels(dbID UniqueID, collectionID UniqueID, channels2NodeID map[string]int64) error
getAssignedNodeIDByChannelName(dbID UniqueID, collectionID UniqueID, channel string) (int64, error)
}
type segment struct {
......@@ -35,7 +35,7 @@ type partition struct {
type collection struct {
id UniqueID
partitions map[UniqueID]*partition
dmChannels2Node map[string]UniqueID
dmChannels2Node map[string]int64
schema *schemapb.CollectionSchema
}
......@@ -59,7 +59,7 @@ func (mp *metaReplicaImpl) addCollection(dbID UniqueID, collectionID UniqueID, s
//TODO:: assert dbID = 0 exist
if _, ok := mp.db2collections[dbID]; ok {
partitions := make(map[UniqueID]*partition)
channels := make(map[string]UniqueID)
channels := make(map[string]int64)
newCollection := &collection{
id: collectionID,
partitions: partitions,
......@@ -229,7 +229,7 @@ func (mp *metaReplicaImpl) releasePartition(dbID UniqueID, collectionID UniqueID
return errors.New("releasePartition: can't find dbID or collectionID or partitionID")
}
func (mp *metaReplicaImpl) addDmChannels(dbID UniqueID, collectionID UniqueID, channels2NodeID map[string]UniqueID) error {
func (mp *metaReplicaImpl) addDmChannels(dbID UniqueID, collectionID UniqueID, channels2NodeID map[string]int64) error {
if collections, ok := mp.db2collections[dbID]; ok {
for _, collection := range collections {
if collectionID == collection.id {
......@@ -243,7 +243,7 @@ func (mp *metaReplicaImpl) addDmChannels(dbID UniqueID, collectionID UniqueID, c
return errors.New("addDmChannels: can't find dbID or collectionID")
}
func (mp *metaReplicaImpl) getAssignedNodeIDByChannelName(dbID UniqueID, collectionID UniqueID, channel string) (UniqueID, error) {
func (mp *metaReplicaImpl) getAssignedNodeIDByChannelName(dbID UniqueID, collectionID UniqueID, channel string) (int64, error) {
if collections, ok := mp.db2collections[dbID]; ok {
for _, collection := range collections {
if collectionID == collection.id {
......
......@@ -5,6 +5,7 @@ import (
"fmt"
"sort"
"strconv"
"sync"
"sync/atomic"
nodeclient "github.com/zilliztech/milvus-distributed/internal/distributed/querynode/client"
......@@ -15,7 +16,6 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
"github.com/zilliztech/milvus-distributed/internal/querynode"
)
type MasterServiceInterface interface {
......@@ -39,6 +39,11 @@ type QueryNodeInterface interface {
GetSegmentInfo(req *querypb.SegmentInfoRequest) (*querypb.SegmentInfoResponse, error)
}
type queryChannelInfo struct {
requestChannel string
responseChannel string
}
type QueryService struct {
loopCtx context.Context
loopCancel context.CancelFunc
......@@ -48,9 +53,9 @@ type QueryService struct {
dataServiceClient DataServiceInterface
masterServiceClient MasterServiceInterface
queryNodes map[UniqueID]*queryNodeInfo
numRegisterNode uint64
numQueryChannel uint64
queryNodes map[int64]*queryNodeInfo
queryChannels []*queryChannelInfo
qcMutex *sync.Mutex
stateCode atomic.Value
isInit atomic.Value
......@@ -124,37 +129,57 @@ func (qs *QueryService) GetStatisticsChannel() (string, error) {
func (qs *QueryService) RegisterNode(req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) {
fmt.Println("register query node =", req.Address)
// TODO:: add mutex
allocatedID := len(qs.queryNodes)
nodeID := req.Base.SourceID
if _, ok := qs.queryNodes[nodeID]; ok {
err := errors.New("nodeID already exists")
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: err.Error(),
},
}, err
}
registerNodeAddress := req.Address.Ip + ":" + strconv.FormatInt(req.Address.Port, 10)
var node *queryNodeInfo
if qs.enableGrpc {
client := nodeclient.NewClient(registerNodeAddress)
if err := client.Init(); err != nil {
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
},
InitParams: new(internalpb2.InitParams),
}, err
}
if err := client.Start(); err != nil {
return nil, err
}
node = newQueryNodeInfo(client)
} else {
client := querynode.NewQueryNode(qs.loopCtx, uint64(allocatedID), qs.msFactory)
node = newQueryNodeInfo(client)
client := nodeclient.NewClient(registerNodeAddress)
if err := client.Init(); err != nil {
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
},
InitParams: new(internalpb2.InitParams),
}, err
}
qs.queryNodes[UniqueID(allocatedID)] = node
if err := client.Start(); err != nil {
return nil, err
}
qs.queryNodes[nodeID] = newQueryNodeInfo(client)
//TODO::return init params to queryNode
startParams := []*commonpb.KeyValuePair{
{Key: "StatsChannelName", Value: Params.StatsChannelName},
{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
}
qs.qcMutex.Lock()
for _, queryChannel := range qs.queryChannels {
startParams = append(startParams, &commonpb.KeyValuePair{
Key: "QueryChannelName",
Value: queryChannel.requestChannel,
})
startParams = append(startParams, &commonpb.KeyValuePair{
Key: "QueryResultChannelName",
Value: queryChannel.responseChannel,
})
}
qs.qcMutex.Unlock()
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_SUCCESS,
},
InitParams: &internalpb2.InitParams{
NodeID: int64(allocatedID),
NodeID: nodeID,
StartParams: startParams,
},
}, nil
}
......@@ -499,18 +524,26 @@ func (qs *QueryService) ReleasePartitions(req *querypb.ReleasePartitionRequest)
}
func (qs *QueryService) CreateQueryChannel() (*querypb.CreateQueryChannelResponse, error) {
channelID := qs.numQueryChannel
qs.numQueryChannel++
channelID := len(qs.queryChannels)
allocatedQueryChannel := "query-" + strconv.FormatInt(int64(channelID), 10)
allocatedQueryResultChannel := "queryResult-" + strconv.FormatInt(int64(channelID), 10)
qs.qcMutex.Lock()
qs.queryChannels = append(qs.queryChannels, &queryChannelInfo{
requestChannel: allocatedQueryChannel,
responseChannel: allocatedQueryResultChannel,
})
addQueryChannelsRequest := &querypb.AddQueryChannelsRequest{
RequestChannelID: allocatedQueryChannel,
ResultChannelID: allocatedQueryResultChannel,
}
for _, node := range qs.queryNodes {
fmt.Println("query service create query channel, queryChannelName = ", allocatedQueryChannel)
for nodeID, node := range qs.queryNodes {
fmt.Println("node ", nodeID, " watch query channel")
_, err := node.AddQueryChannel(addQueryChannelsRequest)
if err != nil {
qs.qcMutex.Unlock()
return &querypb.CreateQueryChannelResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
......@@ -519,6 +552,7 @@ func (qs *QueryService) CreateQueryChannel() (*querypb.CreateQueryChannelRespons
}, err
}
}
qs.qcMutex.Unlock()
return &querypb.CreateQueryChannelResponse{
Status: &commonpb.Status{
......@@ -571,18 +605,18 @@ func (qs *QueryService) GetSegmentInfo(req *querypb.SegmentInfoRequest) (*queryp
}
func NewQueryService(ctx context.Context, factory msgstream.Factory) (*QueryService, error) {
nodes := make(map[UniqueID]*queryNodeInfo)
nodes := make(map[int64]*queryNodeInfo)
queryChannels := make([]*queryChannelInfo, 0)
ctx1, cancel := context.WithCancel(ctx)
replica := newMetaReplica()
service := &QueryService{
loopCtx: ctx1,
loopCancel: cancel,
replica: replica,
queryNodes: nodes,
numRegisterNode: 0,
numQueryChannel: 0,
enableGrpc: false,
msFactory: factory,
loopCtx: ctx1,
loopCancel: cancel,
replica: replica,
queryNodes: nodes,
queryChannels: queryChannels,
qcMutex: &sync.Mutex{},
msFactory: factory,
}
service.stateCode.Store(internalpb2.StateCode_INITIALIZING)
service.isInit.Store(false)
......@@ -597,10 +631,6 @@ func (qs *QueryService) SetDataService(dataService DataServiceInterface) {
qs.dataServiceClient = dataService
}
func (qs *QueryService) SetEnableGrpc(en bool) {
qs.enableGrpc = en
}
func (qs *QueryService) watchDmChannels(dbID UniqueID, collectionID UniqueID) error {
collection, err := qs.replica.getCollectionByID(0, collectionID)
if err != nil {
......@@ -620,7 +650,7 @@ func (qs *QueryService) watchDmChannels(dbID UniqueID, collectionID UniqueID) er
}
dmChannels := resp.Values
watchedChannels2NodeID := make(map[string]UniqueID)
watchedChannels2NodeID := make(map[string]int64)
unWatchedChannels := make([]string, 0)
for _, channel := range dmChannels {
findChannel := false
......@@ -647,7 +677,7 @@ func (qs *QueryService) watchDmChannels(dbID UniqueID, collectionID UniqueID) er
if err != nil {
return err
}
node2channels := make(map[UniqueID][]string)
node2channels := make(map[int64][]string)
for channel, nodeID := range channels2NodeID {
if _, ok := node2channels[nodeID]; ok {
node2channels[nodeID] = append(node2channels[nodeID], channel)
......@@ -674,13 +704,13 @@ func (qs *QueryService) watchDmChannels(dbID UniqueID, collectionID UniqueID) er
return nil
}
func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[string]UniqueID {
func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[string]int64 {
maxNumDMChannel := 0
res := make(map[string]UniqueID)
res := make(map[string]int64)
if len(dmChannels) == 0 {
return res
}
node2lens := make(map[UniqueID]int)
node2lens := make(map[int64]int)
for id, node := range qs.queryNodes {
node2lens[id] = len(node.dmChannelNames)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册