提交 79571702 编写于 作者: D dragondriver 提交者: yefu.chen

Dock Proxy with other service component

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 3a866dab
package grpcproxynode
const (
StartParamsKey = "START_PARAMS"
MasterPort = "master.port"
MasterHost = "master.address"
PulsarPort = "pulsar.port"
PulsarHost = "pulsar.address"
IndexServerPort = "indexBuilder.port"
IndexServerHost = "indexBuilder.address"
QueryNodeIDList = "nodeID.queryNodeIDList"
TimeTickInterval = "proxyNode.timeTickInterval"
SubName = "msgChannel.subNamePrefix.proxySubNamePrefix"
TimeTickChannelNames = "msgChannel.chanNamePrefix.proxyTimeTick"
MsgStreamInsertBufSize = "proxyNode.msgStream.insert.bufSize"
MsgStreamSearchBufSize = "proxyNode.msgStream.search.bufSize"
MsgStreamSearchResultBufSize = "proxyNode.msgStream.searchResult.recvBufSize"
MsgStreamSearchResultPulsarBufSize = "proxyNode.msgStream.searchResult.pulsarBufSize"
MsgStreamTimeTickBufSize = "proxyNode.msgStream.timeTick.bufSize"
MaxNameLength = "proxyNode.maxNameLength"
MaxFieldNum = "proxyNode.maxFieldNum"
MaxDimension = "proxyNode.MaxDimension"
DefaultPartitionTag = "common.defaultPartitionTag"
)
package grpcproxynode
import (
"net"
"strconv"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
)
type ParamTable struct {
paramtable.BaseTable
ProxyServiceAddress string
}
var Params ParamTable
func (pt *ParamTable) Init() {
pt.BaseTable.Init()
pt.initProxyServiceAddress()
}
func (pt *ParamTable) initProxyServiceAddress() {
addr, err := pt.Load("proxyService.address")
if err != nil {
panic(err)
}
hostName, _ := net.LookupHost(addr)
if len(hostName) <= 0 {
if ip := net.ParseIP(addr); ip == nil {
panic("invalid ip proxyService.address")
}
}
port, err := pt.Load("proxyService.port")
if err != nil {
panic(err)
}
_, err = strconv.Atoi(port)
if err != nil {
panic(err)
}
pt.ProxyServiceAddress = addr + ":" + port
}
package grpcproxynode
import (
"bytes"
"context"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
"github.com/spf13/viper"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
"github.com/go-basic/ipv4"
......@@ -36,10 +45,125 @@ func CreateProxyNodeServer() (*Server, error) {
return &Server{}, nil
}
func (s *Server) loadConfigFromInitParams(initParams *internalpb2.InitParams) error {
proxynode.Params.ProxyID = initParams.NodeID
config := viper.New()
config.SetConfigType("yaml")
for _, pair := range initParams.StartParams {
if pair.Key == StartParamsKey {
err := config.ReadConfig(bytes.NewBuffer([]byte(pair.Value)))
if err != nil {
return err
}
break
}
}
masterPort := config.GetString(MasterPort)
masterHost := config.GetString(MasterHost)
proxynode.Params.MasterAddress = masterHost + ":" + masterPort
pulsarPort := config.GetString(PulsarPort)
pulsarHost := config.GetString(PulsarHost)
proxynode.Params.PulsarAddress = pulsarHost + ":" + pulsarPort
indexServerPort := config.GetString(IndexServerPort)
indexServerHost := config.GetString(IndexServerHost)
proxynode.Params.IndexServerAddress = indexServerHost + ":" + indexServerPort
queryNodeIDList := config.GetString(QueryNodeIDList)
proxynode.Params.QueryNodeIDList = nil
queryNodeIDs := strings.Split(queryNodeIDList, ",")
for _, queryNodeID := range queryNodeIDs {
v, err := strconv.Atoi(queryNodeID)
if err != nil {
return err
}
proxynode.Params.QueryNodeIDList = append(proxynode.Params.QueryNodeIDList, typeutil.UniqueID(v))
}
proxynode.Params.QueryNodeNum = len(proxynode.Params.QueryNodeIDList)
timeTickInterval := config.GetString(TimeTickInterval)
interval, err := strconv.Atoi(timeTickInterval)
if err != nil {
return err
}
proxynode.Params.TimeTickInterval = time.Duration(interval) * time.Millisecond
subName := config.GetString(SubName)
proxynode.Params.ProxySubName = subName
timeTickChannelNames := config.GetString(TimeTickChannelNames)
proxynode.Params.ProxyTimeTickChannelNames = []string{timeTickChannelNames}
msgStreamInsertBufSizeStr := config.GetString(MsgStreamInsertBufSize)
msgStreamInsertBufSize, err := strconv.Atoi(msgStreamInsertBufSizeStr)
if err != nil {
return err
}
proxynode.Params.MsgStreamInsertBufSize = int64(msgStreamInsertBufSize)
msgStreamSearchBufSizeStr := config.GetString(MsgStreamSearchBufSize)
msgStreamSearchBufSize, err := strconv.Atoi(msgStreamSearchBufSizeStr)
if err != nil {
return err
}
proxynode.Params.MsgStreamSearchBufSize = int64(msgStreamSearchBufSize)
msgStreamSearchResultBufSizeStr := config.GetString(MsgStreamSearchResultBufSize)
msgStreamSearchResultBufSize, err := strconv.Atoi(msgStreamSearchResultBufSizeStr)
if err != nil {
return err
}
proxynode.Params.MsgStreamSearchResultBufSize = int64(msgStreamSearchResultBufSize)
msgStreamSearchResultPulsarBufSizeStr := config.GetString(MsgStreamSearchResultPulsarBufSize)
msgStreamSearchResultPulsarBufSize, err := strconv.Atoi(msgStreamSearchResultPulsarBufSizeStr)
if err != nil {
return err
}
proxynode.Params.MsgStreamSearchResultPulsarBufSize = int64(msgStreamSearchResultPulsarBufSize)
msgStreamTimeTickBufSizeStr := config.GetString(MsgStreamTimeTickBufSize)
msgStreamTimeTickBufSize, err := strconv.Atoi(msgStreamTimeTickBufSizeStr)
if err != nil {
return err
}
proxynode.Params.MsgStreamTimeTickBufSize = int64(msgStreamTimeTickBufSize)
maxNameLengthStr := config.GetString(MaxNameLength)
maxNameLength, err := strconv.Atoi(maxNameLengthStr)
if err != nil {
return err
}
proxynode.Params.MaxNameLength = int64(maxNameLength)
maxFieldNumStr := config.GetString(MaxFieldNum)
maxFieldNum, err := strconv.Atoi(maxFieldNumStr)
if err != nil {
return err
}
proxynode.Params.MaxFieldNum = int64(maxFieldNum)
maxDimensionStr := config.GetString(MaxDimension)
maxDimension, err := strconv.Atoi(maxDimensionStr)
if err != nil {
return err
}
proxynode.Params.MaxDimension = int64(maxDimension)
defaultPartitionTag := config.GetString(DefaultPartitionTag)
proxynode.Params.DefaultPartitionTag = defaultPartitionTag
return nil
}
func (s *Server) connectProxyService() error {
Params.Init()
proxynode.Params.Init()
s.proxyServiceAddress = proxynode.Params.ProxyServiceAddress()
s.proxyServiceAddress = Params.ProxyServiceAddress
s.proxyServiceClient = grpcproxyservice.NewClient(s.proxyServiceAddress)
getAvailablePort := func() int {
......@@ -74,13 +198,7 @@ func (s *Server) connectProxyService() error {
panic(err)
}
proxynode.Params.Save("_proxyID", strconv.Itoa(int(response.InitParams.NodeID)))
for _, params := range response.InitParams.StartParams {
proxynode.Params.Save(params.Key, params.Value)
}
return err
return s.loadConfigFromInitParams(response.InitParams)
}
func (s *Server) Init() error {
......
package proxynode
import (
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/indexpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
)
type MasterClientInterface interface {
Init() error
Start() error
Stop() error
CreateCollection(in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error)
DropCollection(in *milvuspb.DropCollectionRequest) (*commonpb.Status, error)
HasCollection(in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error)
DescribeCollection(in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
ShowCollections(in *milvuspb.ShowCollectionRequest) (*milvuspb.ShowCollectionResponse, error)
CreatePartition(in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error)
DropPartition(in *milvuspb.DropPartitionRequest) (*commonpb.Status, error)
HasPartition(in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error)
ShowPartitions(in *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error)
CreateIndex(in *milvuspb.CreateIndexRequest) (*commonpb.Status, error)
DescribeIndex(in *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error)
}
type IndexServiceClient interface {
Init() error
Start() error
Stop() error
GetIndexStates(req *indexpb.IndexStatesRequest) (*indexpb.IndexStatesResponse, error)
}
type QueryServiceClient interface {
Init() error
Start() error
Stop() error
GetSearchChannelNames() ([]string, error)
GetSearchResultChannelNames() ([]string, error)
}
type DataServiceClient interface {
Init() error
Start() error
Stop() error
GetInsertChannelNames() ([]string, error)
}
func (node *NodeImpl) SetMasterClient(cli MasterClientInterface) {
node.masterClient = cli
}
func (node *NodeImpl) SetIndexServiceClient(cli IndexServiceClient) {
node.indexServiceClient = cli
}
func (node *NodeImpl) SetQueryServiceClient(cli QueryServiceClient) {
node.queryServiceClient = cli
}
func (node *NodeImpl) SetDataServiceClient(cli DataServiceClient) {
node.dataServiceClient = cli
}
......@@ -503,7 +503,6 @@ func (node *NodeImpl) GetIndexState(ctx context.Context, request *milvuspb.Index
dipt := &GetIndexStateTask{
Condition: NewTaskCondition(ctx),
IndexStateRequest: request,
masterClient: node.masterClient,
}
var cancel func()
......@@ -568,7 +567,7 @@ func (node *NodeImpl) Insert(ctx context.Context, request *milvuspb.InsertReques
rowIDAllocator: node.idAllocator,
}
if len(it.PartitionName) <= 0 {
it.PartitionName = Params.defaultPartitionTag()
it.PartitionName = Params.DefaultPartitionTag
}
var cancel func()
......@@ -621,9 +620,9 @@ func (node *NodeImpl) Search(ctx context.Context, request *milvuspb.SearchReques
SearchRequest: internalpb2.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kSearch,
SourceID: Params.ProxyID(),
SourceID: Params.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyID(), 10),
ResultChannelID: strconv.FormatInt(Params.ProxyID, 10),
},
queryMsgStream: node.queryMsgStream,
resultBuf: make(chan []*internalpb2.SearchResults),
......
......@@ -13,6 +13,34 @@ import (
type ParamTable struct {
paramtable.BaseTable
NetworkPort int
NetworkAddress string
ProxyServiceAddress string
MasterAddress string
PulsarAddress string
IndexServerAddress string
QueryNodeNum int
QueryNodeIDList []UniqueID
ProxyID UniqueID
TimeTickInterval time.Duration
InsertChannelNames []string
DeleteChannelNames []string
K2SChannelNames []string
SearchChannelNames []string
SearchResultChannelNames []string
ProxySubName string
ProxyTimeTickChannelNames []string
DataDefinitionChannelNames []string
MsgStreamInsertBufSize int64
MsgStreamSearchBufSize int64
MsgStreamSearchResultBufSize int64
MsgStreamSearchResultPulsarBufSize int64
MsgStreamTimeTickBufSize int64
MaxNameLength int64
MaxFieldNum int64
MaxDimension int64
DefaultPartitionTag string
}
var Params ParamTable
......@@ -20,28 +48,40 @@ var Params ParamTable
func (pt *ParamTable) Init() {
pt.BaseTable.Init()
err := pt.LoadYaml("advanced/proxy_node.yaml")
if err != nil {
panic(err)
}
proxyIDStr := os.Getenv("PROXY_ID")
if proxyIDStr == "" {
proxyIDList := pt.ProxyIDList()
if len(proxyIDList) <= 0 {
proxyIDStr = "0"
} else {
proxyIDStr = strconv.Itoa(int(proxyIDList[0]))
}
}
pt.Save("_proxyID", proxyIDStr)
pt.initNetworkPort()
pt.initNetworkAddress()
pt.initProxyServiceAddress()
pt.initMasterAddress()
pt.initPulsarAddress()
pt.initIndexServerAddress()
pt.initQueryNodeIDList()
pt.initQueryNodeNum()
pt.initProxyID()
pt.initTimeTickInterval()
pt.initInsertChannelNames()
pt.initDeleteChannelNames()
pt.initK2SChannelNames()
pt.initSearchChannelNames()
pt.initSearchResultChannelNames()
pt.initProxySubName()
pt.initProxyTimeTickChannelNames()
pt.initDataDefinitionChannelNames()
pt.initMsgStreamInsertBufSize()
pt.initMsgStreamSearchBufSize()
pt.initMsgStreamSearchResultBufSize()
pt.initMsgStreamSearchResultPulsarBufSize()
pt.initMsgStreamTimeTickBufSize()
pt.initMaxNameLength()
pt.initMaxFieldNum()
pt.initMaxDimension()
pt.initDefaultPartitionTag()
}
func (pt *ParamTable) NetworkPort() int {
return pt.ParseInt("proxyNode.port")
func (pt *ParamTable) initNetworkPort() {
pt.NetworkPort = pt.ParseInt("proxyNode.port")
}
func (pt *ParamTable) NetworkAddress() string {
func (pt *ParamTable) initNetworkAddress() {
addr, err := pt.Load("proxyNode.address")
if err != nil {
panic(err)
......@@ -62,14 +102,14 @@ func (pt *ParamTable) NetworkAddress() string {
if err != nil {
panic(err)
}
return addr + ":" + port
pt.NetworkAddress = addr + ":" + port
}
func (pt *ParamTable) ProxyServiceAddress() string {
func (pt *ParamTable) initProxyServiceAddress() {
addressFromEnv := os.Getenv("PROXY_SERVICE_ADDRESS")
if len(addressFromEnv) > 0 {
// TODO: or write to param table?
return addressFromEnv
pt.ProxyServiceAddress = addressFromEnv
}
addr, err := pt.Load("proxyService.address")
......@@ -92,35 +132,55 @@ func (pt *ParamTable) ProxyServiceAddress() string {
if err != nil {
panic(err)
}
return addr + ":" + port
pt.ProxyServiceAddress = addr + ":" + port
}
func (pt *ParamTable) MasterAddress() string {
func (pt *ParamTable) initMasterAddress() {
ret, err := pt.Load("_MasterAddress")
if err != nil {
panic(err)
}
return ret
pt.MasterAddress = ret
}
func (pt *ParamTable) PulsarAddress() string {
func (pt *ParamTable) initPulsarAddress() {
ret, err := pt.Load("_PulsarAddress")
if err != nil {
panic(err)
}
return ret
pt.PulsarAddress = ret
}
func (pt *ParamTable) ProxyNum() int {
ret := pt.ProxyIDList()
return len(ret)
func (pt *ParamTable) initIndexServerAddress() {
addr, err := pt.Load("indexServer.address")
if err != nil {
panic(err)
}
hostName, _ := net.LookupHost(addr)
if len(hostName) <= 0 {
if ip := net.ParseIP(addr); ip == nil {
panic("invalid ip indexServer.address")
}
}
port, err := pt.Load("indexServer.port")
if err != nil {
panic(err)
}
_, err = strconv.Atoi(port)
if err != nil {
panic(err)
}
pt.IndexServerAddress = addr + ":" + port
}
func (pt *ParamTable) queryNodeNum() int {
return len(pt.queryNodeIDList())
func (pt *ParamTable) initQueryNodeNum() {
pt.QueryNodeNum = len(pt.QueryNodeIDList)
}
func (pt *ParamTable) queryNodeIDList() []UniqueID {
func (pt *ParamTable) initQueryNodeIDList() []UniqueID {
queryNodeIDStr, err := pt.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
......@@ -137,7 +197,7 @@ func (pt *ParamTable) queryNodeIDList() []UniqueID {
return ret
}
func (pt *ParamTable) ProxyID() UniqueID {
func (pt *ParamTable) initProxyID() {
proxyID, err := pt.Load("_proxyID")
if err != nil {
panic(err)
......@@ -146,10 +206,10 @@ func (pt *ParamTable) ProxyID() UniqueID {
if err != nil {
panic(err)
}
return UniqueID(ID)
pt.ProxyID = UniqueID(ID)
}
func (pt *ParamTable) TimeTickInterval() time.Duration {
func (pt *ParamTable) initTimeTickInterval() {
internalStr, err := pt.Load("proxyNode.timeTickInterval")
if err != nil {
panic(err)
......@@ -158,21 +218,10 @@ func (pt *ParamTable) TimeTickInterval() time.Duration {
if err != nil {
panic(err)
}
return time.Duration(interval) * time.Millisecond
pt.TimeTickInterval = time.Duration(interval) * time.Millisecond
}
func (pt *ParamTable) sliceIndex() int {
proxyID := pt.ProxyID()
proxyIDList := pt.ProxyIDList()
for i := 0; i < len(proxyIDList); i++ {
if proxyID == proxyIDList[i] {
return i
}
}
return -1
}
func (pt *ParamTable) InsertChannelNames() []string {
func (pt *ParamTable) initInsertChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
panic(err)
......@@ -188,17 +237,10 @@ func (pt *ParamTable) InsertChannelNames() []string {
ret = append(ret, prefix+strconv.Itoa(ID))
}
proxyNum := pt.ProxyNum()
sep := len(channelIDs) / proxyNum
index := pt.sliceIndex()
if index == -1 {
panic("ProxyID not Match with Config")
}
start := index * sep
return ret[start : start+sep]
pt.InsertChannelNames = ret
}
func (pt *ParamTable) DeleteChannelNames() []string {
func (pt *ParamTable) initDeleteChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.delete")
if err != nil {
panic(err)
......@@ -213,10 +255,10 @@ func (pt *ParamTable) DeleteChannelNames() []string {
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
pt.DeleteChannelNames = ret
}
func (pt *ParamTable) K2SChannelNames() []string {
func (pt *ParamTable) initK2SChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.k2s")
if err != nil {
panic(err)
......@@ -231,10 +273,10 @@ func (pt *ParamTable) K2SChannelNames() []string {
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
pt.K2SChannelNames = ret
}
func (pt *ParamTable) SearchChannelNames() []string {
func (pt *ParamTable) initSearchChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.search")
if err != nil {
panic(err)
......@@ -249,10 +291,10 @@ func (pt *ParamTable) SearchChannelNames() []string {
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
pt.SearchChannelNames = ret
}
func (pt *ParamTable) SearchResultChannelNames() []string {
func (pt *ParamTable) initSearchResultChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
panic(err)
......@@ -267,18 +309,10 @@ func (pt *ParamTable) SearchResultChannelNames() []string {
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
proxyNum := pt.ProxyNum()
sep := len(channelIDs) / proxyNum
index := pt.sliceIndex()
if index == -1 {
panic("ProxyID not Match with Config")
}
start := index * sep
return ret[start : start+sep]
pt.SearchResultChannelNames = ret
}
func (pt *ParamTable) ProxySubName() string {
func (pt *ParamTable) initProxySubName() {
prefix, err := pt.Load("msgChannel.subNamePrefix.proxySubNamePrefix")
if err != nil {
panic(err)
......@@ -287,48 +321,48 @@ func (pt *ParamTable) ProxySubName() string {
if err != nil {
panic(err)
}
return prefix + "-" + proxyIDStr
pt.ProxySubName = prefix + "-" + proxyIDStr
}
func (pt *ParamTable) ProxyTimeTickChannelNames() []string {
func (pt *ParamTable) initProxyTimeTickChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.proxyTimeTick")
if err != nil {
panic(err)
}
prefix += "-0"
return []string{prefix}
pt.ProxyTimeTickChannelNames = []string{prefix}
}
func (pt *ParamTable) DataDefinitionChannelNames() []string {
func (pt *ParamTable) initDataDefinitionChannelNames() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.dataDefinition")
if err != nil {
panic(err)
}
prefix += "-0"
return []string{prefix}
pt.DataDefinitionChannelNames = []string{prefix}
}
func (pt *ParamTable) MsgStreamInsertBufSize() int64 {
return pt.ParseInt64("proxyNode.msgStream.insert.bufSize")
func (pt *ParamTable) initMsgStreamInsertBufSize() {
pt.MsgStreamInsertBufSize = pt.ParseInt64("proxyNode.msgStream.insert.bufSize")
}
func (pt *ParamTable) MsgStreamSearchBufSize() int64 {
return pt.ParseInt64("proxyNode.msgStream.search.bufSize")
func (pt *ParamTable) initMsgStreamSearchBufSize() {
pt.MsgStreamSearchBufSize = pt.ParseInt64("proxyNode.msgStream.search.bufSize")
}
func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 {
return pt.ParseInt64("proxyNode.msgStream.searchResult.recvBufSize")
func (pt *ParamTable) initMsgStreamSearchResultBufSize() {
pt.MsgStreamSearchResultBufSize = pt.ParseInt64("proxyNode.msgStream.searchResult.recvBufSize")
}
func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
return pt.ParseInt64("proxyNode.msgStream.searchResult.pulsarBufSize")
func (pt *ParamTable) initMsgStreamSearchResultPulsarBufSize() {
pt.MsgStreamSearchResultPulsarBufSize = pt.ParseInt64("proxyNode.msgStream.searchResult.pulsarBufSize")
}
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
return pt.ParseInt64("proxyNode.msgStream.timeTick.bufSize")
func (pt *ParamTable) initMsgStreamTimeTickBufSize() {
pt.MsgStreamTimeTickBufSize = pt.ParseInt64("proxyNode.msgStream.timeTick.bufSize")
}
func (pt *ParamTable) MaxNameLength() int64 {
func (pt *ParamTable) initMaxNameLength() {
str, err := pt.Load("proxyNode.maxNameLength")
if err != nil {
panic(err)
......@@ -337,10 +371,10 @@ func (pt *ParamTable) MaxNameLength() int64 {
if err != nil {
panic(err)
}
return maxNameLength
pt.MaxNameLength = maxNameLength
}
func (pt *ParamTable) MaxFieldNum() int64 {
func (pt *ParamTable) initMaxFieldNum() {
str, err := pt.Load("proxyNode.maxFieldNum")
if err != nil {
panic(err)
......@@ -349,10 +383,10 @@ func (pt *ParamTable) MaxFieldNum() int64 {
if err != nil {
panic(err)
}
return maxFieldNum
pt.MaxFieldNum = maxFieldNum
}
func (pt *ParamTable) MaxDimension() int64 {
func (pt *ParamTable) initMaxDimension() {
str, err := pt.Load("proxyNode.maxDimension")
if err != nil {
panic(err)
......@@ -361,13 +395,13 @@ func (pt *ParamTable) MaxDimension() int64 {
if err != nil {
panic(err)
}
return maxDimension
pt.MaxDimension = maxDimension
}
func (pt *ParamTable) defaultPartitionTag() string {
func (pt *ParamTable) initDefaultPartitionTag() {
tag, err := pt.Load("common.defaultPartitionTag")
if err != nil {
panic(err)
}
return tag
pt.DefaultPartitionTag = tag
}
......@@ -6,71 +6,71 @@ import (
)
func TestParamTable_InsertChannelRange(t *testing.T) {
ret := Params.InsertChannelNames()
ret := Params.InsertChannelNames
fmt.Println(ret)
}
func TestParamTable_DeleteChannelNames(t *testing.T) {
ret := Params.DeleteChannelNames()
ret := Params.DeleteChannelNames
fmt.Println(ret)
}
func TestParamTable_K2SChannelNames(t *testing.T) {
ret := Params.K2SChannelNames()
ret := Params.K2SChannelNames
fmt.Println(ret)
}
func TestParamTable_SearchChannelNames(t *testing.T) {
ret := Params.SearchChannelNames()
ret := Params.SearchChannelNames
fmt.Println(ret)
}
func TestParamTable_SearchResultChannelNames(t *testing.T) {
ret := Params.SearchResultChannelNames()
ret := Params.SearchResultChannelNames
fmt.Println(ret)
}
func TestParamTable_ProxySubName(t *testing.T) {
ret := Params.ProxySubName()
ret := Params.ProxySubName
fmt.Println(ret)
}
func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) {
ret := Params.ProxyTimeTickChannelNames()
ret := Params.ProxyTimeTickChannelNames
fmt.Println(ret)
}
func TestParamTable_DataDefinitionChannelNames(t *testing.T) {
ret := Params.DataDefinitionChannelNames()
ret := Params.DataDefinitionChannelNames
fmt.Println(ret)
}
func TestParamTable_MsgStreamInsertBufSize(t *testing.T) {
ret := Params.MsgStreamInsertBufSize()
ret := Params.MsgStreamInsertBufSize
fmt.Println(ret)
}
func TestParamTable_MsgStreamSearchBufSize(t *testing.T) {
ret := Params.MsgStreamSearchBufSize()
ret := Params.MsgStreamSearchBufSize
fmt.Println(ret)
}
func TestParamTable_MsgStreamSearchResultBufSize(t *testing.T) {
ret := Params.MsgStreamSearchResultBufSize()
ret := Params.MsgStreamSearchResultBufSize
fmt.Println(ret)
}
func TestParamTable_MsgStreamSearchResultPulsarBufSize(t *testing.T) {
ret := Params.MsgStreamSearchResultPulsarBufSize()
ret := Params.MsgStreamSearchResultPulsarBufSize
fmt.Println(ret)
}
func TestParamTable_MsgStreamTimeTickBufSize(t *testing.T) {
ret := Params.MsgStreamTimeTickBufSize()
ret := Params.MsgStreamTimeTickBufSize
fmt.Println(ret)
}
func TestParamTable_defaultPartitionTag(t *testing.T) {
ret := Params.defaultPartitionTag()
ret := Params.DefaultPartitionTag
fmt.Println("default partition tag: ", ret)
}
......@@ -4,25 +4,19 @@ import (
"context"
"fmt"
"io"
"log"
"math/rand"
"sync"
"time"
"github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms"
grpcproxyservice "github.com/zilliztech/milvus-distributed/internal/distributed/proxyservice"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
"github.com/opentracing/opentracing-go"
"github.com/uber/jaeger-client-go/config"
"google.golang.org/grpc"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)
......@@ -34,15 +28,17 @@ type NodeImpl struct {
cancel func()
wg sync.WaitGroup
proxyServiceClient *grpcproxyservice.Client
initParams *internalpb2.InitParams
ip string
port int
initParams *internalpb2.InitParams
ip string
port int
masterConn *grpc.ClientConn
masterClient masterpb.MasterServiceClient
sched *TaskScheduler
tick *timeTick
masterClient MasterClientInterface
indexServiceClient IndexServiceClient
queryServiceClient QueryServiceClient
dataServiceClient DataServiceClient
sched *TaskScheduler
tick *timeTick
idAllocator *allocator.IDAllocator
tsoAllocator *allocator.TimestampAllocator
......@@ -75,6 +71,36 @@ func (node *NodeImpl) Init() error {
var err error
err = node.masterClient.Init()
if err != nil {
return err
}
err = node.indexServiceClient.Init()
if err != nil {
return err
}
err = node.queryServiceClient.Init()
if err != nil {
return err
}
err = node.dataServiceClient.Init()
if err != nil {
return err
}
Params.SearchChannelNames, err = node.queryServiceClient.GetSearchChannelNames()
if err != nil {
return err
}
Params.SearchResultChannelNames, err = node.queryServiceClient.GetSearchResultChannelNames()
if err != nil {
return err
}
Params.InsertChannelNames, err = node.dataServiceClient.GetInsertChannelNames()
if err != nil {
return err
}
cfg := &config.Configuration{
ServiceName: "proxynode",
Sampler: &config.SamplerConfig{
......@@ -88,38 +114,38 @@ func (node *NodeImpl) Init() error {
}
opentracing.SetGlobalTracer(node.tracer)
pulsarAddress := Params.PulsarAddress()
pulsarAddress := Params.PulsarAddress
node.queryMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamSearchBufSize())
node.queryMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamSearchBufSize)
node.queryMsgStream.SetPulsarClient(pulsarAddress)
node.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
node.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames)
masterAddr := Params.MasterAddress()
masterAddr := Params.MasterAddress
idAllocator, err := allocator.NewIDAllocator(node.ctx, masterAddr)
if err != nil {
return err
}
node.idAllocator = idAllocator
node.idAllocator.PeerID = Params.ProxyID()
node.idAllocator.PeerID = Params.ProxyID
tsoAllocator, err := allocator.NewTimestampAllocator(node.ctx, masterAddr)
if err != nil {
return err
}
node.tsoAllocator = tsoAllocator
node.tsoAllocator.PeerID = Params.ProxyID()
node.tsoAllocator.PeerID = Params.ProxyID
segAssigner, err := allocator.NewSegIDAssigner(node.ctx, masterAddr, node.lastTick)
if err != nil {
panic(err)
}
node.segAssigner = segAssigner
node.segAssigner.PeerID = Params.ProxyID()
node.segAssigner.PeerID = Params.ProxyID
node.manipulationMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamInsertBufSize())
node.manipulationMsgStream = pulsarms.NewPulsarMsgStream(node.ctx, Params.MsgStreamInsertBufSize)
node.manipulationMsgStream.SetPulsarClient(pulsarAddress)
node.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
node.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames)
repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return insertRepackFunc(tsMsgs, hashKeys, node.segAssigner, true)
}
......@@ -136,7 +162,20 @@ func (node *NodeImpl) Init() error {
}
func (node *NodeImpl) Start() error {
err := node.connectMaster()
var err error
err = node.masterClient.Start()
if err != nil {
return err
}
err = node.indexServiceClient.Start()
if err != nil {
return err
}
err = node.queryServiceClient.Start()
if err != nil {
return err
}
err = node.dataServiceClient.Start()
if err != nil {
return err
}
......@@ -167,6 +206,23 @@ func (node *NodeImpl) Stop() error {
node.manipulationMsgStream.Close()
node.queryMsgStream.Close()
node.tick.Close()
var err error
err = node.dataServiceClient.Stop()
if err != nil {
return err
}
err = node.queryServiceClient.Stop()
if err != nil {
return err
}
err = node.indexServiceClient.Stop()
if err != nil {
return err
}
err = node.masterClient.Stop()
if err != nil {
return err
}
node.wg.Wait()
......@@ -197,19 +253,3 @@ func (node *NodeImpl) lastTick() Timestamp {
func (node *NodeImpl) AddCloseCallback(callbacks ...func()) {
node.closeCallbacks = append(node.closeCallbacks, callbacks...)
}
func (node *NodeImpl) connectMaster() error {
masterAddr := Params.MasterAddress()
log.Printf("NodeImpl connected to master, master_addr=%s", masterAddr)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, err := grpc.DialContext(ctx, masterAddr, grpc.WithInsecure(), grpc.WithBlock())
if err != nil {
log.Printf("NodeImpl connect to master failed, error= %v", err)
return err
}
log.Printf("NodeImpl connected to master, master_addr=%s", masterAddr)
node.masterConn = conn
node.masterClient = masterpb.NewMasterServiceClient(conn)
return nil
}
......@@ -345,7 +345,7 @@ func TestProxy_Search(t *testing.T) {
queryResultChannels := []string{"QueryResult"}
bufSize := 1024
queryResultMsgStream := pulsarms.NewPulsarMsgStream(ctx, int64(bufSize))
pulsarAddress := Params.PulsarAddress()
pulsarAddress := Params.PulsarAddress
queryResultMsgStream.SetPulsarClient(pulsarAddress)
assert.NotEqual(t, queryResultMsgStream, nil, "query result message stream should not be nil!")
queryResultMsgStream.CreatePulsarProducers(queryResultChannels)
......@@ -417,7 +417,7 @@ func TestProxy_AssignSegID(t *testing.T) {
testNum := 1
futureTS := tsoutil.ComposeTS(time.Now().Add(time.Second*-1000).UnixNano()/int64(time.Millisecond), 0)
for i := 0; i < testNum; i++ {
segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, Params.defaultPartitionTag(), int32(i), 200000, futureTS)
segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, Params.DefaultPartitionTag, int32(i), 200000, futureTS)
assert.Nil(t, err)
fmt.Println("segID", segID)
}
......
......@@ -16,7 +16,6 @@ import (
"github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
......@@ -84,7 +83,7 @@ func (it *InsertTask) Type() commonpb.MsgType {
func (it *InsertTask) PreExecute() error {
it.Base.MsgType = commonpb.MsgType_kInsert
it.Base.SourceID = Params.ProxyID()
it.Base.SourceID = Params.ProxyID
span, ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask preExecute")
defer span.Finish()
......@@ -189,7 +188,7 @@ func (it *InsertTask) PostExecute() error {
type CreateCollectionTask struct {
Condition
*milvuspb.CreateCollectionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *commonpb.Status
ctx context.Context
schema *schemapb.CollectionSchema
......@@ -226,7 +225,7 @@ func (cct *CreateCollectionTask) SetTs(ts Timestamp) {
func (cct *CreateCollectionTask) PreExecute() error {
cct.Base.MsgType = commonpb.MsgType_kCreateCollection
cct.Base.SourceID = Params.ProxyID()
cct.Base.SourceID = Params.ProxyID
cct.schema = &schemapb.CollectionSchema{}
err := proto.Unmarshal(cct.Schema, cct.schema)
......@@ -234,8 +233,8 @@ func (cct *CreateCollectionTask) PreExecute() error {
return err
}
if int64(len(cct.schema.Fields)) > Params.MaxFieldNum() {
return errors.New("maximum field's number should be limited to " + strconv.FormatInt(Params.MaxFieldNum(), 10))
if int64(len(cct.schema.Fields)) > Params.MaxFieldNum {
return errors.New("maximum field's number should be limited to " + strconv.FormatInt(Params.MaxFieldNum, 10))
}
// validate collection name
......@@ -293,7 +292,7 @@ func (cct *CreateCollectionTask) PreExecute() error {
func (cct *CreateCollectionTask) Execute() error {
var err error
cct.result, err = cct.masterClient.CreateCollection(cct.ctx, cct.CreateCollectionRequest)
cct.result, err = cct.masterClient.CreateCollection(cct.CreateCollectionRequest)
return err
}
......@@ -304,7 +303,7 @@ func (cct *CreateCollectionTask) PostExecute() error {
type DropCollectionTask struct {
Condition
*milvuspb.DropCollectionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *commonpb.Status
ctx context.Context
}
......@@ -340,7 +339,7 @@ func (dct *DropCollectionTask) SetTs(ts Timestamp) {
func (dct *DropCollectionTask) PreExecute() error {
dct.Base.MsgType = commonpb.MsgType_kDropCollection
dct.Base.SourceID = Params.ProxyID()
dct.Base.SourceID = Params.ProxyID
if err := ValidateCollectionName(dct.CollectionName); err != nil {
return err
......@@ -350,7 +349,7 @@ func (dct *DropCollectionTask) PreExecute() error {
func (dct *DropCollectionTask) Execute() error {
var err error
dct.result, err = dct.masterClient.DropCollection(dct.ctx, dct.DropCollectionRequest)
dct.result, err = dct.masterClient.DropCollection(dct.DropCollectionRequest)
return err
}
......@@ -401,7 +400,7 @@ func (st *SearchTask) SetTs(ts Timestamp) {
func (st *SearchTask) PreExecute() error {
st.Base.MsgType = commonpb.MsgType_kSearch
st.Base.SourceID = Params.ProxyID()
st.Base.SourceID = Params.ProxyID
span, ctx := opentracing.StartSpanFromContext(st.ctx, "SearchTask preExecute")
defer span.Finish()
......@@ -460,7 +459,7 @@ func (st *SearchTask) Execute() error {
var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
SearchRequest: st.SearchRequest,
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{uint32(Params.ProxyID())},
HashValues: []uint32{uint32(Params.ProxyID)},
BeginTimestamp: st.Base.Timestamp,
EndTimestamp: st.Base.Timestamp,
},
......@@ -646,7 +645,7 @@ func (st *SearchTask) PostExecute() error {
type HasCollectionTask struct {
Condition
*milvuspb.HasCollectionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *milvuspb.BoolResponse
ctx context.Context
}
......@@ -682,7 +681,7 @@ func (hct *HasCollectionTask) SetTs(ts Timestamp) {
func (hct *HasCollectionTask) PreExecute() error {
hct.Base.MsgType = commonpb.MsgType_kHasCollection
hct.Base.SourceID = Params.ProxyID()
hct.Base.SourceID = Params.ProxyID
if err := ValidateCollectionName(hct.CollectionName); err != nil {
return err
......@@ -692,7 +691,7 @@ func (hct *HasCollectionTask) PreExecute() error {
func (hct *HasCollectionTask) Execute() error {
var err error
hct.result, err = hct.masterClient.HasCollection(hct.ctx, hct.HasCollectionRequest)
hct.result, err = hct.masterClient.HasCollection(hct.HasCollectionRequest)
return err
}
......@@ -703,7 +702,7 @@ func (hct *HasCollectionTask) PostExecute() error {
type DescribeCollectionTask struct {
Condition
*milvuspb.DescribeCollectionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *milvuspb.DescribeCollectionResponse
ctx context.Context
}
......@@ -739,7 +738,7 @@ func (dct *DescribeCollectionTask) SetTs(ts Timestamp) {
func (dct *DescribeCollectionTask) PreExecute() error {
dct.Base.MsgType = commonpb.MsgType_kDescribeCollection
dct.Base.SourceID = Params.ProxyID()
dct.Base.SourceID = Params.ProxyID
if err := ValidateCollectionName(dct.CollectionName); err != nil {
return err
......@@ -749,7 +748,7 @@ func (dct *DescribeCollectionTask) PreExecute() error {
func (dct *DescribeCollectionTask) Execute() error {
var err error
dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, dct.DescribeCollectionRequest)
dct.result, err = dct.masterClient.DescribeCollection(dct.DescribeCollectionRequest)
if err != nil {
return err
}
......@@ -764,7 +763,7 @@ func (dct *DescribeCollectionTask) PostExecute() error {
type ShowCollectionsTask struct {
Condition
*milvuspb.ShowCollectionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *milvuspb.ShowCollectionResponse
ctx context.Context
}
......@@ -800,14 +799,14 @@ func (sct *ShowCollectionsTask) SetTs(ts Timestamp) {
func (sct *ShowCollectionsTask) PreExecute() error {
sct.Base.MsgType = commonpb.MsgType_kShowCollections
sct.Base.SourceID = Params.ProxyID()
sct.Base.SourceID = Params.ProxyID
return nil
}
func (sct *ShowCollectionsTask) Execute() error {
var err error
sct.result, err = sct.masterClient.ShowCollections(sct.ctx, sct.ShowCollectionRequest)
sct.result, err = sct.masterClient.ShowCollections(sct.ShowCollectionRequest)
return err
}
......@@ -818,7 +817,7 @@ func (sct *ShowCollectionsTask) PostExecute() error {
type CreatePartitionTask struct {
Condition
*milvuspb.CreatePartitionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *commonpb.Status
ctx context.Context
}
......@@ -854,7 +853,7 @@ func (cpt *CreatePartitionTask) SetTs(ts Timestamp) {
func (cpt *CreatePartitionTask) PreExecute() error {
cpt.Base.MsgType = commonpb.MsgType_kCreatePartition
cpt.Base.SourceID = Params.ProxyID()
cpt.Base.SourceID = Params.ProxyID
collName, partitionTag := cpt.CollectionName, cpt.PartitionName
......@@ -870,7 +869,7 @@ func (cpt *CreatePartitionTask) PreExecute() error {
}
func (cpt *CreatePartitionTask) Execute() (err error) {
cpt.result, err = cpt.masterClient.CreatePartition(cpt.ctx, cpt.CreatePartitionRequest)
cpt.result, err = cpt.masterClient.CreatePartition(cpt.CreatePartitionRequest)
return err
}
......@@ -881,7 +880,7 @@ func (cpt *CreatePartitionTask) PostExecute() error {
type DropPartitionTask struct {
Condition
*milvuspb.DropPartitionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *commonpb.Status
ctx context.Context
}
......@@ -917,7 +916,7 @@ func (dpt *DropPartitionTask) SetTs(ts Timestamp) {
func (dpt *DropPartitionTask) PreExecute() error {
dpt.Base.MsgType = commonpb.MsgType_kDropPartition
dpt.Base.SourceID = Params.ProxyID()
dpt.Base.SourceID = Params.ProxyID
collName, partitionTag := dpt.CollectionName, dpt.PartitionName
......@@ -933,7 +932,7 @@ func (dpt *DropPartitionTask) PreExecute() error {
}
func (dpt *DropPartitionTask) Execute() (err error) {
dpt.result, err = dpt.masterClient.DropPartition(dpt.ctx, dpt.DropPartitionRequest)
dpt.result, err = dpt.masterClient.DropPartition(dpt.DropPartitionRequest)
return err
}
......@@ -944,7 +943,7 @@ func (dpt *DropPartitionTask) PostExecute() error {
type HasPartitionTask struct {
Condition
*milvuspb.HasPartitionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *milvuspb.BoolResponse
ctx context.Context
}
......@@ -980,7 +979,7 @@ func (hpt *HasPartitionTask) SetTs(ts Timestamp) {
func (hpt *HasPartitionTask) PreExecute() error {
hpt.Base.MsgType = commonpb.MsgType_kHasPartition
hpt.Base.SourceID = Params.ProxyID()
hpt.Base.SourceID = Params.ProxyID
collName, partitionTag := hpt.CollectionName, hpt.PartitionName
......@@ -995,7 +994,7 @@ func (hpt *HasPartitionTask) PreExecute() error {
}
func (hpt *HasPartitionTask) Execute() (err error) {
hpt.result, err = hpt.masterClient.HasPartition(hpt.ctx, hpt.HasPartitionRequest)
hpt.result, err = hpt.masterClient.HasPartition(hpt.HasPartitionRequest)
return err
}
......@@ -1060,7 +1059,7 @@ func (hpt *HasPartitionTask) PostExecute() error {
type ShowPartitionsTask struct {
Condition
*milvuspb.ShowPartitionRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *milvuspb.ShowPartitionResponse
ctx context.Context
}
......@@ -1096,7 +1095,7 @@ func (spt *ShowPartitionsTask) SetTs(ts Timestamp) {
func (spt *ShowPartitionsTask) PreExecute() error {
spt.Base.MsgType = commonpb.MsgType_kShowPartitions
spt.Base.SourceID = Params.ProxyID()
spt.Base.SourceID = Params.ProxyID
if err := ValidateCollectionName(spt.CollectionName); err != nil {
return err
......@@ -1106,7 +1105,7 @@ func (spt *ShowPartitionsTask) PreExecute() error {
func (spt *ShowPartitionsTask) Execute() error {
var err error
spt.result, err = spt.masterClient.ShowPartitions(spt.ctx, spt.ShowPartitionRequest)
spt.result, err = spt.masterClient.ShowPartitions(spt.ShowPartitionRequest)
return err
}
......@@ -1117,7 +1116,7 @@ func (spt *ShowPartitionsTask) PostExecute() error {
type CreateIndexTask struct {
Condition
*milvuspb.CreateIndexRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *commonpb.Status
ctx context.Context
}
......@@ -1153,7 +1152,7 @@ func (cit *CreateIndexTask) SetTs(ts Timestamp) {
func (cit *CreateIndexTask) PreExecute() error {
cit.Base.MsgType = commonpb.MsgType_kCreateIndex
cit.Base.SourceID = Params.ProxyID()
cit.Base.SourceID = Params.ProxyID
collName, fieldName := cit.CollectionName, cit.FieldName
......@@ -1169,7 +1168,7 @@ func (cit *CreateIndexTask) PreExecute() error {
}
func (cit *CreateIndexTask) Execute() (err error) {
cit.result, err = cit.masterClient.CreateIndex(cit.ctx, cit.CreateIndexRequest)
cit.result, err = cit.masterClient.CreateIndex(cit.CreateIndexRequest)
return err
}
......@@ -1180,7 +1179,7 @@ func (cit *CreateIndexTask) PostExecute() error {
type DescribeIndexTask struct {
Condition
*milvuspb.DescribeIndexRequest
masterClient masterpb.MasterServiceClient
masterClient MasterClientInterface
result *milvuspb.DescribeIndexResponse
ctx context.Context
}
......@@ -1216,7 +1215,7 @@ func (dit *DescribeIndexTask) SetTs(ts Timestamp) {
func (dit *DescribeIndexTask) PreExecute() error {
dit.Base.MsgType = commonpb.MsgType_kDescribeIndex
dit.Base.SourceID = Params.ProxyID()
dit.Base.SourceID = Params.ProxyID
collName, fieldName := dit.CollectionName, dit.FieldName
......@@ -1233,7 +1232,7 @@ func (dit *DescribeIndexTask) PreExecute() error {
func (dit *DescribeIndexTask) Execute() error {
var err error
dit.result, err = dit.masterClient.DescribeIndex(dit.ctx, dit.DescribeIndexRequest)
dit.result, err = dit.masterClient.DescribeIndex(dit.DescribeIndexRequest)
return err
}
......@@ -1244,9 +1243,9 @@ func (dit *DescribeIndexTask) PostExecute() error {
type GetIndexStateTask struct {
Condition
*milvuspb.IndexStateRequest
masterClient masterpb.MasterServiceClient
result *milvuspb.IndexStateResponse
ctx context.Context
indexServiceClient IndexServiceClient
result *milvuspb.IndexStateResponse
ctx context.Context
}
func (dipt *GetIndexStateTask) OnEnqueue() error {
......@@ -1280,7 +1279,7 @@ func (dipt *GetIndexStateTask) SetTs(ts Timestamp) {
func (dipt *GetIndexStateTask) PreExecute() error {
dipt.Base.MsgType = commonpb.MsgType_kGetIndexState
dipt.Base.SourceID = Params.ProxyID()
dipt.Base.SourceID = Params.ProxyID
collName, fieldName := dipt.CollectionName, dipt.FieldName
......@@ -1296,9 +1295,18 @@ func (dipt *GetIndexStateTask) PreExecute() error {
}
func (dipt *GetIndexStateTask) Execute() error {
var err error
dipt.result, err = dipt.masterClient.GetIndexState(dipt.ctx, dipt.IndexStateRequest)
return err
// TODO: use index service client
//var err error
//dipt.result, err = dipt.masterClient.GetIndexState(dipt.IndexStateRequest)
//return err
dipt.result = &milvuspb.IndexStateResponse{
Status: &commonpb.Status{
ErrorCode: 0,
Reason: "",
},
State: commonpb.IndexState_FINISHED,
}
return nil
}
func (dipt *GetIndexStateTask) PostExecute() error {
......
......@@ -374,13 +374,13 @@ func (sched *TaskScheduler) queryResultLoop() {
defer sched.wg.Done()
unmarshal := util.NewUnmarshalDispatcher()
queryResultMsgStream := pulsarms.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
Params.ProxySubName(),
queryResultMsgStream := pulsarms.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize)
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress)
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames,
Params.ProxySubName,
unmarshal,
Params.MsgStreamSearchResultPulsarBufSize())
queryNodeNum := Params.queryNodeNum()
Params.MsgStreamSearchResultPulsarBufSize)
queryNodeNum := Params.QueryNodeNum
queryResultMsgStream.Start()
defer queryResultMsgStream.Close()
......
......@@ -47,15 +47,15 @@ func newTimeTick(ctx context.Context,
cancel: cancel,
tsoAllocator: tsoAllocator,
interval: interval,
peerID: Params.ProxyID(),
peerID: Params.ProxyID,
checkFunc: checkFunc,
}
t.tickMsgStream = pulsarms.NewPulsarMsgStream(t.ctx, Params.MsgStreamTimeTickBufSize())
pulsarAddress := Params.PulsarAddress()
t.tickMsgStream = pulsarms.NewPulsarMsgStream(t.ctx, Params.MsgStreamTimeTickBufSize)
pulsarAddress := Params.PulsarAddress
t.tickMsgStream.SetPulsarClient(pulsarAddress)
t.tickMsgStream.CreatePulsarProducers(Params.ProxyTimeTickChannelNames())
t.tickMsgStream.CreatePulsarProducers(Params.ProxyTimeTickChannelNames)
return t
}
......@@ -74,7 +74,7 @@ func (tt *timeTick) tick() error {
msgPack := msgstream.MsgPack{}
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{uint32(Params.ProxyID())},
HashValues: []uint32{uint32(Params.ProxyID)},
},
TimeTickMsg: internalpb2.TimeTickMsg{
Base: &commonpb.MsgBase{
......
......@@ -28,13 +28,13 @@ func TestTimeTick_Start(t *testing.T) {
func TestTimeTick_Start2(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
masterAddr := Params.MasterAddress()
masterAddr := Params.MasterAddress
tsoAllocator, err := allocator.NewTimestampAllocator(ctx, masterAddr)
assert.Nil(t, err)
err = tsoAllocator.Start()
assert.Nil(t, err)
tt := newTimeTick(ctx, tsoAllocator, Params.TimeTickInterval(), checkFunc)
tt := newTimeTick(ctx, tsoAllocator, Params.TimeTickInterval, checkFunc)
defer func() {
cancel()
......
......@@ -30,9 +30,9 @@ func ValidateCollectionName(collName string) error {
}
invalidMsg := "Invalid collection name: " + collName + ". "
if int64(len(collName)) > Params.MaxNameLength() {
if int64(len(collName)) > Params.MaxNameLength {
msg := invalidMsg + "The length of a collection name must be less than " +
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
strconv.FormatInt(Params.MaxNameLength, 10) + " characters."
return errors.New(msg)
}
......@@ -61,9 +61,9 @@ func ValidatePartitionTag(partitionTag string, strictCheck bool) error {
return errors.New(msg)
}
if int64(len(partitionTag)) > Params.MaxNameLength() {
if int64(len(partitionTag)) > Params.MaxNameLength {
msg := invalidMsg + "The length of a partition tag must be less than " +
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
strconv.FormatInt(Params.MaxNameLength, 10) + " characters."
return errors.New(msg)
}
......@@ -95,9 +95,9 @@ func ValidateFieldName(fieldName string) error {
}
invalidMsg := "Invalid field name: " + fieldName + ". "
if int64(len(fieldName)) > Params.MaxNameLength() {
if int64(len(fieldName)) > Params.MaxNameLength {
msg := invalidMsg + "The length of a field name must be less than " +
strconv.FormatInt(Params.MaxNameLength(), 10) + " characters."
strconv.FormatInt(Params.MaxNameLength, 10) + " characters."
return errors.New(msg)
}
......@@ -119,9 +119,9 @@ func ValidateFieldName(fieldName string) error {
}
func ValidateDimension(dim int64, isBinary bool) error {
if dim <= 0 || dim > Params.MaxDimension() {
if dim <= 0 || dim > Params.MaxDimension {
return errors.New("invalid dimension: " + strconv.FormatInt(dim, 10) + ". should be in range 1 ~ " +
strconv.FormatInt(Params.MaxDimension(), 10) + ".")
strconv.FormatInt(Params.MaxDimension, 10) + ".")
}
if isBinary && dim%8 != 0 {
return errors.New("invalid dimension: " + strconv.FormatInt(dim, 10) + ". should be multiple of 8.")
......
......@@ -84,13 +84,13 @@ func TestValidateFieldName(t *testing.T) {
func TestValidateDimension(t *testing.T) {
assert.Nil(t, ValidateDimension(1, false))
assert.Nil(t, ValidateDimension(Params.MaxDimension(), false))
assert.Nil(t, ValidateDimension(Params.MaxDimension, false))
assert.Nil(t, ValidateDimension(8, true))
assert.Nil(t, ValidateDimension(Params.MaxDimension(), true))
assert.Nil(t, ValidateDimension(Params.MaxDimension, true))
// invalid dim
assert.NotNil(t, ValidateDimension(-1, false))
assert.NotNil(t, ValidateDimension(Params.MaxDimension()+1, false))
assert.NotNil(t, ValidateDimension(Params.MaxDimension+1, false))
assert.NotNil(t, ValidateDimension(9, true))
}
......
......@@ -2,7 +2,11 @@ package proxyservice
import (
"context"
"fmt"
"io/ioutil"
"log"
"os"
"path"
"runtime"
"time"
"github.com/zilliztech/milvus-distributed/internal/msgstream/util"
......@@ -18,18 +22,63 @@ import (
)
const (
timeoutInterval = time.Second * 10
timeoutInterval = time.Second * 10
StartParamsKey = "START_PARAMS"
ChannelYamlContent = "advanced/channel.yaml"
CommonYamlContent = "advanced/common.yaml"
DataNodeYamlContent = "advanced/data_node.yaml"
MasterYamlContent = "advanced/master.yaml"
ProxyNodeYamlContent = "advanced/proxy_node.yaml"
QueryNodeYamlContent = "advanced/query_node.yaml"
WriteNodeYamlContent = "advanced/write_node.yaml"
MilvusYamlContent = "milvus.yaml"
)
func (s *ServiceImpl) fillNodeInitParams() error {
s.nodeStartParams = make([]*commonpb.KeyValuePair, 0)
nodeParams := &ParamTable{}
nodeParams.Init()
err := nodeParams.LoadYaml("advanced/proxy_node.yaml")
if err != nil {
return err
getConfigContentByName := func(fileName string) []byte {
_, fpath, _, _ := runtime.Caller(0)
configFile := path.Dir(fpath) + "/../../../configs/" + fileName
_, err := os.Stat(configFile)
if os.IsNotExist(err) {
runPath, err := os.Getwd()
if err != nil {
panic(err)
}
configFile = runPath + "/configs/" + fileName
}
data, err := ioutil.ReadFile(configFile)
if err != nil {
panic(err)
}
return data
}
channelYamlContent := getConfigContentByName(ChannelYamlContent)
commonYamlContent := getConfigContentByName(CommonYamlContent)
dataNodeYamlContent := getConfigContentByName(DataNodeYamlContent)
masterYamlContent := getConfigContentByName(MasterYamlContent)
proxyNodeYamlContent := getConfigContentByName(ProxyNodeYamlContent)
queryNodeYamlContent := getConfigContentByName(QueryNodeYamlContent)
writeNodeYamlContent := getConfigContentByName(WriteNodeYamlContent)
milvusYamlContent := getConfigContentByName(MilvusYamlContent)
var allContent []byte
allContent = append(allContent, channelYamlContent...)
allContent = append(allContent, commonYamlContent...)
allContent = append(allContent, dataNodeYamlContent...)
allContent = append(allContent, masterYamlContent...)
allContent = append(allContent, proxyNodeYamlContent...)
allContent = append(allContent, queryNodeYamlContent...)
allContent = append(allContent, writeNodeYamlContent...)
allContent = append(allContent, milvusYamlContent...)
s.nodeStartParams = append(s.nodeStartParams, &commonpb.KeyValuePair{
Key: StartParamsKey,
Value: string(allContent),
})
return nil
}
......@@ -40,12 +89,12 @@ func (s *ServiceImpl) Init() error {
}
serviceTimeTickMsgStream := pulsarms.NewPulsarTtMsgStream(s.ctx, 1024)
serviceTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress())
serviceTimeTickMsgStream.CreatePulsarProducers([]string{Params.ServiceTimeTickChannel()})
serviceTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress)
serviceTimeTickMsgStream.CreatePulsarProducers([]string{Params.ServiceTimeTickChannel})
nodeTimeTickMsgStream := pulsarms.NewPulsarMsgStream(s.ctx, 1024)
nodeTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress())
nodeTimeTickMsgStream.CreatePulsarConsumers(Params.NodeTimeTickChannel(),
nodeTimeTickMsgStream.SetPulsarClient(Params.PulsarAddress)
nodeTimeTickMsgStream.CreatePulsarConsumers(Params.NodeTimeTickChannel,
"proxyservicesub", // TODO: add config
util.NewUnmarshalDispatcher(),
1024)
......@@ -53,20 +102,6 @@ func (s *ServiceImpl) Init() error {
ttBarrier := newSoftTimeTickBarrier(s.ctx, nodeTimeTickMsgStream, []UniqueID{0}, 10)
s.tick = newTimeTick(s.ctx, ttBarrier, serviceTimeTickMsgStream)
// dataServiceAddr := Params.DataServiceAddress()
// s.dataServiceClient = dataservice.NewClient(dataServiceAddr)
// insertChannelsRequest := &datapb.InsertChannelRequest{}
// insertChannelNames, err := s.dataServiceClient.GetInsertChannels(insertChannelsRequest)
// if err != nil {
// return err
// }
// if len(insertChannelNames.Values) > 0 {
// namesStr := strings.Join(insertChannelNames.Values, ",")
// s.nodeStartParams = append(s.nodeStartParams, &commonpb.KeyValuePair{Key: KInsertChannelNames, Value: namesStr})
// }
s.state.State.StateCode = internalpb2.StateCode_HEALTHY
return nil
......@@ -88,7 +123,7 @@ func (s *ServiceImpl) GetComponentStates() (*internalpb2.ComponentStates, error)
}
func (s *ServiceImpl) GetTimeTickChannel() (string, error) {
return Params.ServiceTimeTickChannel(), nil
return Params.ServiceTimeTickChannel, nil
}
func (s *ServiceImpl) GetStatisticsChannel() (string, error) {
......@@ -96,7 +131,7 @@ func (s *ServiceImpl) GetStatisticsChannel() (string, error) {
}
func (s *ServiceImpl) RegisterLink() (*milvuspb.RegisterLinkResponse, error) {
fmt.Println("register link")
log.Println("register link")
ctx, cancel := context.WithTimeout(s.ctx, timeoutInterval)
defer cancel()
......@@ -133,7 +168,7 @@ func (s *ServiceImpl) RegisterLink() (*milvuspb.RegisterLinkResponse, error) {
}
func (s *ServiceImpl) RegisterNode(request *proxypb.RegisterNodeRequest) (*proxypb.RegisterNodeResponse, error) {
fmt.Println("RegisterNode: ", request)
log.Println("RegisterNode: ", request)
ctx, cancel := context.WithTimeout(s.ctx, timeoutInterval)
defer cancel()
......@@ -173,7 +208,7 @@ func (s *ServiceImpl) RegisterNode(request *proxypb.RegisterNodeRequest) (*proxy
}
func (s *ServiceImpl) InvalidateCollectionMetaCache(request *proxypb.InvalidateCollMetaCacheRequest) error {
fmt.Println("InvalidateCollectionMetaCache")
log.Println("InvalidateCollectionMetaCache")
ctx, cancel := context.WithTimeout(s.ctx, timeoutInterval)
defer cancel()
......
......@@ -8,48 +8,60 @@ import (
type ParamTable struct {
paramtable.BaseTable
PulsarAddress string
MasterAddress string
NodeTimeTickChannel []string
ServiceTimeTickChannel string
DataServiceAddress string
}
var Params ParamTable
func (pt *ParamTable) Init() {
pt.BaseTable.Init()
pt.initPulsarAddress()
pt.initMasterAddress()
pt.initNodeTimeTickChannel()
pt.initServiceTimeTickChannel()
pt.initDataServiceAddress()
}
func (pt *ParamTable) PulsarAddress() string {
func (pt *ParamTable) initPulsarAddress() {
ret, err := pt.Load("_PulsarAddress")
if err != nil {
panic(err)
}
return ret
pt.PulsarAddress = ret
}
func (pt *ParamTable) MasterAddress() string {
func (pt *ParamTable) initMasterAddress() {
ret, err := pt.Load("_MasterAddress")
if err != nil {
panic(err)
}
return ret
pt.MasterAddress = ret
}
func (pt *ParamTable) NodeTimeTickChannel() []string {
func (pt *ParamTable) initNodeTimeTickChannel() {
prefix, err := pt.Load("msgChannel.chanNamePrefix.proxyTimeTick")
if err != nil {
log.Panic(err)
}
prefix += "-0"
return []string{prefix}
pt.NodeTimeTickChannel = []string{prefix}
}
func (pt *ParamTable) ServiceTimeTickChannel() string {
func (pt *ParamTable) initServiceTimeTickChannel() {
ch, err := pt.Load("msgChannel.chanNamePrefix.proxyServiceTimeTick")
if err != nil {
log.Panic(err)
}
return ch
pt.ServiceTimeTickChannel = ch
}
func (pt *ParamTable) DataServiceAddress() string {
func (pt *ParamTable) initDataServiceAddress() {
// NOT USED NOW
return "TODO: read from config"
pt.DataServiceAddress = "TODO: read from config"
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册