提交 1578c132 编写于 作者: B bigsheeper 提交者: yefu.chen

Refactor collection replica

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 a48590a2
......@@ -134,12 +134,10 @@ func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
m.usageHistogram[loc]--
if m.usageHistogram[loc] <= 0 {
m.insertMsgStreams[loc].Close()
m.droppedBitMap[loc] = 1
delete(m.collectionID2InsertChannels, collID)
log.Print("close insert message stream ...")
}
log.Print("close insert message stream ...")
m.droppedBitMap[loc] = 1
delete(m.collectionID2InsertChannels, collID)
return nil
}
......
......@@ -19,22 +19,32 @@ import (
type Collection struct {
collectionPtr C.CCollection
id UniqueID
partitionIDs []UniqueID
schema *schemapb.CollectionSchema
partitions []*Partition
}
func (c *Collection) ID() UniqueID {
return c.id
}
func (c *Collection) Partitions() *[]*Partition {
return &c.partitions
}
func (c *Collection) Schema() *schemapb.CollectionSchema {
return c.schema
}
func (c *Collection) addPartitionID(partitionID UniqueID) {
c.partitionIDs = append(c.partitionIDs, partitionID)
}
func (c *Collection) removePartitionID(partitionID UniqueID) {
tmpIDs := make([]UniqueID, 0)
for _, id := range c.partitionIDs {
if id == partitionID {
tmpIDs = append(tmpIDs, id)
}
}
c.partitionIDs = tmpIDs
}
func newCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) *Collection {
/*
CCollection
......
......@@ -11,13 +11,15 @@ func TestCollectionReplica_getCollectionNum(t *testing.T) {
node := newQueryNodeMock()
initTestMeta(t, node, 0, 0)
assert.Equal(t, node.replica.getCollectionNum(), 1)
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_addCollection(t *testing.T) {
node := newQueryNodeMock()
initTestMeta(t, node, 0, 0)
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_removeCollection(t *testing.T) {
......@@ -28,7 +30,8 @@ func TestCollectionReplica_removeCollection(t *testing.T) {
err := node.replica.removeCollection(0)
assert.NoError(t, err)
assert.Equal(t, node.replica.getCollectionNum(), 0)
node.Stop()
err = node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_getCollectionByID(t *testing.T) {
......@@ -39,7 +42,8 @@ func TestCollectionReplica_getCollectionByID(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, targetCollection)
assert.Equal(t, targetCollection.ID(), collectionID)
node.Stop()
err = node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_hasCollection(t *testing.T) {
......@@ -52,7 +56,8 @@ func TestCollectionReplica_hasCollection(t *testing.T) {
hasCollection = node.replica.hasCollection(UniqueID(1))
assert.Equal(t, hasCollection, false)
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
//----------------------------------------------------------------------------------------------------- partition
......@@ -65,15 +70,15 @@ func TestCollectionReplica_getPartitionNum(t *testing.T) {
for _, id := range partitionIDs {
err := node.replica.addPartition(collectionID, id)
assert.NoError(t, err)
partition, err := node.replica.getPartitionByID(collectionID, id)
partition, err := node.replica.getPartitionByID(id)
assert.NoError(t, err)
assert.Equal(t, partition.ID(), id)
}
partitionNum, err := node.replica.getPartitionNum(collectionID)
partitionNum := node.replica.getPartitionNum()
assert.Equal(t, partitionNum, len(partitionIDs)+1)
err := node.Stop()
assert.NoError(t, err)
assert.Equal(t, partitionNum, len(partitionIDs)+1) // _default
node.Stop()
}
func TestCollectionReplica_addPartition(t *testing.T) {
......@@ -85,11 +90,12 @@ func TestCollectionReplica_addPartition(t *testing.T) {
for _, id := range partitionIDs {
err := node.replica.addPartition(collectionID, id)
assert.NoError(t, err)
partition, err := node.replica.getPartitionByID(collectionID, id)
partition, err := node.replica.getPartitionByID(id)
assert.NoError(t, err)
assert.Equal(t, partition.ID(), id)
}
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_removePartition(t *testing.T) {
......@@ -102,60 +108,14 @@ func TestCollectionReplica_removePartition(t *testing.T) {
for _, id := range partitionIDs {
err := node.replica.addPartition(collectionID, id)
assert.NoError(t, err)
partition, err := node.replica.getPartitionByID(collectionID, id)
partition, err := node.replica.getPartitionByID(id)
assert.NoError(t, err)
assert.Equal(t, partition.ID(), id)
err = node.replica.removePartition(collectionID, id)
err = node.replica.removePartition(id)
assert.NoError(t, err)
}
node.Stop()
}
func TestCollectionReplica_addPartitionsByCollectionMeta(t *testing.T) {
node := newQueryNodeMock()
collectionID := UniqueID(0)
initTestMeta(t, node, collectionID, 0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collectionMeta.PartitionIDs = []UniqueID{0, 1, 2}
err := node.replica.addPartitionsByCollectionMeta(collectionMeta)
err := node.Stop()
assert.NoError(t, err)
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, len(collectionMeta.PartitionIDs)+1)
hasPartition := node.replica.hasPartition(UniqueID(0), UniqueID(0))
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), UniqueID(1))
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), UniqueID(2))
assert.Equal(t, hasPartition, true)
node.Stop()
}
func TestCollectionReplica_removePartitionsByCollectionMeta(t *testing.T) {
node := newQueryNodeMock()
collectionID := UniqueID(0)
initTestMeta(t, node, collectionID, 0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collectionMeta.PartitionIDs = []UniqueID{0}
err := node.replica.addPartitionsByCollectionMeta(collectionMeta)
assert.NoError(t, err)
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, len(collectionMeta.PartitionIDs)+1)
hasPartition := node.replica.hasPartition(UniqueID(0), UniqueID(0))
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), UniqueID(1))
assert.Equal(t, hasPartition, false)
hasPartition = node.replica.hasPartition(UniqueID(0), UniqueID(2))
assert.Equal(t, hasPartition, false)
node.Stop()
}
func TestCollectionReplica_getPartitionByTag(t *testing.T) {
......@@ -168,12 +128,13 @@ func TestCollectionReplica_getPartitionByTag(t *testing.T) {
for _, id := range collectionMeta.PartitionIDs {
err := node.replica.addPartition(collectionID, id)
assert.NoError(t, err)
partition, err := node.replica.getPartitionByID(collectionID, id)
partition, err := node.replica.getPartitionByID(id)
assert.NoError(t, err)
assert.Equal(t, partition.ID(), id)
assert.NotNil(t, partition)
}
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_hasPartition(t *testing.T) {
......@@ -184,11 +145,12 @@ func TestCollectionReplica_hasPartition(t *testing.T) {
collectionMeta := genTestCollectionMeta(collectionID, false)
err := node.replica.addPartition(collectionID, collectionMeta.PartitionIDs[0])
assert.NoError(t, err)
hasPartition := node.replica.hasPartition(collectionID, defaultPartitionID)
hasPartition := node.replica.hasPartition(defaultPartitionID)
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(collectionID, defaultPartitionID+1)
hasPartition = node.replica.hasPartition(defaultPartitionID + 1)
assert.Equal(t, hasPartition, false)
node.Stop()
err = node.Stop()
assert.NoError(t, err)
}
//----------------------------------------------------------------------------------------------------- segment
......@@ -206,7 +168,8 @@ func TestCollectionReplica_addSegment(t *testing.T) {
assert.Equal(t, targetSeg.segmentID, UniqueID(i))
}
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_removeSegment(t *testing.T) {
......@@ -226,7 +189,8 @@ func TestCollectionReplica_removeSegment(t *testing.T) {
assert.NoError(t, err)
}
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_getSegmentByID(t *testing.T) {
......@@ -244,7 +208,8 @@ func TestCollectionReplica_getSegmentByID(t *testing.T) {
assert.Equal(t, targetSeg.segmentID, UniqueID(i))
}
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_hasSegment(t *testing.T) {
......@@ -266,7 +231,8 @@ func TestCollectionReplica_hasSegment(t *testing.T) {
assert.Equal(t, hasSeg, false)
}
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
func TestCollectionReplica_freeAll(t *testing.T) {
......@@ -274,6 +240,7 @@ func TestCollectionReplica_freeAll(t *testing.T) {
collectionID := UniqueID(0)
initTestMeta(t, node, collectionID, 0)
node.Stop()
err := node.Stop()
assert.NoError(t, err)
}
......@@ -6,18 +6,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestCollection_Partitions(t *testing.T) {
node := newQueryNodeMock()
collectionID := UniqueID(0)
initTestMeta(t, node, collectionID, 0)
collection, err := node.replica.getCollectionByID(collectionID)
assert.NoError(t, err)
partitions := collection.Partitions()
assert.Equal(t, 1, len(*partitions))
}
func TestCollection_newCollection(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
......
......@@ -41,7 +41,7 @@ func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
//
//// drop partitions
//for _, partition := range gcMsg.gcRecord.partitions {
// err := gcNode.replica.removePartition(partition.collectionID, partition.partitionID)
// err := gcNode.replica.removePartition(partition.partitionID)
// if err != nil {
// log.Println(err)
// }
......
......@@ -13,23 +13,35 @@ package querynode
import "C"
type Partition struct {
id UniqueID
segments []*Segment
enableDM bool
collectionID UniqueID
partitionID UniqueID
segmentIDs []UniqueID
enableDM bool
}
func (p *Partition) ID() UniqueID {
return p.id
return p.partitionID
}
func (p *Partition) Segments() *[]*Segment {
return &(*p).segments
func (p *Partition) addSegmentID(segmentID UniqueID) {
p.segmentIDs = append(p.segmentIDs, segmentID)
}
func newPartition(partitionID UniqueID) *Partition {
func (p *Partition) removeSegmentID(segmentID UniqueID) {
tmpIDs := make([]UniqueID, 0)
for _, id := range p.segmentIDs {
if id == segmentID {
tmpIDs = append(tmpIDs, id)
}
}
p.segmentIDs = tmpIDs
}
func newPartition(collectionID UniqueID, partitionID UniqueID) *Partition {
var newPartition = &Partition{
id: partitionID,
enableDM: false,
collectionID: collectionID,
partitionID: partitionID,
enableDM: false,
}
return newPartition
......
......@@ -6,29 +6,8 @@ import (
"github.com/stretchr/testify/assert"
)
func TestPartition_Segments(t *testing.T) {
node := newQueryNodeMock()
collectionID := UniqueID(0)
initTestMeta(t, node, collectionID, 0)
collection, err := node.replica.getCollectionByID(collectionID)
assert.NoError(t, err)
partitions := collection.Partitions()
targetPartition := (*partitions)[0]
const segmentNum = 3
for i := 0; i < segmentNum; i++ {
err := node.replica.addSegment(UniqueID(i), targetPartition.ID(), collection.ID(), segTypeGrowing)
assert.NoError(t, err)
}
segments := targetPartition.Segments()
assert.Equal(t, segmentNum+1, len(*segments))
}
func TestPartition_newPartition(t *testing.T) {
partitionID := defaultPartitionID
partition := newPartition(partitionID)
partition := newPartition(UniqueID(0), partitionID)
assert.Equal(t, partition.ID(), defaultPartitionID)
}
......@@ -81,17 +81,7 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
statsService: nil,
}
segmentsMap := make(map[int64]*Segment)
collections := make([]*Collection, 0)
tSafe := newTSafe()
node.replica = &collectionReplicaImpl{
collections: collections,
segments: segmentsMap,
tSafe: tSafe,
}
node.replica = newCollectionReplicaImpl()
node.stateCode.Store(internalpb2.StateCode_INITIALIZING)
return node
}
......@@ -108,17 +98,7 @@ func NewQueryNodeWithoutID(ctx context.Context) *QueryNode {
statsService: nil,
}
segmentsMap := make(map[int64]*Segment)
collections := make([]*Collection, 0)
tSafe := newTSafe()
node.replica = &collectionReplicaImpl{
collections: collections,
segments: segmentsMap,
tSafe: tSafe,
}
node.replica = newCollectionReplicaImpl()
node.stateCode.Store(internalpb2.StateCode_INITIALIZING)
return node
}
......@@ -403,7 +383,7 @@ func (node *QueryNode) LoadSegments(in *queryPb.LoadSegmentRequest) (*commonpb.S
segmentIDs := in.SegmentIDs
fieldIDs := in.FieldIDs
err := node.replica.enablePartitionDM(collectionID, partitionID)
err := node.replica.enablePartitionDM(partitionID)
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
......@@ -444,7 +424,7 @@ func (node *QueryNode) LoadSegments(in *queryPb.LoadSegmentRequest) (*commonpb.S
func (node *QueryNode) ReleaseSegments(in *queryPb.ReleaseSegmentRequest) (*commonpb.Status, error) {
for _, id := range in.PartitionIDs {
err := node.replica.enablePartitionDM(in.CollectionID, id)
err := node.replica.enablePartitionDM(id)
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
......
......@@ -18,7 +18,7 @@ import (
)
const ctxTimeInMillisecond = 5000
const closeWithDeadline = true
const debug = false
const defaultPartitionID = UniqueID(2021)
......@@ -121,7 +121,9 @@ func newQueryNodeMock() *QueryNode {
var ctx context.Context
if closeWithDeadline {
if debug {
ctx = context.Background()
} else {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
......@@ -129,8 +131,6 @@ func newQueryNodeMock() *QueryNode {
<-ctx.Done()
cancel()
}()
} else {
ctx = context.Background()
}
svr := NewQueryNode(ctx, 0)
......
......@@ -245,11 +245,9 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
matchedSegments := make([]*Segment, 0)
//fmt.Println("search msg's partitionID = ", partitionIDsInQuery)
var partitionIDsInCol []UniqueID
for _, partition := range collection.partitions {
partitionID := partition.ID()
partitionIDsInCol = append(partitionIDsInCol, partitionID)
partitionIDsInCol, err := ss.replica.getPartitionIDs(collectionID)
if err != nil {
return err
}
var searchPartitionIDs []UniqueID
partitionIDsInQuery := searchMsg.PartitionIDs
......@@ -267,10 +265,16 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
for _, partitionID := range searchPartitionIDs {
partition, _ := ss.replica.getPartitionByID(collectionID, partitionID)
for _, segment := range partition.segments {
segmentIDs, err := ss.replica.getSegmentIDs(partitionID)
if err != nil {
return err
}
for _, segmentID := range segmentIDs {
//fmt.Println("dsl = ", dsl)
segment, err := ss.replica.getSegmentByID(segmentID)
if err != nil {
return err
}
searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp})
if err != nil {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册