diff --git a/internal/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go index 1635bad65a9ccca29d715e6ead139e9cb72cb9a9..a8ad6111a7d76906a215cbc292094e9001af66fa 100644 --- a/internal/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -24,15 +24,13 @@ type ConfAdapterMgr interface { // ConfAdapterMgrImpl implements ConfAdapter. type ConfAdapterMgrImpl struct { - init bool adapters map[IndexType]ConfAdapter + once sync.Once } // GetAdapter gets the conf adapter by the index type. func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) { - if !mgr.init { - mgr.registerConfAdapter() - } + mgr.once.Do(mgr.registerConfAdapter) adapter, ok := mgr.adapters[indexType] if ok { @@ -42,8 +40,6 @@ func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) } func (mgr *ConfAdapterMgrImpl) registerConfAdapter() { - mgr.init = true - mgr.adapters[IndexFaissIDMap] = newBaseConfAdapter() mgr.adapters[IndexFaissIvfFlat] = newIVFConfAdapter() mgr.adapters[IndexFaissIvfPQ] = newIVFPQConfAdapter() @@ -63,7 +59,6 @@ func (mgr *ConfAdapterMgrImpl) registerConfAdapter() { func newConfAdapterMgrImpl() *ConfAdapterMgrImpl { return &ConfAdapterMgrImpl{ - init: false, adapters: make(map[IndexType]ConfAdapter), } } diff --git a/internal/util/indexparamcheck/conf_adapter_mgr_test.go b/internal/util/indexparamcheck/conf_adapter_mgr_test.go index e90a5fa6bd8ca0fa87175f7cf1b1d31db031b411..075ae72bc2084793d1d8f9aaf90203f68b42ac25 100644 --- a/internal/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr_test.go @@ -12,6 +12,7 @@ package indexparamcheck import ( + "sync" "testing" "github.com/stretchr/testify/assert" @@ -220,3 +221,19 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { _, ok = adapter.(*NGTONNGConfAdapter) assert.Equal(t, true, ok) } + +func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) { + num := 4 + mgr := newConfAdapterMgrImpl() + var wg sync.WaitGroup + for i := 0; i < num; i++ { + wg.Add(1) + go func() { + defer wg.Done() + adapter, err := mgr.GetAdapter(IndexHNSW) + assert.NoError(t, err) + assert.NotNil(t, adapter) + }() + } + wg.Wait() +}