未验证 提交 f915f173 编写于 作者: C chyezh 提交者: GitHub

[Fixup] Short-term fix metacache data race (#25802)

Signed-off-by: Nchyezh <ye.zhen@zilliz.com>
上级 d71efd60
......@@ -55,10 +55,10 @@ import (
type Cache interface {
// GetCollectionID get collection's id by name.
GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error)
// GetDatabaseAndCollectionName get collection's name and database by id
GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error)
// GetCollectionName get collection's name by id
GetCollectionName(ctx context.Context, collectionID int64) (string, error)
// GetCollectionInfo get collection's information by name or collection id, such as schema, and etc.
GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error)
GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionBasicInfo, error)
// GetPartitionID get partition's identifier of specific collection.
GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
// GetPartitions get all partitions' id of specific collection.
......@@ -87,6 +87,14 @@ type Cache interface {
RemoveDatabase(ctx context.Context, database string)
}
// collectionBasicInfo is the reduced view of a cached collection that
// Cache.GetCollectionInfo returns to callers. It is produced by
// collectionInfo.getBasicInfo, which copies the scalar fields and rebuilds
// partInfo as a fresh map holding fresh partitionInfo values.
type collectionBasicInfo struct {
// collID is copied from collectionInfo.collID.
collID typeutil.UniqueID
// createdTimestamp is copied from collectionInfo.createdTimestamp.
createdTimestamp uint64
// createdUtcTimestamp is copied from collectionInfo.createdUtcTimestamp.
createdUtcTimestamp uint64
// consistencyLevel is copied from collectionInfo.consistencyLevel.
consistencyLevel commonpb.ConsistencyLevel
// partInfo uses the same string keys as collectionInfo.partInfo
// (presumably partition names — confirm against the cache writer).
partInfo map[string]*partitionInfo
}
type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
......@@ -96,7 +104,23 @@ type collectionInfo struct {
createdTimestamp uint64
createdUtcTimestamp uint64
consistencyLevel commonpb.ConsistencyLevel
database string
}
// getBasicInfo returns a collectionBasicInfo snapshot of info.
//
// Scalar fields are copied by value. partInfo is rebuilt as a new map whose
// values point at struct copies of the original partitionInfo entries, so the
// result shares neither the map nor the partitionInfo pointers with info.
// NOTE(review): each partitionInfo is copied with a plain struct assignment;
// if it contains reference-typed fields those would still be shared — its
// definition is not visible here.
func (info *collectionInfo) getBasicInfo() *collectionBasicInfo {
	partitions := make(map[string]*partitionInfo, len(info.partInfo))
	for key, part := range info.partInfo {
		cloned := *part
		partitions[key] = &cloned
	}

	return &collectionBasicInfo{
		collID:              info.collID,
		createdTimestamp:    info.createdTimestamp,
		createdUtcTimestamp: info.createdUtcTimestamp,
		consistencyLevel:    info.consistencyLevel,
		partInfo:            partitions,
	}
}
func (info *collectionInfo) isCollectionCached() bool {
......@@ -252,8 +276,8 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
return collInfo.collID, nil
}
// GetDatabaseAndCollectionName returns the corresponding collection name for provided collection id
func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) {
// GetCollectionName returns the corresponding collection name for provided collection id
func (m *MetaCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) {
m.mu.RLock()
var collInfo *collectionInfo
for _, db := range m.collInfo {
......@@ -272,24 +296,59 @@ func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collection
m.mu.RUnlock()
coll, err := m.describeCollection(ctx, "", "", collectionID)
if err != nil {
return "", "", err
return "", err
}
m.mu.Lock()
defer m.mu.Unlock()
m.updateCollection(coll, coll.GetDbName(), coll.Schema.Name)
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return coll.GetDbName(), coll.Schema.Name, nil
return coll.Schema.Name, nil
}
defer m.mu.RUnlock()
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
return collInfo.database, collInfo.schema.Name, nil
return collInfo.schema.Name, nil
}
// GetCollectionInfo returns a copied, basic view of the collection cached
// under database/collectionName.
//
// The cached entry is used only when it exists, reports itself as cached and
// carries the requested collectionID. Otherwise the collection is described
// again by collectionID (with an empty name), the cache is updated under the
// write lock, and the refreshed entry is returned. In both cases the result
// is built with getBasicInfo while the lock is still held, so callers never
// receive a pointer to the cache entry itself.
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) {
	const method = "GetCollectionInfo"

	m.mu.RLock()
	var cached *collectionInfo
	found := false
	if colls, hasDB := m.collInfo[database]; hasDB {
		cached, found = colls[collectionName]
	}

	// Cache hit: the entry is trusted only if its collID matches the request.
	if found && cached.isCollectionCached() && cached.collID == collectionID {
		defer m.mu.RUnlock()
		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
		return cached.getBasicInfo(), nil
	}

	// Cache miss: release the read lock before the describe call.
	m.mu.RUnlock()
	tr := timerecord.NewTimeRecorder("UpdateCache")
	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
	resp, err := m.describeCollection(ctx, database, "", collectionID)
	if err != nil {
		return nil, err
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	m.updateCollection(resp, database, collectionName)
	refreshed := m.collInfo[database][collectionName]
	metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
	return refreshed.getBasicInfo(), nil
}
// GetCollectionInfo returns the collection information related to provided collection name
// If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
// TODO: may cause data race of this implementation, should be refactored in future.
func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
......@@ -298,12 +357,12 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN
if dbOk {
collInfo, ok = db[collectionName]
}
m.mu.RUnlock()
method := "GetCollectionInfo"
// if collInfo.collID != collectionID, means that the cache is not trustable
// try to get collection according to collectionID
if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
m.mu.RUnlock()
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
var coll *milvuspb.DescribeCollectionResponse
......@@ -320,8 +379,10 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN
collInfo = m.collInfo[database][collectionName]
m.mu.Unlock()
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return collInfo, nil
}
m.mu.RUnlock()
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
return collInfo, nil
}
......@@ -707,7 +768,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
info, err := m.GetCollectionInfo(ctx, database, collectionName, collectionID)
info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
......@@ -764,7 +825,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
shards := parseShardLeaderList2QueryNode(resp.GetShards())
info, err = m.GetCollectionInfo(ctx, database, collectionName, collectionID)
info, err = m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, fmt.Errorf("failed to get shards, collectionName %s, colectionID %d not found", collectionName, collectionID)
}
......@@ -825,7 +886,6 @@ func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
if ok {
info.deprecateLeaderCache()
}
}
func (m *MetaCache) expireShardLeaderCache(ctx context.Context) {
......
......@@ -259,7 +259,40 @@ func TestMetaCache_GetCollection(t *testing.T) {
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
})
}
// TestMetaCache_GetBasicCollectionInfo fetches the basic info of the same
// collection from several goroutines at once and reads every field of the
// result. It is meant to be run with the race detector: the reads must not
// race with each other or with the cache update done on the first miss.
func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockRootCoordClientInterface{}
	queryCoord := &mocks.MockQueryCoord{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
	assert.NoError(t, err)

	// should be no data race.
	const readers = 2
	var wg sync.WaitGroup
	wg.Add(readers)
	for i := 0; i < readers; i++ {
		go func() {
			defer wg.Done()
			info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
			// On error info is nil; stop here so the failure is reported as a
			// test failure instead of a nil-pointer panic in this goroutine.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, info.collID, int64(1))
			_ = info.consistencyLevel
			_ = info.createdTimestamp
			_ = info.createdUtcTimestamp
			_ = info.partInfo
		}()
	}
	wg.Wait()
}
func TestMetaCache_GetCollectionName(t *testing.T) {
......@@ -270,9 +303,8 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err)
db, collection, err := globalMetaCache.GetDatabaseAndCollectionName(ctx, 1)
collection, err := globalMetaCache.GetCollectionName(ctx, 1)
assert.NoError(t, err)
assert.Equal(t, db, dbName)
assert.Equal(t, collection, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 1)
......@@ -285,7 +317,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
})
_, collection, err = globalMetaCache.GetDatabaseAndCollectionName(ctx, 1)
collection, err = globalMetaCache.GetCollectionName(ctx, 1)
assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err)
assert.Equal(t, collection, "collection1")
......@@ -299,7 +331,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
})
// test to get from cache, this should trigger root request
_, collection, err = globalMetaCache.GetDatabaseAndCollectionName(ctx, 1)
collection, err = globalMetaCache.GetCollectionName(ctx, 1)
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, collection, "collection1")
......@@ -397,7 +429,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
getCollectionCacheFunc := func(wg *sync.WaitGroup) {
defer wg.Done()
for i := 0; i < cnt; i++ {
//GetCollectionSchema will never fail
// GetCollectionSchema will never fail
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
......@@ -412,7 +444,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
getPartitionCacheFunc := func(wg *sync.WaitGroup) {
defer wg.Done()
for i := 0; i < cnt; i++ {
//GetPartitions may fail
// GetPartitions may fail
globalMetaCache.GetPartitions(ctx, dbName, "collection1")
time.Sleep(10 * time.Millisecond)
}
......@@ -421,7 +453,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
invalidCacheFunc := func(wg *sync.WaitGroup) {
defer wg.Done()
for i := 0; i < cnt; i++ {
//periodically invalid collection cache
// periodically invalid collection cache
globalMetaCache.RemoveCollection(ctx, dbName, "collection1")
time.Sleep(10 * time.Millisecond)
}
......@@ -574,7 +606,6 @@ func TestMetaCache_ClearShards(t *testing.T) {
})
t.Run("Clear valid collection valid cache", func(t *testing.T) {
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
......@@ -731,6 +762,13 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
assert.NoError(t, err)
// shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 3)
globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1))
// no collectionInfo of collection1, should access RootCoord
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
assert.NoError(t, err)
// no collectionInfo of collection1, should access RootCoord
assert.Equal(t, rootCoord.GetAccessCount(), 4)
}
func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
......
......@@ -115,19 +115,19 @@ func (_c *MockCache_GetCollectionID_Call) RunAndReturn(run func(context.Context,
}
// GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName, collectionID
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionInfo, error) {
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) {
ret := _m.Called(ctx, database, collectionName, collectionID)
var r0 *collectionInfo
var r0 *collectionBasicInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*collectionInfo, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*collectionBasicInfo, error)); ok {
return rf(ctx, database, collectionName, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *collectionInfo); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *collectionBasicInfo); ok {
r0 = rf(ctx, database, collectionName, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*collectionInfo)
r0 = ret.Get(0).(*collectionBasicInfo)
}
}
......@@ -161,12 +161,65 @@ func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, da
return _c
}
func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionInfo, _a1 error) *MockCache_GetCollectionInfo_Call {
func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionBasicInfo, _a1 error) *MockCache_GetCollectionInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Context, string, string, int64) (*collectionInfo, error)) *MockCache_GetCollectionInfo_Call {
func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Context, string, string, int64) (*collectionBasicInfo, error)) *MockCache_GetCollectionInfo_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionName provides a mock function with given fields: ctx, collectionID
//
// The configured return values may be plain values, one function per return
// value, or a single function producing both results; the single-function
// form takes precedence.
func (_m *MockCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) {
	args := _m.Called(ctx, collectionID)

	var name string
	switch first := args.Get(0).(type) {
	case func(context.Context, int64) (string, error):
		// One function supplies both results.
		return first(ctx, collectionID)
	case func(context.Context, int64) string:
		name = first(ctx, collectionID)
	default:
		name = args.Get(0).(string)
	}

	var err error
	if errFn, ok := args.Get(1).(func(context.Context, int64) error); ok {
		err = errFn(ctx, collectionID)
	} else {
		err = args.Error(1)
	}

	return name, err
}
// MockCache_GetCollectionName_Call wraps the *mock.Call registered for
// 'GetCollectionName' so that its Run/Return/RunAndReturn methods can take
// explicitly typed arguments instead of the untyped ones on *mock.Call.
type MockCache_GetCollectionName_Call struct {
*mock.Call
}
// GetCollectionName registers an expectation for the mocked
// 'GetCollectionName' method via mock.On and returns its typed call wrapper.
//   - ctx context.Context
//   - collectionID int64
func (_e *MockCache_Expecter) GetCollectionName(ctx interface{}, collectionID interface{}) *MockCache_GetCollectionName_Call {
	registered := _e.mock.On("GetCollectionName", ctx, collectionID)
	return &MockCache_GetCollectionName_Call{Call: registered}
}
// Run installs run as the callback of this call. The untyped mock arguments
// are type-asserted to (context.Context, int64) before run is invoked.
func (_c *MockCache_GetCollectionName_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_GetCollectionName_Call {
	adapter := func(args mock.Arguments) {
		ctx := args[0].(context.Context)
		collectionID := args[1].(int64)
		run(ctx, collectionID)
	}
	_c.Call.Run(adapter)
	return _c
}
// Return sets the fixed (string, error) pair handed to the underlying
// mock.Call as this call's return values, and returns the call for chaining.
func (_c *MockCache_GetCollectionName_Call) Return(_a0 string, _a1 error) *MockCache_GetCollectionName_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// RunAndReturn stores the run function itself as the call's return value.
// MockCache.GetCollectionName detects a value of this function type in
// return slot 0 and invokes it to produce both results.
func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, error)) *MockCache_GetCollectionName_Call {
_c.Call.Return(run)
return _c
}
......@@ -282,66 +335,6 @@ func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Contex
return _c
}
// GetDatabaseAndCollectionName provides a mock function with given fields: ctx, collectionID
func (_m *MockCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) {
ret := _m.Called(ctx, collectionID)
var r0 string
var r1 string
var r2 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (string, string, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func(context.Context, int64) string); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Get(1).(string)
}
if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok {
r2 = rf(ctx, collectionID)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// MockCache_GetDatabaseAndCollectionName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseAndCollectionName'
type MockCache_GetDatabaseAndCollectionName_Call struct {
*mock.Call
}
// GetDatabaseAndCollectionName is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockCache_Expecter) GetDatabaseAndCollectionName(ctx interface{}, collectionID interface{}) *MockCache_GetDatabaseAndCollectionName_Call {
return &MockCache_GetDatabaseAndCollectionName_Call{Call: _e.mock.On("GetDatabaseAndCollectionName", ctx, collectionID)}
}
func (_c *MockCache_GetDatabaseAndCollectionName_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_GetDatabaseAndCollectionName_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockCache_GetDatabaseAndCollectionName_Call) Return(_a0 string, _a1 string, _a2 error) *MockCache_GetDatabaseAndCollectionName_Call {
_c.Call.Return(_a0, _a1, _a2)
return _c
}
func (_c *MockCache_GetDatabaseAndCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, string, error)) *MockCache_GetDatabaseAndCollectionName_Call {
_c.Call.Return(run)
return _c
}
// GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName, partitionName)
......
......@@ -521,7 +521,6 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error {
}
result, err := dct.rootCoord.DescribeCollection(ctx, dct.DescribeCollectionRequest)
if err != nil {
return err
}
......@@ -634,7 +633,6 @@ func (sct *showCollectionsTask) PreExecute(ctx context.Context) error {
func (sct *showCollectionsTask) Execute(ctx context.Context) error {
respFromRootCoord, err := sct.rootCoord.ShowCollections(ctx, sct.ShowCollectionsRequest)
if err != nil {
return err
}
......@@ -670,10 +668,9 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error {
sct.Base,
commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
),
//DbID: sct.ShowCollectionsRequest.DbName,
// DbID: sct.ShowCollectionsRequest.DbName,
CollectionIDs: collectionIDs,
})
if err != nil {
return err
}
......@@ -1179,7 +1176,6 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error {
CollectionID: collectionID,
PartitionIDs: partitionIDs,
})
if err != nil {
return err
}
......@@ -2209,13 +2205,12 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error {
resp, err := t.queryCoord.DescribeResourceGroup(ctx, &querypb.DescribeResourceGroupRequest{
ResourceGroup: t.ResourceGroup,
})
if err != nil {
return err
}
getCollectionNameFunc := func(value int32, key int64) string {
_, name, err := globalMetaCache.GetDatabaseAndCollectionName(ctx, key)
name, err := globalMetaCache.GetCollectionName(ctx, key)
if err != nil {
// unreachable logic path
return "unavailable_collection"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册