From bf391f2449a4ba7d5fb47f2f27f4286cee5e64ac Mon Sep 17 00:00:00 2001 From: dragondriver Date: Sun, 21 Nov 2021 00:39:13 +0800 Subject: [PATCH] Make indexparamcheck thread-safe (#11916) Signed-off-by: dragondriver --- .../util/indexparamcheck/conf_adapter_mgr.go | 9 ++------- .../indexparamcheck/conf_adapter_mgr_test.go | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/internal/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go index 1635bad65..a8ad6111a 100644 --- a/internal/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -24,15 +24,13 @@ type ConfAdapterMgr interface { // ConfAdapterMgrImpl implements ConfAdapter. type ConfAdapterMgrImpl struct { - init bool adapters map[IndexType]ConfAdapter + once sync.Once } // GetAdapter gets the conf adapter by the index type. func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) { - if !mgr.init { - mgr.registerConfAdapter() - } + mgr.once.Do(mgr.registerConfAdapter) adapter, ok := mgr.adapters[indexType] if ok { @@ -42,8 +40,6 @@ func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) } func (mgr *ConfAdapterMgrImpl) registerConfAdapter() { - mgr.init = true - mgr.adapters[IndexFaissIDMap] = newBaseConfAdapter() mgr.adapters[IndexFaissIvfFlat] = newIVFConfAdapter() mgr.adapters[IndexFaissIvfPQ] = newIVFPQConfAdapter() @@ -63,7 +59,6 @@ func (mgr *ConfAdapterMgrImpl) registerConfAdapter() { func newConfAdapterMgrImpl() *ConfAdapterMgrImpl { return &ConfAdapterMgrImpl{ - init: false, adapters: make(map[IndexType]ConfAdapter), } } diff --git a/internal/util/indexparamcheck/conf_adapter_mgr_test.go b/internal/util/indexparamcheck/conf_adapter_mgr_test.go index e90a5fa6b..075ae72bc 100644 --- a/internal/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr_test.go @@ -12,6 +12,7 @@ package indexparamcheck import ( + "sync" "testing" "github.com/stretchr/testify/assert" @@ -220,3 +221,19 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { _, ok = adapter.(*NGTONNGConfAdapter) assert.Equal(t, true, ok) } + +func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) { + num := 4 + mgr := newConfAdapterMgrImpl() + var wg sync.WaitGroup + for i := 0; i < num; i++ { + wg.Add(1) + go func() { + defer wg.Done() + adapter, err := mgr.GetAdapter(IndexHNSW) + assert.NoError(t, err) + assert.NotNil(t, adapter) + }() + } + wg.Wait() +} -- GitLab