diff --git a/internal/dataservice/allocator.go b/internal/dataservice/allocator.go index 08a7106fbd645dbf384b24b2e4e99b39aa8a3701..b3b4b3730868a9ce7dede4f12cf07e36bb723192 100644 --- a/internal/dataservice/allocator.go +++ b/internal/dataservice/allocator.go @@ -1,23 +1,54 @@ package dataservice +import ( + "github.com/zilliztech/milvus-distributed/internal/distributed/masterservice" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" +) + type allocator interface { allocTimestamp() (Timestamp, error) allocID() (UniqueID, error) } type allocatorImpl struct { - // TODO call allocate functions in client.go in master service + masterClient *masterservice.GrpcClient } -// TODO implements -func newAllocatorImpl() *allocatorImpl { - return nil +func newAllocatorImpl(masterClient *masterservice.GrpcClient) *allocatorImpl { + return &allocatorImpl{ + masterClient: masterClient, + } } func (allocator *allocatorImpl) allocTimestamp() (Timestamp, error) { - return 0, nil + resp, err := allocator.masterClient.AllocTimestamp(&masterpb.TsoRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kShowCollections, + MsgID: -1, // todo add msg id + Timestamp: 0, // todo + SourceID: -1, // todo + }, + Count: 1, + }) + if err != nil { + return 0, err + } + return resp.Timestamp, nil } func (allocator *allocatorImpl) allocID() (UniqueID, error) { - return 0, nil + resp, err := allocator.masterClient.AllocID(&masterpb.IDRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kShowCollections, + MsgID: -1, // todo add msg id + Timestamp: 0, // todo + SourceID: -1, // todo + }, + Count: 1, + }) + if err != nil { + return 0, err + } + return resp.ID, nil } diff --git a/internal/dataservice/cluster.go b/internal/dataservice/cluster.go new file mode 100644 index 0000000000000000000000000000000000000000..4ce264253a31c249768321c7c0fd4219c1162ef5 --- /dev/null +++ b/internal/dataservice/cluster.go @@ -0,0 +1,127 @@ +package dataservice + +import ( + "log" + "sort" + "sync" + + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + + "github.com/zilliztech/milvus-distributed/internal/proto/datapb" + + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" + + "github.com/zilliztech/milvus-distributed/internal/distributed/datanode" +) + +type ( + dataNode struct { + id int64 + address struct { + ip string + port int64 + } + client *datanode.Client + channelNum int + } + dataNodeCluster struct { + mu sync.RWMutex + finishCh chan struct{} + nodes []*dataNode + } +) + +func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster { + return &dataNodeCluster{ + finishCh: finishCh, + nodes: make([]*dataNode, 0), + } +} + +func (c *dataNodeCluster) Register(ip string, port int64, id int64) { + c.mu.Lock() + defer c.mu.Unlock() + if !c.checkDataNodeNotExist(ip, port) { + c.nodes = append(c.nodes, &dataNode{ + id: id, + address: struct { + ip string + port int64 + }{ip: ip, port: port}, + channelNum: 0, + }) + } + if len(c.nodes) == Params.DataNodeNum { + close(c.finishCh) + } +} + +func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool { + for _, node := range c.nodes { + if node.address.ip == ip || node.address.port == port { + return false + } + } + return true +} + +func (c *dataNodeCluster) GetNumOfNodes() int { + return len(c.nodes) +} + +func (c *dataNodeCluster) GetNodeIDs() []int64 { + c.mu.RLock() + defer c.mu.RUnlock() + ret := make([]int64, len(c.nodes)) + for _, node := range c.nodes { + ret = append(ret, node.id) + } + return ret +} + +func (c *dataNodeCluster) WatchInsertChannels(groups []channelGroup) { + c.mu.Lock() + defer c.mu.Unlock() + sort.Slice(c.nodes, func(i, j int) bool { return c.nodes[i].channelNum < c.nodes[j].channelNum }) + for i, group := range groups { + err := c.nodes[i%len(c.nodes)].client.WatchDmChannels(&datapb.WatchDmChannelRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kDescribeCollection, + MsgID: -1, // todo + Timestamp: 0, // todo + SourceID: -1, // todo + }, + ChannelNames: group, + }) + if err != nil { + log.Println(err.Error()) + continue + } + } +} + +func (c *dataNodeCluster) GetDataNodeStates() ([]*internalpb2.ComponentInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + ret := make([]*internalpb2.ComponentInfo, 0) + for _, node := range c.nodes { + states, err := node.client.GetComponentStates(nil) + if err != nil { + log.Println(err.Error()) + continue + } + ret = append(ret, states.State) + } + return ret, nil +} + +func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegRequest) { + c.mu.RLock() + defer c.mu.RUnlock() + for _, node := range c.nodes { + if err := node.client.FlushSegments(request); err != nil { + log.Println(err.Error()) + continue + } + } +} diff --git a/internal/dataservice/meta.go b/internal/dataservice/meta.go index f3475182a82854f4ac0113fcb748318d275038ac..84846acfcba465e5e2565d88e852ff9e4879b0df 100644 --- a/internal/dataservice/meta.go +++ b/internal/dataservice/meta.go @@ -8,16 +8,12 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/datapb" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" - "github.com/zilliztech/milvus-distributed/internal/util/typeutil" - "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/errors" "github.com/zilliztech/milvus-distributed/internal/kv" ) type ( - UniqueID = typeutil.UniqueID - Timestamp = typeutil.Timestamp errSegmentNotFound struct { segmentID UniqueID } @@ -33,9 +29,8 @@ type ( client kv.TxnBase // client of a reliable kv service, i.e. etcd client collID2Info map[UniqueID]*collectionInfo // collection id to collection info segID2Info map[UniqueID]*datapb.SegmentInfo // segment id to segment info - - allocator allocator - ddLock sync.RWMutex + allocator allocator + ddLock sync.RWMutex } ) diff --git a/internal/dataservice/mock.go b/internal/dataservice/mock.go new file mode 100644 index 0000000000000000000000000000000000000000..b3115a7b89d8ebac93c4eee4433ac8352973ad98 --- /dev/null +++ b/internal/dataservice/mock.go @@ -0,0 +1,48 @@ +package dataservice + +import ( + "sync/atomic" + "time" + + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + + memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem" + "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" +) + +func newMemoryMeta(allocator allocator) (*meta, error) { + memoryKV := memkv.NewMemoryKV() + return newMeta(memoryKV, allocator) +} + +type MockAllocator struct { + cnt int64 +} + +func (m *MockAllocator) allocTimestamp() (Timestamp, error) { + val := atomic.AddInt64(&m.cnt, 1) + phy := time.Now().UnixNano() / int64(time.Millisecond) + ts := tsoutil.ComposeTS(phy, val) + return ts, nil +} + +func (m *MockAllocator) allocID() (UniqueID, error) { + val := atomic.AddInt64(&m.cnt, 1) + return val, nil +} + +func newMockAllocator() *MockAllocator { + return &MockAllocator{} +} + +func NewTestSchema() *schemapb.CollectionSchema { + return &schemapb.CollectionSchema{ + Name: "test", + Description: "schema for test used", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + {FieldID: 1, Name: "field1", IsPrimaryKey: false, Description: "field no.1", DataType: schemapb.DataType_STRING}, + {FieldID: 2, Name: "field2", IsPrimaryKey: false, Description: "field no.2", DataType: schemapb.DataType_VECTOR_FLOAT}, + }, + } +} diff --git a/internal/dataservice/param.go b/internal/dataservice/param.go index 034db2adfd152aa2057ad7dfe4cd903bd0710bfe..b697406c704cb3ca2e059f0b427705e55cbf036b 100644 --- a/internal/dataservice/param.go +++ b/internal/dataservice/param.go @@ -9,6 +9,9 @@ type ParamTable struct { Address string Port int + NodeID int64 + + MasterAddress string EtcdAddress string MetaRootPath string @@ -25,7 +28,7 @@ type ParamTable struct { InsertChannelNumPerCollection int64 StatisticsChannelName string TimeTickChannelName string - DataNodeNum int64 + DataNodeNum int } var Params ParamTable @@ -42,6 +45,7 @@ func (p *ParamTable) Init() { // set members p.initAddress() p.initPort() + p.NodeID = 1 // todo p.initEtcdAddress() p.initMetaRootPath() @@ -51,6 +55,12 @@ func (p *ParamTable) Init() { p.initSegmentSize() p.initSegmentSizeFactor() p.initDefaultRecordSize() + p.initSegIDAssignExpiration() + p.initInsertChannelPrefixName() + p.initInsertChannelNumPerCollection() + p.initStatisticsChannelName() + p.initTimeTickChannelName() + p.initDataNodeNum() } func (p *ParamTable) initAddress() { @@ -115,3 +125,28 @@ func (p *ParamTable) initSegmentSizeFactor() { func (p *ParamTable) initDefaultRecordSize() { p.DefaultRecordSize = p.ParseInt64("master.segment.defaultSizePerRecord") } + +// TODO read from config/env +func (p *ParamTable) initSegIDAssignExpiration() { + p.SegIDAssignExpiration = 3000 //ms +} + +func (p *ParamTable) initInsertChannelPrefixName() { + p.InsertChannelPrefixName = "insert-channel-" +} + +func (p *ParamTable) initInsertChannelNumPerCollection() { + p.InsertChannelNumPerCollection = 4 +} + +func (p *ParamTable) initStatisticsChannelName() { + p.StatisticsChannelName = "dataservice-statistics-channel" +} + +func (p *ParamTable) initTimeTickChannelName() { + p.TimeTickChannelName = "dataservice-timetick-channel" +} + +func (p *ParamTable) initDataNodeNum() { + p.DataNodeNum = 2 +} diff --git a/internal/dataservice/segment_allocator.go b/internal/dataservice/segment_allocator.go index abd92d14dc9e4f4ec3afa8aa1812da86a878008c..b92e88e8413ed448e03ea178d00adf6cc1f49e03 100644 --- a/internal/dataservice/segment_allocator.go +++ b/internal/dataservice/segment_allocator.go @@ -2,10 +2,13 @@ package dataservice import ( "fmt" + "log" "strconv" "sync" "time" + "github.com/zilliztech/milvus-distributed/internal/proto/datapb" + "github.com/zilliztech/milvus-distributed/internal/util/typeutil" "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" @@ -26,7 +29,7 @@ func (err errRemainInSufficient) Error() string { // segmentAllocator is used to allocate rows for segments and record the allocations. type segmentAllocator interface { // OpenSegment add the segment to allocator and set it allocatable - OpenSegment(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, cRange channelGroup) error + OpenSegment(segmentInfo *datapb.SegmentInfo) error // AllocSegment allocate rows and record the allocation. AllocSegment(collectionID UniqueID, partitionID UniqueID, channelName string, requestRows int) (UniqueID, int, Timestamp, error) // GetSealedSegments get all sealed segment. @@ -37,6 +40,8 @@ type segmentAllocator interface { DropSegment(segmentID UniqueID) // ExpireAllocations check all allocations' expire time and remove the expired allocation. ExpireAllocations(timeTick Timestamp) error + // SealAllSegments get all opened segment ids of collection. return success and failed segment ids + SealAllSegments(collectionID UniqueID) (bool, []UniqueID) // IsAllocationsExpired check all allocations of segment expired. IsAllocationsExpired(segmentID UniqueID, ts Timestamp) (bool, error) } @@ -50,7 +55,7 @@ type ( sealed bool lastExpireTime Timestamp allocations []*allocation - cRange channelGroup + channelGroup channelGroup } allocation struct { rowNums int @@ -67,9 +72,9 @@ type ( } ) -func newSegmentAssigner(metaTable *meta, allocator allocator) (*segmentAllocatorImpl, error) { +func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl, error) { segmentAllocator := &segmentAllocatorImpl{ - mt: metaTable, + mt: meta, segments: make(map[UniqueID]*segmentStatus), segmentExpireDuration: Params.SegIDAssignExpiration, segmentThreshold: Params.SegmentSize * 1024 * 1024, @@ -79,22 +84,22 @@ func newSegmentAssigner(metaTable *meta, allocator allocator) (*segmentAllocator return segmentAllocator, nil } -func (allocator *segmentAllocatorImpl) OpenSegment(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, cRange channelGroup) error { - if _, ok := allocator.segments[segmentID]; ok { - return fmt.Errorf("segment %d already exist", segmentID) +func (allocator *segmentAllocatorImpl) OpenSegment(segmentInfo *datapb.SegmentInfo) error { + if _, ok := allocator.segments[segmentInfo.SegmentID]; ok { + return fmt.Errorf("segment %d already exist", segmentInfo.SegmentID) } - totalRows, err := allocator.estimateTotalRows(collectionID) + totalRows, err := allocator.estimateTotalRows(segmentInfo.CollectionID) if err != nil { return err } - allocator.segments[segmentID] = &segmentStatus{ - id: segmentID, - collectionID: collectionID, - partitionID: partitionID, + allocator.segments[segmentInfo.SegmentID] = &segmentStatus{ + id: segmentInfo.SegmentID, + collectionID: segmentInfo.CollectionID, + partitionID: segmentInfo.PartitionID, total: totalRows, sealed: false, lastExpireTime: 0, - cRange: cRange, + channelGroup: segmentInfo.InsertChannels, } return nil } @@ -106,7 +111,7 @@ func (allocator *segmentAllocatorImpl) AllocSegment(collectionID UniqueID, for _, segStatus := range allocator.segments { if segStatus.sealed || segStatus.collectionID != collectionID || segStatus.partitionID != partitionID || - !segStatus.cRange.Contains(channelName) { + !segStatus.channelGroup.Contains(channelName) { continue } var success bool @@ -240,3 +245,24 @@ func (allocator *segmentAllocatorImpl) IsAllocationsExpired(segmentID UniqueID, } return status.lastExpireTime <= ts, nil } + +func (allocator *segmentAllocatorImpl) SealAllSegments(collectionID UniqueID) (bool, []UniqueID) { + allocator.mu.Lock() + defer allocator.mu.Unlock() + failed := make([]UniqueID, 0) + success := true + for _, status := range allocator.segments { + if status.collectionID == collectionID { + if status.sealed { + continue + } + if err := allocator.mt.SealSegment(status.id); err != nil { + log.Printf("seal segment error: %s", err.Error()) + failed = append(failed, status.id) + success = false + } + status.sealed = true + } + } + return success, failed +} diff --git a/internal/dataservice/segment_allocator_test.go b/internal/dataservice/segment_allocator_test.go new file mode 100644 index 0000000000000000000000000000000000000000..33d0bb4c66ef378eeccb7a2495979a06b5c295a1 --- /dev/null +++ b/internal/dataservice/segment_allocator_test.go @@ -0,0 +1,129 @@ +package dataservice + +import ( + "math" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAllocSegment(t *testing.T) { + Params.Init() + mockAllocator := newMockAllocator() + meta, err := newMemoryMeta(mockAllocator) + assert.Nil(t, err) + segAllocator, err := newSegmentAllocator(meta, mockAllocator) + assert.Nil(t, err) + + schema := NewTestSchema() + collID, err := mockAllocator.allocID() + err = meta.AddCollection(&collectionInfo{ + ID: collID, + Schema: schema, + }) + assert.Nil(t, err) + segmentInfo, err := meta.BuildSegment(collID, 100, []string{"c1", "c2"}) + assert.Nil(t, err) + err = meta.AddSegment(segmentInfo) + assert.Nil(t, err) + err = segAllocator.OpenSegment(segmentInfo) + assert.Nil(t, err) + + cases := []struct { + collectionID UniqueID + partitionID UniqueID + channelName string + requestRows int + expectResult bool + }{ + {collID, 100, "c1", 100, true}, + {collID + 1, 100, "c1", 100, false}, + {collID, 101, "c1", 100, false}, + {collID, 100, "c3", 100, false}, + {collID, 100, "c1", math.MaxInt64, false}, + } + for _, c := range cases { + id, count, expireTime, err := segAllocator.AllocSegment(c.collectionID, c.partitionID, c.channelName, c.requestRows) + if c.expectResult { + assert.Nil(t, err) + assert.EqualValues(t, c.requestRows, count) + assert.NotEqualValues(t, 0, id) + assert.NotEqualValues(t, 0, expireTime) + } else { + assert.NotNil(t, err) + } + } +} + +func TestSealSegment(t *testing.T) { + Params.Init() + mockAllocator := newMockAllocator() + meta, err := newMemoryMeta(mockAllocator) + assert.Nil(t, err) + segAllocator, err := newSegmentAllocator(meta, mockAllocator) + assert.Nil(t, err) + + schema := NewTestSchema() + collID, err := mockAllocator.allocID() + err = meta.AddCollection(&collectionInfo{ + ID: collID, + Schema: schema, + }) + assert.Nil(t, err) + var lastSegID UniqueID + for i := 0; i < 10; i++ { + segmentInfo, err := meta.BuildSegment(collID, 100, []string{"c" + strconv.Itoa(i)}) + assert.Nil(t, err) + err = meta.AddSegment(segmentInfo) + assert.Nil(t, err) + err = segAllocator.OpenSegment(segmentInfo) + assert.Nil(t, err) + lastSegID = segmentInfo.SegmentID + } + + err = segAllocator.SealSegment(lastSegID) + assert.Nil(t, err) + success, ids := segAllocator.SealAllSegments(collID) + assert.True(t, success) + assert.EqualValues(t, 0, len(ids)) + sealedSegments, err := segAllocator.GetSealedSegments() + assert.Nil(t, err) + assert.EqualValues(t, 10, sealedSegments) +} + +func TestExpireSegment(t *testing.T) { + Params.Init() + mockAllocator := newMockAllocator() + meta, err := newMemoryMeta(mockAllocator) + assert.Nil(t, err) + segAllocator, err := newSegmentAllocator(meta, mockAllocator) + assert.Nil(t, err) + + schema := NewTestSchema() + collID, err := mockAllocator.allocID() + err = meta.AddCollection(&collectionInfo{ + ID: collID, + Schema: schema, + }) + assert.Nil(t, err) + segmentInfo, err := meta.BuildSegment(collID, 100, []string{"c1", "c2"}) + assert.Nil(t, err) + err = meta.AddSegment(segmentInfo) + assert.Nil(t, err) + err = segAllocator.OpenSegment(segmentInfo) + assert.Nil(t, err) + + id1, _, _, err := segAllocator.AllocSegment(collID, 100, "c1", 10) + assert.Nil(t, err) + time.Sleep(time.Duration(Params.SegIDAssignExpiration) * time.Millisecond) + ts, err := mockAllocator.allocTimestamp() + assert.Nil(t, err) + err = segAllocator.ExpireAllocations(ts) + assert.Nil(t, err) + expired, err := segAllocator.IsAllocationsExpired(id1, ts) + assert.Nil(t, err) + assert.True(t, expired) + assert.EqualValues(t, 0, len(segAllocator.segments[id1].allocations)) +} diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index 36e0a193fb7b44e68675ca44bbc6ee83b8e46307..848e4db98e682ee2c5aa41a9288bb2cd025f7360 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -4,7 +4,12 @@ import ( "context" "fmt" "log" - "sync" + "time" + + "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms" + + "github.com/zilliztech/milvus-distributed/internal/distributed/masterservice" "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" @@ -19,6 +24,8 @@ import ( "github.com/zilliztech/milvus-distributed/internal/util/typeutil" ) +const role = "dataservice" + type DataService interface { typeutil.Service RegisterNode(req *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error) @@ -38,16 +45,9 @@ type DataService interface { } type ( - datanode struct { - nodeID int64 - address struct { - ip string - port int64 - } - // todo add client - } - - Server struct { + UniqueID = typeutil.UniqueID + Timestamp = typeutil.Timestamp + Server struct { ctx context.Context state internalpb2.StateCode client *etcdkv.EtcdKV @@ -56,40 +56,70 @@ type ( statsHandler *statsHandler insertChannelMgr *insertChannelManager allocator allocator + cluster *dataNodeCluster msgProducer *timesync.MsgProducer - nodeIDCounter int64 - nodes []*datanode registerFinishCh chan struct{} - registerMu sync.RWMutex + masterClient *masterservice.GrpcClient + ttMsgStream msgstream.MsgStream } ) func CreateServer(ctx context.Context) (*Server, error) { + ch := make(chan struct{}) return &Server{ ctx: ctx, state: internalpb2.StateCode_INITIALIZING, insertChannelMgr: newInsertChannelManager(), - nodeIDCounter: 0, - nodes: make([]*datanode, 0), - registerFinishCh: make(chan struct{}), + registerFinishCh: ch, + cluster: newDataNodeCluster(ch), }, nil } func (s *Server) Init() error { Params.Init() - s.allocator = newAllocatorImpl() + return nil +} + +func (s *Server) Start() error { + if err := s.connectMaster(); err != nil { + return err + } + s.allocator = newAllocatorImpl(s.masterClient) if err := s.initMeta(); err != nil { return err } s.statsHandler = newStatsHandler(s.meta) - segAllocator, err := newSegmentAssigner(s.meta, s.allocator) + segAllocator, err := newSegmentAllocator(s.meta, s.allocator) if err != nil { return err } s.segAllocator = segAllocator + s.waitDataNodeRegister() + if err = s.loadMetaFromMaster(); err != nil { + return err + } if err = s.initMsgProducer(); err != nil { return err } + s.state = internalpb2.StateCode_HEALTHY + log.Println("start success") + return nil +} + +func (s *Server) connectMaster() error { + log.Println("connecting to master") + master, err := masterservice.NewGrpcClient(Params.MasterAddress, 30*time.Second) + if err != nil { + return err + } + if err = master.Init(nil); err != nil { + return err + } + if err = master.Start(); err != nil { + return err + } + s.masterClient = master + log.Println("connect to master success") return nil } @@ -107,37 +137,109 @@ func (s *Server) initMeta() error { return nil } +func (s *Server) waitDataNodeRegister() { + log.Println("waiting data node to register") + <-s.registerFinishCh + log.Println("all data nodes register") +} + func (s *Server) initMsgProducer() error { // todo ttstream and peerids - timeTickBarrier := timesync.NewHardTimeTickBarrier(nil, nil) - // todo add watchers - producer, err := timesync.NewTimeSyncMsgProducer(timeTickBarrier) + s.ttMsgStream = pulsarms.NewPulsarTtMsgStream(s.ctx, 1024) + s.ttMsgStream.Start() + timeTickBarrier := timesync.NewHardTimeTickBarrier(s.ttMsgStream, s.cluster.GetNodeIDs()) + dataNodeTTWatcher := newDataNodeTimeTickWatcher(s.meta, s.segAllocator, s.cluster) + producer, err := timesync.NewTimeSyncMsgProducer(timeTickBarrier, dataNodeTTWatcher) if err != nil { return err } s.msgProducer = producer - return nil -} - -func (s *Server) Start() error { - s.waitDataNodeRegister() - // todo add load meta from master s.msgProducer.Start(s.ctx) return nil } - -func (s *Server) waitDataNodeRegister() { - <-s.registerFinishCh +func (s *Server) loadMetaFromMaster() error { + log.Println("loading collection meta from master") + collections, err := s.masterClient.ShowCollections(&milvuspb.ShowCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kShowCollections, + MsgID: -1, // todo add msg id + Timestamp: 0, // todo + SourceID: -1, // todo + }, + DbName: "", + }) + if err != nil { + return err + } + for _, collectionName := range collections.CollectionNames { + collection, err := s.masterClient.DescribeCollection(&milvuspb.DescribeCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kDescribeCollection, + MsgID: -1, // todo + Timestamp: 0, // todo + SourceID: -1, // todo + }, + DbName: "", + CollectionName: collectionName, + }) + if err != nil { + log.Println(err.Error()) + continue + } + partitions, err := s.masterClient.ShowPartitions(&milvuspb.ShowPartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kShowPartitions, + MsgID: -1, // todo + Timestamp: 0, // todo + SourceID: -1, // todo + }, + DbName: "", + CollectionName: collectionName, + CollectionID: collection.CollectionID, + }) + if err != nil { + log.Println(err.Error()) + continue + } + err = s.meta.AddCollection(&collectionInfo{ + ID: collection.CollectionID, + Schema: collection.Schema, + partitions: partitions.PartitionIDs, + }) + if err != nil { + log.Println(err.Error()) + continue + } + } + log.Println("load collection meta from master complete") + return nil } func (s *Server) Stop() error { + s.ttMsgStream.Close() s.msgProducer.Close() return nil } func (s *Server) GetComponentStates() (*internalpb2.ComponentStates, error) { - // todo foreach datanode, call GetServiceStates - return nil, nil + resp := &internalpb2.ComponentStates{ + State: &internalpb2.ComponentInfo{ + NodeID: Params.NodeID, + Role: role, + StateCode: s.state, + }, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + }, + } + dataNodeStates, err := s.cluster.GetDataNodeStates() + if err != nil { + resp.Status.Reason = err.Error() + return resp, nil + } + resp.SubcomponentStates = dataNodeStates + resp.Status.ErrorCode = commonpb.ErrorCode_SUCCESS + return resp, nil } func (s *Server) GetTimeTickChannel() (*milvuspb.StringResponse, error) { @@ -159,45 +261,27 @@ func (s *Server) GetStatisticsChannel() (*milvuspb.StringResponse, error) { } func (s *Server) RegisterNode(req *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error) { - s.registerMu.Lock() - defer s.registerMu.Unlock() - resp := &datapb.RegisterNodeResponse{ + s.cluster.Register(req.Address.Ip, req.Address.Port, req.Base.SourceID) + // add init params + return &datapb.RegisterNodeResponse{ Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + ErrorCode: commonpb.ErrorCode_SUCCESS, }, - } - if !s.checkDataNodeNotExist(req.Address.Ip, req.Address.Port) { - resp.Status.Reason = fmt.Sprintf("data node with address %s exist", req.Address.String()) - return resp, nil - } - s.nodeIDCounter++ - s.nodes = append(s.nodes, &datanode{ - nodeID: s.nodeIDCounter, - address: struct { - ip string - port int64 - }{ip: req.Address.Ip, port: req.Address.Port}, - }) - if s.nodeIDCounter == Params.DataNodeNum { - close(s.registerFinishCh) - } - resp.Status.ErrorCode = commonpb.ErrorCode_SUCCESS - // add init params - return resp, nil -} - -func (s *Server) checkDataNodeNotExist(ip string, port int64) bool { - for _, node := range s.nodes { - if node.address.ip == ip || node.address.port == port { - return false - } - } - return true + }, nil } func (s *Server) Flush(req *datapb.FlushRequest) (*commonpb.Status, error) { - // todo call datanode flush - return nil, nil + success, fails := s.segAllocator.SealAllSegments(req.CollectionID) + log.Printf("sealing failed segments: %v", fails) + if !success { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: fmt.Sprintf("flush failed, %d segment can not be sealed", len(fails)), + }, nil + } + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_SUCCESS, + }, nil } func (s *Server) AssignSegmentID(req *datapb.AssignSegIDRequest) (*datapb.AssignSegIDResponse, error) { @@ -264,7 +348,7 @@ func (s *Server) openNewSegment(collectionID UniqueID, partitionID UniqueID, cha if err = s.meta.AddSegment(segmentInfo); err != nil { return err } - if err = s.segAllocator.OpenSegment(collectionID, partitionID, segmentInfo.SegmentID, segmentInfo.InsertChannels); err != nil { + if err = s.segAllocator.OpenSegment(segmentInfo); err != nil { return err } return nil @@ -310,19 +394,19 @@ func (s *Server) GetInsertChannels(req *datapb.InsertChannelRequest) (*internalp resp.Values = ret return resp, nil } - channelGroups, err := s.insertChannelMgr.AllocChannels(req.CollectionID, len(s.nodes)) + channelGroups, err := s.insertChannelMgr.AllocChannels(req.CollectionID, s.cluster.GetNumOfNodes()) if err != nil { resp.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR resp.Status.Reason = err.Error() return resp, nil } + channels := make([]string, Params.InsertChannelNumPerCollection) for _, group := range channelGroups { - for _, c := range group { - channels = append(channels, c) - } + channels = append(channels, group...) } - // todo datanode watch dm channels + s.cluster.WatchInsertChannels(channelGroups) + resp.Values = channels return resp, nil } diff --git a/internal/dataservice/watcher.go b/internal/dataservice/watcher.go index d3c13271c0585c69e8fb1c9da9178b07ac553ede..581e1485574084f54c7e6dcc02226ed12ca9e226 100644 --- a/internal/dataservice/watcher.go +++ b/internal/dataservice/watcher.go @@ -3,6 +3,9 @@ package dataservice import ( "log" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/datapb" + "golang.org/x/net/context" "github.com/zilliztech/milvus-distributed/internal/msgstream" @@ -14,6 +17,8 @@ type ( msgQueue chan *msgstream.TimeTickMsg } dataNodeTimeTickWatcher struct { + meta *meta + cluster *dataNodeCluster allocator segmentAllocator msgQueue chan *msgstream.TimeTickMsg } @@ -30,7 +35,7 @@ func (watcher *proxyTimeTickWatcher) StartBackgroundLoop(ctx context.Context) { for { select { case <-ctx.Done(): - log.Println("proxy time tick watcher clsoed") + log.Println("proxy time tick watcher closed") return case msg := <-watcher.msgQueue: if err := watcher.allocator.ExpireAllocations(msg.Base.Timestamp); err != nil { @@ -44,9 +49,11 @@ func (watcher *proxyTimeTickWatcher) Watch(msg *msgstream.TimeTickMsg) { watcher.msgQueue <- msg } -func newDataNodeTimeTickWatcher(allocator segmentAllocator) *dataNodeTimeTickWatcher { +func newDataNodeTimeTickWatcher(meta *meta, allocator segmentAllocator, cluster *dataNodeCluster) *dataNodeTimeTickWatcher { return &dataNodeTimeTickWatcher{ + meta: meta, allocator: allocator, + cluster: cluster, msgQueue: make(chan *msgstream.TimeTickMsg, 1), } } @@ -74,7 +81,21 @@ func (watcher *dataNodeTimeTickWatcher) StartBackgroundLoop(ctx context.Context) continue } if expired { - // TODO: flush segment + segmentInfo, err := watcher.meta.GetSegment(id) + if err != nil { + log.Println(err.Error()) + continue + } + watcher.cluster.FlushSegment(&datapb.FlushSegRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kShowCollections, + MsgID: -1, // todo add msg id + Timestamp: 0, // todo + SourceID: -1, // todo + }, + CollectionID: segmentInfo.CollectionID, + SegmentIDs: []int64{segmentInfo.SegmentID}, + }) watcher.allocator.DropSegment(id) } } diff --git a/internal/distributed/dataservice/grpc_service.go b/internal/distributed/dataservice/grpc_service.go index 38dd333ccfc31ecb9e3b4e1e4616b5584623f01a..10e287cb651b5b7646ce1f4ea816983221654507 100644 --- a/internal/distributed/dataservice/grpc_service.go +++ b/internal/distributed/dataservice/grpc_service.go @@ -2,6 +2,10 @@ package dataservice import ( "context" + "log" + "net" + + "google.golang.org/grpc" "github.com/zilliztech/milvus-distributed/internal/dataservice" @@ -13,7 +17,47 @@ import ( ) type Service struct { - server *dataservice.Server + server *dataservice.Server + ctx context.Context + cancel context.CancelFunc + grpcServer *grpc.Server +} + +func NewGrpcService() { + s := &Service{} + var err error + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.server, err = dataservice.CreateServer(s.ctx) + if err != nil { + log.Fatalf("create server error: %s", err.Error()) + return + } + s.grpcServer = grpc.NewServer() + datapb.RegisterDataServiceServer(s.grpcServer, s) + lis, err := net.Listen("tcp", "localhost:11111") // todo address + if err != nil { + log.Fatal(err.Error()) + return + } + if err = s.grpcServer.Serve(lis); err != nil { + log.Fatal(err.Error()) + return + } +} + +func (s *Service) Init() error { + return s.server.Init() +} + +func (s *Service) Start() error { + return s.server.Start() +} + +func (s *Service) Stop() error { + err := s.server.Stop() + s.grpcServer.GracefulStop() + s.cancel() + return err } func (s *Service) RegisterNode(ctx context.Context, request *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error) { diff --git a/internal/kv/mem/mem_kv.go b/internal/kv/mem/mem_kv.go index 7f2281ff95f8a6713f385935cd73172e8bff5d27..0310ac90e6f77e3b373e1dc460efff8cee9a7d86 100644 --- a/internal/kv/mem/mem_kv.go +++ b/internal/kv/mem/mem_kv.go @@ -1,6 +1,7 @@ package memkv import ( + "strings" "sync" "github.com/google/btree" @@ -110,7 +111,19 @@ func (kv *MemoryKV) MultiSaveAndRemove(saves map[string]string, removals []strin // todo func (kv *MemoryKV) LoadWithPrefix(key string) ([]string, []string, error) { - panic("implement me") + kv.Lock() + defer kv.Unlock() + + keys := make([]string, 0) + values := make([]string, 0) + kv.tree.Ascend(func(i btree.Item) bool { + if strings.HasPrefix(i.(memoryKVItem).key, key) { + keys = append(keys, i.(memoryKVItem).key) + values = append(values, i.(memoryKVItem).value) + } + return true + }) + return keys, values, nil } func (kv *MemoryKV) Close() { diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go index 4cd071d3f601f038f6fff8a8ca6def48c5c56151..4f74aca5a798180c8b69868d2b029fa80cde61dd 100644 --- a/internal/timesync/timesync.go +++ b/internal/timesync/timesync.go @@ -38,7 +38,7 @@ type ( } ) -func NewSoftTimeTickBarrier(ttStream *ms.MsgStream, peerIds []UniqueID, minTtInterval Timestamp) *softTimeTickBarrier { +func NewSoftTimeTickBarrier(ttStream ms.MsgStream, peerIds []UniqueID, minTtInterval Timestamp) *softTimeTickBarrier { if len(peerIds) <= 0 { log.Printf("[newSoftTimeTickBarrier] Error: peerIds is empty!\n") return nil @@ -46,7 +46,7 @@ func NewSoftTimeTickBarrier(ttStream *ms.MsgStream, peerIds []UniqueID, minTtInt sttbarrier := softTimeTickBarrier{} sttbarrier.minTtInterval = minTtInterval - sttbarrier.ttStream = *ttStream + sttbarrier.ttStream = ttStream sttbarrier.outTt = make(chan Timestamp, 1024) sttbarrier.peer2LastTt = make(map[UniqueID]Timestamp) for _, id := range peerIds { @@ -86,28 +86,29 @@ func (ttBarrier *softTimeTickBarrier) StartBackgroundLoop(ctx context.Context) { case <-ctx.Done(): log.Printf("[TtBarrierStart] %s\n", ctx.Err()) return - case ttmsgs := <-ttBarrier.ttStream.Chan(): - if len(ttmsgs.Msgs) > 0 { - for _, timetickmsg := range ttmsgs.Msgs { - ttmsg := timetickmsg.(*ms.TimeTickMsg) - oldT, ok := ttBarrier.peer2LastTt[ttmsg.Base.SourceID] - // log.Printf("[softTimeTickBarrier] peer(%d)=%d\n", ttmsg.PeerID, ttmsg.Timestamp) - - if !ok { - log.Printf("[softTimeTickBarrier] Warning: peerID %d not exist\n", ttmsg.Base.SourceID) + default: + } + ttmsgs := ttBarrier.ttStream.Consume() + if len(ttmsgs.Msgs) > 0 { + for _, timetickmsg := range ttmsgs.Msgs { + ttmsg := timetickmsg.(*ms.TimeTickMsg) + oldT, ok := ttBarrier.peer2LastTt[ttmsg.Base.SourceID] + // log.Printf("[softTimeTickBarrier] peer(%d)=%d\n", ttmsg.PeerID, ttmsg.Timestamp) + + if !ok { + log.Printf("[softTimeTickBarrier] Warning: peerID %d not exist\n", ttmsg.Base.SourceID) + continue + } + if ttmsg.Base.Timestamp > oldT { + ttBarrier.peer2LastTt[ttmsg.Base.SourceID] = ttmsg.Base.Timestamp + + // get a legal Timestamp + ts := ttBarrier.minTimestamp() + lastTt := atomic.LoadInt64(&(ttBarrier.lastTt)) + if lastTt != 0 && ttBarrier.minTtInterval > ts-Timestamp(lastTt) { continue } - if ttmsg.Base.Timestamp > oldT { - ttBarrier.peer2LastTt[ttmsg.Base.SourceID] = ttmsg.Base.Timestamp - - // get a legal Timestamp - ts := ttBarrier.minTimestamp() - lastTt := atomic.LoadInt64(&(ttBarrier.lastTt)) - if lastTt != 0 && ttBarrier.minTtInterval > ts-Timestamp(lastTt) { - continue - } - ttBarrier.outTt <- ts - } + ttBarrier.outTt <- ts } } } @@ -145,32 +146,32 @@ func (ttBarrier *hardTimeTickBarrier) StartBackgroundLoop(ctx context.Context) { case <-ctx.Done(): log.Printf("[TtBarrierStart] %s\n", ctx.Err()) return - case ttmsgs := <-ttBarrier.ttStream.Chan(): - if len(ttmsgs.Msgs) > 0 { - for _, timetickmsg := range ttmsgs.Msgs { - - // Suppose ttmsg.Timestamp from stream is always larger than the previous one, - // that `ttmsg.Timestamp > oldT` - ttmsg := timetickmsg.(*ms.TimeTickMsg) - - oldT, ok := ttBarrier.peer2Tt[ttmsg.Base.SourceID] - if !ok { - log.Printf("[hardTimeTickBarrier] Warning: peerID %d not exist\n", ttmsg.Base.SourceID) - continue - } + default: + } + ttmsgs := ttBarrier.ttStream.Consume() + if len(ttmsgs.Msgs) > 0 { + for _, timetickmsg := range ttmsgs.Msgs { + // Suppose ttmsg.Timestamp from stream is always larger than the previous one, + // that `ttmsg.Timestamp > oldT` + ttmsg := timetickmsg.(*ms.TimeTickMsg) + + oldT, ok := ttBarrier.peer2Tt[ttmsg.Base.SourceID] + if !ok { + log.Printf("[hardTimeTickBarrier] Warning: peerID %d not exist\n", ttmsg.Base.SourceID) + continue + } - if oldT > state { - log.Printf("[hardTimeTickBarrier] Warning: peer(%d) timestamp(%d) ahead\n", - ttmsg.Base.SourceID, ttmsg.Base.Timestamp) - } + if oldT > state { + log.Printf("[hardTimeTickBarrier] Warning: peer(%d) timestamp(%d) ahead\n", + ttmsg.Base.SourceID, ttmsg.Base.Timestamp) + } - ttBarrier.peer2Tt[ttmsg.Base.SourceID] = ttmsg.Base.Timestamp + ttBarrier.peer2Tt[ttmsg.Base.SourceID] = ttmsg.Base.Timestamp - newState := ttBarrier.minTimestamp() - if newState > state { - ttBarrier.outTt <- newState - state = newState - } + newState := ttBarrier.minTimestamp() + if newState > state { + ttBarrier.outTt <- newState + state = newState } } } @@ -187,14 +188,14 @@ func (ttBarrier *hardTimeTickBarrier) minTimestamp() Timestamp { return tempMin } -func NewHardTimeTickBarrier(ttStream *ms.MsgStream, peerIds []UniqueID) *hardTimeTickBarrier { +func NewHardTimeTickBarrier(ttStream ms.MsgStream, peerIds []UniqueID) *hardTimeTickBarrier { if len(peerIds) <= 0 { log.Printf("[newSoftTimeTickBarrier] Error: peerIds is empty!") return nil } sttbarrier := hardTimeTickBarrier{} - sttbarrier.ttStream = *ttStream + sttbarrier.ttStream = ttStream sttbarrier.outTt = make(chan Timestamp, 1024) sttbarrier.peer2Tt = make(map[UniqueID]Timestamp)