未验证 提交 5dae6a65 编写于 作者: Y yah01 提交者: GitHub

Protect segment from being released while query/search (#26322)

Signed-off-by: Nyah01 <yah2er0ne@outlook.com>
上级 97237871
......@@ -388,20 +388,24 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge
}
if req.GetFromShardLeader() {
var results []segments.SegmentStats
var err error
var (
results []segments.SegmentStats
readSegments []segments.Segment
err error
)
switch req.GetScope() {
case querypb.DataScope_Historical:
results, _, _, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
results, readSegments, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
case querypb.DataScope_Streaming:
results, _, _, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
results, readSegments, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
}
if err != nil {
log.Warn("get segments statistics failed", zap.Error(err))
return nil, err
}
defer node.manager.Segment.Unpin(readSegments)
return segmentStatsResponse(results), nil
}
......
......@@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/pkg/eventlog"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
. "github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/zap"
......@@ -90,6 +91,10 @@ type SegmentManager interface {
Get(segmentID UniqueID) Segment
GetWithType(segmentID UniqueID, typ SegmentType) Segment
GetBy(filters ...SegmentFilter) []Segment
// Get segments and acquire the read locks
GetAndPinBy(filters ...SegmentFilter) ([]Segment, error)
GetAndPin(segments []int64, filters ...SegmentFilter) ([]Segment, error)
Unpin(segments []Segment)
GetSealed(segmentID UniqueID) Segment
GetGrowing(segmentID UniqueID) Segment
Empty() bool
......@@ -252,6 +257,95 @@ func (mgr *segmentManager) GetBy(filters ...SegmentFilter) []Segment {
return ret
}
func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, error) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
ret := make([]Segment, 0)
var err error
defer func() {
if err != nil {
for _, segment := range ret {
segment.RUnlock()
}
}
}()
for _, segment := range mgr.growingSegments {
if filter(segment, filters...) {
err = segment.RLock()
if err != nil {
return nil, err
}
ret = append(ret, segment)
}
}
for _, segment := range mgr.sealedSegments {
if filter(segment, filters...) {
err = segment.RLock()
if err != nil {
return nil, err
}
ret = append(ret, segment)
}
}
return ret, nil
}
func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) ([]Segment, error) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
lockedSegments := make([]Segment, 0, len(segments))
var err error
defer func() {
if err != nil {
for _, segment := range lockedSegments {
segment.RUnlock()
}
}
}()
for _, id := range segments {
growing, growingExist := mgr.growingSegments[id]
sealed, sealedExist := mgr.sealedSegments[id]
growingExist = growingExist && filter(growing, filters...)
sealedExist = sealedExist && filter(sealed, filters...)
if growingExist {
err = growing.RLock()
if err != nil {
return nil, err
}
lockedSegments = append(lockedSegments, growing)
}
if sealedExist {
err = sealed.RLock()
if err != nil {
return nil, err
}
lockedSegments = append(lockedSegments, sealed)
}
if !growingExist && !sealedExist {
err = merr.WrapErrSegmentNotLoaded(id, "segment not found")
return nil, err
}
}
return lockedSegments, nil
}
func (mgr *segmentManager) Unpin(segments []Segment) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
for _, segment := range segments {
segment.RUnlock()
}
}
func filter(segment Segment, filters ...SegmentFilter) bool {
for _, filter := range filters {
if !filter(segment) {
......
// Code generated by mockery v2.21.1. DO NOT EDIT.
// Code generated by mockery v2.32.4. DO NOT EDIT.
package segments
......@@ -564,6 +564,79 @@ func (_c *MockSegment_Partition_Call) RunAndReturn(run func() int64) *MockSegmen
return _c
}
// RLock provides a mock function with given fields:
func (_m *MockSegment) RLock() error {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockSegment_RLock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RLock'
type MockSegment_RLock_Call struct {
*mock.Call
}
// RLock is a helper method to define mock.On call
func (_e *MockSegment_Expecter) RLock() *MockSegment_RLock_Call {
return &MockSegment_RLock_Call{Call: _e.mock.On("RLock")}
}
func (_c *MockSegment_RLock_Call) Run(run func()) *MockSegment_RLock_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_RLock_Call) Return(_a0 error) *MockSegment_RLock_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSegment_RLock_Call) RunAndReturn(run func() error) *MockSegment_RLock_Call {
_c.Call.Return(run)
return _c
}
// RUnlock provides a mock function with given fields:
func (_m *MockSegment) RUnlock() {
_m.Called()
}
// MockSegment_RUnlock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RUnlock'
type MockSegment_RUnlock_Call struct {
*mock.Call
}
// RUnlock is a helper method to define mock.On call
func (_e *MockSegment_Expecter) RUnlock() *MockSegment_RUnlock_Call {
return &MockSegment_RUnlock_Call{Call: _e.mock.On("RUnlock")}
}
func (_c *MockSegment_RUnlock_Call) Run(run func()) *MockSegment_RUnlock_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_RUnlock_Call) Return() *MockSegment_RUnlock_Call {
_c.Call.Return()
return _c
}
func (_c *MockSegment_RUnlock_Call) RunAndReturn(run func()) *MockSegment_RUnlock_Call {
_c.Call.Return(run)
return _c
}
// RowNum provides a mock function with given fields:
func (_m *MockSegment) RowNum() int64 {
ret := _m.Called()
......@@ -837,13 +910,12 @@ func (_c *MockSegment_Version_Call) RunAndReturn(run func() int64) *MockSegment_
return _c
}
type mockConstructorTestingTNewMockSegment interface {
// NewMockSegment creates a new instance of MockSegment. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockSegment(t interface {
mock.TestingT
Cleanup(func())
}
// NewMockSegment creates a new instance of MockSegment. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockSegment(t mockConstructorTestingTNewMockSegment) *MockSegment {
}) *MockSegment {
mock := &MockSegment{}
mock.Mock.Test(t)
......
// Code generated by mockery v2.21.1. DO NOT EDIT.
// Code generated by mockery v2.32.4. DO NOT EDIT.
package segments
......@@ -139,6 +139,142 @@ func (_c *MockSegmentManager_Get_Call) RunAndReturn(run func(int64) Segment) *Mo
return _c
}
// GetAndPin provides a mock function with given fields: segments, filters
func (_m *MockSegmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) ([]Segment, error) {
_va := make([]interface{}, len(filters))
for _i := range filters {
_va[_i] = filters[_i]
}
var _ca []interface{}
_ca = append(_ca, segments)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 []Segment
var r1 error
if rf, ok := ret.Get(0).(func([]int64, ...SegmentFilter) ([]Segment, error)); ok {
return rf(segments, filters...)
}
if rf, ok := ret.Get(0).(func([]int64, ...SegmentFilter) []Segment); ok {
r0 = rf(segments, filters...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]Segment)
}
}
if rf, ok := ret.Get(1).(func([]int64, ...SegmentFilter) error); ok {
r1 = rf(segments, filters...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockSegmentManager_GetAndPin_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAndPin'
type MockSegmentManager_GetAndPin_Call struct {
*mock.Call
}
// GetAndPin is a helper method to define mock.On call
// - segments []int64
// - filters ...SegmentFilter
func (_e *MockSegmentManager_Expecter) GetAndPin(segments interface{}, filters ...interface{}) *MockSegmentManager_GetAndPin_Call {
return &MockSegmentManager_GetAndPin_Call{Call: _e.mock.On("GetAndPin",
append([]interface{}{segments}, filters...)...)}
}
func (_c *MockSegmentManager_GetAndPin_Call) Run(run func(segments []int64, filters ...SegmentFilter)) *MockSegmentManager_GetAndPin_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]SegmentFilter, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(SegmentFilter)
}
}
run(args[0].([]int64), variadicArgs...)
})
return _c
}
func (_c *MockSegmentManager_GetAndPin_Call) Return(_a0 []Segment, _a1 error) *MockSegmentManager_GetAndPin_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockSegmentManager_GetAndPin_Call) RunAndReturn(run func([]int64, ...SegmentFilter) ([]Segment, error)) *MockSegmentManager_GetAndPin_Call {
_c.Call.Return(run)
return _c
}
// GetAndPinBy provides a mock function with given fields: filters
func (_m *MockSegmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, error) {
_va := make([]interface{}, len(filters))
for _i := range filters {
_va[_i] = filters[_i]
}
var _ca []interface{}
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 []Segment
var r1 error
if rf, ok := ret.Get(0).(func(...SegmentFilter) ([]Segment, error)); ok {
return rf(filters...)
}
if rf, ok := ret.Get(0).(func(...SegmentFilter) []Segment); ok {
r0 = rf(filters...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]Segment)
}
}
if rf, ok := ret.Get(1).(func(...SegmentFilter) error); ok {
r1 = rf(filters...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockSegmentManager_GetAndPinBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAndPinBy'
type MockSegmentManager_GetAndPinBy_Call struct {
*mock.Call
}
// GetAndPinBy is a helper method to define mock.On call
// - filters ...SegmentFilter
func (_e *MockSegmentManager_Expecter) GetAndPinBy(filters ...interface{}) *MockSegmentManager_GetAndPinBy_Call {
return &MockSegmentManager_GetAndPinBy_Call{Call: _e.mock.On("GetAndPinBy",
append([]interface{}{}, filters...)...)}
}
func (_c *MockSegmentManager_GetAndPinBy_Call) Run(run func(filters ...SegmentFilter)) *MockSegmentManager_GetAndPinBy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]SegmentFilter, len(args)-0)
for i, a := range args[0:] {
if a != nil {
variadicArgs[i] = a.(SegmentFilter)
}
}
run(variadicArgs...)
})
return _c
}
func (_c *MockSegmentManager_GetAndPinBy_Call) Return(_a0 []Segment, _a1 error) *MockSegmentManager_GetAndPinBy_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockSegmentManager_GetAndPinBy_Call) RunAndReturn(run func(...SegmentFilter) ([]Segment, error)) *MockSegmentManager_GetAndPinBy_Call {
_c.Call.Return(run)
return _c
}
// GetBy provides a mock function with given fields: filters
func (_m *MockSegmentManager) GetBy(filters ...SegmentFilter) []Segment {
_va := make([]interface{}, len(filters))
......@@ -495,6 +631,39 @@ func (_c *MockSegmentManager_RemoveBy_Call) RunAndReturn(run func(...SegmentFilt
return _c
}
// Unpin provides a mock function with given fields: segments
func (_m *MockSegmentManager) Unpin(segments []Segment) {
_m.Called(segments)
}
// MockSegmentManager_Unpin_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Unpin'
type MockSegmentManager_Unpin_Call struct {
*mock.Call
}
// Unpin is a helper method to define mock.On call
// - segments []Segment
func (_e *MockSegmentManager_Expecter) Unpin(segments interface{}) *MockSegmentManager_Unpin_Call {
return &MockSegmentManager_Unpin_Call{Call: _e.mock.On("Unpin", segments)}
}
func (_c *MockSegmentManager_Unpin_Call) Run(run func(segments []Segment)) *MockSegmentManager_Unpin_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]Segment))
})
return _c
}
func (_c *MockSegmentManager_Unpin_Call) Return() *MockSegmentManager_Unpin_Call {
_c.Call.Return()
return _c
}
func (_c *MockSegmentManager_Unpin_Call) RunAndReturn(run func([]Segment)) *MockSegmentManager_Unpin_Call {
_c.Call.Return(run)
return _c
}
// UpdateSegmentVersion provides a mock function with given fields: segmentType, segmentID, newVersion
func (_m *MockSegmentManager) UpdateSegmentVersion(segmentType commonpb.SegmentState, segmentID int64, newVersion int64) {
_m.Called(segmentType, segmentID, newVersion)
......@@ -530,13 +699,12 @@ func (_c *MockSegmentManager_UpdateSegmentVersion_Call) RunAndReturn(run func(co
return _c
}
type mockConstructorTestingTNewMockSegmentManager interface {
// NewMockSegmentManager creates a new instance of MockSegmentManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockSegmentManager(t interface {
mock.TestingT
Cleanup(func())
}
// NewMockSegmentManager creates a new instance of MockSegmentManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockSegmentManager(t mockConstructorTestingTNewMockSegmentManager) *MockSegmentManager {
}) *MockSegmentManager {
mock := &MockSegmentManager{}
mock.Mock.Test(t)
......
......@@ -31,10 +31,10 @@ import (
// retrieveOnSegments performs retrieve on listed segments
// all segment ids are validated before calling this function
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, error) {
func retrieveOnSegments(ctx context.Context, segments []Segment, segType SegmentType, plan *RetrievePlan) ([]*segcorepb.RetrieveResults, error) {
var (
resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs))
errs = make([]error, len(segIDs))
resultCh = make(chan *segcorepb.RetrieveResults, len(segments))
errs = make([]error, len(segments))
wg sync.WaitGroup
)
......@@ -43,22 +43,18 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
label = metrics.GrowingSegmentLabel
}
for i, segID := range segIDs {
for i, segment := range segments {
wg.Add(1)
go func(segID int64, i int) {
go func(segment Segment, i int) {
defer wg.Done()
segment, _ := manager.Segment.GetWithType(segID, segType).(*LocalSegment)
if segment == nil {
errs[i] = nil
return
}
seg := segment.(*LocalSegment)
tr := timerecord.NewTimeRecorder("retrieveOnSegments")
result, err := segment.Retrieve(ctx, plan)
result, err := seg.Retrieve(ctx, plan)
if err != nil {
errs[i] = err
return
}
if err = segment.ValidateIndexedFieldsData(ctx, result); err != nil {
if err = seg.ValidateIndexedFieldsData(ctx, result); err != nil {
errs[i] = err
return
}
......@@ -66,7 +62,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
resultCh <- result
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel, label).Observe(float64(tr.ElapseSpan().Milliseconds()))
}(segID, i)
}(segment, i)
}
wg.Wait()
close(resultCh)
......@@ -86,31 +82,22 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
}
// retrieveHistorical will retrieve all the target segments in historical
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error
var retrieveResults []*segcorepb.RetrieveResults
var retrieveSegmentIDs []UniqueID
var retrievePartIDs []UniqueID
retrievePartIDs, retrieveSegmentIDs, err = validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []Segment, error) {
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
return nil, nil, err
}
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs)
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
retrieveResults, err := retrieveOnSegments(ctx, segments, SegmentTypeSealed, plan)
return retrieveResults, segments, err
}
// retrieveStreaming will retrieve all the target segments in streaming
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error
var retrieveResults []*segcorepb.RetrieveResults
var retrievePartIDs []UniqueID
var retrieveSegmentIDs []UniqueID
retrievePartIDs, retrieveSegmentIDs, err = validateOnStream(ctx, manager, collID, partIDs, segIDs)
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []Segment, error) {
segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
return nil, nil, err
}
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs)
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
retrieveResults, err := retrieveOnSegments(ctx, segments, SegmentTypeGrowing, plan)
return retrieveResults, segments, err
}
......@@ -140,36 +140,39 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
plan, err := genSimpleRetrievePlan(suite.collection)
suite.NoError(err)
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.sealed.ID()})
suite.NoError(err)
suite.Len(res[0].Offset, 3)
suite.manager.Segment.Unpin(segments)
}
func (suite *RetrieveSuite) TestRetrieveGrowing() {
plan, err := genSimpleRetrievePlan(suite.collection)
suite.NoError(err)
res, _, _, err := RetrieveStreaming(context.TODO(), suite.manager, plan,
res, segments, err := RetrieveStreaming(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.growing.ID()})
suite.NoError(err)
suite.Len(res[0].Offset, 3)
suite.manager.Segment.Unpin(segments)
}
func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
plan, err := genSimpleRetrievePlan(suite.collection)
suite.NoError(err)
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{999})
suite.NoError(err)
suite.ErrorIs(err, merr.ErrSegmentNotLoaded)
suite.Len(res, 0)
suite.manager.Segment.Unpin(segments)
}
func (suite *RetrieveSuite) TestRetrieveNilSegment() {
......@@ -177,12 +180,13 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
suite.NoError(err)
DeleteSegment(suite.sealed)
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.sealed.ID()})
suite.ErrorIs(err, merr.ErrSegmentNotLoaded)
suite.Len(res, 0)
suite.manager.Segment.Unpin(segments)
}
func TestRetrieve(t *testing.T) {
......
......@@ -32,11 +32,11 @@ import (
// searchOnSegments performs search on listed segments
// all segment ids are validated before calling this function
func searchSegments(ctx context.Context, manager *Manager, segType SegmentType, searchReq *SearchRequest, segIDs []int64) ([]*SearchResult, error) {
func searchSegments(ctx context.Context, segments []Segment, segType SegmentType, searchReq *SearchRequest) ([]*SearchResult, error) {
var (
// results variables
resultCh = make(chan *SearchResult, len(segIDs))
errs = make([]error, len(segIDs))
resultCh = make(chan *SearchResult, len(segments))
errs = make([]error, len(segments))
wg sync.WaitGroup
// For log only
......@@ -50,19 +50,14 @@ func searchSegments(ctx context.Context, manager *Manager, segType SegmentType,
}
// calling segment search in goroutines
for i, segID := range segIDs {
for i, segment := range segments {
wg.Add(1)
go func(segID int64, i int) {
go func(segment Segment, i int) {
defer wg.Done()
seg, _ := manager.Segment.GetWithType(segID, segType).(*LocalSegment)
if seg == nil {
log.Warn("segment released while searching", zap.Int64("segmentID", segID))
return
}
seg := segment.(*LocalSegment)
if !seg.ExistIndex(searchReq.searchFieldID) {
mu.Lock()
segmentsWithoutIndex = append(segmentsWithoutIndex, segID)
segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID())
mu.Unlock()
}
// record search time
......@@ -76,12 +71,12 @@ func searchSegments(ctx context.Context, manager *Manager, segType SegmentType,
metrics.SearchLabel, searchLabel).Observe(float64(elapsed))
metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(elapsed) / float64(searchReq.getNumOfQuery()))
}(segID, i)
}(segment, i)
}
wg.Wait()
close(resultCh)
searchResults := make([]*SearchResult, 0, len(segIDs))
searchResults := make([]*SearchResult, 0, len(segments))
for result := range resultCh {
searchResults = append(searchResults, result)
}
......@@ -104,31 +99,22 @@ func searchSegments(ctx context.Context, manager *Manager, segType SegmentType,
// if segIDs is not specified, it will search on all the historical segments speficied by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []int64, []int64, error) {
var err error
var searchResults []*SearchResult
var searchSegmentIDs []int64
var searchPartIDs []int64
searchPartIDs, searchSegmentIDs, err = validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return searchResults, searchSegmentIDs, searchPartIDs, err
return nil, nil, err
}
searchResults, err = searchSegments(ctx, manager, SegmentTypeSealed, searchReq, searchSegmentIDs)
return searchResults, searchPartIDs, searchSegmentIDs, err
searchResults, err := searchSegments(ctx, segments, SegmentTypeSealed, searchReq)
return searchResults, segments, err
}
// searchStreaming will search all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []int64, []int64, error) {
var err error
var searchResults []*SearchResult
var searchPartIDs []int64
var searchSegmentIDs []int64
searchPartIDs, searchSegmentIDs, err = validateOnStream(ctx, manager, collID, partIDs, segIDs)
func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return searchResults, searchSegmentIDs, searchPartIDs, err
return nil, nil, err
}
searchResults, err = searchSegments(ctx, manager, SegmentTypeGrowing, searchReq, searchSegmentIDs)
return searchResults, searchPartIDs, searchSegmentIDs, err
searchResults, err := searchSegments(ctx, segments, SegmentTypeGrowing, searchReq)
return searchResults, segments, err
}
......@@ -137,21 +137,23 @@ func (suite *SearchSuite) TestSearchSealed() {
searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.sealed.ID()}, IndexFaissIDMap, nq)
suite.NoError(err)
_, _, _, err = SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()})
_, segments, err := SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()})
suite.NoError(err)
suite.manager.Segment.Unpin(segments)
}
func (suite *SearchSuite) TestSearchGrowing() {
searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.growing.ID()}, IndexFaissIDMap, 1)
suite.NoError(err)
res, _, _, err := SearchStreaming(context.TODO(), suite.manager, searchReq,
res, segments, err := SearchStreaming(context.TODO(), suite.manager, searchReq,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.growing.ID()},
)
suite.NoError(err)
suite.Len(res, 1)
suite.manager.Segment.Unpin(segments)
}
func TestSearch(t *testing.T) {
......
......@@ -146,8 +146,8 @@ var _ Segment = (*LocalSegment)(nil)
// Segment is a wrapper of the underlying C-structure segment.
type LocalSegment struct {
baseSegment
mut sync.RWMutex // protects segmentPtr
ptr C.CSegmentInterface
ptrLock sync.RWMutex // protects segmentPtr
ptr C.CSegmentInterface
size int64
row int64
......@@ -199,9 +199,25 @@ func (s *LocalSegment) isValid() bool {
return s.ptr != nil
}
// RLock acquires the `ptrLock` and returns true if the pointer is valid
// Provide ONLY the read lock operations,
// don't make `ptrLock` public to avoid abusing of the mutex.
func (s *LocalSegment) RLock() error {
s.ptrLock.RLock()
if !s.isValid() {
s.ptrLock.RUnlock()
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
return nil
}
func (s *LocalSegment) RUnlock() {
s.ptrLock.RUnlock()
}
func (s *LocalSegment) InsertCount() int64 {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
return 0
......@@ -216,8 +232,8 @@ func (s *LocalSegment) InsertCount() int64 {
}
func (s *LocalSegment) RowNum() int64 {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
return 0
......@@ -232,8 +248,8 @@ func (s *LocalSegment) RowNum() int64 {
}
func (s *LocalSegment) MemSize() int64 {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
return 0
......@@ -269,8 +285,8 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool {
}
func (s *LocalSegment) HasRawData(fieldID int64) bool {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
return false
}
......@@ -299,10 +315,10 @@ func DeleteSegment(segment *LocalSegment) {
// wait all read ops finished
var ptr C.CSegmentInterface
segment.mut.Lock()
segment.ptrLock.Lock()
ptr = segment.ptr
segment.ptr = nil
segment.mut.Unlock()
segment.ptrLock.Unlock()
if ptr == nil {
return
......@@ -331,8 +347,8 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
zap.Int64("segmentID", s.ID()),
zap.String("segmentType", s.typ.String()),
)
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return nil, merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -373,8 +389,8 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
}
func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return nil, merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -502,8 +518,8 @@ func (s *LocalSegment) Insert(rowIDs []int64, timestamps []typeutil.Timestamp, r
return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.typ.String())
}
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -561,8 +577,8 @@ func (s *LocalSegment) Delete(primaryKeys []storage.PrimaryKey, timestamps []typ
const unsigned long* timestamps);
*/
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -626,8 +642,8 @@ func (s *LocalSegment) Delete(primaryKeys []storage.PrimaryKey, timestamps []typ
// -------------------------------------------------------------------------------------- interfaces for sealed segment
func (s *LocalSegment) LoadMultiFieldData(rowCount int64, fields []*datapb.FieldBinlog) error {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -679,8 +695,8 @@ func (s *LocalSegment) LoadMultiFieldData(rowCount int64, fields []*datapb.Field
}
func (s *LocalSegment) LoadFieldData(fieldID int64, rowCount int64, field *datapb.FieldBinlog) error {
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -732,8 +748,8 @@ func (s *LocalSegment) LoadDeltaData(deltaData *storage.DeleteData) error {
pks, tss := deltaData.Pks, deltaData.Tss
rowNum := deltaData.RowCount
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......@@ -834,8 +850,8 @@ func (s *LocalSegment) LoadIndexInfo(indexInfo *querypb.FieldIndexInfo, info *Lo
zap.Int64("segmentID", s.ID()),
zap.Int64("fieldID", indexInfo.FieldID),
)
s.mut.RLock()
defer s.mut.RUnlock()
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released")
......
......@@ -32,6 +32,8 @@ type Segment interface {
Version() int64
StartPosition() *msgpb.MsgPosition
Type() SegmentType
RLock() error
RUnlock()
// Stats related
// InsertCount returns the number of inserted rows, not effected by deletion
......
......@@ -201,9 +201,9 @@ func (suite *SegmentSuite) TestValidateIndexedFieldsData() {
func (suite *SegmentSuite) TestSegmentReleased() {
DeleteSegment(suite.sealed)
suite.sealed.mut.RLock()
suite.sealed.ptrLock.RLock()
suite.False(suite.sealed.isValid())
suite.sealed.mut.RUnlock()
suite.sealed.ptrLock.RUnlock()
suite.EqualValues(0, suite.sealed.InsertCount())
suite.EqualValues(0, suite.sealed.RowNum())
suite.EqualValues(0, suite.sealed.MemSize())
......
......@@ -19,8 +19,6 @@ package segments
import (
"context"
"sync"
"github.com/milvus-io/milvus/pkg/log"
)
// SegmentStats struct for segment statistics.
......@@ -31,28 +29,22 @@ type SegmentStats struct {
// statisticOnSegments performs statistic on listed segments
// all segment ids are validated before calling this function
func statisticOnSegments(ctx context.Context, manager *Manager, segType SegmentType, segIDs []int64) ([]SegmentStats, error) {
func statisticOnSegments(ctx context.Context, segments []Segment, segType SegmentType) ([]SegmentStats, error) {
// results variables
results := make([]SegmentStats, 0, len(segIDs))
resultCh := make(chan SegmentStats, len(segIDs))
results := make([]SegmentStats, 0, len(segments))
resultCh := make(chan SegmentStats, len(segments))
// fetch seg statistics in goroutines
var wg sync.WaitGroup
for i, segID := range segIDs {
for i, segment := range segments {
wg.Add(1)
go func(segID int64, i int) {
go func(segment Segment, i int) {
defer wg.Done()
seg := manager.Segment.GetWithType(segID, segType)
if seg == nil {
log.Warn("segment released while get statistics")
return
}
resultCh <- SegmentStats{
SegmentID: segID,
RowCount: seg.RowNum(),
SegmentID: segment.ID(),
RowCount: segment.RowNum(),
}
}(segID, i)
}(segment, i)
}
wg.Wait()
close(resultCh)
......@@ -67,30 +59,22 @@ func statisticOnSegments(ctx context.Context, manager *Manager, segType SegmentT
// if segIDs is not specified, it will search on all the historical segments specified by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func StatisticsHistorical(ctx context.Context, manager *Manager, collID int64, partIDs []int64, segIDs []int64) ([]SegmentStats, []int64, []int64, error) {
var err error
var result []SegmentStats
var searchSegmentIDs []int64
var searchPartIDs []int64
searchPartIDs, searchSegmentIDs, err = validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
func StatisticsHistorical(ctx context.Context, manager *Manager, collID int64, partIDs []int64, segIDs []int64) ([]SegmentStats, []Segment, error) {
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return result, searchSegmentIDs, searchPartIDs, err
return nil, nil, err
}
result, err = statisticOnSegments(ctx, manager, SegmentTypeSealed, searchSegmentIDs)
return result, searchPartIDs, searchSegmentIDs, err
result, err := statisticOnSegments(ctx, segments, SegmentTypeSealed)
return result, segments, err
}
// StatisticStreaming will do statistics all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func StatisticStreaming(ctx context.Context, manager *Manager, collID int64, partIDs []int64, segIDs []int64) ([]SegmentStats, []int64, []int64, error) {
var err error
var result []SegmentStats
var searchSegmentIDs []int64
var searchPartIDs []int64
searchPartIDs, searchSegmentIDs, err = validateOnStream(ctx, manager, collID, partIDs, segIDs)
func StatisticStreaming(ctx context.Context, manager *Manager, collID int64, partIDs []int64, segIDs []int64) ([]SegmentStats, []Segment, error) {
segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return result, searchSegmentIDs, searchPartIDs, err
return nil, nil, err
}
result, err = statisticOnSegments(ctx, manager, SegmentTypeGrowing, searchSegmentIDs)
return result, searchPartIDs, searchSegmentIDs, err
result, err := statisticOnSegments(ctx, segments, SegmentTypeGrowing)
return result, segments, err
}
......@@ -29,13 +29,12 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
func validate(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64, segmentFilter SegmentFilter) ([]int64, []int64, error) {
func validate(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64, segmentFilter SegmentFilter) ([]Segment, error) {
var searchPartIDs []int64
var newSegmentIDs []int64
collection := manager.Collection.Get(collectionID)
if collection == nil {
return nil, nil, merr.WrapErrCollectionNotFound(collectionID)
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
//validate partition
......@@ -52,43 +51,43 @@ func validate(ctx context.Context, manager *Manager, collectionID int64, partiti
// all partitions have been released
if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadPartition {
return searchPartIDs, newSegmentIDs, errors.New("partitions have been released , collectionID = " +
return nil, errors.New("partitions have been released , collectionID = " +
fmt.Sprintln(collectionID) + "target partitionIDs = " + fmt.Sprintln(searchPartIDs))
}
if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadCollection {
return searchPartIDs, newSegmentIDs, nil
return []Segment{}, nil
}
//validate segment
segments := make([]Segment, 0, len(segmentIDs))
var err error
if len(segmentIDs) == 0 {
for _, partID := range searchPartIDs {
segments := manager.Segment.GetBy(WithPartition(partID), segmentFilter)
for _, seg := range segments {
newSegmentIDs = append(segmentIDs, seg.ID())
segments, err = manager.Segment.GetAndPinBy(WithPartition(partID), segmentFilter)
if err != nil {
return nil, err
}
}
} else {
newSegmentIDs = segmentIDs
for _, segmentID := range newSegmentIDs {
segments := manager.Segment.GetBy(WithID(segmentID), segmentFilter)
if len(segments) != 1 {
continue
}
segment := segments[0]
segments, err = manager.Segment.GetAndPin(segmentIDs, segmentFilter)
if err != nil {
return nil, err
}
for _, segment := range segments {
if !funcutil.SliceContain(searchPartIDs, segment.Partition()) {
err := fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segmentID, segment.Partition(), searchPartIDs)
return searchPartIDs, newSegmentIDs, err
err := fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segment.ID(), segment.Partition(), searchPartIDs)
return nil, err
}
}
}
return searchPartIDs, newSegmentIDs, nil
return segments, nil
}
func validateOnHistorical(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64) ([]int64, []int64, error) {
func validateOnHistorical(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64) ([]Segment, error) {
return validate(ctx, manager, collectionID, partitionIDs, segmentIDs, WithType(SegmentTypeSealed))
}
func validateOnStream(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64) ([]int64, []int64, error) {
func validateOnStream(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64) ([]Segment, error) {
return validate(ctx, manager, collectionID, partitionIDs, segmentIDs, WithType(SegmentTypeGrowing))
}
......@@ -88,9 +88,12 @@ func (t *QueryTask) Execute() error {
}
defer retrievePlan.Delete()
var results []*segcorepb.RetrieveResults
var (
results []*segcorepb.RetrieveResults
searchedSegments []segments.Segment
)
if t.req.GetScope() == querypb.DataScope_Historical {
results, _, _, err = segments.RetrieveHistorical(
results, searchedSegments, err = segments.RetrieveHistorical(
t.ctx,
t.segmentManager,
retrievePlan,
......@@ -99,7 +102,7 @@ func (t *QueryTask) Execute() error {
t.req.GetSegmentIDs(),
)
} else {
results, _, _, err = segments.RetrieveStreaming(
results, searchedSegments, err = segments.RetrieveStreaming(
t.ctx,
t.segmentManager,
retrievePlan,
......@@ -108,6 +111,7 @@ func (t *QueryTask) Execute() error {
t.req.GetSegmentIDs(),
)
}
defer t.segmentManager.Segment.Unpin(searchedSegments)
if err != nil {
return err
}
......
......@@ -124,9 +124,12 @@ func (t *SearchTask) Execute() error {
}
defer searchReq.Delete()
var results []*segments.SearchResult
var (
results []*segments.SearchResult
searchedSegments []segments.Segment
)
if req.GetScope() == querypb.DataScope_Historical {
results, _, _, err = segments.SearchHistorical(
results, searchedSegments, err = segments.SearchHistorical(
t.ctx,
t.segmentManager,
searchReq,
......@@ -135,7 +138,7 @@ func (t *SearchTask) Execute() error {
req.GetSegmentIDs(),
)
} else if req.GetScope() == querypb.DataScope_Streaming {
results, _, _, err = segments.SearchStreaming(
results, searchedSegments, err = segments.SearchStreaming(
t.ctx,
t.segmentManager,
searchReq,
......@@ -144,6 +147,7 @@ func (t *SearchTask) Execute() error {
req.GetSegmentIDs(),
)
}
defer t.segmentManager.Segment.Unpin(searchedSegments)
if err != nil {
return err
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册