未验证 提交 d94361fd 编写于 作者: C Cai Yudong 提交者: GitHub

Add parameter partitionID for API filterSegments (#8813)

Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>
上级 5cf46cdb
......@@ -68,13 +68,13 @@ func (dn *deleteNode) Operate(in []Msg) []Msg {
// filterSegmentByPK returns the bloom filter check result.
// If the key may exists in the segment, returns it in map.
// If the key not exists in the segment, the segment is filter out.
func (dn *deleteNode) filterSegmentByPK(pks []int64) (map[int64][]int64, error) {
func (dn *deleteNode) filterSegmentByPK(partID UniqueID, pks []int64) (map[int64][]int64, error) {
if pks == nil {
return nil, errors.New("pks is nil")
}
results := make(map[int64][]int64)
buf := make([]byte, 8)
segments := dn.replica.getSegments(dn.channelName)
segments := dn.replica.filterSegments(dn.channelName, partID)
for _, segment := range segments {
for _, pk := range pks {
binary.BigEndian.PutUint64(buf, uint64(pk))
......
......@@ -27,7 +27,7 @@ type mockReplica struct {
flushedSegments map[UniqueID]*Segment
}
func (replica *mockReplica) getSegments(channelName string) []*Segment {
func (replica *mockReplica) filterSegments(channelName string, partitionID UniqueID) []*Segment {
results := make([]*Segment, 0)
for _, value := range replica.newSegments {
results = append(results, value)
......@@ -148,7 +148,7 @@ func Test_GetSegmentsByPKs(t *testing.T) {
mockReplica.flushedSegments[segment5.segmentID] = segment5
mockReplica.flushedSegments[segment6.segmentID] = segment6
dn := newDeleteNode(mockReplica, "test", make(chan *flushMsg))
results, err := dn.filterSegmentByPK([]int64{0, 1, 2, 3, 4})
results, err := dn.filterSegmentByPK(0, []int64{0, 1, 2, 3, 4})
assert.Nil(t, err)
expected := map[int64][]int64{
0: {1, 2, 3},
......@@ -160,5 +160,4 @@ func Test_GetSegmentsByPKs(t *testing.T) {
for key, value := range expected {
assert.ElementsMatch(t, value, results[key])
}
}
......@@ -43,7 +43,7 @@ type Replica interface {
addNewSegment(segID, collID, partitionID UniqueID, channelName string, startPos, endPos *internalpb.MsgPosition) error
addNormalSegment(segID, collID, partitionID UniqueID, channelName string, numOfRows int64, cp *segmentCheckPoint) error
getSegments(channelName string) []*Segment
filterSegments(channelName string, partitionID UniqueID) []*Segment
listNewSegmentsStartPositions() []*datapb.SegmentStartPosition
listSegmentsCheckPoints() map[UniqueID]segmentCheckPoint
updateSegmentEndPosition(segID UniqueID, endPos *internalpb.MsgPosition)
......@@ -223,24 +223,28 @@ func (replica *SegmentReplica) addNewSegment(segID, collID, partitionID UniqueID
return nil
}
// getSegments return segments with same channelName
func (replica *SegmentReplica) getSegments(channelName string) []*Segment {
// filterSegments return segments with same channelName and partition ID
func (replica *SegmentReplica) filterSegments(channelName string, partitionID UniqueID) []*Segment {
replica.segMu.Lock()
defer replica.segMu.Unlock()
results := make([]*Segment, 0)
for _, value := range replica.newSegments {
if value.channelName == channelName {
results = append(results, value)
isMatched := func(segment *Segment, chanName string, partID UniqueID) bool {
return segment.channelName == chanName && (partID == 0 || segment.partitionID == partID)
}
for _, seg := range replica.newSegments {
if isMatched(seg, channelName, partitionID) {
results = append(results, seg)
}
}
for _, value := range replica.normalSegments {
if value.channelName == channelName {
results = append(results, value)
for _, seg := range replica.normalSegments {
if isMatched(seg, channelName, partitionID) {
results = append(results, seg)
}
}
for _, value := range replica.flushedSegments {
if value.channelName == channelName {
results = append(results, value)
for _, seg := range replica.flushedSegments {
if isMatched(seg, channelName, partitionID) {
results = append(results, seg)
}
}
return results
......
......@@ -559,7 +559,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) {
err = replica.addFlushedSegment(1, 1, 2, "insert-01", int64(0))
assert.Nil(t, err)
totalSegments := replica.getSegments("insert-01")
totalSegments := replica.filterSegments("insert-01", 0)
assert.Equal(t, len(totalSegments), 3)
})
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册