From 0c9a10e8f8ca8b53423a8ce17453746b1281516a Mon Sep 17 00:00:00 2001 From: Zach Date: Tue, 23 Aug 2022 10:44:52 +0800 Subject: [PATCH] 1. refine logging interfaces (#18692) 2. adjust logs for query/search requests Signed-off-by: Zach41 Signed-off-by: Zach41 --- cmd/roles/roles.go | 5 +- internal/distributed/datacoord/service.go | 9 +- internal/distributed/datanode/service.go | 10 +- internal/distributed/indexcoord/service.go | 10 +- internal/distributed/indexnode/service.go | 10 +- internal/distributed/proxy/service.go | 7 +- internal/distributed/querycoord/service.go | 10 +- internal/distributed/querynode/service.go | 10 +- internal/distributed/rootcoord/service.go | 10 +- internal/log/global.go | 100 ++++++++++++++ internal/log/log.go | 85 +++++++++++- internal/log/log_test.go | 72 ++++++++++ internal/log/meta_logger.go | 129 ++++++++++++++++++ internal/log/meta_logger_test.go | 52 +++++++ internal/log/meta_ops.go | 17 +++ internal/log/mlogger.go | 31 +++++ internal/log/mlogger_test.go | 102 ++++++++++++++ internal/log/zap_log_test.go | 41 ++++++ internal/metastore/model/collection_test.go | 18 ++- internal/metastore/model/index_test.go | 7 +- internal/metastore/model/segment_test.go | 5 +- internal/proxy/impl.go | 40 ++---- internal/proxy/meta_cache_test.go | 2 +- internal/proxy/task_policies.go | 8 +- internal/proxy/task_query.go | 56 ++++---- internal/proxy/task_search.go | 68 ++++----- internal/proxy/task_search_test.go | 4 +- internal/querycoord/impl.go | 12 +- internal/querynode/benchmark_test.go | 8 +- internal/querynode/impl.go | 40 +++--- internal/querynode/query_shard_test.go | 9 +- internal/querynode/result.go | 25 ++-- internal/querynode/retrieve.go | 10 +- internal/querynode/retrieve_test.go | 5 +- internal/querynode/search.go | 15 +- internal/querynode/search_test.go | 22 +-- internal/querynode/statistic_test.go | 20 +-- internal/querynode/statistics.go | 9 +- internal/querynode/task_query.go | 16 ++- internal/querynode/task_search.go | 27 ++-- internal/querynode/task_statistics.go | 17 ++- internal/querynode/validate.go | 11 +- internal/querynode/validate_test.go | 18 +-- internal/rootcoord/root_coord.go | 6 +- internal/rootcoord/task.go | 37 ++++- internal/util/logutil/grpc_interceptor.go | 78 +++++++++++ .../util/logutil/grpc_interceptor_test.go | 64 +++++++++ internal/util/logutil/logutil.go | 52 +------ internal/util/timerecord/time_recorder.go | 21 ++- 49 files changed, 1123 insertions(+), 317 deletions(-) create mode 100644 internal/log/meta_logger.go create mode 100644 internal/log/meta_logger_test.go create mode 100644 internal/log/meta_ops.go create mode 100644 internal/log/mlogger.go create mode 100644 internal/log/mlogger_test.go create mode 100644 internal/util/logutil/grpc_interceptor.go create mode 100644 internal/util/logutil/grpc_interceptor_test.go diff --git a/cmd/roles/roles.go b/cmd/roles/roles.go index e4fca1216..12a8eb909 100644 --- a/cmd/roles/roles.go +++ b/cmd/roles/roles.go @@ -48,7 +48,6 @@ import ( "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/healthz" - "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/trace" @@ -238,7 +237,7 @@ func (mr *MilvusRoles) runDataCoord(ctx context.Context, localMsg bool) *compone factory := dependency.NewFactory(localMsg) - dctx := logutil.WithModule(ctx, "DataCoord") 
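The heart of this patch is that the request context now carries the logger: logutil.WithModule moves into the log package itself, and call sites switch from the global log.Debug/log.Warn helpers to log.Ctx(ctx), so fields attached to the context (module, traceID, reqID) appear on every emitted line without being passed by hand. A minimal usage sketch of that pattern, assuming the internal/log package is imported as shown (the snippet is illustrative and not part of the diff):

    package main

    import (
        "context"

        "github.com/milvus-io/milvus/internal/log"
        "go.uber.org/zap"
    )

    func main() {
        // Attach stable fields to the context once ...
        ctx := log.WithModule(context.Background(), "DataCoord")
        ctx = log.WithTraceID(ctx, "trace-123") // example trace ID, not a real one

        // ... then every context-aware log line carries them automatically.
        log.Ctx(ctx).Info("segment flushed", zap.Int64("segmentID", 42))
        log.Ctx(ctx).RatedWarn(1.0, "rate-limited warning")
    }
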
+ dctx := log.WithModule(ctx, "DataCoord") var err error ds, err = components.NewDataCoord(dctx, factory) if err != nil { @@ -406,7 +405,7 @@ func (mr *MilvusRoles) Run(local bool, alias string) { var pn *components.Proxy if mr.EnableProxy { - pctx := logutil.WithModule(ctx, "Proxy") + pctx := log.WithModule(ctx, "Proxy") pn = mr.runProxy(pctx, local, alias) if pn != nil { defer pn.Stop() diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index b4bfb3901..bb5c762b9 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -25,6 +25,7 @@ import ( "sync" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -149,8 +150,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) datapb.RegisterDataCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(lis); err != nil { diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 2e82d030c..74ad31b99 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -26,6 +26,7 @@ import ( "sync" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -44,6 +45,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/trace" @@ -134,8 +136,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) datapb.RegisterDataNodeServer(s.grpcServer, s) ctx, cancel := context.WithCancel(s.ctx) diff --git a/internal/distributed/indexcoord/service.go b/internal/distributed/indexcoord/service.go index a8b20e494..59123099f 100644 --- a/internal/distributed/indexcoord/service.go +++ b/internal/distributed/indexcoord/service.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot 
"github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" "github.com/milvus-io/milvus/internal/indexcoord" @@ -42,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -277,8 +279,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) indexpb.RegisterIndexCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index 3a3544546..1b9a7e1f3 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" "github.com/milvus-io/milvus/internal/indexnode" @@ -42,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -108,8 +110,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) indexpb.RegisterIndexNodeServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(lis); err != nil { diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 54668fe99..a35ea9c60 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -62,6 +62,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -171,10 +172,12 @@ func (s *Server) 
startExternalGrpc(grpcPort int, errChan chan error) { ot.UnaryServerInterceptor(opts...), grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor), proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor), + logutil.UnaryTraceLoggerInterceptor, )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( ot.StreamServerInterceptor(opts...), - grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor))), + grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor), + logutil.StreamTraceLoggerInterceptor)), } if Params.TLSMode == 1 { @@ -261,10 +264,12 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( ot.UnaryServerInterceptor(opts...), grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor), + logutil.UnaryTraceLoggerInterceptor, )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( ot.StreamServerInterceptor(opts...), grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor), + logutil.StreamTraceLoggerInterceptor, )), ) proxypb.RegisterProxyServer(s.grpcInternalServer, s) diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 60a9b471a..2174a2b74 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -24,6 +24,7 @@ import ( "sync" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -43,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -259,8 +261,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) querypb.RegisterQueryCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 207a52ce3..696b4d767 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -25,6 +25,7 @@ import ( "sync" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -41,6 +42,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/trace" @@ -179,8 +181,12 @@ func (s 
*Server) startGrpcLoop(grpcPort int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) querypb.RegisterQueryNodeServer(s.grpcServer, s) ctx, cancel := context.WithCancel(s.ctx) diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 1ff9199cb..7cc367de2 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -45,6 +46,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/trace" @@ -258,8 +260,12 @@ func (s *Server) startGrpcLoop(port int) { grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize), grpc.MaxSendMsgSize(Params.ServerMaxSendSize), - grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)), - grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...))) + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + ot.UnaryServerInterceptor(opts...), + logutil.UnaryTraceLoggerInterceptor)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + ot.StreamServerInterceptor(opts...), + logutil.StreamTraceLoggerInterceptor))) rootcoordpb.RegisterRootCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/log/global.go b/internal/log/global.go index a61ecbd59..576bb6043 100644 --- a/internal/log/global.go +++ b/internal/log/global.go @@ -14,10 +14,18 @@ package log import ( + "context" + "go.uber.org/zap" "go.uber.org/zap/zapcore" ) +type ctxLogKeyType struct{} + +var ( + CtxLogKey = ctxLogKeyType{} +) + // Debug logs a message at DebugLevel. The message includes any fields passed // at the log site, as well as any fields accumulated on the logger. func Debug(msg string, fields ...zap.Field) { @@ -107,3 +115,95 @@ func SetLevel(l zapcore.Level) { func GetLevel() zapcore.Level { return _globalP.Load().(*ZapProperties).Level.Level() } + +// WithTraceID returns a context with trace_id attached +func WithTraceID(ctx context.Context, traceID string) context.Context { + return WithFields(ctx, zap.String("traceID", traceID)) +} + +// WithReqID adds given reqID field to the logger in ctx +func WithReqID(ctx context.Context, reqID int64) context.Context { + fields := []zap.Field{zap.Int64("reqID", reqID)} + return WithFields(ctx, fields...) +} + +// WithModule adds given module field to the logger in ctx +func WithModule(ctx context.Context, module string) context.Context { + fields := []zap.Field{zap.String("module", module)} + return WithFields(ctx, fields...) 
+} + +// WithFields returns a context with fields attached +func WithFields(ctx context.Context, fields ...zap.Field) context.Context { + var zlogger *zap.Logger + if ctxLogger, ok := ctx.Value(CtxLogKey).(*MLogger); ok { + zlogger = ctxLogger.Logger + } else { + zlogger = ctxL() + } + mLogger := &MLogger{ + Logger: zlogger.With(fields...), + } + return context.WithValue(ctx, CtxLogKey, mLogger) +} + +// Ctx returns a logger which will log contextual messages attached in ctx +func Ctx(ctx context.Context) *MLogger { + if ctx == nil { + return &MLogger{Logger: ctxL()} + } + if ctxLogger, ok := ctx.Value(CtxLogKey).(*MLogger); ok { + return ctxLogger + } + return &MLogger{Logger: ctxL()} +} + +// withLogLevel returns ctx with a leveled logger, notes that it will overwrite logger previous attached! +func withLogLevel(ctx context.Context, level zapcore.Level) context.Context { + var zlogger *zap.Logger + switch level { + case zap.DebugLevel: + zlogger = debugL() + case zap.InfoLevel: + zlogger = infoL() + case zap.WarnLevel: + zlogger = warnL() + case zap.ErrorLevel: + zlogger = errorL() + case zap.FatalLevel: + zlogger = fatalL() + default: + zlogger = L() + } + return context.WithValue(ctx, CtxLogKey, &MLogger{Logger: zlogger}) +} + +// WithDebugLevel returns context with a debug level enabled logger. +// Notes that it will overwrite previous attached logger within context +func WithDebugLevel(ctx context.Context) context.Context { + return withLogLevel(ctx, zapcore.DebugLevel) +} + +// WithInfoLevel returns context with a info level enabled logger. +// Notes that it will overwrite previous attached logger within context +func WithInfoLevel(ctx context.Context) context.Context { + return withLogLevel(ctx, zapcore.InfoLevel) +} + +// WithWarnLevel returns context with a warning level enabled logger. +// Notes that it will overwrite previous attached logger within context +func WithWarnLevel(ctx context.Context) context.Context { + return withLogLevel(ctx, zapcore.WarnLevel) +} + +// WithErrorLevel returns context with a error level enabled logger. +// Notes that it will overwrite previous attached logger within context +func WithErrorLevel(ctx context.Context) context.Context { + return withLogLevel(ctx, zapcore.ErrorLevel) +} + +// WithFatalLevel returns context with a fatal level enabled logger. +// Notes that it will overwrite previous attached logger within context +func WithFatalLevel(ctx context.Context) context.Context { + return withLogLevel(ctx, zapcore.FatalLevel) +} diff --git a/internal/log/log.go b/internal/log/log.go index b3052bdfd..dcfa195c8 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -33,6 +33,7 @@ import ( "fmt" "net/http" "os" + "sync" "sync/atomic" "errors" @@ -45,12 +46,20 @@ import ( ) var _globalL, _globalP, _globalS, _globalR atomic.Value + +var ( + _globalLevelLogger sync.Map +) + var rateLimiter *utils.ReconfigurableRateLimiter func init() { l, p := newStdLogger() + + replaceLeveledLoggers(l) _globalL.Store(l) _globalP.Store(p) + s := _globalL.Load().(*zap.Logger).Sugar() _globalS.Store(s) @@ -80,7 +89,19 @@ func InitLogger(cfg *Config, opts ...zap.Option) (*zap.Logger, *ZapProperties, e } output = stdOut } - return InitLoggerWithWriteSyncer(cfg, output, opts...) + debugCfg := *cfg + debugCfg.Level = "debug" + debugL, r, err := InitLoggerWithWriteSyncer(&debugCfg, output, opts...) 
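What makes the new per-level loggers cheap is zap.IncreaseLevel: InitLogger now always builds one debug-level logger, replaceLeveledLoggers derives one clamped logger per zapcore level from it, and the logger actually returned is raised back to the configured level. A self-contained sketch of that derivation using plain zap (the names here are stand-ins, not the patch's code):

    package main

    import (
        "go.uber.org/zap"
        "go.uber.org/zap/zapcore"
    )

    func main() {
        // One verbose (debug-level) base logger, standing in for the
        // logger InitLoggerWithWriteSyncer produces in this patch.
        base, _ := zap.NewDevelopment()
        defer base.Sync()

        // Derive a view clamped to a higher level from the same core,
        // as replaceLeveledLoggers does for every zapcore level.
        warnOnly := base.WithOptions(zap.IncreaseLevel(zapcore.WarnLevel))

        base.Debug("kept: the base logger still sees debug output")
        warnOnly.Debug("dropped: below the raised level")
        warnOnly.Warn("kept: at or above WarnLevel")
    }
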
+ if err != nil { + return nil, nil, err + } + replaceLeveledLoggers(debugL) + level := zapcore.DebugLevel + if err := level.UnmarshalText([]byte(cfg.Level)); err != nil { + return nil, nil, err + } + r.Level.SetLevel(level) + return debugL.WithOptions(zap.IncreaseLevel(level), zap.AddCallerSkip(1)), r, nil } // InitTestLogger initializes a logger for unit tests @@ -136,7 +157,7 @@ func initFileLog(cfg *FileLogConfig) (*lumberjack.Logger, error) { func newStdLogger() (*zap.Logger, *ZapProperties) { conf := &Config{Level: "debug", File: FileLogConfig{}} - lg, r, _ := InitLogger(conf, zap.AddCallerSkip(1)) + lg, r, _ := InitLogger(conf) return lg, r } @@ -157,6 +178,40 @@ func R() *utils.ReconfigurableRateLimiter { return _globalR.Load().(*utils.ReconfigurableRateLimiter) } +func ctxL() *zap.Logger { + level := _globalP.Load().(*ZapProperties).Level.Level() + l, ok := _globalLevelLogger.Load(level) + if !ok { + return L() + } + return l.(*zap.Logger) +} + +func debugL() *zap.Logger { + v, _ := _globalLevelLogger.Load(zapcore.DebugLevel) + return v.(*zap.Logger) +} + +func infoL() *zap.Logger { + v, _ := _globalLevelLogger.Load(zapcore.InfoLevel) + return v.(*zap.Logger) +} + +func warnL() *zap.Logger { + v, _ := _globalLevelLogger.Load(zapcore.WarnLevel) + return v.(*zap.Logger) +} + +func errorL() *zap.Logger { + v, _ := _globalLevelLogger.Load(zapcore.ErrorLevel) + return v.(*zap.Logger) +} + +func fatalL() *zap.Logger { + v, _ := _globalLevelLogger.Load(zapcore.FatalLevel) + return v.(*zap.Logger) +} + // ReplaceGlobals replaces the global Logger and SugaredLogger. // It's safe for concurrent use. func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) { @@ -165,11 +220,31 @@ func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) { _globalP.Store(props) } +func replaceLeveledLoggers(debugLogger *zap.Logger) { + levels := []zapcore.Level{zapcore.DebugLevel, zapcore.InfoLevel, zapcore.WarnLevel, zapcore.ErrorLevel, + zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel} + for _, level := range levels { + levelL := debugLogger.WithOptions(zap.IncreaseLevel(level)) + _globalLevelLogger.Store(level, levelL) + } +} + // Sync flushes any buffered log entries. 
func Sync() error { - err := L().Sync() - if err != nil { + if err := L().Sync(); err != nil { return err } - return S().Sync() + if err := S().Sync(); err != nil { + return err + } + var reterr error + _globalLevelLogger.Range(func(key, val interface{}) bool { + l := val.(*zap.Logger) + if err := l.Sync(); err != nil { + reterr = err + return false + } + return true + }) + return reterr } diff --git a/internal/log/log_test.go b/internal/log/log_test.go index f975e883e..c63b13f09 100644 --- a/internal/log/log_test.go +++ b/internal/log/log_test.go @@ -104,6 +104,11 @@ func TestInvalidFileConfig(t *testing.T) { _, _, err := InitLogger(conf) assert.Equal(t, "can't use directory as log file name", err.Error()) + + // invalid level + conf = &Config{Level: "debuge", DisableTimestamp: true} + _, _, err = InitLogger(conf) + assert.Error(t, err) } func TestLevelGetterAndSetter(t *testing.T) { @@ -235,3 +240,70 @@ func TestRatedLog(t *testing.T) { assert.True(t, success) Sync() } + +func TestLeveledLogger(t *testing.T) { + ts := newTestLogSpy(t) + conf := &Config{Level: "debug", DisableTimestamp: true, DisableCaller: true} + logger, _, _ := InitTestLogger(ts, conf) + replaceLeveledLoggers(logger) + + debugL().Debug("DEBUG LOG") + debugL().Info("INFO LOG") + debugL().Warn("WARN LOG") + debugL().Error("ERROR LOG") + Sync() + + ts.assertMessageContainAny(`[DEBUG] ["DEBUG LOG"]`) + ts.assertMessageContainAny(`[INFO] ["INFO LOG"]`) + ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`) + ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`) + + ts.CleanBuffer() + + infoL().Debug("DEBUG LOG") + infoL().Info("INFO LOG") + infoL().Warn("WARN LOG") + infoL().Error("ERROR LOG") + Sync() + + ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`) + ts.assertMessageContainAny(`[INFO] ["INFO LOG"]`) + ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`) + ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`) + ts.CleanBuffer() + + warnL().Debug("DEBUG LOG") + warnL().Info("INFO LOG") + warnL().Warn("WARN LOG") + warnL().Error("ERROR LOG") + Sync() + + ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`) + ts.assertMessagesNotContains(`[INFO] ["INFO LOG"]`) + ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`) + ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`) + ts.CleanBuffer() + + errorL().Debug("DEBUG LOG") + errorL().Info("INFO LOG") + errorL().Warn("WARN LOG") + errorL().Error("ERROR LOG") + Sync() + + ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`) + ts.assertMessagesNotContains(`[INFO] ["INFO LOG"]`) + ts.assertMessagesNotContains(`[WARN] ["WARN LOG"]`) + ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`) + + ts.CleanBuffer() + + ctx := withLogLevel(context.TODO(), zapcore.DPanicLevel) + assert.Equal(t, Ctx(ctx).Logger, L()) + + // set invalid level + orgLevel := GetLevel() + SetLevel(zapcore.FatalLevel + 1) + assert.Equal(t, ctxL(), L()) + SetLevel(orgLevel) + +} diff --git a/internal/log/meta_logger.go b/internal/log/meta_logger.go new file mode 100644 index 000000000..f3e775835 --- /dev/null +++ b/internal/log/meta_logger.go @@ -0,0 +1,129 @@ +package log + +import ( + "encoding/json" + + "github.com/milvus-io/milvus/internal/metastore/model" + "go.uber.org/zap" +) + +type Operator string + +const ( + // CreateCollection operator + Creator Operator = "create" + Update Operator = "update" + Delete Operator = "delete" + Insert Operator = "insert" + Sealed Operator = "sealed" +) + +type MetaLogger struct { + fields []zap.Field + logger *zap.Logger +} + +func NewMetaLogger() *MetaLogger { + l := infoL() + 
fields := []zap.Field{zap.Bool("MetaLogInfo", true)} + return &MetaLogger{ + fields: fields, + logger: l, + } +} + +func (m *MetaLogger) WithCollectionMeta(coll *model.Collection) *MetaLogger { + payload, _ := json.Marshal(coll) + m.fields = append(m.fields, zap.Binary("CollectionMeta", payload)) + return m +} + +func (m *MetaLogger) WithIndexMeta(idx *model.Index) *MetaLogger { + payload, _ := json.Marshal(idx) + m.fields = append(m.fields, zap.Binary("IndexMeta", payload)) + return m +} + +func (m *MetaLogger) WithCollectionID(collID int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("CollectionID", collID)) + return m +} + +func (m *MetaLogger) WithCollectionName(collname string) *MetaLogger { + m.fields = append(m.fields, zap.String("CollectionName", collname)) + return m +} + +func (m *MetaLogger) WithPartitionID(partID int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("PartitionID", partID)) + return m +} + +func (m *MetaLogger) WithPartitionName(partName string) *MetaLogger { + m.fields = append(m.fields, zap.String("PartitionName", partName)) + return m +} + +func (m *MetaLogger) WithFieldID(fieldID int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("FieldID", fieldID)) + return m +} + +func (m *MetaLogger) WithFieldName(fieldName string) *MetaLogger { + m.fields = append(m.fields, zap.String("FieldName", fieldName)) + return m +} + +func (m *MetaLogger) WithIndexID(idxID int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("IndexID", idxID)) + return m +} + +func (m *MetaLogger) WithIndexName(idxName string) *MetaLogger { + m.fields = append(m.fields, zap.String("IndexName", idxName)) + return m +} + +func (m *MetaLogger) WithBuildID(buildID int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("BuildID", buildID)) + return m +} + +func (m *MetaLogger) WithBuildIDS(buildIDs []int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64s("BuildIDs", buildIDs)) + return m +} + +func (m *MetaLogger) WithSegmentID(segID int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("SegmentID", segID)) + return m +} + +func (m *MetaLogger) WithIndexFiles(files []string) *MetaLogger { + m.fields = append(m.fields, zap.Strings("IndexFiles", files)) + return m +} + +func (m *MetaLogger) WithIndexVersion(version int64) *MetaLogger { + m.fields = append(m.fields, zap.Int64("IndexVersion", version)) + return m +} + +func (m *MetaLogger) WithTSO(tso uint64) *MetaLogger { + m.fields = append(m.fields, zap.Uint64("TSO", tso)) + return m +} + +func (m *MetaLogger) WithAlias(alias string) *MetaLogger { + m.fields = append(m.fields, zap.String("Alias", alias)) + return m +} + +func (m *MetaLogger) WithOperation(op MetaOperation) *MetaLogger { + m.fields = append(m.fields, zap.Int("Operation", int(op))) + return m +} + +func (m *MetaLogger) Info() { + m.logger.Info("", m.fields...) +} diff --git a/internal/log/meta_logger_test.go b/internal/log/meta_logger_test.go new file mode 100644 index 000000000..acf4a329d --- /dev/null +++ b/internal/log/meta_logger_test.go @@ -0,0 +1,52 @@ +package log + +import ( + "testing" + + "github.com/milvus-io/milvus/internal/metastore/model" +) + +func TestMetaLogger(t *testing.T) { + ts := newTestLogSpy(t) + conf := &Config{Level: "debug", DisableTimestamp: true, DisableCaller: true} + logger, _, _ := InitTestLogger(ts, conf) + replaceLeveledLoggers(logger) + + NewMetaLogger().WithCollectionID(0). + WithIndexMeta(&model.Index{}). + WithCollectionMeta(&model.Collection{}). + WithCollectionName("coll"). 
+ WithPartitionID(0). + WithPartitionName("part"). + WithFieldID(0). + WithFieldName("field"). + WithIndexID(0). + WithIndexName("idx"). + WithBuildID(0). + WithBuildIDS([]int64{0, 0}). + WithSegmentID(0). + WithIndexFiles([]string{"idx", "idx"}). + WithIndexVersion(0). + WithTSO(0). + WithAlias("alias"). + WithOperation(DropCollection).Info() + + ts.assertMessagesContains("CollectionID=0") + ts.assertMessagesContains("CollectionMeta=eyJUZW5hbnRJRCI6IiIsIkNvbGxlY3Rpb25JRCI6MCwiUGFydGl0aW9ucyI6bnVsbCwiTmFtZSI6IiIsIkRlc2NyaXB0aW9uIjoiIiwiQXV0b0lEIjpmYWxzZSwiRmllbGRzIjpudWxsLCJGaWVsZElEVG9JbmRleElEIjpudWxsLCJWaXJ0dWFsQ2hhbm5lbE5hbWVzIjpudWxsLCJQaHlzaWNhbENoYW5uZWxOYW1lcyI6bnVsbCwiU2hhcmRzTnVtIjowLCJTdGFydFBvc2l0aW9ucyI6bnVsbCwiQ3JlYXRlVGltZSI6MCwiQ29uc2lzdGVuY3lMZXZlbCI6MCwiQWxpYXNlcyI6bnVsbCwiRXh0cmEiOm51bGx9") + ts.assertMessagesContains("IndexMeta=eyJDb2xsZWN0aW9uSUQiOjAsIkZpZWxkSUQiOjAsIkluZGV4SUQiOjAsIkluZGV4TmFtZSI6IiIsIklzRGVsZXRlZCI6ZmFsc2UsIkNyZWF0ZVRpbWUiOjAsIkluZGV4UGFyYW1zIjpudWxsLCJTZWdtZW50SW5kZXhlcyI6bnVsbCwiRXh0cmEiOm51bGx9") + ts.assertMessagesContains("CollectionName=coll") + ts.assertMessagesContains("PartitionID=0") + ts.assertMessagesContains("PartitionName=part") + ts.assertMessagesContains("FieldID=0") + ts.assertMessagesContains("FieldName=field") + ts.assertMessagesContains("IndexID=0") + ts.assertMessagesContains("IndexName=idx") + ts.assertMessagesContains("BuildID=0") + ts.assertMessagesContains("\"[0,0]\"") + ts.assertMessagesContains("SegmentID=0") + ts.assertMessagesContains("IndexFiles=\"[idx,idx]\"") + ts.assertMessagesContains("IndexVersion=0") + ts.assertMessagesContains("TSO=0") + ts.assertMessagesContains("Alias=alias") + ts.assertMessagesContains("Operation=1") +} diff --git a/internal/log/meta_ops.go b/internal/log/meta_ops.go new file mode 100644 index 000000000..34e63b709 --- /dev/null +++ b/internal/log/meta_ops.go @@ -0,0 +1,17 @@ +package log + +type MetaOperation int + +const ( + InvalidMetaOperation MetaOperation = iota - 1 + CreateCollection + DropCollection + CreateCollectionAlias + AlterCollectionAlias + DropCollectionAlias + CreatePartition + DropPartition + CreateIndex + DropIndex + BuildSegmentIndex +) diff --git a/internal/log/mlogger.go b/internal/log/mlogger.go new file mode 100644 index 000000000..7166ee736 --- /dev/null +++ b/internal/log/mlogger.go @@ -0,0 +1,31 @@ +package log + +import "go.uber.org/zap" + +type MLogger struct { + *zap.Logger +} + +func (l *MLogger) RatedDebug(cost float64, msg string, fields ...zap.Field) bool { + if R().CheckCredit(cost) { + l.Debug(msg, fields...) + return true + } + return false +} + +func (l *MLogger) RatedInfo(cost float64, msg string, fields ...zap.Field) bool { + if R().CheckCredit(cost) { + l.Info(msg, fields...) + return true + } + return false +} + +func (l *MLogger) RatedWarn(cost float64, msg string, fields ...zap.Field) bool { + if R().CheckCredit(cost) { + l.Warn(msg, fields...) 
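The rated helpers report whether the message was actually emitted, so hot paths can tell when the limiter swallowed a line. A usage sketch modeled on TestMLoggerRatedLog below (illustrative only; the import path is assumed):

    package main

    import (
        "context"

        "github.com/milvus-io/milvus/internal/log"
    )

    func main() {
        ctx := log.WithTraceID(context.TODO(), "test-trace")

        // The cost is charged against the shared limiter (R().CheckCredit):
        // cheap messages usually pass, expensive ones are emitted rarely.
        if ok := log.Ctx(ctx).RatedInfo(1.0, "cheap, usually emitted"); !ok {
            // returned false: the message was dropped rather than logged
        }
        log.Ctx(ctx).RatedWarn(100.0, "expensive, emitted only when credit allows")
    }
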
+ return true + } + return false +} diff --git a/internal/log/mlogger_test.go b/internal/log/mlogger_test.go new file mode 100644 index 000000000..fcdf9339c --- /dev/null +++ b/internal/log/mlogger_test.go @@ -0,0 +1,102 @@ +package log + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func TestExporterV2(t *testing.T) { + ts := newTestLogSpy(t) + conf := &Config{Level: "debug", DisableTimestamp: true} + logger, properties, _ := InitTestLogger(ts, conf) + ReplaceGlobals(logger, properties) + + replaceLeveledLoggers(logger) + ctx := WithTraceID(context.TODO(), "mock-trace") + + Ctx(ctx).Info("Info Test") + Ctx(ctx).Debug("Debug Test") + Ctx(ctx).Warn("Warn Test") + Ctx(ctx).Error("Error Test") + Ctx(ctx).Sync() + + ts.assertMessagesContains("log/mlogger_test.go") + ts.assertMessagesContains("traceID=mock-trace") + + ts.CleanBuffer() + Ctx(nil).Info("empty context") + ts.assertMessagesNotContains("traceID") + + fieldCtx := WithFields(ctx, zap.String("field", "test")) + reqCtx := WithReqID(ctx, 123456) + modCtx := WithModule(ctx, "test") + + Ctx(fieldCtx).Info("Info Test") + Ctx(fieldCtx).Sync() + ts.assertLastMessageContains("field=test") + ts.assertLastMessageContains("traceID=mock-trace") + + Ctx(reqCtx).Info("Info Test") + Ctx(reqCtx).Sync() + ts.assertLastMessageContains("reqID=123456") + ts.assertLastMessageContains("traceID=mock-trace") + ts.assertLastMessageNotContains("field=test") + + Ctx(modCtx).Info("Info Test") + Ctx(modCtx).Sync() + ts.assertLastMessageContains("module=test") + ts.assertLastMessageContains("traceID=mock-trace") + ts.assertLastMessageNotContains("reqID=123456") + ts.assertLastMessageNotContains("field=test") +} + +func TestMLoggerRatedLog(t *testing.T) { + ts := newTestLogSpy(t) + conf := &Config{Level: "debug", DisableTimestamp: true} + logger, p, _ := InitTestLogger(ts, conf) + ReplaceGlobals(logger, p) + + ctx := WithTraceID(context.TODO(), "test-trace") + time.Sleep(time.Duration(1) * time.Second) + success := Ctx(ctx).RatedDebug(1.0, "debug test") + assert.True(t, success) + + time.Sleep(time.Duration(1) * time.Second) + success = Ctx(ctx).RatedDebug(100.0, "debug test") + assert.False(t, success) + + time.Sleep(time.Duration(1) * time.Second) + success = Ctx(ctx).RatedInfo(1.0, "info test") + assert.True(t, success) + + time.Sleep(time.Duration(1) * time.Second) + success = Ctx(ctx).RatedWarn(1.0, "warn test") + assert.True(t, success) + + time.Sleep(time.Duration(1) * time.Second) + success = Ctx(ctx).RatedWarn(100.0, "warn test") + assert.False(t, success) + + time.Sleep(time.Duration(1) * time.Second) + success = Ctx(ctx).RatedInfo(100.0, "info test") + assert.False(t, success) + + successNum := 0 + for i := 0; i < 1000; i++ { + if Ctx(ctx).RatedInfo(1.0, "info test") { + successNum++ + } + time.Sleep(time.Duration(10) * time.Millisecond) + } + assert.True(t, successNum < 1000) + assert.True(t, successNum > 10) + + time.Sleep(time.Duration(3) * time.Second) + success = Ctx(ctx).RatedInfo(3.0, "info test") + assert.True(t, success) + Ctx(ctx).Sync() +} diff --git a/internal/log/zap_log_test.go b/internal/log/zap_log_test.go index 9b79b9510..c1715c190 100644 --- a/internal/log/zap_log_test.go +++ b/internal/log/zap_log_test.go @@ -292,6 +292,10 @@ func (t *testLogSpy) FailNow() { t.TB.FailNow() } +func (t *testLogSpy) CleanBuffer() { + t.Messages = []string{} +} + func (t *testLogSpy) Logf(format string, args ...interface{}) { // Log messages are in the format, // @@ -320,3 +324,40 @@ func (t 
*testLogSpy) assertMessagesNotContains(msg string) { assert.NotContains(t.TB, actualMsg, msg) } } + +func (t *testLogSpy) assertLastMessageContains(msg string) { + if len(t.Messages) == 0 { + assert.Error(t.TB, fmt.Errorf("empty message")) + } + assert.Contains(t.TB, t.Messages[len(t.Messages)-1], msg) +} + +func (t *testLogSpy) assertLastMessageNotContains(msg string) { + if len(t.Messages) == 0 { + assert.Error(t.TB, fmt.Errorf("empty message")) + } + assert.NotContains(t.TB, t.Messages[len(t.Messages)-1], msg) +} + +func (t *testLogSpy) assertMessageContainAny(msg string) { + found := false + for _, actualMsg := range t.Messages { + if strings.Contains(actualMsg, msg) { + found = true + } + } + assert.True(t, found, "can't found any message contain `%s`, all messages: %v", msg, fmtMsgs(t.Messages)) +} + +func fmtMsgs(messages []string) string { + builder := strings.Builder{} + builder.WriteString("[") + for i, msg := range messages { + if i == len(messages)-1 { + builder.WriteString(fmt.Sprintf("`%s]`", msg)) + } else { + builder.WriteString(fmt.Sprintf("`%s`, ", msg)) + } + } + return builder.String() +} diff --git a/internal/metastore/model/collection_test.go b/internal/metastore/model/collection_test.go index 5e7971cb1..56bea8565 100644 --- a/internal/metastore/model/collection_test.go +++ b/internal/metastore/model/collection_test.go @@ -9,20 +9,18 @@ import ( pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/schemapb" - "github.com/milvus-io/milvus/internal/util/typeutil" - "github.com/milvus-io/milvus/internal/proto/commonpb" ) var ( - colID = typeutil.UniqueID(1) - colName = "c" - fieldID = typeutil.UniqueID(101) - fieldName = "field110" - partID = typeutil.UniqueID(20) - partName = "testPart" - tenantID = "tenant-1" - typeParams = []*commonpb.KeyValuePair{ + colID int64 = 1 + colName = "c" + fieldID int64 = 101 + fieldName = "field110" + partID int64 = 20 + partName = "testPart" + tenantID = "tenant-1" + typeParams = []*commonpb.KeyValuePair{ { Key: "field110-k1", Value: "field110-v1", diff --git a/internal/metastore/model/index_test.go b/internal/metastore/model/index_test.go index 2a122a6c8..a57c33e5b 100644 --- a/internal/metastore/model/index_test.go +++ b/internal/metastore/model/index_test.go @@ -7,13 +7,12 @@ import ( "github.com/milvus-io/milvus/internal/proto/commonpb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" - "github.com/milvus-io/milvus/internal/util/typeutil" ) var ( - indexID = typeutil.UniqueID(1) - indexName = "idx" - indexParams = []*commonpb.KeyValuePair{ + indexID int64 = 1 + indexName = "idx" + indexParams = []*commonpb.KeyValuePair{ { Key: "field110-i1", Value: "field110-v1", diff --git a/internal/metastore/model/segment_test.go b/internal/metastore/model/segment_test.go index 9c4b4c4d3..16444c96a 100644 --- a/internal/metastore/model/segment_test.go +++ b/internal/metastore/model/segment_test.go @@ -4,13 +4,12 @@ import ( "testing" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" - "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" ) var ( - segmentID = typeutil.UniqueID(1) - buildID = typeutil.UniqueID(1) + segmentID int64 = 1 + buildID int64 = 1 segmentIdxPb = &pb.SegmentIndexInfo{ CollectionID: colID, diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 1e9bbb7fd..bb49cdac1 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2520,7 +2520,6 @@ func (node *Proxy) Search(ctx context.Context, request 
*milvuspb.SearchRequest) sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Search") defer sp.Finish() - traceID, _, _ := trace.InfoFromSpan(sp) qt := &searchTask{ ctx: ctx, @@ -2541,9 +2540,8 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) travelTs := request.TravelTimestamp guaranteeTs := request.GuaranteeTimestamp - log.Debug( + log.Ctx(ctx).Info( rpcReceived(method), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), zap.String("collection", request.CollectionName), @@ -2556,10 +2554,9 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) zap.Uint64("guarantee_timestamp", guaranteeTs)) if err := node.sched.dqQueue.Enqueue(qt); err != nil { - log.Warn( + log.Ctx(ctx).Warn( rpcFailedToEnqueue(method), zap.Error(err), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), zap.String("collection", request.CollectionName), @@ -2581,11 +2578,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) }, }, nil } - tr.Record("search request enqueue") + tr.CtxRecord(ctx, "search request enqueue") - log.Debug( + log.Ctx(ctx).Debug( rpcEnqueued(method), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", qt.ID()), zap.Uint64("timestamp", qt.Base.Timestamp), @@ -2600,10 +2596,9 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) zap.Uint64("guarantee_timestamp", guaranteeTs)) if err := qt.WaitToFinish(); err != nil { - log.Warn( + log.Ctx(ctx).Warn( rpcFailedToWaitToFinish(method), zap.Error(err), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", qt.ID()), zap.String("db", request.DbName), @@ -2627,12 +2622,11 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) }, nil } - span := tr.Record("wait search result") + span := tr.CtxRecord(ctx, "wait search result") metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(span.Milliseconds())) - log.Debug( + log.Ctx(ctx).Debug( rpcDone(method), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", qt.ID()), zap.String("db", request.DbName), @@ -2763,7 +2757,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Query") defer sp.Finish() - traceID, _, _ := trace.InfoFromSpan(sp) tr := timerecord.NewTimeRecorder("Query") qt := &queryTask{ @@ -2787,9 +2780,8 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* metrics.ProxyDQLFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.TotalLabel).Inc() - log.Debug( + log.Ctx(ctx).Info( rpcReceived(method), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), zap.String("collection", request.CollectionName), @@ -2800,10 +2792,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp)) if err := node.sched.dqQueue.Enqueue(qt); err != nil { - log.Warn( + log.Ctx(ctx).Warn( rpcFailedToEnqueue(method), zap.Error(err), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), 
zap.String("collection", request.CollectionName), @@ -2819,11 +2810,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* }, }, nil } - tr.Record("query request enqueue") + tr.CtxRecord(ctx, "query request enqueue") - log.Debug( + log.Ctx(ctx).Debug( rpcEnqueued(method), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", qt.ID()), zap.String("db", request.DbName), @@ -2831,10 +2821,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* zap.Strings("partitions", request.PartitionNames)) if err := qt.WaitToFinish(); err != nil { - log.Warn( + log.Ctx(ctx).Warn( rpcFailedToWaitToFinish(method), zap.Error(err), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", qt.ID()), zap.String("db", request.DbName), @@ -2851,12 +2840,11 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* }, }, nil } - span := tr.Record("wait query result") + span := tr.CtxRecord(ctx, "wait query result") metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(span.Milliseconds())) - log.Debug( + log.Ctx(ctx).Debug( rpcDone(method), - zap.String("traceID", traceID), zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", qt.ID()), zap.String("db", request.DbName), diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 50a9c02d7..0466a12ed 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -394,12 +394,12 @@ func TestMetaCache_GetShards(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, shards) assert.Equal(t, 1, len(shards)) - assert.Equal(t, 3, len(shards["channel-1"])) // get from cache qc.validShardLeaders = false shards, err = globalMetaCache.GetShards(ctx, true, collectionName) + assert.NoError(t, err) assert.NotEmpty(t, shards) assert.Equal(t, 1, len(shards)) diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index 829c40309..0c4bece23 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -40,7 +40,7 @@ func groupShardleadersWithSameQueryNode( // check if all leaders were checked for dml, idx := range nexts { if idx >= len(shard2leaders[dml]) { - log.Warn("no shard leaders were available", + log.Ctx(ctx).Warn("no shard leaders were available", zap.String("channel", dml), zap.String("leaders", fmt.Sprintf("%v", shard2leaders[dml]))) if e, ok := errSet[dml]; ok { @@ -59,7 +59,7 @@ func groupShardleadersWithSameQueryNode( if _, ok := qnSet[nodeInfo.nodeID]; !ok { qn, err := mgr.GetClient(ctx, nodeInfo.nodeID) if err != nil { - log.Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err)) + log.Ctx(ctx).Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err)) // if get client failed, just record error and wait for next round to get client and do query errSet[dml] = err continue @@ -111,7 +111,7 @@ func mergeRoundRobinPolicy( go func() { defer wg.Done() if err := query(ctx, nodeID, qn, channels); err != nil { - log.Warn("failed to do query with node", zap.Int64("nodeID", nodeID), + log.Ctx(ctx).Warn("failed to do query with node", zap.Int64("nodeID", nodeID), zap.Strings("dmlChannels", channels), zap.Error(err)) mu.Lock() defer mu.Unlock() @@ -138,7 +138,7 @@ func mergeRoundRobinPolicy( nextSet[dml] = dml2leaders[dml][idx].nodeID } } - log.Warn("retry 
another query node with round robin", zap.Any("Nexts", nextSet)) + log.Ctx(ctx).Warn("retry another query node with round robin", zap.Any("Nexts", nextSet)) } } return nil diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 3804d2038..d2beb26b1 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -110,44 +110,44 @@ func (t *queryTask) PreExecute(ctx context.Context) error { collectionName := t.request.CollectionName t.collectionName = collectionName if err := validateCollectionName(collectionName); err != nil { - log.Warn("Invalid collection name.", zap.String("collectionName", collectionName), + log.Ctx(ctx).Warn("Invalid collection name.", zap.String("collectionName", collectionName), zap.Int64("msgID", t.ID()), zap.String("requestType", "query")) return err } - log.Info("Validate collection name.", zap.Any("collectionName", collectionName), + log.Ctx(ctx).Debug("Validate collection name.", zap.Any("collectionName", collectionName), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) if err != nil { - log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), + log.Ctx(ctx).Warn("Failed to get collection id.", zap.Any("collectionName", collectionName), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return err } t.CollectionID = collID - log.Info("Get collection ID by name", + log.Ctx(ctx).Debug("Get collection ID by name", zap.Int64("collectionID", t.CollectionID), zap.String("collection name", collectionName), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) for _, tag := range t.request.PartitionNames { if err := validatePartitionTag(tag, false); err != nil { - log.Warn("invalid partition name", zap.String("partition name", tag), + log.Ctx(ctx).Warn("invalid partition name", zap.String("partition name", tag), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return err } } - log.Debug("Validate partition names.", + log.Ctx(ctx).Debug("Validate partition names.", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames()) if err != nil { - log.Warn("failed to get partitions in collection.", zap.String("collection name", collectionName), + log.Ctx(ctx).Warn("failed to get partitions in collection.", zap.String("collection name", collectionName), zap.Error(err), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return err } - log.Debug("Get partitions in collection.", zap.Any("collectionName", collectionName), + log.Ctx(ctx).Debug("Get partitions in collection.", zap.Any("collectionName", collectionName), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.RetrieveRequest.GetPartitionIDs()) @@ -182,7 +182,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if err != nil { return err } - log.Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields), + log.Ctx(ctx).Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema) @@ -191,7 +191,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { } t.RetrieveRequest.OutputFieldsId = outputFieldIDs plan.OutputFieldIds = outputFieldIDs - log.Debug("translate 
output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId), + log.Ctx(ctx).Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan) @@ -219,7 +219,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { } t.DbID = 0 // TODO - log.Info("Query PreExecute done.", + log.Ctx(ctx).Debug("Query PreExecute done.", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"), zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("travel_ts", t.GetTravelTimestamp()), zap.Uint64("timeout_ts", t.GetTimeoutTimestamp())) @@ -228,7 +228,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { func (t *queryTask) Execute(ctx context.Context) error { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", t.ID())) - defer tr.Elapse("done") + defer tr.CtxElapse(ctx, "done") executeQuery := func(withCache bool) error { shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName) @@ -246,7 +246,7 @@ func (t *queryTask) Execute(ctx context.Context) error { err := executeQuery(WithCache) if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) { - log.Warn("invalid shard leaders cache, updating shardleader caches and retry search", + log.Ctx(ctx).Warn("invalid shard leaders cache, updating shardleader caches and retry search", zap.Int64("msgID", t.ID()), zap.Error(err)) return executeQuery(WithoutCache) } @@ -254,7 +254,7 @@ func (t *queryTask) Execute(ctx context.Context) error { return fmt.Errorf("fail to search on all shard leaders, err=%s", err.Error()) } - log.Info("Query Execute done.", + log.Ctx(ctx).Debug("Query Execute done.", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return nil } @@ -262,27 +262,27 @@ func (t *queryTask) Execute(ctx context.Context) error { func (t *queryTask) PostExecute(ctx context.Context) error { tr := timerecord.NewTimeRecorder("queryTask PostExecute") defer func() { - tr.Elapse("done") + tr.CtxElapse(ctx, "done") }() var err error select { case <-t.TraceCtx().Done(): - log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID())) + log.Ctx(ctx).Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID())) return nil default: - log.Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID())) + log.Ctx(ctx).Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID())) close(t.resultBuf) for res := range t.resultBuf { t.toReduceResults = append(t.toReduceResults, res) - log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID())) + log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID())) } } metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0) - tr.Record("reduceResultStart") - t.result, err = mergeRetrieveResults(t.toReduceResults) + tr.CtxRecord(ctx, "reduceResultStart") + t.result, err = mergeRetrieveResults(ctx, t.toReduceResults) if err != nil { return err } @@ -294,7 +294,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { ErrorCode: commonpb.ErrorCode_Success, } } else { - log.Info("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) + log.Ctx(ctx).Warn("Query 
result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) t.result.Status = &commonpb.Status{ ErrorCode: commonpb.ErrorCode_EmptyCollection, Reason: "empty collection", // TODO @@ -315,7 +315,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { } } } - log.Info("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query")) + log.Ctx(ctx).Debug("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query")) return nil } @@ -328,21 +328,21 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query result, err := qn.Query(ctx, req) if err != nil { - log.Warn("QueryNode query return error", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Warn("QueryNode query return error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err)) return err } if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { - log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs)) + log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs)) return errInvalidShardLeaders } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), + log.Ctx(ctx).Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.String("reason", result.GetStatus().GetReason())) return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason()) } - log.Debug("get query result", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channelIDs", channelIDs)) + log.Ctx(ctx).Debug("get query result", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channelIDs", channelIDs)) t.resultBuf <- result return nil } @@ -360,7 +360,7 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string { return fieldName + " in [ " + idsStr + " ]" } -func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) { +func mergeRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) { var ret *milvuspb.QueryResults var skipDupCnt int64 var idSet = make(map[interface{}]struct{}) @@ -394,7 +394,7 @@ func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvu } } } - log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) + log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) if ret == nil { ret = &milvuspb.QueryResults{ diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index af2fa8617..da6fd1bab 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -210,7 +210,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if err != nil { return err } - log.Debug("translate output fields", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()), zap.Strings("output fields", t.request.GetOutputFields())) if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { @@ -226,12 +226,12 @@ func (t *searchTask) PreExecute(ctx context.Context) error { plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo) if err != nil { - log.Debug("failed to create query plan", zap.Error(err), 
zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Warn("failed to create query plan", zap.Error(err), zap.Int64("msgID", t.ID()), zap.String("dsl", t.request.Dsl), // may be very large if large term passed. zap.String("anns field", annsField), zap.Any("query info", queryInfo)) return fmt.Errorf("failed to create query plan: %v", err) } - log.Debug("create query plan", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Debug("create query plan", zap.Int64("msgID", t.ID()), zap.String("dsl", t.request.Dsl), // may be very large if large term passed. zap.String("anns field", annsField), zap.Any("query info", queryInfo)) @@ -253,7 +253,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if err := validateTopK(queryInfo.GetTopk()); err != nil { return err } - log.Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()), zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.String("plan", plan.String())) // may be very large if large term passed. } @@ -282,7 +282,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if t.SearchRequest.Nq, err = getNq(t.request); err != nil { return err } - log.Info("search PreExecute done.", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()), zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp())) @@ -294,7 +294,7 @@ func (t *searchTask) Execute(ctx context.Context) error { defer sp.Finish() tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID())) - defer tr.Elapse("done") + defer tr.CtxElapse(ctx, "done") executeSearch := func(withCache bool) error { shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName) @@ -304,7 +304,7 @@ func (t *searchTask) Execute(ctx context.Context) error { t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders)) t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders)) if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil { - log.Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders))) + log.Ctx(ctx).Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders))) return err } return nil @@ -312,7 +312,7 @@ func (t *searchTask) Execute(ctx context.Context) error { err := executeSearch(WithCache) if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) { - log.Warn("first search failed, updating shardleader caches and retry search", + log.Ctx(ctx).Warn("first search failed, updating shardleader caches and retry search", zap.Int64("msgID", t.ID()), zap.Error(err)) return executeSearch(WithoutCache) } @@ -320,7 +320,7 @@ func (t *searchTask) Execute(ctx context.Context) error { return fmt.Errorf("fail to search on all shard leaders, err=%v", err) } - log.Debug("Search Execute done.", zap.Int64("msgID", t.ID())) + log.Ctx(ctx).Debug("Search Execute done.", zap.Int64("msgID", t.ID())) return nil } @@ -329,34 +329,34 @@ func (t *searchTask) PostExecute(ctx context.Context) error { defer sp.Finish() tr := timerecord.NewTimeRecorder("searchTask PostExecute") defer func() { - tr.Elapse("done") + tr.CtxElapse(ctx, "done") }() select { // in case timeout happened case <-t.TraceCtx().Done(): - log.Debug("wait to finish timeout!", zap.Int64("msgID", 
t.ID())) + log.Ctx(ctx).Debug("wait to finish timeout!", zap.Int64("msgID", t.ID())) return nil default: - log.Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID())) + log.Ctx(ctx).Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID())) close(t.resultBuf) for res := range t.resultBuf { t.toReduceResults = append(t.toReduceResults, res) - log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID())) + log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID())) } } - tr.Record("decodeResultStart") - validSearchResults, err := decodeSearchResults(t.toReduceResults) + tr.CtxRecord(ctx, "decodeResultStart") + validSearchResults, err := decodeSearchResults(ctx, t.toReduceResults) if err != nil { return err } metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds())) - log.Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()), zap.Int("len(validSearchResults)", len(validSearchResults))) if len(validSearchResults) <= 0 { - log.Warn("search result is empty", zap.Int64("msgID", t.ID())) + log.Ctx(ctx).Warn("search result is empty", zap.Int64("msgID", t.ID())) t.result = &milvuspb.SearchResults{ Status: &commonpb.Status{ @@ -375,12 +375,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error { return nil } - tr.Record("reduceResultStart") + tr.CtxRecord(ctx, "reduceResultStart") primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema) if err != nil { return err } - t.result, err = reduceSearchResultData(validSearchResults, t.toReduceResults[0].NumQueries, t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType) + t.result, err = reduceSearchResultData(ctx, validSearchResults, t.toReduceResults[0].NumQueries, + t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType) if err != nil { return err } @@ -403,7 +404,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } } } - log.Info("Search post execute done", zap.Int64("msgID", t.ID())) + log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID())) return nil } @@ -415,17 +416,17 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que } result, err := qn.Search(ctx, req) if err != nil { - log.Warn("QueryNode search return error", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Warn("QueryNode search return error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err)) return err } if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { - log.Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()), + log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs)) return errInvalidShardLeaders } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), + log.Ctx(ctx).Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.String("reason", result.GetStatus().GetReason())) return fmt.Errorf("fail to Search, QueryNode ID=%d, 
reason=%s", nodeID, result.GetStatus().GetReason()) } @@ -482,7 +483,7 @@ func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName stri } if len(resp.GetPartitionIDs()) > 0 { - log.Warn("collection not fully loaded, search on these partitions", + log.Ctx(ctx).Warn("collection not fully loaded, search on these partitions", zap.String("collection", collectionName), zap.Int64("collectionID", info.collID), zap.Int64s("partitionIDs", resp.GetPartitionIDs())) return true, nil @@ -491,7 +492,7 @@ func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName stri return false, nil } -func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) { +func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) { tr := timerecord.NewTimeRecorder("decodeSearchResults") results := make([]*schemapb.SearchResultData, 0) for _, partialSearchResult := range searchResults { @@ -507,7 +508,7 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb results = append(results, &partialResultData) } - tr.Elapse("decodeSearchResults done") + tr.CtxElapse(ctx, "decodeSearchResults done") return results, nil } @@ -544,14 +545,13 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffset return sel } -func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) { - +func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) { tr := timerecord.NewTimeRecorder("reduceSearchResultData") defer func() { - tr.Elapse("done") + tr.CtxElapse(ctx, "done") }() - log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)), + log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)), zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType)) ret := &milvuspb.SearchResults{ @@ -585,14 +585,14 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in } for i, sData := range searchResultData { - log.Debug("reduceSearchResultData", + log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("result No.", i), zap.Int64("nq", sData.NumQueries), zap.Int64("topk", sData.TopK), zap.Any("len(topks)", len(sData.Topks)), zap.Any("len(FieldsData)", len(sData.FieldsData))) if err := checkSearchResultData(sData, nq, topk); err != nil { - log.Warn("invalid search results", zap.Error(err)) + log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) return ret, err } //printSearchResultData(sData, strconv.FormatInt(int64(i), 10)) @@ -637,13 +637,13 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in offsets[sel]++ } if realTopK != -1 && realTopK != j { - log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) + log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) // return nil, errors.New("the length (topk) between all result of query is different") } realTopK = j ret.Results.Topks = append(ret.Results.Topks, realTopK) } - log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) + 
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) ret.Results.TopK = realTopK if !distance.PositivelyRelated(metricType) { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 79a4de022..03ae3b6ab 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1354,7 +1354,7 @@ func Test_reduceSearchResultData_int(t *testing.T) { }, } - reduced, err := reduceSearchResultData(results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64) + reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64) assert.NoError(t, err) assert.ElementsMatch(t, []int64{3, 4, 7, 8, 11, 12}, reduced.GetResults().GetIds().GetIntId().GetData()) // hard to compare floating point value. @@ -1393,7 +1393,7 @@ func Test_reduceSearchResultData_str(t *testing.T) { }, } - reduced, err := reduceSearchResultData(results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar) + reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar) assert.NoError(t, err) assert.ElementsMatch(t, []string{"3", "4", "7", "8", "11", "12"}, reduced.GetResults().GetIds().GetStrId().GetData()) // hard to compare floating point value. diff --git a/internal/querycoord/impl.go b/internal/querycoord/impl.go index 70c495bdc..f980da1ce 100644 --- a/internal/querycoord/impl.go +++ b/internal/querycoord/impl.go @@ -397,7 +397,7 @@ func (qc *QueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Releas // ShowPartitions return all the partitions that have been loaded func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { collectionID := req.CollectionID - log.Info("show partitions start", + log.Ctx(ctx).Debug("show partitions start", zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", req.PartitionIDs), @@ -409,7 +409,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti status.ErrorCode = commonpb.ErrorCode_UnexpectedError err := errors.New("QueryCoord is not healthy") status.Reason = err.Error() - log.Error("show partition failed", zap.Int64("msgID", req.Base.MsgID), zap.Error(err)) + log.Ctx(ctx).Warn("show partition failed", zap.Int64("msgID", req.Base.MsgID), zap.Error(err)) return &querypb.ShowPartitionsResponse{ Status: status, }, nil @@ -420,7 +420,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti err = fmt.Errorf("collection %d has not been loaded into QueryNode", collectionID) status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.Reason = err.Error() - log.Warn("show partitions failed", + log.Ctx(ctx).Warn("show partitions failed", zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), zap.Int64("msgID", req.Base.MsgID), zap.Error(err)) @@ -439,7 +439,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti for _, id := range inMemoryPartitionIDs { inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage) } - log.Info("show partitions end", + log.Ctx(ctx).Debug("show partitions end", zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), zap.Int64("msgID", req.Base.MsgID), @@ -456,7 +456,7 @@ func (qc *QueryCoord) ShowPartitions(ctx 
context.Context, req *querypb.ShowParti err = fmt.Errorf("partition %d of collection %d has not been loaded into QueryNode", id, collectionID) status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.Reason = err.Error() - log.Warn("show partitions failed", + log.Ctx(ctx).Warn("show partitions failed", zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id), @@ -469,7 +469,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage) } - log.Info("show partitions end", + log.Ctx(ctx).Debug("show partitions end", zap.String("role", typeutil.QueryCoordRole), zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", req.PartitionIDs), diff --git a/internal/querynode/benchmark_test.go b/internal/querynode/benchmark_test.go index 80d2c9775..29b561127 100644 --- a/internal/querynode/benchmark_test.go +++ b/internal/querynode/benchmark_test.go @@ -84,7 +84,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) { searchReq, err := newSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup()) assert.NoError(b, err) for j := 0; j < 10000; j++ { - _, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs()) + _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs()) assert.NoError(b, err) } @@ -108,7 +108,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { for j := int64(0); j < benchmarkMaxNQ/nq; j++ { - _, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs()) + _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs()) assert.NoError(b, err) } } @@ -153,7 +153,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing. searchReq, _ := genSearchPlanAndRequests(collection, indexType, nq) for j := 0; j < 10000; j++ { - _, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) + _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) assert.NoError(b, err) } @@ -178,7 +178,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing. b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < benchmarkMaxNQ/int(nq); j++ { - _, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) + _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) assert.NoError(b, err) } } diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 6e2f5b922..439cb6e84 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -568,7 +568,7 @@ func (node *QueryNode) isHealthy() bool { // Search performs replica search tasks. 
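The recurring edit above swaps the package-level log.Debug/Info/Warn calls for log.Ctx(ctx).Debug/Warn, so request-scoped fields (trace ID, msgID, role) attached to the context follow the request through the proxy, query coord and query node code instead of being repeated by hand; routine completion messages also drop from Info to Debug, while failures move up to Warn. As a rough sketch of the idea only, not the patch's internal/log implementation (the ctxlog package, WithFields helper and fallback argument are hypothetical), such a context-aware entry point can be layered on zap like this:

    // Sketch of a context-aware logger in the spirit of log.Ctx(ctx); names are hypothetical.
    package ctxlog

    import (
        "context"

        "go.uber.org/zap"
    )

    type loggerKey struct{}

    // WithFields stores a child logger enriched with request-scoped fields in the
    // context, typically done once by a gRPC interceptor when the request arrives.
    func WithFields(ctx context.Context, base *zap.Logger, fields ...zap.Field) context.Context {
        return context.WithValue(ctx, loggerKey{}, base.With(fields...))
    }

    // Ctx returns the request-scoped logger if one was attached, otherwise the fallback,
    // so call sites can always write ctxlog.Ctx(ctx, global).Debug(...).
    func Ctx(ctx context.Context, fallback *zap.Logger) *zap.Logger {
        if l, ok := ctx.Value(loggerKey{}).(*zap.Logger); ok {
            return l
        }
        return fallback
    }

With something of this shape in place, a call such as log.Ctx(ctx).Debug("Query PreExecute done.", zap.Int64("msgID", t.ID())) emits the per-request fields automatically.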
func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (*internalpb.SearchResults, error) { - log.Debug("Received SearchRequest", + log.Ctx(ctx).Debug("Received SearchRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), zap.Strings("vChannels", req.GetDmlChannels()), zap.Int64s("segmentIDs", req.GetSegmentIDs()), @@ -613,7 +613,7 @@ func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) ( if err := runningGp.Wait(); err != nil { return failRet, nil } - ret, err := reduceSearchResults(toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) + ret, err := reduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) if err != nil { failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError failRet.Status.Reason = err.Error() @@ -641,7 +641,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se } msgID := req.GetReq().GetBase().GetMsgID() - log.Debug("Received SearchRequest", + log.Ctx(ctx).Debug("Received SearchRequest", zap.Int64("msgID", msgID), zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.String("vChannel", dmlChannel), @@ -656,7 +656,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se qs, err := node.queryShardService.getQueryShard(dmlChannel) if err != nil { - log.Warn("Search failed, failed to get query shard", + log.Ctx(ctx).Warn("Search failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err)) @@ -665,7 +665,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se return failRet, nil } - log.Debug("start do search", + log.Ctx(ctx).Debug("start do search", zap.Int64("msgID", msgID), zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.String("vChannel", dmlChannel), @@ -692,7 +692,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se return failRet, nil } - tr.Elapse(fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + tr.CtxElapse(ctx, fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) failRet.Status.ErrorCode = commonpb.ErrorCode_Success @@ -747,22 +747,22 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se // shard leader dispatches request to its shard cluster results, errCluster = cluster.Search(searchCtx, req, withStreaming) if errCluster != nil { - log.Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster)) + log.Ctx(ctx).Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster)) failRet.Status.Reason = errCluster.Error() return failRet, nil } - tr.Elapse(fmt.Sprintf("start reduce search result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + tr.CtxElapse(ctx, fmt.Sprintf("start reduce search result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) results = append(results, streamingResult) - ret, err2 := reduceSearchResults(results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) + ret, err2 := reduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) if err2 
!= nil { failRet.Status.Reason = err2.Error() return failRet, nil } - tr.Elapse(fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + tr.CtxElapse(ctx, fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) failRet.Status.ErrorCode = commonpb.ErrorCode_Success @@ -793,7 +793,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que } msgID := req.GetReq().GetBase().GetMsgID() - log.Debug("Received QueryRequest", + log.Ctx(ctx).Debug("Received QueryRequest", zap.Int64("msgID", msgID), zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.String("vChannel", dmlChannel), @@ -808,12 +808,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que qs, err := node.queryShardService.getQueryShard(dmlChannel) if err != nil { - log.Warn("Query failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err)) + log.Ctx(ctx).Warn("Query failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err)) failRet.Status.Reason = err.Error() return failRet, nil } - log.Debug("start do query", + log.Ctx(ctx).Debug("start do query", zap.Int64("msgID", msgID), zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.String("vChannel", dmlChannel), @@ -837,7 +837,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que return failRet, nil } - tr.Elapse(fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + tr.CtxElapse(ctx, fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) failRet.Status.ErrorCode = commonpb.ErrorCode_Success @@ -890,22 +890,22 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que // shard leader dispatches request to its shard cluster results, errCluster = cluster.Query(queryCtx, req, withStreaming) if errCluster != nil { - log.Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster)) + log.Ctx(ctx).Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster)) failRet.Status.Reason = errCluster.Error() return failRet, nil } - tr.Elapse(fmt.Sprintf("start reduce query result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) results = append(results, streamingResult) - ret, err2 := mergeInternalRetrieveResults(results) + ret, err2 := mergeInternalRetrieveResults(ctx, results) if err2 != nil { failRet.Status.Reason = err2.Error() return failRet, nil } - tr.Elapse(fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + tr.CtxElapse(ctx, fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) failRet.Status.ErrorCode = commonpb.ErrorCode_Success @@ -917,7 +917,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que // Query performs 
replica query tasks. func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - log.Debug("Received QueryRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), + log.Ctx(ctx).Debug("Received QueryRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), zap.Strings("vChannels", req.GetDmlChannels()), zap.Int64s("segmentIDs", req.GetSegmentIDs()), zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()), @@ -963,7 +963,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i if err := runningGp.Wait(); err != nil { return failRet, nil } - ret, err := mergeInternalRetrieveResults(toMergeResults) + ret, err := mergeInternalRetrieveResults(ctx, toMergeResults) if err != nil { failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError failRet.Status.Reason = err.Error() diff --git a/internal/querynode/query_shard_test.go b/internal/querynode/query_shard_test.go index 02d5df73d..22ea0b82c 100644 --- a/internal/querynode/query_shard_test.go +++ b/internal/querynode/query_shard_test.go @@ -159,7 +159,7 @@ func TestReduceSearchResultData(t *testing.T) { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - res, err := reduceSearchResultData(dataArray, nq, topk) + res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk) assert.Nil(t, err) assert.Equal(t, ids, res.Ids.GetIntId().Data) assert.Equal(t, scores, res.Scores) @@ -176,7 +176,7 @@ func TestReduceSearchResultData(t *testing.T) { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - res, err := reduceSearchResultData(dataArray, nq, topk) + res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk) assert.Nil(t, err) assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Ids.GetIntId().Data) }) @@ -223,12 +223,13 @@ func TestMergeInternalRetrieveResults(t *testing.T) { // Offset: []int64{0, 1}, FieldsData: fieldDataArray2, } + ctx := context.TODO() - result, err := mergeInternalRetrieveResults([]*internalpb.RetrieveResults{result1, result2}) + result, err := mergeInternalRetrieveResults(ctx, []*internalpb.RetrieveResults{result1, result2}) assert.NoError(t, err) assert.Equal(t, 2, len(result.FieldsData[0].GetScalars().GetLongData().Data)) assert.Equal(t, 2*Dim, len(result.FieldsData[1].GetVectors().GetFloatVector().Data)) - _, err = mergeInternalRetrieveResults(nil) + _, err = mergeInternalRetrieveResults(ctx, nil) assert.NoError(t, err) } diff --git a/internal/querynode/result.go b/internal/querynode/result.go index 98f63252e..c7020ab48 100644 --- a/internal/querynode/result.go +++ b/internal/querynode/result.go @@ -17,6 +17,7 @@ package querynode import ( + "context" "fmt" "math" "strconv" @@ -73,24 +74,24 @@ func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*inte return ret, nil } -func reduceSearchResults(results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) { +func reduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) { searchResultData, err := decodeSearchResults(results) if err != nil { - log.Warn("shard leader decode search results errors", zap.Error(err)) + log.Ctx(ctx).Warn("shard leader decode search results errors", zap.Error(err)) return nil, err } - log.Debug("shard 
leader get valid search results", zap.Int("numbers", len(searchResultData))) + log.Ctx(ctx).Debug("shard leader get valid search results", zap.Int("numbers", len(searchResultData))) for i, sData := range searchResultData { - log.Debug("reduceSearchResultData", + log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("result No.", i), zap.Int64("nq", sData.NumQueries), zap.Int64("topk", sData.TopK)) } - reducedResultData, err := reduceSearchResultData(searchResultData, nq, topk) + reducedResultData, err := reduceSearchResultData(ctx, searchResultData, nq, topk) if err != nil { - log.Warn("shard leader reduce errors", zap.Error(err)) + log.Ctx(ctx).Warn("shard leader reduce errors", zap.Error(err)) return nil, err } searchResults, err := encodeSearchResultData(reducedResultData, nq, topk, metricType) @@ -110,7 +111,7 @@ func reduceSearchResults(results []*internalpb.SearchResults, nq int64, topk int return searchResults, nil } -func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) { +func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) { if len(searchResultData) == 0 { return &schemapb.SearchResultData{ NumQueries: nq, @@ -174,7 +175,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in // } ret.Topks = append(ret.Topks, j) } - log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) + log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) return ret, nil } @@ -234,7 +235,7 @@ func encodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int6 } // TODO: largely based on function mergeSegcoreRetrieveResults, need rewriting -func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { +func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { var ret *internalpb.RetrieveResults var skipDupCnt int64 var idSet = make(map[interface{}]struct{}) @@ -254,7 +255,7 @@ func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults) } if len(ret.FieldsData) != len(rr.FieldsData) { - log.Warn("mismatch FieldData in RetrieveResults") + log.Ctx(ctx).Warn("mismatch FieldData in RetrieveResults") return nil, fmt.Errorf("mismatch FieldData in RetrieveResults") } @@ -283,7 +284,7 @@ func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults) return ret, nil } -func mergeSegcoreRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { +func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { var ret *segcorepb.RetrieveResults var skipDupCnt int64 var idSet = make(map[interface{}]struct{}) @@ -319,7 +320,7 @@ func mergeSegcoreRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) ( } } } - log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) + log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) // not found, return default values indicating not result found if ret == nil { diff --git a/internal/querynode/retrieve.go b/internal/querynode/retrieve.go index a3032a953..3410abcbf 100644 --- a/internal/querynode/retrieve.go +++ b/internal/querynode/retrieve.go @@ -17,6 +17,8 @@ package querynode 
import ( + "context" + "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/storage" ) @@ -44,12 +46,12 @@ func retrieveOnSegments(replica ReplicaInterface, segType segmentType, collID Un } // retrieveHistorical will retrieve all the target segments in historical -func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { +func retrieveHistorical(ctx context.Context, replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { var err error var retrieveResults []*segcorepb.RetrieveResults var retrieveSegmentIDs []UniqueID var retrievePartIDs []UniqueID - retrievePartIDs, retrieveSegmentIDs, err = validateOnHistoricalReplica(replica, collID, partIDs, segIDs) + retrievePartIDs, retrieveSegmentIDs, err = validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs) if err != nil { return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err } @@ -59,13 +61,13 @@ func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID Uni } // retrieveStreaming will retrieve all the target segments in streaming -func retrieveStreaming(replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, vChannel Channel, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { +func retrieveStreaming(ctx context.Context, replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, vChannel Channel, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { var err error var retrieveResults []*segcorepb.RetrieveResults var retrievePartIDs []UniqueID var retrieveSegmentIDs []UniqueID - retrievePartIDs, retrieveSegmentIDs, err = validateOnStreamReplica(replica, collID, partIDs, vChannel) + retrievePartIDs, retrieveSegmentIDs, err = validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel) if err != nil { return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err } diff --git a/internal/querynode/retrieve_test.go b/internal/querynode/retrieve_test.go index be7f60f43..a6dfe2a7d 100644 --- a/internal/querynode/retrieve_test.go +++ b/internal/querynode/retrieve_test.go @@ -17,6 +17,7 @@ package querynode import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -49,7 +50,7 @@ func TestStreaming_retrieve(t *testing.T) { assert.NoError(t, err) t.Run("test retrieve", func(t *testing.T) { - res, _, ids, err := retrieveStreaming(streaming, plan, + res, _, ids, err := retrieveStreaming(context.TODO(), streaming, plan, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel, @@ -61,7 +62,7 @@ func TestStreaming_retrieve(t *testing.T) { t.Run("test empty partition", func(t *testing.T) { - res, _, ids, err := retrieveStreaming(streaming, plan, + res, _, ids, err := retrieveStreaming(context.TODO(), streaming, plan, defaultCollectionID, nil, defaultDMLChannel, diff --git a/internal/querynode/search.go b/internal/querynode/search.go index 8e5c8bf16..71c771ee8 100644 --- a/internal/querynode/search.go +++ b/internal/querynode/search.go @@ -17,6 +17,7 @@ package querynode import ( + "context" "fmt" "sync" @@ -28,7 +29,7 @@ import ( // searchOnSegments performs search on listed segments // all segment ids are validated before calling this 
function -func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) { +func searchOnSegments(ctx context.Context, replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) { // results variables searchResults := make([]*SearchResult, len(segIDs)) errs := make([]error, len(segIDs)) @@ -72,31 +73,31 @@ func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq * // if segIDs is not specified, it will search on all the historical segments speficied by partIDs. // if segIDs is specified, it will only search on the segments specified by the segIDs. // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. -func searchHistorical(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*SearchResult, []UniqueID, []UniqueID, error) { +func searchHistorical(ctx context.Context, replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*SearchResult, []UniqueID, []UniqueID, error) { var err error var searchResults []*SearchResult var searchSegmentIDs []UniqueID var searchPartIDs []UniqueID - searchPartIDs, searchSegmentIDs, err = validateOnHistoricalReplica(replica, collID, partIDs, segIDs) + searchPartIDs, searchSegmentIDs, err = validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs) if err != nil { return searchResults, searchSegmentIDs, searchPartIDs, err } - searchResults, err = searchOnSegments(replica, segmentTypeSealed, searchReq, searchSegmentIDs) + searchResults, err = searchOnSegments(ctx, replica, segmentTypeSealed, searchReq, searchSegmentIDs) return searchResults, searchPartIDs, searchSegmentIDs, err } // searchStreaming will search all the target segments in streaming // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. 
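Alongside the logger change, helpers such as mergeInternalRetrieveResults, mergeSegcoreRetrieveResults, reduceSearchResultData and the retrieve/search/statistic entry points now take a context.Context as their first parameter, so the duplicate-skipping and error logs they emit inherit the request's trace fields. A minimal sketch of that shape, assuming parallel ID/score slices (Result and mergeResults are hypothetical stand-ins, not the patch's code):

    package resultmerge

    import (
        "context"

        "github.com/milvus-io/milvus/internal/log"
        "go.uber.org/zap"
    )

    // Result stands in for the real protobuf result types.
    type Result struct {
        IDs    []int64
        Scores []float32
    }

    // mergeResults takes ctx first so its logs join the request trace; rows whose
    // primary key was already seen are skipped and counted, mirroring the skipDupCnt
    // bookkeeping in result.go.
    func mergeResults(ctx context.Context, results []*Result) *Result {
        var skipDupCnt int64
        seen := make(map[int64]struct{})
        merged := &Result{}
        for _, r := range results {
            for i, id := range r.IDs {
                if _, dup := seen[id]; dup {
                    skipDupCnt++
                    continue
                }
                seen[id] = struct{}{}
                merged.IDs = append(merged.IDs, id)
                merged.Scores = append(merged.Scores, r.Scores[i])
            }
        }
        log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
        return merged
    }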
-func searchStreaming(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]*SearchResult, []UniqueID, []UniqueID, error) { +func searchStreaming(ctx context.Context, replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]*SearchResult, []UniqueID, []UniqueID, error) { var err error var searchResults []*SearchResult var searchPartIDs []UniqueID var searchSegmentIDs []UniqueID - searchPartIDs, searchSegmentIDs, err = validateOnStreamReplica(replica, collID, partIDs, vChannel) + searchPartIDs, searchSegmentIDs, err = validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel) if err != nil { return searchResults, searchSegmentIDs, searchPartIDs, err } - searchResults, err = searchOnSegments(replica, segmentTypeGrowing, searchReq, searchSegmentIDs) + searchResults, err = searchOnSegments(ctx, replica, segmentTypeGrowing, searchReq, searchSegmentIDs) return searchResults, searchPartIDs, searchSegmentIDs, err } diff --git a/internal/querynode/search_test.go b/internal/querynode/search_test.go index 087777dfe..7a08993b5 100644 --- a/internal/querynode/search_test.go +++ b/internal/querynode/search_test.go @@ -36,7 +36,7 @@ func TestHistorical_Search(t *testing.T) { searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) assert.NoError(t, err) - _, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) + _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) @@ -52,7 +52,7 @@ func TestHistorical_Search(t *testing.T) { err = his.removeCollection(defaultCollectionID) assert.NoError(t, err) - _, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, nil) + _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil) assert.Error(t, err) }) @@ -68,7 +68,7 @@ func TestHistorical_Search(t *testing.T) { err = his.removeCollection(defaultCollectionID) assert.NoError(t, err) - _, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, nil) + _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, nil) assert.Error(t, err) }) @@ -88,7 +88,7 @@ func TestHistorical_Search(t *testing.T) { err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, nil) + _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil) assert.Error(t, err) }) @@ -104,7 +104,7 @@ func TestHistorical_Search(t *testing.T) { err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - res, _, ids, err := searchHistorical(his, searchReq, defaultCollectionID, nil, nil) + res, _, ids, err := searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil) assert.Equal(t, 0, len(res)) assert.Equal(t, 0, len(ids)) assert.NoError(t, err) @@ -121,7 +121,7 @@ func TestStreaming_search(t *testing.T) { searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) assert.NoError(t, err) - res, _, _, err := searchStreaming(streaming, searchReq, + res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -138,7 +138,7 @@ func TestStreaming_search(t *testing.T) { searchReq, 
err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) assert.NoError(t, err) - res, _, _, err := searchStreaming(streaming, searchReq, + res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -162,7 +162,7 @@ func TestStreaming_search(t *testing.T) { err = streaming.removePartition(defaultPartitionID) assert.NoError(t, err) - res, _, _, err := searchStreaming(streaming, searchReq, + res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -187,7 +187,7 @@ func TestStreaming_search(t *testing.T) { err = streaming.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, _, err = searchStreaming(streaming, searchReq, + _, _, _, err = searchStreaming(context.TODO(), streaming, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -206,7 +206,7 @@ func TestStreaming_search(t *testing.T) { err = streaming.removePartition(defaultPartitionID) assert.NoError(t, err) - res, _, _, err := searchStreaming(streaming, searchReq, + res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq, defaultCollectionID, []UniqueID{}, defaultDMLChannel) @@ -228,7 +228,7 @@ func TestStreaming_search(t *testing.T) { seg.segmentPtr = nil - _, _, _, err = searchStreaming(streaming, searchReq, + _, _, _, err = searchStreaming(context.TODO(), streaming, searchReq, defaultCollectionID, []UniqueID{}, defaultDMLChannel) diff --git a/internal/querynode/statistic_test.go b/internal/querynode/statistic_test.go index d4d4c3ecd..c142938de 100644 --- a/internal/querynode/statistic_test.go +++ b/internal/querynode/statistic_test.go @@ -31,7 +31,7 @@ func TestHistorical_statistic(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - _, _, _, err = statisticHistorical(his, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) + _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) @@ -42,7 +42,7 @@ func TestHistorical_statistic(t *testing.T) { err = his.removeCollection(defaultCollectionID) assert.NoError(t, err) - _, _, _, err = statisticHistorical(his, defaultCollectionID, nil, nil) + _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil) assert.Error(t, err) }) @@ -53,7 +53,7 @@ func TestHistorical_statistic(t *testing.T) { err = his.removeCollection(defaultCollectionID) assert.NoError(t, err) - _, _, _, err = statisticHistorical(his, defaultCollectionID, []UniqueID{defaultPartitionID}, nil) + _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, nil) assert.Error(t, err) }) @@ -68,7 +68,7 @@ func TestHistorical_statistic(t *testing.T) { err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, _, err = statisticHistorical(his, defaultCollectionID, nil, nil) + _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil) assert.Error(t, err) }) @@ -79,7 +79,7 @@ func TestHistorical_statistic(t *testing.T) { err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - res, _, ids, err := statisticHistorical(his, defaultCollectionID, nil, nil) + res, _, ids, err := statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil) assert.Equal(t, 0, len(res)) assert.Equal(t, 0, len(ids)) assert.NoError(t, err) @@ -91,7 
+91,7 @@ func TestStreaming_statistics(t *testing.T) { streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) - res, _, _, err := statisticStreaming(streaming, + res, _, _, err := statisticStreaming(context.TODO(), streaming, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -103,7 +103,7 @@ func TestStreaming_statistics(t *testing.T) { streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) - res, _, _, err := statisticStreaming(streaming, + res, _, _, err := statisticStreaming(context.TODO(), streaming, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -122,7 +122,7 @@ func TestStreaming_statistics(t *testing.T) { err = streaming.removePartition(defaultPartitionID) assert.NoError(t, err) - res, _, _, err := statisticStreaming(streaming, + res, _, _, err := statisticStreaming(context.TODO(), streaming, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -142,7 +142,7 @@ func TestStreaming_statistics(t *testing.T) { err = streaming.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, _, err = statisticStreaming(streaming, + _, _, _, err = statisticStreaming(context.TODO(), streaming, defaultCollectionID, []UniqueID{defaultPartitionID}, defaultDMLChannel) @@ -156,7 +156,7 @@ func TestStreaming_statistics(t *testing.T) { err = streaming.removePartition(defaultPartitionID) assert.NoError(t, err) - res, _, _, err := statisticStreaming(streaming, + res, _, _, err := statisticStreaming(context.TODO(), streaming, defaultCollectionID, []UniqueID{}, defaultDMLChannel) diff --git a/internal/querynode/statistics.go b/internal/querynode/statistics.go index 0342971b7..9c51a8538 100644 --- a/internal/querynode/statistics.go +++ b/internal/querynode/statistics.go @@ -1,6 +1,7 @@ package querynode import ( + "context" "fmt" "sync" @@ -55,8 +56,8 @@ func statisticOnSegments(replica ReplicaInterface, segType segmentType, segIDs [ // if segIDs is not specified, it will search on all the historical segments specified by partIDs. // if segIDs is specified, it will only search on the segments specified by the segIDs. // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. -func statisticHistorical(replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]map[string]interface{}, []UniqueID, []UniqueID, error) { - searchPartIDs, searchSegmentIDs, err := validateOnHistoricalReplica(replica, collID, partIDs, segIDs) +func statisticHistorical(ctx context.Context, replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]map[string]interface{}, []UniqueID, []UniqueID, error) { + searchPartIDs, searchSegmentIDs, err := validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs) if err != nil { return nil, searchSegmentIDs, searchPartIDs, err } @@ -66,8 +67,8 @@ func statisticHistorical(replica ReplicaInterface, collID UniqueID, partIDs []Un // statisticStreaming will do statistics all the target segments in streaming // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. 
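The same context also threads into timing: the proxy and query node changes above replace tr.Record / tr.Elapse with tr.CtxRecord / tr.CtxElapse so the per-stage latency lines carry the request's trace fields too. Only as a sketch of the idea, not the patch's actual internal/util/timerecord code (field names and return types here are assumptions):

    package timing

    import (
        "context"
        "time"

        "github.com/milvus-io/milvus/internal/log"
        "go.uber.org/zap"
    )

    type TimeRecorder struct {
        header string
        start  time.Time
        last   time.Time
    }

    func NewTimeRecorder(header string) *TimeRecorder {
        now := time.Now()
        return &TimeRecorder{header: header, start: now, last: now}
    }

    // CtxRecord logs the span since the previous checkpoint through the request-scoped logger.
    func (tr *TimeRecorder) CtxRecord(ctx context.Context, msg string) time.Duration {
        now := time.Now()
        span := now.Sub(tr.last)
        tr.last = now
        log.Ctx(ctx).Debug(tr.header+": "+msg, zap.Duration("span", span))
        return span
    }

    // CtxElapse logs the total time since the recorder was created.
    func (tr *TimeRecorder) CtxElapse(ctx context.Context, msg string) time.Duration {
        total := time.Since(tr.start)
        log.Ctx(ctx).Debug(tr.header+": "+msg, zap.Duration("total", total))
        return total
    }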
-func statisticStreaming(replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]map[string]interface{}, []UniqueID, []UniqueID, error) { - searchPartIDs, searchSegmentIDs, err := validateOnStreamReplica(replica, collID, partIDs, vChannel) +func statisticStreaming(ctx context.Context, replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]map[string]interface{}, []UniqueID, []UniqueID, error) { + searchPartIDs, searchSegmentIDs, err := validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel) if err != nil { return nil, searchSegmentIDs, searchPartIDs, err } diff --git a/internal/querynode/task_query.go b/internal/querynode/task_query.go index 07ae85f73..3551214e5 100644 --- a/internal/querynode/task_query.go +++ b/internal/querynode/task_query.go @@ -53,6 +53,7 @@ func (q *queryTask) PreExecute(ctx context.Context) error { // TODO: merge queryOnStreaming and queryOnHistorical? func (q *queryTask) queryOnStreaming() error { // check ctx timeout + ctx := q.Ctx() if !funcutil.CheckCtxValid(q.Ctx()) { return errors.New("query context timeout") } @@ -66,7 +67,7 @@ func (q *queryTask) queryOnStreaming() error { q.QS.collection.RLock() // locks the collectionPtr defer q.QS.collection.RUnlock() if _, released := q.QS.collection.getReleaseTime(); released { - log.Debug("collection release before search", zap.Int64("msgID", q.ID()), + log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", q.ID()), zap.Int64("collectionID", q.CollectionID)) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID) } @@ -78,13 +79,13 @@ func (q *queryTask) queryOnStreaming() error { } defer plan.delete() - sResults, _, _, sErr := retrieveStreaming(q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager) + sResults, _, _, sErr := retrieveStreaming(ctx, q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager) if sErr != nil { return sErr } q.tr.RecordSpan() - mergedResult, err := mergeSegcoreRetrieveResults(sResults) + mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults) if err != nil { return err } @@ -100,7 +101,8 @@ func (q *queryTask) queryOnStreaming() error { func (q *queryTask) queryOnHistorical() error { // check ctx timeout - if !funcutil.CheckCtxValid(q.Ctx()) { + ctx := q.Ctx() + if !funcutil.CheckCtxValid(ctx) { return errors.New("search context timeout3$") } @@ -114,7 +116,7 @@ func (q *queryTask) queryOnHistorical() error { defer q.QS.collection.RUnlock() if _, released := q.QS.collection.getReleaseTime(); released { - log.Debug("collection release before search", zap.Int64("msgID", q.ID()), + log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", q.ID()), zap.Int64("collectionID", q.CollectionID)) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID) } @@ -125,11 +127,11 @@ func (q *queryTask) queryOnHistorical() error { return err } defer plan.delete() - retrieveResults, _, _, err := retrieveHistorical(q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager) + retrieveResults, _, _, err := retrieveHistorical(ctx, q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager) if err != nil { return err } - mergedResult, err := mergeSegcoreRetrieveResults(retrieveResults) + mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults) if err != 
nil { return err } diff --git a/internal/querynode/task_search.go b/internal/querynode/task_search.go index eebf13c25..cc0277704 100644 --- a/internal/querynode/task_search.go +++ b/internal/querynode/task_search.go @@ -85,7 +85,8 @@ func (s *searchTask) init() error { // TODO: merge searchOnStreaming and searchOnHistorical? func (s *searchTask) searchOnStreaming() error { // check ctx timeout - if !funcutil.CheckCtxValid(s.Ctx()) { + ctx := s.Ctx() + if !funcutil.CheckCtxValid(ctx) { return errors.New("search context timeout") } @@ -102,7 +103,7 @@ func (s *searchTask) searchOnStreaming() error { s.QS.collection.RLock() // locks the collectionPtr defer s.QS.collection.RUnlock() if _, released := s.QS.collection.getReleaseTime(); released { - log.Debug("collection release before search", zap.Int64("msgID", s.ID()), + log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", s.ID()), zap.Int64("collectionID", s.CollectionID)) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID) } @@ -113,20 +114,20 @@ func (s *searchTask) searchOnStreaming() error { } defer searchReq.delete() - // TODO add context - partResults, _, _, sErr := searchStreaming(s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0]) + partResults, _, _, sErr := searchStreaming(ctx, s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0]) if sErr != nil { - log.Debug("failed to search streaming data", zap.Int64("msgID", s.ID()), + log.Ctx(ctx).Warn("failed to search streaming data", zap.Int64("msgID", s.ID()), zap.Int64("collectionID", s.CollectionID), zap.Error(sErr)) return sErr } defer deleteSearchResults(partResults) - return s.reduceResults(searchReq, partResults) + return s.reduceResults(ctx, searchReq, partResults) } func (s *searchTask) searchOnHistorical() error { // check ctx timeout - if !funcutil.CheckCtxValid(s.Ctx()) { + ctx := s.Ctx() + if !funcutil.CheckCtxValid(ctx) { return errors.New("search context timeout") } @@ -139,7 +140,7 @@ func (s *searchTask) searchOnHistorical() error { s.QS.collection.RLock() // locks the collectionPtr defer s.QS.collection.RUnlock() if _, released := s.QS.collection.getReleaseTime(); released { - log.Debug("collection release before search", zap.Int64("msgID", s.ID()), + log.Ctx(ctx).Warn("collection release before search", zap.Int64("msgID", s.ID()), zap.Int64("collectionID", s.CollectionID)) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID) } @@ -151,12 +152,12 @@ func (s *searchTask) searchOnHistorical() error { } defer searchReq.delete() - partResults, _, _, err := searchHistorical(s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs) + partResults, _, _, err := searchHistorical(ctx, s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs) if err != nil { return err } defer deleteSearchResults(partResults) - return s.reduceResults(searchReq, partResults) + return s.reduceResults(ctx, searchReq, partResults) } func (s *searchTask) Execute(ctx context.Context) error { @@ -217,7 +218,7 @@ func (s *searchTask) CPUUsage() int32 { } // reduceResults reduce search results -func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchResult) error { +func (s *searchTask) reduceResults(ctx context.Context, searchReq *searchRequest, results []*SearchResult) error { isEmpty := len(results) == 0 cnt := 1 + len(s.otherTasks) var t *searchTask @@ -227,7 +228,7 @@ 
func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe numSegment := int64(len(results)) blobs, err := reduceSearchResultsAndFillData(searchReq.plan, results, numSegment, sInfo.sliceNQs, sInfo.sliceTopKs) if err != nil { - log.Debug("marshal for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err)) + log.Ctx(ctx).Warn("marshal for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err)) return err } defer deleteSearchResultDataBlobs(blobs) @@ -235,7 +236,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe for i := 0; i < cnt; i++ { blob, err := getSearchResultDataBlob(blobs, i) if err != nil { - log.Debug("getSearchResultDataBlob for historical results error", zap.Int64("msgID", s.ID()), + log.Ctx(ctx).Warn("getSearchResultDataBlob for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err)) return err } diff --git a/internal/querynode/task_statistics.go b/internal/querynode/task_statistics.go index 524b0cfe5..062c7f984 100644 --- a/internal/querynode/task_statistics.go +++ b/internal/querynode/task_statistics.go @@ -34,7 +34,8 @@ type statistics struct { func (s *statistics) statisticOnStreaming() error { // check ctx timeout - if !funcutil.CheckCtxValid(s.ctx) { + ctx := s.ctx + if !funcutil.CheckCtxValid(ctx) { return errors.New("get statistics context timeout") } @@ -47,14 +48,15 @@ func (s *statistics) statisticOnStreaming() error { s.qs.collection.RLock() // locks the collectionPtr defer s.qs.collection.RUnlock() if _, released := s.qs.collection.getReleaseTime(); released { - log.Debug("collection release before do statistics", zap.Int64("msgID", s.id), + log.Ctx(ctx).Warn("collection release before do statistics", zap.Int64("msgID", s.id), zap.Int64("collectionID", s.iReq.GetCollectionID())) return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID()) } - results, _, _, err := statisticStreaming(s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0]) + results, _, _, err := statisticStreaming(ctx, s.qs.metaReplica, s.iReq.GetCollectionID(), + s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0]) if err != nil { - log.Debug("failed to statistic on streaming data", zap.Int64("msgID", s.id), + log.Ctx(ctx).Warn("failed to statistic on streaming data", zap.Int64("msgID", s.id), zap.Int64("collectionID", s.iReq.GetCollectionID()), zap.Error(err)) return err } @@ -63,7 +65,8 @@ func (s *statistics) statisticOnStreaming() error { func (s *statistics) statisticOnHistorical() error { // check ctx timeout - if !funcutil.CheckCtxValid(s.ctx) { + ctx := s.ctx + if !funcutil.CheckCtxValid(ctx) { return errors.New("get statistics context timeout") } @@ -76,13 +79,13 @@ func (s *statistics) statisticOnHistorical() error { s.qs.collection.RLock() // locks the collectionPtr defer s.qs.collection.RUnlock() if _, released := s.qs.collection.getReleaseTime(); released { - log.Debug("collection release before do statistics", zap.Int64("msgID", s.id), + log.Ctx(ctx).Debug("collection release before do statistics", zap.Int64("msgID", s.id), zap.Int64("collectionID", s.iReq.GetCollectionID())) return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID()) } segmentIDs := s.req.GetSegmentIDs() - results, _, _, err := statisticHistorical(s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), segmentIDs) + results, _, _, err := 
statisticHistorical(ctx, s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), segmentIDs) if err != nil { return err } diff --git a/internal/querynode/validate.go b/internal/querynode/validate.go index e1bd60c0d..e1d5b13c9 100644 --- a/internal/querynode/validate.go +++ b/internal/querynode/validate.go @@ -17,6 +17,7 @@ package querynode import ( + "context" "errors" "fmt" @@ -26,7 +27,7 @@ import ( ) // TODO: merge validate? -func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) { +func validateOnHistoricalReplica(ctx context.Context, replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) { var err error var searchPartIDs []UniqueID @@ -46,7 +47,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID } } - log.Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) + log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) col, err2 := replica.getCollectionByID(collectionID) if err2 != nil { return searchPartIDs, segmentIDs, err2 @@ -86,7 +87,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID return searchPartIDs, newSegmentIDs, nil } -func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, vChannel Channel) ([]UniqueID, []UniqueID, error) { +func validateOnStreamReplica(ctx context.Context, replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, vChannel Channel) ([]UniqueID, []UniqueID, error) { var err error var searchPartIDs []UniqueID var segmentIDs []UniqueID @@ -107,7 +108,7 @@ func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, pa } } - log.Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) + log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) col, err2 := replica.getCollectionByID(collectionID) if err2 != nil { return searchPartIDs, segmentIDs, err2 @@ -123,7 +124,7 @@ func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, pa } segmentIDs, err = replica.getSegmentIDsByVChannel(searchPartIDs, vChannel, segmentTypeGrowing) - log.Debug("validateOnStreamReplica getSegmentIDsByVChannel", + log.Ctx(ctx).Debug("validateOnStreamReplica getSegmentIDsByVChannel", zap.Any("collectionID", collectionID), zap.Any("vChannel", vChannel), zap.Any("partitionIDs", searchPartIDs), diff --git a/internal/querynode/validate_test.go b/internal/querynode/validate_test.go index 2e5f4cfff..234f92d29 100644 --- a/internal/querynode/validate_test.go +++ b/internal/querynode/validate_test.go @@ -30,35 +30,35 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test normal validate", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) t.Run("test normal validate2", func(t *testing.T) { his, err := 
genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) t.Run("test validate non-existent collection", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) t.Run("test validate non-existent partition", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) t.Run("test validate non-existent segment", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) assert.Error(t, err) }) @@ -79,7 +79,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { assert.NoError(t, err) // Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1) // that does not belong to defaultPartition - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) assert.Error(t, err) }) @@ -88,7 +88,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { assert.NoError(t, err) err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) @@ -100,7 +100,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { col.setLoadType(loadTypePartition) err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) @@ -112,7 +112,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { col.setLoadType(loadTypeCollection) err = his.removePartition(defaultPartitionID) assert.NoError(t, err) - _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) + _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) 
assert.NoError(t, err) }) } diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 0d8dbf7f0..216c364fc 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1580,7 +1580,7 @@ func (c *Core) DescribeCollection(ctx context.Context, in *milvuspb.DescribeColl } tr := timerecord.NewTimeRecorder("DescribeCollection") - log.Debug("DescribeCollection", zap.String("role", typeutil.RootCoordRole), + log.Ctx(ctx).Debug("DescribeCollection", zap.String("role", typeutil.RootCoordRole), zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID)) t := &DescribeCollectionReqTask{ baseReqTask: baseReqTask{ @@ -1592,14 +1592,14 @@ func (c *Core) DescribeCollection(ctx context.Context, in *milvuspb.DescribeColl } err := executeTask(t) if err != nil { - log.Error("DescribeCollection failed", zap.String("role", typeutil.RootCoordRole), + log.Ctx(ctx).Warn("DescribeCollection failed", zap.String("role", typeutil.RootCoordRole), zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID), zap.Error(err)) metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.FailLabel).Inc() return &milvuspb.DescribeCollectionResponse{ Status: failStatus(commonpb.ErrorCode_UnexpectedError, "DescribeCollection failed: "+err.Error()), }, nil } - log.Debug("DescribeCollection success", zap.String("role", typeutil.RootCoordRole), + log.Ctx(ctx).Debug("DescribeCollection success", zap.String("role", typeutil.RootCoordRole), zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID)) metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.SuccessLabel).Inc() diff --git a/internal/rootcoord/task.go b/internal/rootcoord/task.go index fa735a07b..d9181d481 100644 --- a/internal/rootcoord/task.go +++ b/internal/rootcoord/task.go @@ -274,6 +274,7 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error { return err } + log.NewMetaLogger().WithCollectionMeta(&collInfo).WithOperation(log.CreateCollection).WithTSO(ts).Info() return nil } @@ -391,6 +392,8 @@ func (t *DropCollectionReqTask) Execute(ctx context.Context) error { return err } + log.NewMetaLogger().WithCollectionID(collMeta.CollectionID). + WithCollectionName(collMeta.Name).WithTSO(ts).WithOperation(log.DropCollection).Info() return nil } @@ -594,6 +597,8 @@ func (t *CreatePartitionReqTask) Execute(ctx context.Context) error { return err } + log.NewMetaLogger().WithCollectionName(collMeta.Name).WithCollectionID(collMeta.CollectionID). + WithPartitionID(partID).WithPartitionName(t.Req.PartitionName).WithTSO(ts).WithOperation(log.CreatePartition).Info() return nil } @@ -691,6 +696,8 @@ func (t *DropPartitionReqTask) Execute(ctx context.Context) error { // return err //} + log.NewMetaLogger().WithCollectionID(collInfo.CollectionID).WithCollectionName(collInfo.Name). 
+ WithPartitionName(t.Req.PartitionName).WithTSO(ts).WithOperation(log.DropPartition).Info() return nil } @@ -1038,6 +1045,11 @@ func (t *CreateIndexReqTask) Execute(ctx context.Context) error { } } + idxMeta, err := t.core.MetaTable.GetIndexByID(indexID) + if err == nil { + log.NewMetaLogger().WithIndexMeta(idxMeta).WithOperation(log.CreateIndex).WithTSO(createTS).Info() + } + return nil } @@ -1098,6 +1110,15 @@ func (t *DropIndexReqTask) Execute(ctx context.Context) error { if err := t.core.MetaTable.MarkIndexDeleted(t.Req.CollectionName, t.Req.FieldName, t.Req.IndexName); err != nil { return err } + deleteTS, err := t.core.TSOAllocator(1) + if err != nil { + return err + } + log.NewMetaLogger().WithCollectionName(t.Req.CollectionName). + WithFieldName(t.Req.FieldName). + WithIndexName(t.Req.IndexName). + WithOperation(log.DropIndex).WithTSO(deleteTS).Info() + return nil } @@ -1127,6 +1148,7 @@ func (t *CreateAliasReqTask) Execute(ctx context.Context) error { return fmt.Errorf("meta table add alias failed, error = %w", err) } + log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).WithAlias(t.Req.Alias).WithTSO(ts).WithOperation(log.CreateCollectionAlias).Info() return nil } @@ -1156,7 +1178,12 @@ func (t *DropAliasReqTask) Execute(ctx context.Context) error { return fmt.Errorf("meta table drop alias failed, error = %w", err) } - return t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts) + if err := t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts); err != nil { + return err + } + + log.NewMetaLogger().WithAlias(t.Req.Alias).WithOperation(log.DropCollectionAlias).WithTSO(ts).Info() + return nil } // AlterAliasReqTask alter alias request task @@ -1185,5 +1212,11 @@ func (t *AlterAliasReqTask) Execute(ctx context.Context) error { return fmt.Errorf("meta table alter alias failed, error = %w", err) } - return t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts) + if err := t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts); err != nil { + return err + } + + log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).
+ WithAlias(t.Req.Alias).WithOperation(log.AlterCollectionAlias).WithTSO(ts).Info() + return nil } diff --git a/internal/util/logutil/grpc_interceptor.go b/internal/util/logutil/grpc_interceptor.go new file mode 100644 index 000000000..fd48fa9da --- /dev/null +++ b/internal/util/logutil/grpc_interceptor.go @@ -0,0 +1,78 @@ +package logutil + +import ( + "context" + + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/util/trace" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +const ( + logLevelRPCMetaKey = "log_level" + clientRequestIDKey = "client_request_id" +) + +// UnaryTraceLoggerInterceptor adds a traced logger in unary rpc call ctx +func UnaryTraceLoggerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + newctx := withLevelAndTrace(ctx) + return handler(newctx, req) +} + +// StreamTraceLoggerInterceptor add a traced logger in stream rpc call ctx +func StreamTraceLoggerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := ss.Context() + newctx := withLevelAndTrace(ctx) + wrappedStream := grpc_middleware.WrapServerStream(ss) + wrappedStream.WrappedContext = newctx + return handler(srv, wrappedStream) +} + +func withLevelAndTrace(ctx context.Context) context.Context { + newctx := ctx + var traceID string + if md, ok := metadata.FromIncomingContext(ctx); ok { + levels := md.Get(logLevelRPCMetaKey) + // get log level + if len(levels) >= 1 { + level := zapcore.DebugLevel + if err := level.UnmarshalText([]byte(levels[0])); err != nil { + newctx = ctx + } else { + switch level { + case zapcore.DebugLevel: + newctx = log.WithDebugLevel(ctx) + case zapcore.InfoLevel: + newctx = log.WithInfoLevel(ctx) + case zapcore.WarnLevel: + newctx = log.WithWarnLevel(ctx) + case zapcore.ErrorLevel: + newctx = log.WithErrorLevel(ctx) + case zapcore.FatalLevel: + newctx = log.WithFatalLevel(ctx) + default: + newctx = ctx + } + } + // inject log level to outgoing meta + newctx = metadata.AppendToOutgoingContext(newctx, logLevelRPCMetaKey, level.String()) + } + // client request id + requestID := md.Get(clientRequestIDKey) + if len(requestID) >= 1 { + traceID = requestID[0] + // inject traceid in order to pass client request id + newctx = metadata.AppendToOutgoingContext(newctx, clientRequestIDKey, traceID) + } + } + if traceID == "" { + traceID, _, _ = trace.InfoFromContext(newctx) + } + if traceID != "" { + newctx = log.WithTraceID(newctx, traceID) + } + return newctx +} diff --git a/internal/util/logutil/grpc_interceptor_test.go b/internal/util/logutil/grpc_interceptor_test.go new file mode 100644 index 000000000..b49afb043 --- /dev/null +++ b/internal/util/logutil/grpc_interceptor_test.go @@ -0,0 +1,64 @@ +package logutil + +import ( + "context" + "testing" + + "github.com/milvus-io/milvus/internal/log" + "github.com/stretchr/testify/assert" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc/metadata" +) + +func TestCtxWithLevelAndTrace(t *testing.T) { + t.Run("debug level", func(t *testing.T) { + ctx := withMetaData(context.TODO(), zapcore.DebugLevel) + newctx := withLevelAndTrace(ctx) + + assert.Equal(t, log.Ctx(log.WithDebugLevel(context.TODO())), log.Ctx(newctx)) + }) + + t.Run("info level", func(t *testing.T) { + ctx := context.TODO() + newctx := withLevelAndTrace(withMetaData(ctx, zapcore.InfoLevel)) + 
assert.Equal(t, log.Ctx(log.WithInfoLevel(ctx)), log.Ctx(newctx)) + }) + + t.Run("warn level", func(t *testing.T) { + ctx := context.TODO() + newctx := withLevelAndTrace(withMetaData(ctx, zapcore.WarnLevel)) + assert.Equal(t, log.Ctx(log.WithWarnLevel(ctx)), log.Ctx(newctx)) + }) + + t.Run("error level", func(t *testing.T) { + ctx := context.TODO() + newctx := withLevelAndTrace(withMetaData(ctx, zapcore.ErrorLevel)) + assert.Equal(t, log.Ctx(log.WithErrorLevel(ctx)), log.Ctx(newctx)) + }) + + t.Run("fatal level", func(t *testing.T) { + ctx := context.TODO() + newctx := withLevelAndTrace(withMetaData(ctx, zapcore.FatalLevel)) + assert.Equal(t, log.Ctx(log.WithFatalLevel(ctx)), log.Ctx(newctx)) + }) + + t.Run(("pass through variables"), func(t *testing.T) { + md := metadata.New(map[string]string{ + logLevelRPCMetaKey: zapcore.ErrorLevel.String(), + clientRequestIDKey: "client-req-id", + }) + ctx := metadata.NewIncomingContext(context.TODO(), md) + newctx := withLevelAndTrace(ctx) + md, ok := metadata.FromOutgoingContext(newctx) + assert.True(t, ok) + assert.Equal(t, "client-req-id", md.Get(clientRequestIDKey)[0]) + assert.Equal(t, zapcore.ErrorLevel.String(), md.Get(logLevelRPCMetaKey)[0]) + }) +} + +func withMetaData(ctx context.Context, level zapcore.Level) context.Context { + md := metadata.New(map[string]string{ + logLevelRPCMetaKey: level.String(), + }) + return metadata.NewIncomingContext(context.TODO(), md) +} diff --git a/internal/util/logutil/logutil.go b/internal/util/logutil/logutil.go index 9e64880a6..c3900657d 100644 --- a/internal/util/logutil/logutil.go +++ b/internal/util/logutil/logutil.go @@ -141,7 +141,7 @@ var once sync.Once func SetupLogger(cfg *log.Config) { once.Do(func() { // Initialize logger. - logger, p, err := log.InitLogger(cfg, zap.AddStacktrace(zap.ErrorLevel), zap.AddCallerSkip(1)) + logger, p, err := log.InitLogger(cfg, zap.AddStacktrace(zap.ErrorLevel)) if err == nil { log.ReplaceGlobals(logger, p) } else { @@ -167,54 +167,10 @@ func SetupLogger(cfg *log.Config) { }) } -type logKey int - -const logCtxKey logKey = iota - -// WithField adds given kv field to the logger in ctx -func WithField(ctx context.Context, key string, value string) context.Context { - logger := log.L() - if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok { - logger = ctxLogger - } - - return context.WithValue(ctx, logCtxKey, logger.With(zap.String(key, value))) -} - -// WithReqID adds given reqID field to the logger in ctx -func WithReqID(ctx context.Context, reqID int64) context.Context { - logger := log.L() - if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok { - logger = ctxLogger - } - - return context.WithValue(ctx, logCtxKey, logger.With(zap.Int64("reqID", reqID))) -} - -// WithModule adds given module field to the logger in ctx -func WithModule(ctx context.Context, module string) context.Context { - logger := log.L() - if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok { - logger = ctxLogger - } - - return context.WithValue(ctx, logCtxKey, logger.With(zap.String("module", module))) -} - -func WithLogger(ctx context.Context, logger *zap.Logger) context.Context { - if logger == nil { - logger = log.L() - } - return context.WithValue(ctx, logCtxKey, logger) -} - func Logger(ctx context.Context) *zap.Logger { - if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok { - return ctxLogger - } - return log.L() + return log.Ctx(ctx).Logger } -func BgLogger() *zap.Logger { - return log.L() +func WithModule(ctx context.Context, module string) context.Context { + 
return log.WithModule(ctx, module) } diff --git a/internal/util/timerecord/time_recorder.go b/internal/util/timerecord/time_recorder.go index fe405abf2..6061be3d5 100644 --- a/internal/util/timerecord/time_recorder.go +++ b/internal/util/timerecord/time_recorder.go @@ -25,6 +25,7 @@ type TimeRecorder struct { header string start time.Time last time.Time + ctx context.Context } // NewTimeRecorder creates a new TimeRecorder @@ -55,18 +56,30 @@ func (tr *TimeRecorder) ElapseSpan() time.Duration { // Record calculates the time span from previous Record call func (tr *TimeRecorder) Record(msg string) time.Duration { span := tr.RecordSpan() - tr.printTimeRecord(msg, span) + tr.printTimeRecord(context.TODO(), msg, span) + return span +} + +func (tr *TimeRecorder) CtxRecord(ctx context.Context, msg string) time.Duration { + span := tr.RecordSpan() + tr.printTimeRecord(ctx, msg, span) return span } // Elapse calculates the time span from the beginning of this TimeRecorder func (tr *TimeRecorder) Elapse(msg string) time.Duration { span := tr.ElapseSpan() - tr.printTimeRecord(msg, span) + tr.printTimeRecord(context.TODO(), msg, span) + return span +} + +func (tr *TimeRecorder) CtxElapse(ctx context.Context, msg string) time.Duration { + span := tr.ElapseSpan() + tr.printTimeRecord(ctx, msg, span) return span } -func (tr *TimeRecorder) printTimeRecord(msg string, span time.Duration) { +func (tr *TimeRecorder) printTimeRecord(ctx context.Context, msg string, span time.Duration) { str := "" if tr.header != "" { str += tr.header + ": " @@ -75,7 +88,7 @@ func (tr *TimeRecorder) printTimeRecord(msg string, span time.Duration) { str += " (" str += strconv.Itoa(int(span.Milliseconds())) str += "ms)" - log.Debug(str) + log.Ctx(ctx).Debug(str) } // LongTermChecker checks we receive at least one msg in d duration. If not, checker -- GitLab
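
For readers of this patch, a minimal usage sketch of the context-aware logging it introduces, assuming the log.Ctx and log.WithTraceID signatures shown above; exampleNode, doWork, and the field values are placeholders rather than code from the patch:

    package example

    import (
        "context"

        "github.com/milvus-io/milvus/internal/log"
        "go.uber.org/zap"
    )

    // exampleNode and doWork are illustrative placeholders.
    type exampleNode struct{}

    func doWork(ctx context.Context) error { return nil }

    func (n *exampleNode) handleRequest(ctx context.Context, collectionID int64) error {
        // The traceID is normally injected by UnaryTraceLoggerInterceptor; it can
        // also be attached by hand when no interceptor sits in the call path.
        ctx = log.WithTraceID(ctx, "example-trace-id")

        log.Ctx(ctx).Debug("start handling request", zap.Int64("collectionID", collectionID))

        if err := doWork(ctx); err != nil {
            log.Ctx(ctx).Warn("request failed", zap.Int64("collectionID", collectionID), zap.Error(err))
            return err
        }
        return nil
    }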
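
A sketch of how the new gRPC interceptors could be wired into a server, and of the metadata keys a client would set to raise the log level for a single request; the server construction and literal values are illustrative, not taken from the patch:

    package example

    import (
        "context"

        "github.com/milvus-io/milvus/internal/util/logutil"
        "google.golang.org/grpc"
        "google.golang.org/grpc/metadata"
    )

    // newServer wires the trace-logger interceptors into a gRPC server so every
    // handler receives a ctx that already carries the request's logger.
    func newServer() *grpc.Server {
        return grpc.NewServer(
            grpc.ChainUnaryInterceptor(logutil.UnaryTraceLoggerInterceptor),
            grpc.ChainStreamInterceptor(logutil.StreamTraceLoggerInterceptor),
        )
    }

    // withDebugLogs shows the metadata keys the interceptor reads: "log_level"
    // and "client_request_id" mirror logLevelRPCMetaKey and clientRequestIDKey.
    func withDebugLogs(ctx context.Context) context.Context {
        return metadata.AppendToOutgoingContext(ctx,
            "log_level", "debug",
            "client_request_id", "my-request-id",
        )
    }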
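
The MetaLogger calls added to rootcoord/task.go follow a fluent builder style; a small sketch of one audit record, with argument values that are purely illustrative:

    package example

    import "github.com/milvus-io/milvus/internal/log"

    // logDropCollectionMeta records a DDL audit entry at Info level; the With*
    // setters and the log.DropCollection operation come from this patch.
    func logDropCollectionMeta(collectionID int64, name string, ts uint64) {
        log.NewMetaLogger().
            WithCollectionID(collectionID).
            WithCollectionName(name).
            WithOperation(log.DropCollection).
            WithTSO(ts).
            Info()
    }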
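
Likewise, a sketch of the new CtxRecord/CtxElapse helpers on TimeRecorder, which route per-stage timing logs through the context-bound logger; the header and stage names here are made up:

    package example

    import (
        "context"

        "github.com/milvus-io/milvus/internal/util/timerecord"
    )

    func timedSearch(ctx context.Context) {
        tr := timerecord.NewTimeRecorder("Search")

        // ... stage 1 work ...

        // Logs "Search: segment search done (NNms)" via log.Ctx(ctx).Debug.
        tr.CtxRecord(ctx, "segment search done")

        // ... stage 2 work ...

        // Logs the total elapsed time since NewTimeRecorder was called.
        tr.CtxElapse(ctx, "search request finished")
    }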