diff --git a/internal/distributed/proxynode/config_key.go b/internal/distributed/proxynode/config_key.go new file mode 100644 index 0000000000000000000000000000000000000000..b22ba11e15896860911321cf7e963417525645c6 --- /dev/null +++ b/internal/distributed/proxynode/config_key.go @@ -0,0 +1,24 @@ +package grpcproxynode + +const ( + StartParamsKey = "START_PARAMS" + MasterPort = "master.port" + MasterHost = "master.address" + PulsarPort = "pulsar.port" + PulsarHost = "pulsar.address" + IndexServerPort = "indexBuilder.port" + IndexServerHost = "indexBuilder.address" + QueryNodeIDList = "nodeID.queryNodeIDList" + TimeTickInterval = "proxyNode.timeTickInterval" + SubName = "msgChannel.subNamePrefix.proxySubNamePrefix" + TimeTickChannelNames = "msgChannel.chanNamePrefix.proxyTimeTick" + MsgStreamInsertBufSize = "proxyNode.msgStream.insert.bufSize" + MsgStreamSearchBufSize = "proxyNode.msgStream.search.bufSize" + MsgStreamSearchResultBufSize = "proxyNode.msgStream.searchResult.recvBufSize" + MsgStreamSearchResultPulsarBufSize = "proxyNode.msgStream.searchResult.pulsarBufSize" + MsgStreamTimeTickBufSize = "proxyNode.msgStream.timeTick.bufSize" + MaxNameLength = "proxyNode.maxNameLength" + MaxFieldNum = "proxyNode.maxFieldNum" + MaxDimension = "proxyNode.MaxDimension" + DefaultPartitionTag = "common.defaultPartitionTag" +) diff --git a/internal/distributed/proxynode/paramtable.go b/internal/distributed/proxynode/paramtable.go new file mode 100644 index 0000000000000000000000000000000000000000..d2fddef81c2d866e632b4bed0ceb6ed87d03548e --- /dev/null +++ b/internal/distributed/proxynode/paramtable.go @@ -0,0 +1,46 @@ +package grpcproxynode + +import ( + "net" + "strconv" + + "github.com/zilliztech/milvus-distributed/internal/util/paramtable" +) + +type ParamTable struct { + paramtable.BaseTable + + ProxyServiceAddress string +} + +var Params ParamTable + +func (pt *ParamTable) Init() { + pt.BaseTable.Init() + + pt.initProxyServiceAddress() +} + +func (pt *ParamTable) initProxyServiceAddress() { + addr, err := pt.Load("proxyService.address") + if err != nil { + panic(err) + } + + hostName, _ := net.LookupHost(addr) + if len(hostName) <= 0 { + if ip := net.ParseIP(addr); ip == nil { + panic("invalid ip proxyService.address") + } + } + + port, err := pt.Load("proxyService.port") + if err != nil { + panic(err) + } + _, err = strconv.Atoi(port) + if err != nil { + panic(err) + } + pt.ProxyServiceAddress = addr + ":" + port +} diff --git a/internal/distributed/proxynode/service.go b/internal/distributed/proxynode/service.go index 825aae916914a6ab5a088dbe095a6ca86d546a5e..f98ffe20b27b6967db04cf4426cfaa8962307c11 100644 --- a/internal/distributed/proxynode/service.go +++ b/internal/distributed/proxynode/service.go @@ -1,11 +1,20 @@ package grpcproxynode import ( + "bytes" "context" "net" "os" "strconv" + "strings" "sync" + "time" + + "github.com/zilliztech/milvus-distributed/internal/util/typeutil" + + "github.com/spf13/viper" + + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" "github.com/go-basic/ipv4" @@ -36,10 +45,125 @@ func CreateProxyNodeServer() (*Server, error) { return &Server{}, nil } +func (s *Server) loadConfigFromInitParams(initParams *internalpb2.InitParams) error { + proxynode.Params.ProxyID = initParams.NodeID + + config := viper.New() + config.SetConfigType("yaml") + for _, pair := range initParams.StartParams { + if pair.Key == StartParamsKey { + err := config.ReadConfig(bytes.NewBuffer([]byte(pair.Value))) + if err != nil { + return err + } + break + } + } + + masterPort := config.GetString(MasterPort) + masterHost := config.GetString(MasterHost) + proxynode.Params.MasterAddress = masterHost + ":" + masterPort + + pulsarPort := config.GetString(PulsarPort) + pulsarHost := config.GetString(PulsarHost) + proxynode.Params.PulsarAddress = pulsarHost + ":" + pulsarPort + + indexServerPort := config.GetString(IndexServerPort) + indexServerHost := config.GetString(IndexServerHost) + proxynode.Params.IndexServerAddress = indexServerHost + ":" + indexServerPort + + queryNodeIDList := config.GetString(QueryNodeIDList) + proxynode.Params.QueryNodeIDList = nil + queryNodeIDs := strings.Split(queryNodeIDList, ",") + for _, queryNodeID := range queryNodeIDs { + v, err := strconv.Atoi(queryNodeID) + if err != nil { + return err + } + proxynode.Params.QueryNodeIDList = append(proxynode.Params.QueryNodeIDList, typeutil.UniqueID(v)) + } + proxynode.Params.QueryNodeNum = len(proxynode.Params.QueryNodeIDList) + + timeTickInterval := config.GetString(TimeTickInterval) + interval, err := strconv.Atoi(timeTickInterval) + if err != nil { + return err + } + proxynode.Params.TimeTickInterval = time.Duration(interval) * time.Millisecond + + subName := config.GetString(SubName) + proxynode.Params.ProxySubName = subName + + timeTickChannelNames := config.GetString(TimeTickChannelNames) + proxynode.Params.ProxyTimeTickChannelNames = []string{timeTickChannelNames} + + msgStreamInsertBufSizeStr := config.GetString(MsgStreamInsertBufSize) + msgStreamInsertBufSize, err := strconv.Atoi(msgStreamInsertBufSizeStr) + if err != nil { + return err + } + proxynode.Params.MsgStreamInsertBufSize = int64(msgStreamInsertBufSize) + + msgStreamSearchBufSizeStr := config.GetString(MsgStreamSearchBufSize) + msgStreamSearchBufSize, err := strconv.Atoi(msgStreamSearchBufSizeStr) + if err != nil { + return err + } + proxynode.Params.MsgStreamSearchBufSize = int64(msgStreamSearchBufSize) + + msgStreamSearchResultBufSizeStr := config.GetString(MsgStreamSearchResultBufSize) + msgStreamSearchResultBufSize, err := strconv.Atoi(msgStreamSearchResultBufSizeStr) + if err != nil { + return err + } + proxynode.Params.MsgStreamSearchResultBufSize = int64(msgStreamSearchResultBufSize) + + msgStreamSearchResultPulsarBufSizeStr := config.GetString(MsgStreamSearchResultPulsarBufSize) + msgStreamSearchResultPulsarBufSize, err := strconv.Atoi(msgStreamSearchResultPulsarBufSizeStr) + if err != nil { + return err + } + proxynode.Params.MsgStreamSearchResultPulsarBufSize = int64(msgStreamSearchResultPulsarBufSize) + + msgStreamTimeTickBufSizeStr := config.GetString(MsgStreamTimeTickBufSize) + msgStreamTimeTickBufSize, err := strconv.Atoi(msgStreamTimeTickBufSizeStr) + if err != nil { + return err + } + proxynode.Params.MsgStreamTimeTickBufSize = int64(msgStreamTimeTickBufSize) + + maxNameLengthStr := config.GetString(MaxNameLength) + maxNameLength, err := strconv.Atoi(maxNameLengthStr) + if err != nil { + return err + } + proxynode.Params.MaxNameLength = int64(maxNameLength) + + maxFieldNumStr := config.GetString(MaxFieldNum) + maxFieldNum, err := strconv.Atoi(maxFieldNumStr) + if err != nil { + return err + } + proxynode.Params.MaxFieldNum = int64(maxFieldNum) + + maxDimensionStr := config.GetString(MaxDimension) + maxDimension, err := strconv.Atoi(maxDimensionStr) + if err != nil { + return err + } + proxynode.Params.MaxDimension = int64(maxDimension) + + defaultPartitionTag := config.GetString(DefaultPartitionTag) + proxynode.Params.DefaultPartitionTag = defaultPartitionTag + + return nil +} + func (s *Server) connectProxyService() error { + Params.Init() proxynode.Params.Init() - s.proxyServiceAddress = proxynode.Params.ProxyServiceAddress() + s.proxyServiceAddress = Params.ProxyServiceAddress s.proxyServiceClient = grpcproxyservice.NewClient(s.proxyServiceAddress) getAvailablePort := func() int { @@ -74,13 +198,7 @@ func (s *Server) connectProxyService() error { panic(err) } - proxynode.Params.Save("_proxyID", strconv.Itoa(int(response.InitParams.NodeID))) - - for _, params := range response.InitParams.StartParams { - proxynode.Params.Save(params.Key, params.Value) - } - - return err + return s.loadConfigFromInitParams(response.InitParams) } func (s *Server) Init() error { diff --git a/internal/proxynode/component_clients.go b/internal/proxynode/component_clients.go new file mode 100644 index 0000000000000000000000000000000000000000..aa948e38dc84b9d81593531e40a9e0875e2ed9b1 --- /dev/null +++ b/internal/proxynode/component_clients.go @@ -0,0 +1,66 @@ +package proxynode + +import ( + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/indexpb" + "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" +) + +type MasterClientInterface interface { + Init() error + Start() error + Stop() error + + CreateCollection(in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) + DropCollection(in *milvuspb.DropCollectionRequest) (*commonpb.Status, error) + HasCollection(in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) + DescribeCollection(in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) + ShowCollections(in *milvuspb.ShowCollectionRequest) (*milvuspb.ShowCollectionResponse, error) + CreatePartition(in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) + DropPartition(in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) + HasPartition(in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) + ShowPartitions(in *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error) + CreateIndex(in *milvuspb.CreateIndexRequest) (*commonpb.Status, error) + DescribeIndex(in *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) +} + +type IndexServiceClient interface { + Init() error + Start() error + Stop() error + + GetIndexStates(req *indexpb.IndexStatesRequest) (*indexpb.IndexStatesResponse, error) +} + +type QueryServiceClient interface { + Init() error + Start() error + Stop() error + + GetSearchChannelNames() ([]string, error) + GetSearchResultChannelNames() ([]string, error) +} + +type DataServiceClient interface { + Init() error + Start() error + Stop() error + + GetInsertChannelNames() ([]string, error) +} + +func (node *NodeImpl) SetMasterClient(cli MasterClientInterface) { + node.masterClient = cli +} + +func (node *NodeImpl) SetIndexServiceClient(cli IndexServiceClient) { + node.indexServiceClient = cli +} + +func (node *NodeImpl) SetQueryServiceClient(cli QueryServiceClient) { + node.queryServiceClient = cli +} + +func (node *NodeImpl) SetDataServiceClient(cli DataServiceClient) { + node.dataServiceClient = cli +} diff --git a/internal/proxynode/impl.go b/internal/proxynode/impl.go index 37dc99332cad3539547cab3f7b8001c8449cb2a2..748fcefd1141cdb4e6fa1940453aa310e08fedfd 100644 --- a/internal/proxynode/impl.go +++ b/internal/proxynode/impl.go @@ -503,7 +503,6 @@ func (node *NodeImpl) GetIndexState(ctx context.Context, request *milvuspb.Index dipt := &GetIndexStateTask{ Condition: NewTaskCondition(ctx), IndexStateRequest: request, - masterClient: node.masterClient, } var cancel func() @@ -568,7 +567,7 @@ func (node *NodeImpl) Insert(ctx context.Context, request *milvuspb.InsertReques rowIDAllocator: node.idAllocator, } if len(it.PartitionName) <= 0 { - it.PartitionName = Params.defaultPartitionTag() + it.PartitionName = Params.DefaultPartitionTag } var cancel func() @@ -621,9 +620,9 @@ func (node *NodeImpl) Search(ctx context.Context, request *milvuspb.SearchReques SearchRequest: internalpb2.SearchRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_kSearch, - SourceID: Params.ProxyID(), + SourceID: Params.ProxyID, }, - ResultChannelID: strconv.FormatInt(Params.ProxyID(), 10), + ResultChannelID: strconv.FormatInt(Params.ProxyID, 10), }, queryMsgStream: node.queryMsgStream, resultBuf: make(chan []*internalpb2.SearchResults), diff --git a/internal/proxynode/paramtable.go b/internal/proxynode/paramtable.go index 8dcce983df82872b3513c648ce0ad2550cbfb5e5..40e5a2bdcd10b9a90b55046f7b20d450048376f5 100644 --- a/internal/proxynode/paramtable.go +++ b/internal/proxynode/paramtable.go @@ -13,6 +13,34 @@ import ( type ParamTable struct { paramtable.BaseTable + + NetworkPort int + NetworkAddress string + ProxyServiceAddress string + MasterAddress string + PulsarAddress string + IndexServerAddress string + QueryNodeNum int + QueryNodeIDList []UniqueID + ProxyID UniqueID + TimeTickInterval time.Duration + InsertChannelNames []string + DeleteChannelNames []string + K2SChannelNames []string + SearchChannelNames []string + SearchResultChannelNames []string + ProxySubName string + ProxyTimeTickChannelNames []string + DataDefinitionChannelNames []string + MsgStreamInsertBufSize int64 + MsgStreamSearchBufSize int64 + MsgStreamSearchResultBufSize int64 + MsgStreamSearchResultPulsarBufSize int64 + MsgStreamTimeTickBufSize int64 + MaxNameLength int64 + MaxFieldNum int64 + MaxDimension int64 + DefaultPartitionTag string } var Params ParamTable @@ -20,28 +48,40 @@ var Params ParamTable func (pt *ParamTable) Init() { pt.BaseTable.Init() - err := pt.LoadYaml("advanced/proxy_node.yaml") - if err != nil { - panic(err) - } - - proxyIDStr := os.Getenv("PROXY_ID") - if proxyIDStr == "" { - proxyIDList := pt.ProxyIDList() - if len(proxyIDList) <= 0 { - proxyIDStr = "0" - } else { - proxyIDStr = strconv.Itoa(int(proxyIDList[0])) - } - } - pt.Save("_proxyID", proxyIDStr) + pt.initNetworkPort() + pt.initNetworkAddress() + pt.initProxyServiceAddress() + pt.initMasterAddress() + pt.initPulsarAddress() + pt.initIndexServerAddress() + pt.initQueryNodeIDList() + pt.initQueryNodeNum() + pt.initProxyID() + pt.initTimeTickInterval() + pt.initInsertChannelNames() + pt.initDeleteChannelNames() + pt.initK2SChannelNames() + pt.initSearchChannelNames() + pt.initSearchResultChannelNames() + pt.initProxySubName() + pt.initProxyTimeTickChannelNames() + pt.initDataDefinitionChannelNames() + pt.initMsgStreamInsertBufSize() + pt.initMsgStreamSearchBufSize() + pt.initMsgStreamSearchResultBufSize() + pt.initMsgStreamSearchResultPulsarBufSize() + pt.initMsgStreamTimeTickBufSize() + pt.initMaxNameLength() + pt.initMaxFieldNum() + pt.initMaxDimension() + pt.initDefaultPartitionTag() } -func (pt *ParamTable) NetworkPort() int { - return pt.ParseInt("proxyNode.port") +func (pt *ParamTable) initNetworkPort() { + pt.NetworkPort = pt.ParseInt("proxyNode.port") } -func (pt *ParamTable) NetworkAddress() string { +func (pt *ParamTable) initNetworkAddress() { addr, err := pt.Load("proxyNode.address") if err != nil { panic(err) @@ -62,14 +102,14 @@ func (pt *ParamTable) NetworkAddress() string { if err != nil { panic(err) } - return addr + ":" + port + + pt.NetworkAddress = addr + ":" + port } -func (pt *ParamTable) ProxyServiceAddress() string { +func (pt *ParamTable) initProxyServiceAddress() { addressFromEnv := os.Getenv("PROXY_SERVICE_ADDRESS") if len(addressFromEnv) > 0 { - // TODO: or write to param table? - return addressFromEnv + pt.ProxyServiceAddress = addressFromEnv } addr, err := pt.Load("proxyService.address") @@ -92,35 +132,55 @@ func (pt *ParamTable) ProxyServiceAddress() string { if err != nil { panic(err) } - return addr + ":" + port + pt.ProxyServiceAddress = addr + ":" + port } -func (pt *ParamTable) MasterAddress() string { +func (pt *ParamTable) initMasterAddress() { ret, err := pt.Load("_MasterAddress") if err != nil { panic(err) } - return ret + pt.MasterAddress = ret } -func (pt *ParamTable) PulsarAddress() string { +func (pt *ParamTable) initPulsarAddress() { ret, err := pt.Load("_PulsarAddress") if err != nil { panic(err) } - return ret + pt.PulsarAddress = ret } -func (pt *ParamTable) ProxyNum() int { - ret := pt.ProxyIDList() - return len(ret) +func (pt *ParamTable) initIndexServerAddress() { + addr, err := pt.Load("indexServer.address") + if err != nil { + panic(err) + } + + hostName, _ := net.LookupHost(addr) + if len(hostName) <= 0 { + if ip := net.ParseIP(addr); ip == nil { + panic("invalid ip indexServer.address") + } + } + + port, err := pt.Load("indexServer.port") + if err != nil { + panic(err) + } + _, err = strconv.Atoi(port) + if err != nil { + panic(err) + } + + pt.IndexServerAddress = addr + ":" + port } -func (pt *ParamTable) queryNodeNum() int { - return len(pt.queryNodeIDList()) +func (pt *ParamTable) initQueryNodeNum() { + pt.QueryNodeNum = len(pt.QueryNodeIDList) } -func (pt *ParamTable) queryNodeIDList() []UniqueID { +func (pt *ParamTable) initQueryNodeIDList() []UniqueID { queryNodeIDStr, err := pt.Load("nodeID.queryNodeIDList") if err != nil { panic(err) @@ -137,7 +197,7 @@ func (pt *ParamTable) queryNodeIDList() []UniqueID { return ret } -func (pt *ParamTable) ProxyID() UniqueID { +func (pt *ParamTable) initProxyID() { proxyID, err := pt.Load("_proxyID") if err != nil { panic(err) @@ -146,10 +206,10 @@ func (pt *ParamTable) ProxyID() UniqueID { if err != nil { panic(err) } - return UniqueID(ID) + pt.ProxyID = UniqueID(ID) } -func (pt *ParamTable) TimeTickInterval() time.Duration { +func (pt *ParamTable) initTimeTickInterval() { internalStr, err := pt.Load("proxyNode.timeTickInterval") if err != nil { panic(err) @@ -158,21 +218,10 @@ func (pt *ParamTable) TimeTickInterval() time.Duration { if err != nil { panic(err) } - return time.Duration(interval) * time.Millisecond + pt.TimeTickInterval = time.Duration(interval) * time.Millisecond } -func (pt *ParamTable) sliceIndex() int { - proxyID := pt.ProxyID() - proxyIDList := pt.ProxyIDList() - for i := 0; i < len(proxyIDList); i++ { - if proxyID == proxyIDList[i] { - return i - } - } - return -1 -} - -func (pt *ParamTable) InsertChannelNames() []string { +func (pt *ParamTable) initInsertChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.insert") if err != nil { panic(err) @@ -188,17 +237,10 @@ func (pt *ParamTable) InsertChannelNames() []string { ret = append(ret, prefix+strconv.Itoa(ID)) } - proxyNum := pt.ProxyNum() - sep := len(channelIDs) / proxyNum - index := pt.sliceIndex() - if index == -1 { - panic("ProxyID not Match with Config") - } - start := index * sep - return ret[start : start+sep] + pt.InsertChannelNames = ret } -func (pt *ParamTable) DeleteChannelNames() []string { +func (pt *ParamTable) initDeleteChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.delete") if err != nil { panic(err) @@ -213,10 +255,10 @@ func (pt *ParamTable) DeleteChannelNames() []string { for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) } - return ret + pt.DeleteChannelNames = ret } -func (pt *ParamTable) K2SChannelNames() []string { +func (pt *ParamTable) initK2SChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.k2s") if err != nil { panic(err) @@ -231,10 +273,10 @@ func (pt *ParamTable) K2SChannelNames() []string { for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) } - return ret + pt.K2SChannelNames = ret } -func (pt *ParamTable) SearchChannelNames() []string { +func (pt *ParamTable) initSearchChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.search") if err != nil { panic(err) @@ -249,10 +291,10 @@ func (pt *ParamTable) SearchChannelNames() []string { for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) } - return ret + pt.SearchChannelNames = ret } -func (pt *ParamTable) SearchResultChannelNames() []string { +func (pt *ParamTable) initSearchResultChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.searchResult") if err != nil { panic(err) @@ -267,18 +309,10 @@ func (pt *ParamTable) SearchResultChannelNames() []string { for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) } - proxyNum := pt.ProxyNum() - sep := len(channelIDs) / proxyNum - index := pt.sliceIndex() - if index == -1 { - panic("ProxyID not Match with Config") - } - start := index * sep - - return ret[start : start+sep] + pt.SearchResultChannelNames = ret } -func (pt *ParamTable) ProxySubName() string { +func (pt *ParamTable) initProxySubName() { prefix, err := pt.Load("msgChannel.subNamePrefix.proxySubNamePrefix") if err != nil { panic(err) @@ -287,48 +321,48 @@ func (pt *ParamTable) ProxySubName() string { if err != nil { panic(err) } - return prefix + "-" + proxyIDStr + pt.ProxySubName = prefix + "-" + proxyIDStr } -func (pt *ParamTable) ProxyTimeTickChannelNames() []string { +func (pt *ParamTable) initProxyTimeTickChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.proxyTimeTick") if err != nil { panic(err) } prefix += "-0" - return []string{prefix} + pt.ProxyTimeTickChannelNames = []string{prefix} } -func (pt *ParamTable) DataDefinitionChannelNames() []string { +func (pt *ParamTable) initDataDefinitionChannelNames() { prefix, err := pt.Load("msgChannel.chanNamePrefix.dataDefinition") if err != nil { panic(err) } prefix += "-0" - return []string{prefix} + pt.DataDefinitionChannelNames = []string{prefix} } -func (pt *ParamTable) MsgStreamInsertBufSize() int64 { - return pt.ParseInt64("proxyNode.msgStream.insert.bufSize") +func (pt *ParamTable) initMsgStreamInsertBufSize() { + pt.MsgStreamInsertBufSize = pt.ParseInt64("proxyNode.msgStream.insert.bufSize") } -func (pt *ParamTable) MsgStreamSearchBufSize() int64 { - return pt.ParseInt64("proxyNode.msgStream.search.bufSize") +func (pt *ParamTable) initMsgStreamSearchBufSize() { + pt.MsgStreamSearchBufSize = pt.ParseInt64("proxyNode.msgStream.search.bufSize") } -func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 { - return pt.ParseInt64("proxyNode.msgStream.searchResult.recvBufSize") +func (pt *ParamTable) initMsgStreamSearchResultBufSize() { + pt.MsgStreamSearchResultBufSize = pt.ParseInt64("proxyNode.msgStream.searchResult.recvBufSize") } -func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 { - return pt.ParseInt64("proxyNode.msgStream.searchResult.pulsarBufSize") +func (pt *ParamTable) initMsgStreamSearchResultPulsarBufSize() { + pt.MsgStreamSearchResultPulsarBufSize = pt.ParseInt64("proxyNode.msgStream.searchResult.pulsarBufSize") } -func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 { - return pt.ParseInt64("proxyNode.msgStream.timeTick.bufSize") +func (pt *ParamTable) initMsgStreamTimeTickBufSize() { + pt.MsgStreamTimeTickBufSize = pt.ParseInt64("proxyNode.msgStream.timeTick.bufSize") } -func (pt *ParamTable) MaxNameLength() int64 { +func (pt *ParamTable) initMaxNameLength() { str, err := pt.Load("proxyNode.maxNameLength") if err != nil { panic(err) @@ -337,10 +371,10 @@ func (pt *ParamTable) MaxNameLength() int64 { if err != nil { panic(err) } - return maxNameLength + pt.MaxNameLength = maxNameLength } -func (pt *ParamTable) MaxFieldNum() int64 { +func (pt *ParamTable) initMaxFieldNum() { str, err := pt.Load("proxyNode.maxFieldNum") if err != nil { panic(err) @@ -349,10 +383,10 @@ func (pt *ParamTable) MaxFieldNum() int64 { if err != nil { panic(err) } - return maxFieldNum + pt.MaxFieldNum = maxFieldNum } -func (pt *ParamTable) MaxDimension() int64 { +func (pt *ParamTable) initMaxDimension() { str, err := pt.Load("proxyNode.maxDimension") if err != nil { panic(err) @@ -361,13 +395,13 @@ func (pt *ParamTable) MaxDimension() int64 { if err != nil { panic(err) } - return maxDimension + pt.MaxDimension = maxDimension } -func (pt *ParamTable) defaultPartitionTag() string { +func (pt *ParamTable) initDefaultPartitionTag() { tag, err := pt.Load("common.defaultPartitionTag") if err != nil { panic(err) } - return tag + pt.DefaultPartitionTag = tag } diff --git a/internal/proxynode/paramtable_test.go b/internal/proxynode/paramtable_test.go index 72a3e0ab51f13eadf9f5c9ee84d619d69567a17e..c1da4f6583025ffbfd9c734f7b518f6d12f41b18 100644 --- a/internal/proxynode/paramtable_test.go +++ b/internal/proxynode/paramtable_test.go @@ -6,71 +6,71 @@ import ( ) func TestParamTable_InsertChannelRange(t *testing.T) { - ret := Params.InsertChannelNames() + ret := Params.InsertChannelNames fmt.Println(ret) } func TestParamTable_DeleteChannelNames(t *testing.T) { - ret := Params.DeleteChannelNames() + ret := Params.DeleteChannelNames fmt.Println(ret) } func TestParamTable_K2SChannelNames(t *testing.T) { - ret := Params.K2SChannelNames() + ret := Params.K2SChannelNames fmt.Println(ret) } func TestParamTable_SearchChannelNames(t *testing.T) { - ret := Params.SearchChannelNames() + ret := Params.SearchChannelNames fmt.Println(ret) } func TestParamTable_SearchResultChannelNames(t *testing.T) { - ret := Params.SearchResultChannelNames() + ret := Params.SearchResultChannelNames fmt.Println(ret) } func TestParamTable_ProxySubName(t *testing.T) { - ret := Params.ProxySubName() + ret := Params.ProxySubName fmt.Println(ret) } func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) { - ret := Params.ProxyTimeTickChannelNames() + ret := Params.ProxyTimeTickChannelNames fmt.Println(ret) } func TestParamTable_DataDefinitionChannelNames(t *testing.T) { - ret := Params.DataDefinitionChannelNames() + ret := Params.DataDefinitionChannelNames fmt.Println(ret) } func TestParamTable_MsgStreamInsertBufSize(t *testing.T) { - ret := Params.MsgStreamInsertBufSize() + ret := Params.MsgStreamInsertBufSize fmt.Println(ret) } func TestParamTable_MsgStreamSearchBufSize(t *testing.T) { - ret := Params.MsgStreamSearchBufSize() + ret := Params.MsgStreamSearchBufSize fmt.Println(ret) } func TestParamTable_MsgStreamSearchResultBufSize(t *testing.T) { - ret := Params.MsgStreamSearchResultBufSize() + ret := Params.MsgStreamSearchResultBufSize fmt.Println(ret) } func TestParamTable_MsgStreamSearchResultPulsarBufSize(t *testing.T) { - ret := Params.MsgStreamSearchResultPulsarBufSize() + ret := Params.MsgStreamSearchResultPulsarBufSize fmt.Println(ret) } func TestParamTable_MsgStreamTimeTickBufSize(t *testing.T) { - ret := Params.MsgStreamTimeTickBufSize() + ret := Params.MsgStreamTimeTickBufSize fmt.Println(ret) } func TestParamTable_defaultPartitionTag(t *testing.T) { - ret := Params.defaultPartitionTag() + ret := Params.DefaultPartitionTag fmt.Println("default partition tag: ", ret) } diff --git a/internal/proxynode/proxy_node.go b/internal/proxynode/proxy_node.go index ac52869fcc19eb128c1a63db00c47ff60d46df58..11c52320af98865dff46496c8b22ebf1e5970d94 100644 --- a/internal/proxynode/proxy_node.go +++ b/internal/proxynode/proxy_node.go @@ -4,25 +4,19 @@ import ( "context" "fmt" "io" - "log" "math/rand" "sync" "time" "github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms" - grpcproxyservice "github.com/zilliztech/milvus-distributed/internal/distributed/proxyservice" - "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" "github.com/opentracing/opentracing-go" "github.com/uber/jaeger-client-go/config" - "google.golang.org/grpc" - "github.com/zilliztech/milvus-distributed/internal/allocator" "github.com/zilliztech/milvus-distributed/internal/msgstream" - "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" ) @@ -34,15 +28,17 @@ type NodeImpl struct { cancel func() wg sync.WaitGroup - proxyServiceClient *grpcproxyservice.Client - initParams *internalpb2.InitParams - ip string - port int + initParams *internalpb2.InitParams + ip string + port int - masterConn *grpc.ClientConn - masterClient masterpb.MasterServiceClient - sched *TaskScheduler - tick *timeTick + masterClient MasterClientInterface + indexServiceClient IndexServiceClient + queryServiceClient QueryServiceClient + dataServiceClient DataServiceClient + + sched *TaskScheduler + tick *timeTick idAllocator *allocator.IDAllocator tsoAllocator *allocator.TimestampAllocator @@ -75,6 +71,36 @@ func (node *NodeImpl) Init() error { var err error + err = node.masterClient.Init() + if err != nil { + return err + } + err = node.indexServiceClient.Init() + if err != nil { + return err + } + err = node.queryServiceClient.Init() + if err != nil { + return err + } + err = node.dataServiceClient.Init() + if err != nil { + return err + } + + Params.SearchChannelNames, err = node.queryServiceClient.GetSearchChannelNames() + if err != nil { + return err + } + Params.SearchResultChannelNames, err = node.queryServiceClient.GetSearchResultChannelNames() + if err != nil { + return err + } + Params.InsertChannelNames, err = node.dataServiceClient.GetInsertChannelNames() + if err != nil { + return err + } + cfg := &config.Configuration{ ServiceName: "proxynode", Sampler: &config.SamplerConfig{ @@ -88,38 +114,38 @@ func (node *NodeImpl) Init() error { } opentracing.SetGlobalTracer(node.tracer) - pulsarAddress := Params.PulsarAddress() + pulsarAddress := Params.PulsarAddress - node.queryMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamSearchBufSize()) + node.queryMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamSearchBufSize) node.queryMsgStream.SetPulsarClient(pulsarAddress) - node.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames()) + node.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames) - masterAddr := Params.MasterAddress() + masterAddr := Params.MasterAddress idAllocator, err := allocator.NewIDAllocator(node.ctx, masterAddr) if err != nil { return err } node.idAllocator = idAllocator - node.idAllocator.PeerID = Params.ProxyID() + node.idAllocator.PeerID = Params.ProxyID tsoAllocator, err := allocator.NewTimestampAllocator(node.ctx, masterAddr) if err != nil { return err } node.tsoAllocator = tsoAllocator - node.tsoAllocator.PeerID = Params.ProxyID() + node.tsoAllocator.PeerID = Params.ProxyID segAssigner, err := allocator.NewSegIDAssigner(node.ctx, masterAddr, node.lastTick) if err != nil { panic(err) } node.segAssigner = segAssigner - node.segAssigner.PeerID = Params.ProxyID() + node.segAssigner.PeerID = Params.ProxyID - node.manipulationMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamInsertBufSize()) + node.manipulationMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamInsertBufSize) node.manipulationMsgStream.SetPulsarClient(pulsarAddress) - node.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames()) + node.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames) repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { return insertRepackFunc(tsMsgs, hashKeys, node.segAssigner, true) } @@ -136,7 +162,20 @@ func (node *NodeImpl) Init() error { } func (node *NodeImpl) Start() error { - err := node.connectMaster() + var err error + err = node.masterClient.Start() + if err != nil { + return err + } + err = node.indexServiceClient.Start() + if err != nil { + return err + } + err = node.queryServiceClient.Start() + if err != nil { + return err + } + err = node.dataServiceClient.Start() if err != nil { return err } @@ -167,6 +206,23 @@ func (node *NodeImpl) Stop() error { node.manipulationMsgStream.Close() node.queryMsgStream.Close() node.tick.Close() + var err error + err = node.dataServiceClient.Stop() + if err != nil { + return err + } + err = node.queryServiceClient.Stop() + if err != nil { + return err + } + err = node.indexServiceClient.Stop() + if err != nil { + return err + } + err = node.masterClient.Stop() + if err != nil { + return err + } node.wg.Wait() @@ -197,19 +253,3 @@ func (node *NodeImpl) lastTick() Timestamp { func (node *NodeImpl) AddCloseCallback(callbacks ...func()) { node.closeCallbacks = append(node.closeCallbacks, callbacks...) } - -func (node *NodeImpl) connectMaster() error { - masterAddr := Params.MasterAddress() - log.Printf("NodeImpl connected to master, master_addr=%s", masterAddr) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, masterAddr, grpc.WithInsecure(), grpc.WithBlock()) - if err != nil { - log.Printf("NodeImpl connect to master failed, error= %v", err) - return err - } - log.Printf("NodeImpl connected to master, master_addr=%s", masterAddr) - node.masterConn = conn - node.masterClient = masterpb.NewMasterServiceClient(conn) - return nil -} diff --git a/internal/proxynode/proxy_node_test.go b/internal/proxynode/proxy_node_test.go index 1592d8c7a403a80aa8c78af7d361f3bd89cadb02..8589307fe5e3e6d86277aaad225d56210c6e91da 100644 --- a/internal/proxynode/proxy_node_test.go +++ b/internal/proxynode/proxy_node_test.go @@ -345,7 +345,7 @@ func TestProxy_Search(t *testing.T) { queryResultChannels := []string{"QueryResult"} bufSize := 1024 queryResultMsgStream := pulsarms.NewPulsarMsgStream(ctx, int64(bufSize)) - pulsarAddress := Params.PulsarAddress() + pulsarAddress := Params.PulsarAddress queryResultMsgStream.SetPulsarClient(pulsarAddress) assert.NotEqual(t, queryResultMsgStream, nil, "query result message stream should not be nil!") queryResultMsgStream.CreatePulsarProducers(queryResultChannels) @@ -417,7 +417,7 @@ func TestProxy_AssignSegID(t *testing.T) { testNum := 1 futureTS := tsoutil.ComposeTS(time.Now().Add(time.Second*-1000).UnixNano()/int64(time.Millisecond), 0) for i := 0; i < testNum; i++ { - segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, Params.defaultPartitionTag(), int32(i), 200000, futureTS) + segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, Params.DefaultPartitionTag, int32(i), 200000, futureTS) assert.Nil(t, err) fmt.Println("segID", segID) } diff --git a/internal/proxynode/task.go b/internal/proxynode/task.go index 2d91b0dadffe30d0c490652bf82bcfb0e00a1f42..a8adbb4dd46cf60d29c409cc7edb6c413ed271e8 100644 --- a/internal/proxynode/task.go +++ b/internal/proxynode/task.go @@ -16,7 +16,6 @@ import ( "github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" - "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" @@ -84,7 +83,7 @@ func (it *InsertTask) Type() commonpb.MsgType { func (it *InsertTask) PreExecute() error { it.Base.MsgType = commonpb.MsgType_kInsert - it.Base.SourceID = Params.ProxyID() + it.Base.SourceID = Params.ProxyID span, ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask preExecute") defer span.Finish() @@ -189,7 +188,7 @@ func (it *InsertTask) PostExecute() error { type CreateCollectionTask struct { Condition *milvuspb.CreateCollectionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *commonpb.Status ctx context.Context schema *schemapb.CollectionSchema @@ -226,7 +225,7 @@ func (cct *CreateCollectionTask) SetTs(ts Timestamp) { func (cct *CreateCollectionTask) PreExecute() error { cct.Base.MsgType = commonpb.MsgType_kCreateCollection - cct.Base.SourceID = Params.ProxyID() + cct.Base.SourceID = Params.ProxyID cct.schema = &schemapb.CollectionSchema{} err := proto.Unmarshal(cct.Schema, cct.schema) @@ -234,8 +233,8 @@ func (cct *CreateCollectionTask) PreExecute() error { return err } - if int64(len(cct.schema.Fields)) > Params.MaxFieldNum() { - return errors.New("maximum field's number should be limited to " + strconv.FormatInt(Params.MaxFieldNum(), 10)) + if int64(len(cct.schema.Fields)) > Params.MaxFieldNum { + return errors.New("maximum field's number should be limited to " + strconv.FormatInt(Params.MaxFieldNum, 10)) } // validate collection name @@ -293,7 +292,7 @@ func (cct *CreateCollectionTask) PreExecute() error { func (cct *CreateCollectionTask) Execute() error { var err error - cct.result, err = cct.masterClient.CreateCollection(cct.ctx, cct.CreateCollectionRequest) + cct.result, err = cct.masterClient.CreateCollection(cct.CreateCollectionRequest) return err } @@ -304,7 +303,7 @@ func (cct *CreateCollectionTask) PostExecute() error { type DropCollectionTask struct { Condition *milvuspb.DropCollectionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *commonpb.Status ctx context.Context } @@ -340,7 +339,7 @@ func (dct *DropCollectionTask) SetTs(ts Timestamp) { func (dct *DropCollectionTask) PreExecute() error { dct.Base.MsgType = commonpb.MsgType_kDropCollection - dct.Base.SourceID = Params.ProxyID() + dct.Base.SourceID = Params.ProxyID if err := ValidateCollectionName(dct.CollectionName); err != nil { return err @@ -350,7 +349,7 @@ func (dct *DropCollectionTask) PreExecute() error { func (dct *DropCollectionTask) Execute() error { var err error - dct.result, err = dct.masterClient.DropCollection(dct.ctx, dct.DropCollectionRequest) + dct.result, err = dct.masterClient.DropCollection(dct.DropCollectionRequest) return err } @@ -401,7 +400,7 @@ func (st *SearchTask) SetTs(ts Timestamp) { func (st *SearchTask) PreExecute() error { st.Base.MsgType = commonpb.MsgType_kSearch - st.Base.SourceID = Params.ProxyID() + st.Base.SourceID = Params.ProxyID span, ctx := opentracing.StartSpanFromContext(st.ctx, "SearchTask preExecute") defer span.Finish() @@ -460,7 +459,7 @@ func (st *SearchTask) Execute() error { var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{ SearchRequest: st.SearchRequest, BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{uint32(Params.ProxyID())}, + HashValues: []uint32{uint32(Params.ProxyID)}, BeginTimestamp: st.Base.Timestamp, EndTimestamp: st.Base.Timestamp, }, @@ -646,7 +645,7 @@ func (st *SearchTask) PostExecute() error { type HasCollectionTask struct { Condition *milvuspb.HasCollectionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *milvuspb.BoolResponse ctx context.Context } @@ -682,7 +681,7 @@ func (hct *HasCollectionTask) SetTs(ts Timestamp) { func (hct *HasCollectionTask) PreExecute() error { hct.Base.MsgType = commonpb.MsgType_kHasCollection - hct.Base.SourceID = Params.ProxyID() + hct.Base.SourceID = Params.ProxyID if err := ValidateCollectionName(hct.CollectionName); err != nil { return err @@ -692,7 +691,7 @@ func (hct *HasCollectionTask) PreExecute() error { func (hct *HasCollectionTask) Execute() error { var err error - hct.result, err = hct.masterClient.HasCollection(hct.ctx, hct.HasCollectionRequest) + hct.result, err = hct.masterClient.HasCollection(hct.HasCollectionRequest) return err } @@ -703,7 +702,7 @@ func (hct *HasCollectionTask) PostExecute() error { type DescribeCollectionTask struct { Condition *milvuspb.DescribeCollectionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *milvuspb.DescribeCollectionResponse ctx context.Context } @@ -739,7 +738,7 @@ func (dct *DescribeCollectionTask) SetTs(ts Timestamp) { func (dct *DescribeCollectionTask) PreExecute() error { dct.Base.MsgType = commonpb.MsgType_kDescribeCollection - dct.Base.SourceID = Params.ProxyID() + dct.Base.SourceID = Params.ProxyID if err := ValidateCollectionName(dct.CollectionName); err != nil { return err @@ -749,7 +748,7 @@ func (dct *DescribeCollectionTask) PreExecute() error { func (dct *DescribeCollectionTask) Execute() error { var err error - dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, dct.DescribeCollectionRequest) + dct.result, err = dct.masterClient.DescribeCollection(dct.DescribeCollectionRequest) if err != nil { return err } @@ -764,7 +763,7 @@ func (dct *DescribeCollectionTask) PostExecute() error { type ShowCollectionsTask struct { Condition *milvuspb.ShowCollectionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *milvuspb.ShowCollectionResponse ctx context.Context } @@ -800,14 +799,14 @@ func (sct *ShowCollectionsTask) SetTs(ts Timestamp) { func (sct *ShowCollectionsTask) PreExecute() error { sct.Base.MsgType = commonpb.MsgType_kShowCollections - sct.Base.SourceID = Params.ProxyID() + sct.Base.SourceID = Params.ProxyID return nil } func (sct *ShowCollectionsTask) Execute() error { var err error - sct.result, err = sct.masterClient.ShowCollections(sct.ctx, sct.ShowCollectionRequest) + sct.result, err = sct.masterClient.ShowCollections(sct.ShowCollectionRequest) return err } @@ -818,7 +817,7 @@ func (sct *ShowCollectionsTask) PostExecute() error { type CreatePartitionTask struct { Condition *milvuspb.CreatePartitionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *commonpb.Status ctx context.Context } @@ -854,7 +853,7 @@ func (cpt *CreatePartitionTask) SetTs(ts Timestamp) { func (cpt *CreatePartitionTask) PreExecute() error { cpt.Base.MsgType = commonpb.MsgType_kCreatePartition - cpt.Base.SourceID = Params.ProxyID() + cpt.Base.SourceID = Params.ProxyID collName, partitionTag := cpt.CollectionName, cpt.PartitionName @@ -870,7 +869,7 @@ func (cpt *CreatePartitionTask) PreExecute() error { } func (cpt *CreatePartitionTask) Execute() (err error) { - cpt.result, err = cpt.masterClient.CreatePartition(cpt.ctx, cpt.CreatePartitionRequest) + cpt.result, err = cpt.masterClient.CreatePartition(cpt.CreatePartitionRequest) return err } @@ -881,7 +880,7 @@ func (cpt *CreatePartitionTask) PostExecute() error { type DropPartitionTask struct { Condition *milvuspb.DropPartitionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *commonpb.Status ctx context.Context } @@ -917,7 +916,7 @@ func (dpt *DropPartitionTask) SetTs(ts Timestamp) { func (dpt *DropPartitionTask) PreExecute() error { dpt.Base.MsgType = commonpb.MsgType_kDropPartition - dpt.Base.SourceID = Params.ProxyID() + dpt.Base.SourceID = Params.ProxyID collName, partitionTag := dpt.CollectionName, dpt.PartitionName @@ -933,7 +932,7 @@ func (dpt *DropPartitionTask) PreExecute() error { } func (dpt *DropPartitionTask) Execute() (err error) { - dpt.result, err = dpt.masterClient.DropPartition(dpt.ctx, dpt.DropPartitionRequest) + dpt.result, err = dpt.masterClient.DropPartition(dpt.DropPartitionRequest) return err } @@ -944,7 +943,7 @@ func (dpt *DropPartitionTask) PostExecute() error { type HasPartitionTask struct { Condition *milvuspb.HasPartitionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *milvuspb.BoolResponse ctx context.Context } @@ -980,7 +979,7 @@ func (hpt *HasPartitionTask) SetTs(ts Timestamp) { func (hpt *HasPartitionTask) PreExecute() error { hpt.Base.MsgType = commonpb.MsgType_kHasPartition - hpt.Base.SourceID = Params.ProxyID() + hpt.Base.SourceID = Params.ProxyID collName, partitionTag := hpt.CollectionName, hpt.PartitionName @@ -995,7 +994,7 @@ func (hpt *HasPartitionTask) PreExecute() error { } func (hpt *HasPartitionTask) Execute() (err error) { - hpt.result, err = hpt.masterClient.HasPartition(hpt.ctx, hpt.HasPartitionRequest) + hpt.result, err = hpt.masterClient.HasPartition(hpt.HasPartitionRequest) return err } @@ -1060,7 +1059,7 @@ func (hpt *HasPartitionTask) PostExecute() error { type ShowPartitionsTask struct { Condition *milvuspb.ShowPartitionRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *milvuspb.ShowPartitionResponse ctx context.Context } @@ -1096,7 +1095,7 @@ func (spt *ShowPartitionsTask) SetTs(ts Timestamp) { func (spt *ShowPartitionsTask) PreExecute() error { spt.Base.MsgType = commonpb.MsgType_kShowPartitions - spt.Base.SourceID = Params.ProxyID() + spt.Base.SourceID = Params.ProxyID if err := ValidateCollectionName(spt.CollectionName); err != nil { return err @@ -1106,7 +1105,7 @@ func (spt *ShowPartitionsTask) PreExecute() error { func (spt *ShowPartitionsTask) Execute() error { var err error - spt.result, err = spt.masterClient.ShowPartitions(spt.ctx, spt.ShowPartitionRequest) + spt.result, err = spt.masterClient.ShowPartitions(spt.ShowPartitionRequest) return err } @@ -1117,7 +1116,7 @@ func (spt *ShowPartitionsTask) PostExecute() error { type CreateIndexTask struct { Condition *milvuspb.CreateIndexRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *commonpb.Status ctx context.Context } @@ -1153,7 +1152,7 @@ func (cit *CreateIndexTask) SetTs(ts Timestamp) { func (cit *CreateIndexTask) PreExecute() error { cit.Base.MsgType = commonpb.MsgType_kCreateIndex - cit.Base.SourceID = Params.ProxyID() + cit.Base.SourceID = Params.ProxyID collName, fieldName := cit.CollectionName, cit.FieldName @@ -1169,7 +1168,7 @@ func (cit *CreateIndexTask) PreExecute() error { } func (cit *CreateIndexTask) Execute() (err error) { - cit.result, err = cit.masterClient.CreateIndex(cit.ctx, cit.CreateIndexRequest) + cit.result, err = cit.masterClient.CreateIndex(cit.CreateIndexRequest) return err } @@ -1180,7 +1179,7 @@ func (cit *CreateIndexTask) PostExecute() error { type DescribeIndexTask struct { Condition *milvuspb.DescribeIndexRequest - masterClient masterpb.MasterServiceClient + masterClient MasterClientInterface result *milvuspb.DescribeIndexResponse ctx context.Context } @@ -1216,7 +1215,7 @@ func (dit *DescribeIndexTask) SetTs(ts Timestamp) { func (dit *DescribeIndexTask) PreExecute() error { dit.Base.MsgType = commonpb.MsgType_kDescribeIndex - dit.Base.SourceID = Params.ProxyID() + dit.Base.SourceID = Params.ProxyID collName, fieldName := dit.CollectionName, dit.FieldName @@ -1233,7 +1232,7 @@ func (dit *DescribeIndexTask) PreExecute() error { func (dit *DescribeIndexTask) Execute() error { var err error - dit.result, err = dit.masterClient.DescribeIndex(dit.ctx, dit.DescribeIndexRequest) + dit.result, err = dit.masterClient.DescribeIndex(dit.DescribeIndexRequest) return err } @@ -1244,9 +1243,9 @@ func (dit *DescribeIndexTask) PostExecute() error { type GetIndexStateTask struct { Condition *milvuspb.IndexStateRequest - masterClient masterpb.MasterServiceClient - result *milvuspb.IndexStateResponse - ctx context.Context + indexServiceClient IndexServiceClient + result *milvuspb.IndexStateResponse + ctx context.Context } func (dipt *GetIndexStateTask) OnEnqueue() error { @@ -1280,7 +1279,7 @@ func (dipt *GetIndexStateTask) SetTs(ts Timestamp) { func (dipt *GetIndexStateTask) PreExecute() error { dipt.Base.MsgType = commonpb.MsgType_kGetIndexState - dipt.Base.SourceID = Params.ProxyID() + dipt.Base.SourceID = Params.ProxyID collName, fieldName := dipt.CollectionName, dipt.FieldName @@ -1296,9 +1295,18 @@ func (dipt *GetIndexStateTask) PreExecute() error { } func (dipt *GetIndexStateTask) Execute() error { - var err error - dipt.result, err = dipt.masterClient.GetIndexState(dipt.ctx, dipt.IndexStateRequest) - return err + // TODO: use index service client + //var err error + //dipt.result, err = dipt.masterClient.GetIndexState(dipt.IndexStateRequest) + //return err + dipt.result = &milvuspb.IndexStateResponse{ + Status: &commonpb.Status{ + ErrorCode: 0, + Reason: "", + }, + State: commonpb.IndexState_FINISHED, + } + return nil } func (dipt *GetIndexStateTask) PostExecute() error { diff --git a/internal/proxynode/task_scheduler.go b/internal/proxynode/task_scheduler.go index 4de0d5986aa1392dab03a927e86d535333dc08f4..7716613c25bd60c64c4ea444db0d868c7dfe8d71 100644 --- a/internal/proxynode/task_scheduler.go +++ b/internal/proxynode/task_scheduler.go @@ -374,13 +374,13 @@ func (sched *TaskScheduler) queryResultLoop() { defer sched.wg.Done() unmarshal := util.NewUnmarshalDispatcher() - queryResultMsgStream := pulsarms.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize()) - queryResultMsgStream.SetPulsarClient(Params.PulsarAddress()) - queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(), - Params.ProxySubName(), + queryResultMsgStream := pulsarms.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize) + queryResultMsgStream.SetPulsarClient(Params.PulsarAddress) + queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames, + Params.ProxySubName, unmarshal, - Params.MsgStreamSearchResultPulsarBufSize()) - queryNodeNum := Params.queryNodeNum() + Params.MsgStreamSearchResultPulsarBufSize) + queryNodeNum := Params.QueryNodeNum queryResultMsgStream.Start() defer queryResultMsgStream.Close() diff --git a/internal/proxynode/timetick.go b/internal/proxynode/timetick.go index 1bca3f59ad012adbdcc685e0ef0a37d33cc0dc9b..b6ef1c7d6dde419d187b52032cf34e2a4fe25236 100644 --- a/internal/proxynode/timetick.go +++ b/internal/proxynode/timetick.go @@ -47,15 +47,15 @@ func newTimeTick(ctx context.Context, cancel: cancel, tsoAllocator: tsoAllocator, interval: interval, - peerID: Params.ProxyID(), + peerID: Params.ProxyID, checkFunc: checkFunc, } - t.tickMsgStream = pulsarms.NewPulsarMsgStream(t.ctx, Params.MsgStreamTimeTickBufSize()) - pulsarAddress := Params.PulsarAddress() + t.tickMsgStream = pulsarms.NewPulsarMsgStream(t.ctx, Params.MsgStreamTimeTickBufSize) + pulsarAddress := Params.PulsarAddress t.tickMsgStream.SetPulsarClient(pulsarAddress) - t.tickMsgStream.CreatePulsarProducers(Params.ProxyTimeTickChannelNames()) + t.tickMsgStream.CreatePulsarProducers(Params.ProxyTimeTickChannelNames) return t } @@ -74,7 +74,7 @@ func (tt *timeTick) tick() error { msgPack := msgstream.MsgPack{} timeTickMsg := &msgstream.TimeTickMsg{ BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{uint32(Params.ProxyID())}, + HashValues: []uint32{uint32(Params.ProxyID)}, }, TimeTickMsg: internalpb2.TimeTickMsg{ Base: &commonpb.MsgBase{ diff --git a/internal/proxynode/timetick_test.go b/internal/proxynode/timetick_test.go index 3f0e63700ed1d39b609e76b76f41cf0a23413d15..6520a99fcf15911f378fcb0882a7c886debdb3cc 100644 --- a/internal/proxynode/timetick_test.go +++ b/internal/proxynode/timetick_test.go @@ -28,13 +28,13 @@ func TestTimeTick_Start(t *testing.T) { func TestTimeTick_Start2(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - masterAddr := Params.MasterAddress() + masterAddr := Params.MasterAddress tsoAllocator, err := allocator.NewTimestampAllocator(ctx, masterAddr) assert.Nil(t, err) err = tsoAllocator.Start() assert.Nil(t, err) - tt := newTimeTick(ctx, tsoAllocator, Params.TimeTickInterval(), checkFunc) + tt := newTimeTick(ctx, tsoAllocator, Params.TimeTickInterval, checkFunc) defer func() { cancel() diff --git a/internal/proxynode/validate_util.go b/internal/proxynode/validate_util.go index b07d6532edfcaf1e2ca3c5a1c420b943e8e906d7..839fedd9f0eefea1685e788a7b4c88234483a03f 100644 --- a/internal/proxynode/validate_util.go +++ b/internal/proxynode/validate_util.go @@ -30,9 +30,9 @@ func ValidateCollectionName(collName string) error { } invalidMsg := "Invalid collection name: " + collName + ". " - if int64(len(collName)) > Params.MaxNameLength() { + if int64(len(collName)) > Params.MaxNameLength { msg := invalidMsg + "The length of a collection name must be less than " + - strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + strconv.FormatInt(Params.MaxNameLength, 10) + " characters." return errors.New(msg) } @@ -61,9 +61,9 @@ func ValidatePartitionTag(partitionTag string, strictCheck bool) error { return errors.New(msg) } - if int64(len(partitionTag)) > Params.MaxNameLength() { + if int64(len(partitionTag)) > Params.MaxNameLength { msg := invalidMsg + "The length of a partition tag must be less than " + - strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + strconv.FormatInt(Params.MaxNameLength, 10) + " characters." return errors.New(msg) } @@ -95,9 +95,9 @@ func ValidateFieldName(fieldName string) error { } invalidMsg := "Invalid field name: " + fieldName + ". " - if int64(len(fieldName)) > Params.MaxNameLength() { + if int64(len(fieldName)) > Params.MaxNameLength { msg := invalidMsg + "The length of a field name must be less than " + - strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + strconv.FormatInt(Params.MaxNameLength, 10) + " characters." return errors.New(msg) } @@ -119,9 +119,9 @@ func ValidateFieldName(fieldName string) error { } func ValidateDimension(dim int64, isBinary bool) error { - if dim <= 0 || dim > Params.MaxDimension() { + if dim <= 0 || dim > Params.MaxDimension { return errors.New("invalid dimension: " + strconv.FormatInt(dim, 10) + ". should be in range 1 ~ " + - strconv.FormatInt(Params.MaxDimension(), 10) + ".") + strconv.FormatInt(Params.MaxDimension, 10) + ".") } if isBinary && dim%8 != 0 { return errors.New("invalid dimension: " + strconv.FormatInt(dim, 10) + ". should be multiple of 8.") diff --git a/internal/proxynode/validate_util_test.go b/internal/proxynode/validate_util_test.go index 54f3ffb19b9da403b14de08a7fe92b00d7e2a1d1..32c722c0373901bb8ac16f02c26bede9f0548e64 100644 --- a/internal/proxynode/validate_util_test.go +++ b/internal/proxynode/validate_util_test.go @@ -84,13 +84,13 @@ func TestValidateFieldName(t *testing.T) { func TestValidateDimension(t *testing.T) { assert.Nil(t, ValidateDimension(1, false)) - assert.Nil(t, ValidateDimension(Params.MaxDimension(), false)) + assert.Nil(t, ValidateDimension(Params.MaxDimension, false)) assert.Nil(t, ValidateDimension(8, true)) - assert.Nil(t, ValidateDimension(Params.MaxDimension(), true)) + assert.Nil(t, ValidateDimension(Params.MaxDimension, true)) // invalid dim assert.NotNil(t, ValidateDimension(-1, false)) - assert.NotNil(t, ValidateDimension(Params.MaxDimension()+1, false)) + assert.NotNil(t, ValidateDimension(Params.MaxDimension+1, false)) assert.NotNil(t, ValidateDimension(9, true)) } diff --git a/internal/proxyservice/impl.go b/internal/proxyservice/impl.go index 193534f15bd62e0fd2849da6eec4e6b4033ce214..74ac4c0684ff457f7b71abf0c56f431994212afe 100644 --- a/internal/proxyservice/impl.go +++ b/internal/proxyservice/impl.go @@ -2,7 +2,11 @@ package proxyservice import ( "context" - "fmt" + "io/ioutil" + "log" + "os" + "path" + "runtime" "time" "github.com/zilliztech/milvus-distributed/internal/msgstream/util" @@ -18,18 +22,63 @@ import ( ) const ( - timeoutInterval = time.Second * 10 + timeoutInterval = time.Second * 10 + StartParamsKey = "START_PARAMS" + ChannelYamlContent = "advanced/channel.yaml" + CommonYamlContent = "advanced/common.yaml" + DataNodeYamlContent = "advanced/data_node.yaml" + MasterYamlContent = "advanced/master.yaml" + ProxyNodeYamlContent = "advanced/proxy_node.yaml" + QueryNodeYamlContent = "advanced/query_node.yaml" + WriteNodeYamlContent = "advanced/write_node.yaml" + MilvusYamlContent = "milvus.yaml" ) func (s *ServiceImpl) fillNodeInitParams() error { s.nodeStartParams = make([]*commonpb.KeyValuePair, 0) - nodeParams := &ParamTable{} - nodeParams.Init() - err := nodeParams.LoadYaml("advanced/proxy_node.yaml") - if err != nil { - return err + + getConfigContentByName := func(fileName string) []byte { + _, fpath, _, _ := runtime.Caller(0) + configFile := path.Dir(fpath) + "/../../../configs/" + fileName + _, err := os.Stat(configFile) + if os.IsNotExist(err) { + runPath, err := os.Getwd() + if err != nil { + panic(err) + } + configFile = runPath + "/configs/" + fileName + } + data, err := ioutil.ReadFile(configFile) + if err != nil { + panic(err) + } + return data } + channelYamlContent := getConfigContentByName(ChannelYamlContent) + commonYamlContent := getConfigContentByName(CommonYamlContent) + dataNodeYamlContent := getConfigContentByName(DataNodeYamlContent) + masterYamlContent := getConfigContentByName(MasterYamlContent) + proxyNodeYamlContent := getConfigContentByName(ProxyNodeYamlContent) + queryNodeYamlContent := getConfigContentByName(QueryNodeYamlContent) + writeNodeYamlContent := getConfigContentByName(WriteNodeYamlContent) + milvusYamlContent := getConfigContentByName(MilvusYamlContent) + + var allContent []byte + allContent = append(allContent, channelYamlContent...) + allContent = append(allContent, commonYamlContent...) + allContent = append(allContent, dataNodeYamlContent...) + allContent = append(allContent, masterYamlContent...) + allContent = append(allContent, proxyNodeYamlContent...) + allContent = append(allContent, queryNodeYamlContent...) + allContent = append(allContent, writeNodeYamlContent...) + allContent = append(allContent, milvusYamlContent...) + + s.nodeStartParams = append(s.nodeStartParams, &commonpb.KeyValuePair{ + Key: StartParamsKey, + Value: string(allContent), + }) + return nil } @@ -40,12 +89,12 @@ func (s *ServiceImpl) Init() error { } serviceTimeTickMsgStream := pulsarms.NewPulsarTtMsgStream(s.ctx, 1024) - serviceTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress()) - serviceTimeTickMsgStream.CreatePulsarProducers([]string{Params.ServiceTimeTickChannel()}) + serviceTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress) + serviceTimeTickMsgStream.CreatePulsarProducers([]string{Params.ServiceTimeTickChannel}) nodeTimeTickMsgStream := pulsarms.NewPulsarMsgStream(s.ctx, 1024) - nodeTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress()) - nodeTimeTickMsgStream.CreatePulsarConsumers(Params.NodeTimeTickChannel(), + nodeTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress) + nodeTimeTickMsgStream.CreatePulsarConsumers(Params.NodeTimeTickChannel, "proxyservicesub", // TODO: add config util.NewUnmarshalDispatcher(), 1024) @@ -53,20 +102,6 @@ func (s *ServiceImpl) Init() error { ttBarrier := newSoftTimeTickBarrier(s.ctx, nodeTimeTickMsgStream, []UniqueID{0}, 10) s.tick = newTimeTick(s.ctx, ttBarrier, serviceTimeTickMsgStream) - // dataServiceAddr := Params.DataServiceAddress() - // s.dataServiceClient = dataservice.NewClient(dataServiceAddr) - - // insertChannelsRequest := &datapb.InsertChannelRequest{} - // insertChannelNames, err := s.dataServiceClient.GetInsertChannels(insertChannelsRequest) - // if err != nil { - // return err - // } - - // if len(insertChannelNames.Values) > 0 { - // namesStr := strings.Join(insertChannelNames.Values, ",") - // s.nodeStartParams = append(s.nodeStartParams, &commonpb.KeyValuePair{Key: KInsertChannelNames, Value: namesStr}) - // } - s.state.State.StateCode = internalpb2.StateCode_HEALTHY return nil @@ -88,7 +123,7 @@ func (s *ServiceImpl) GetComponentStates() (*internalpb2.ComponentStates, error) } func (s *ServiceImpl) GetTimeTickChannel() (string, error) { - return Params.ServiceTimeTickChannel(), nil + return Params.ServiceTimeTickChannel, nil } func (s *ServiceImpl) GetStatisticsChannel() (string, error) { @@ -96,7 +131,7 @@ func (s *ServiceImpl) GetStatisticsChannel() (string, error) { } func (s *ServiceImpl) RegisterLink() (*milvuspb.RegisterLinkResponse, error) { - fmt.Println("register link") + log.Println("register link") ctx, cancel := context.WithTimeout(s.ctx, timeoutInterval) defer cancel() @@ -133,7 +168,7 @@ func (s *ServiceImpl) RegisterLink() (*milvuspb.RegisterLinkResponse, error) { } func (s *ServiceImpl) RegisterNode(request *proxypb.RegisterNodeRequest) (*proxypb.RegisterNodeResponse, error) { - fmt.Println("RegisterNode: ", request) + log.Println("RegisterNode: ", request) ctx, cancel := context.WithTimeout(s.ctx, timeoutInterval) defer cancel() @@ -173,7 +208,7 @@ func (s *ServiceImpl) RegisterNode(request *proxypb.RegisterNodeRequest) (*proxy } func (s *ServiceImpl) InvalidateCollectionMetaCache(request *proxypb.InvalidateCollMetaCacheRequest) error { - fmt.Println("InvalidateCollectionMetaCache") + log.Println("InvalidateCollectionMetaCache") ctx, cancel := context.WithTimeout(s.ctx, timeoutInterval) defer cancel() diff --git a/internal/proxyservice/paramtable.go b/internal/proxyservice/paramtable.go index bdd7bae37fc006888c422eeab866f80c03d21f6a..27d6b51098dd560ac407d66a74f568487c93a2d5 100644 --- a/internal/proxyservice/paramtable.go +++ b/internal/proxyservice/paramtable.go @@ -8,48 +8,60 @@ import ( type ParamTable struct { paramtable.BaseTable + + PulsarAddress string + MasterAddress string + NodeTimeTickChannel []string + ServiceTimeTickChannel string + DataServiceAddress string } var Params ParamTable func (pt *ParamTable) Init() { pt.BaseTable.Init() + + pt.initPulsarAddress() + pt.initMasterAddress() + pt.initNodeTimeTickChannel() + pt.initServiceTimeTickChannel() + pt.initDataServiceAddress() } -func (pt *ParamTable) PulsarAddress() string { +func (pt *ParamTable) initPulsarAddress() { ret, err := pt.Load("_PulsarAddress") if err != nil { panic(err) } - return ret + pt.PulsarAddress = ret } -func (pt *ParamTable) MasterAddress() string { +func (pt *ParamTable) initMasterAddress() { ret, err := pt.Load("_MasterAddress") if err != nil { panic(err) } - return ret + pt.MasterAddress = ret } -func (pt *ParamTable) NodeTimeTickChannel() []string { +func (pt *ParamTable) initNodeTimeTickChannel() { prefix, err := pt.Load("msgChannel.chanNamePrefix.proxyTimeTick") if err != nil { log.Panic(err) } prefix += "-0" - return []string{prefix} + pt.NodeTimeTickChannel = []string{prefix} } -func (pt *ParamTable) ServiceTimeTickChannel() string { +func (pt *ParamTable) initServiceTimeTickChannel() { ch, err := pt.Load("msgChannel.chanNamePrefix.proxyServiceTimeTick") if err != nil { log.Panic(err) } - return ch + pt.ServiceTimeTickChannel = ch } -func (pt *ParamTable) DataServiceAddress() string { +func (pt *ParamTable) initDataServiceAddress() { // NOT USED NOW - return "TODO: read from config" + pt.DataServiceAddress = "TODO: read from config" }