diff --git a/cmd/roles/roles.go b/cmd/roles/roles.go index 157bb1f9abc88c8c97e9383570b05890e381f0bb..6353c482fd2418ab4983c25096f8a18e8f340c63 100644 --- a/cmd/roles/roles.go +++ b/cmd/roles/roles.go @@ -28,6 +28,7 @@ import ( "syscall" "github.com/milvus-io/milvus/internal/util/healthz" + "github.com/milvus-io/milvus/internal/util/rocksmq/server/rocksmq" "github.com/milvus-io/milvus/internal/util/metricsinfo" "go.uber.org/zap" @@ -56,6 +57,15 @@ func newMsgFactory(localMsg bool) msgstream.Factory { return msgstream.NewPmsFactory() } +func initRocksmq() error { + err := rocksmq.InitRocksMQ() + return err +} + +func stopRocksmq() { + rocksmq.CloseRocksMQ() +} + // MilvusRoles determines to run which components. type MilvusRoles struct { EnableRootCoord bool `env:"ENABLE_ROOT_COORD"` @@ -349,6 +359,12 @@ func (mr *MilvusRoles) Run(localMsg bool, alias string) { cfg := mr.setLogConfigFilename("standalone.log") logutil.SetupLogger(cfg) defer log.Sync() + + err := initRocksmq() + if err != nil { + panic(err) + } + defer stopRocksmq() } else { err := os.Setenv(metricsinfo.DeployModeEnvKey, metricsinfo.ClusterDeployMode) if err != nil { diff --git a/internal/util/rocksmq/server/rocksmq/global_rmq.go b/internal/util/rocksmq/server/rocksmq/global_rmq.go index 86912a811b5edc94a86d0de5c858e47b2591d017..0533d4e43ebd64bdf1a41c6def4ac16efe289f29 100644 --- a/internal/util/rocksmq/server/rocksmq/global_rmq.go +++ b/internal/util/rocksmq/server/rocksmq/global_rmq.go @@ -12,6 +12,7 @@ package rocksmq import ( + "errors" "os" "strconv" "sync" @@ -43,26 +44,33 @@ func InitRmq(rocksdbName string, idAllocator allocator.GIDAllocator) error { // InitRocksMQ init global rocksmq single instance func InitRocksMQ() error { - var err error + var finalErr error once.Do(func() { params.Init() rocksdbName, _ := params.Load("_RocksmqPath") log.Debug("RocksmqPath=" + rocksdbName) - _, err = os.Stat(rocksdbName) - if os.IsNotExist(err) { - err = os.MkdirAll(rocksdbName, os.ModePerm) - if err != nil { - errMsg := "Create dir " + rocksdbName + " failed" - panic(errMsg) + var fi os.FileInfo + fi, finalErr = os.Stat(rocksdbName) + if os.IsNotExist(finalErr) { + finalErr = os.MkdirAll(rocksdbName, os.ModePerm) + if finalErr != nil { + return + } + } else { + if !fi.IsDir() { + errMsg := "can't create a directory because there exists a file with the same name" + finalErr = errors.New(errMsg) + return } } kvname := rocksdbName + "_kv" - rocksdbKV, err := rocksdbkv.NewRocksdbKV(kvname) - if err != nil { - panic(err) + var rkv *rocksdbkv.RocksdbKV + rkv, finalErr = rocksdbkv.NewRocksdbKV(kvname) + if finalErr != nil { + return } - idAllocator := allocator.NewGlobalIDAllocator("rmq_id", rocksdbKV) + idAllocator := allocator.NewGlobalIDAllocator("rmq_id", rkv) _ = idAllocator.Initialize() rawRmqPageSize, err := params.Load("rocksmq.rocksmqPageSize") @@ -94,12 +102,9 @@ func InitRocksMQ() error { } log.Debug("", zap.Any("RocksmqRetentionTimeInMinutes", RocksmqRetentionTimeInMinutes), zap.Any("RocksmqRetentionSizeInMB", RocksmqRetentionSizeInMB), zap.Any("RocksmqPageSize", RocksmqPageSize)) - Rmq, err = NewRocksMQ(rocksdbName, idAllocator) - if err != nil { - panic(err) - } + Rmq, finalErr = NewRocksMQ(rocksdbName, idAllocator) }) - return err + return finalErr } // CloseRocksMQ is used to close global rocksmq diff --git a/internal/util/rocksmq/server/rocksmq/global_rmq_test.go b/internal/util/rocksmq/server/rocksmq/global_rmq_test.go index 257b2d49e1f9563ec7f0fad375be1c0baddf8627..c6f452eb6e16509eede7036a643e2bc29fa4d82c 100644 --- a/internal/util/rocksmq/server/rocksmq/global_rmq_test.go +++ b/internal/util/rocksmq/server/rocksmq/global_rmq_test.go @@ -15,6 +15,7 @@ import ( "log" "os" "strings" + "sync" "testing" "github.com/milvus-io/milvus/internal/allocator" @@ -69,3 +70,13 @@ func Test_InitRocksMQ(t *testing.T) { } Rmq.RegisterConsumer(consumer) } + +func Test_InitRocksMQError(t *testing.T) { + once = sync.Once{} + dummyPath := "/tmp/milvus/dummy" + os.Create(dummyPath) + os.Setenv("ROCKSMQ_PATH", dummyPath) + defer os.RemoveAll(dummyPath) + err := InitRocksMQ() + assert.Error(t, err) +}