未验证 提交 0c9a10e8 编写于 作者: Z Zach 提交者: GitHub

1. refine logging interfaces (#18692)

2. adjust logs for query/search requests
Signed-off-by: NZach41 <zongmei.zhang@zilliz.com>
Signed-off-by: NZach41 <zongmei.zhang@zilliz.com>
上级 ce434b49
......@@ -48,7 +48,6 @@ import (
"github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/healthz"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -238,7 +237,7 @@ func (mr *MilvusRoles) runDataCoord(ctx context.Context, localMsg bool) *compone
factory := dependency.NewFactory(localMsg)
dctx := logutil.WithModule(ctx, "DataCoord")
dctx := log.WithModule(ctx, "DataCoord")
var err error
ds, err = components.NewDataCoord(dctx, factory)
if err != nil {
......@@ -406,7 +405,7 @@ func (mr *MilvusRoles) Run(local bool, alias string) {
var pn *components.Proxy
if mr.EnableProxy {
pctx := logutil.WithModule(ctx, "Proxy")
pctx := log.WithModule(ctx, "Proxy")
pn = mr.runProxy(pctx, local, alias)
if pn != nil {
defer pn.Stop()
......
......@@ -25,6 +25,7 @@ import (
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
......@@ -149,8 +150,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
datapb.RegisterDataCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(lis); err != nil {
......
......@@ -26,6 +26,7 @@ import (
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
......@@ -44,6 +45,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -134,8 +136,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
datapb.RegisterDataNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)
......
......@@ -29,6 +29,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
"github.com/milvus-io/milvus/internal/indexcoord"
......@@ -42,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
......@@ -277,8 +279,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
......
......@@ -30,6 +30,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/indexnode"
......@@ -42,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
......@@ -108,8 +110,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
indexpb.RegisterIndexNodeServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(lis); err != nil {
......
......@@ -62,6 +62,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
......@@ -171,10 +172,12 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
ot.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor),
logutil.UnaryTraceLoggerInterceptor,
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor))),
grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor),
logutil.StreamTraceLoggerInterceptor)),
}
if Params.TLSMode == 1 {
......@@ -261,10 +264,12 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
logutil.UnaryTraceLoggerInterceptor,
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor),
logutil.StreamTraceLoggerInterceptor,
)),
)
proxypb.RegisterProxyServer(s.grpcInternalServer, s)
......
......@@ -24,6 +24,7 @@ import (
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
......@@ -43,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
......@@ -259,8 +261,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
querypb.RegisterQueryCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
......
......@@ -25,6 +25,7 @@ import (
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
......@@ -41,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -179,8 +181,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
querypb.RegisterQueryNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)
......
......@@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/indexpb"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
......@@ -45,6 +46,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -258,8 +260,12 @@ func (s *Server) startGrpcLoop(port int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
......
......@@ -14,10 +14,18 @@
package log
import (
"context"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type ctxLogKeyType struct{}
var (
CtxLogKey = ctxLogKeyType{}
)
// Debug logs a message at DebugLevel. The message includes any fields passed
// at the log site, as well as any fields accumulated on the logger.
func Debug(msg string, fields ...zap.Field) {
......@@ -107,3 +115,95 @@ func SetLevel(l zapcore.Level) {
func GetLevel() zapcore.Level {
return _globalP.Load().(*ZapProperties).Level.Level()
}
// WithTraceID returns a context with trace_id attached
func WithTraceID(ctx context.Context, traceID string) context.Context {
return WithFields(ctx, zap.String("traceID", traceID))
}
// WithReqID adds given reqID field to the logger in ctx
func WithReqID(ctx context.Context, reqID int64) context.Context {
fields := []zap.Field{zap.Int64("reqID", reqID)}
return WithFields(ctx, fields...)
}
// WithModule adds given module field to the logger in ctx
func WithModule(ctx context.Context, module string) context.Context {
fields := []zap.Field{zap.String("module", module)}
return WithFields(ctx, fields...)
}
// WithFields returns a context with fields attached
func WithFields(ctx context.Context, fields ...zap.Field) context.Context {
var zlogger *zap.Logger
if ctxLogger, ok := ctx.Value(CtxLogKey).(*MLogger); ok {
zlogger = ctxLogger.Logger
} else {
zlogger = ctxL()
}
mLogger := &MLogger{
Logger: zlogger.With(fields...),
}
return context.WithValue(ctx, CtxLogKey, mLogger)
}
// Ctx returns a logger which will log contextual messages attached in ctx
func Ctx(ctx context.Context) *MLogger {
if ctx == nil {
return &MLogger{Logger: ctxL()}
}
if ctxLogger, ok := ctx.Value(CtxLogKey).(*MLogger); ok {
return ctxLogger
}
return &MLogger{Logger: ctxL()}
}
// withLogLevel returns ctx with a leveled logger, notes that it will overwrite logger previous attached!
func withLogLevel(ctx context.Context, level zapcore.Level) context.Context {
var zlogger *zap.Logger
switch level {
case zap.DebugLevel:
zlogger = debugL()
case zap.InfoLevel:
zlogger = infoL()
case zap.WarnLevel:
zlogger = warnL()
case zap.ErrorLevel:
zlogger = errorL()
case zap.FatalLevel:
zlogger = fatalL()
default:
zlogger = L()
}
return context.WithValue(ctx, CtxLogKey, &MLogger{Logger: zlogger})
}
// WithDebugLevel returns context with a debug level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithDebugLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.DebugLevel)
}
// WithInfoLevel returns context with a info level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithInfoLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.InfoLevel)
}
// WithWarnLevel returns context with a warning level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithWarnLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.WarnLevel)
}
// WithErrorLevel returns context with a error level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithErrorLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.ErrorLevel)
}
// WithFatalLevel returns context with a fatal level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithFatalLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.FatalLevel)
}
......@@ -33,6 +33,7 @@ import (
"fmt"
"net/http"
"os"
"sync"
"sync/atomic"
"errors"
......@@ -45,12 +46,20 @@ import (
)
var _globalL, _globalP, _globalS, _globalR atomic.Value
var (
_globalLevelLogger sync.Map
)
var rateLimiter *utils.ReconfigurableRateLimiter
func init() {
l, p := newStdLogger()
replaceLeveledLoggers(l)
_globalL.Store(l)
_globalP.Store(p)
s := _globalL.Load().(*zap.Logger).Sugar()
_globalS.Store(s)
......@@ -80,7 +89,19 @@ func InitLogger(cfg *Config, opts ...zap.Option) (*zap.Logger, *ZapProperties, e
}
output = stdOut
}
return InitLoggerWithWriteSyncer(cfg, output, opts...)
debugCfg := *cfg
debugCfg.Level = "debug"
debugL, r, err := InitLoggerWithWriteSyncer(&debugCfg, output, opts...)
if err != nil {
return nil, nil, err
}
replaceLeveledLoggers(debugL)
level := zapcore.DebugLevel
if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
return nil, nil, err
}
r.Level.SetLevel(level)
return debugL.WithOptions(zap.IncreaseLevel(level), zap.AddCallerSkip(1)), r, nil
}
// InitTestLogger initializes a logger for unit tests
......@@ -136,7 +157,7 @@ func initFileLog(cfg *FileLogConfig) (*lumberjack.Logger, error) {
func newStdLogger() (*zap.Logger, *ZapProperties) {
conf := &Config{Level: "debug", File: FileLogConfig{}}
lg, r, _ := InitLogger(conf, zap.AddCallerSkip(1))
lg, r, _ := InitLogger(conf)
return lg, r
}
......@@ -157,6 +178,40 @@ func R() *utils.ReconfigurableRateLimiter {
return _globalR.Load().(*utils.ReconfigurableRateLimiter)
}
func ctxL() *zap.Logger {
level := _globalP.Load().(*ZapProperties).Level.Level()
l, ok := _globalLevelLogger.Load(level)
if !ok {
return L()
}
return l.(*zap.Logger)
}
func debugL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.DebugLevel)
return v.(*zap.Logger)
}
func infoL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.InfoLevel)
return v.(*zap.Logger)
}
func warnL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.WarnLevel)
return v.(*zap.Logger)
}
func errorL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.ErrorLevel)
return v.(*zap.Logger)
}
func fatalL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.FatalLevel)
return v.(*zap.Logger)
}
// ReplaceGlobals replaces the global Logger and SugaredLogger.
// It's safe for concurrent use.
func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) {
......@@ -165,11 +220,31 @@ func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) {
_globalP.Store(props)
}
func replaceLeveledLoggers(debugLogger *zap.Logger) {
levels := []zapcore.Level{zapcore.DebugLevel, zapcore.InfoLevel, zapcore.WarnLevel, zapcore.ErrorLevel,
zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel}
for _, level := range levels {
levelL := debugLogger.WithOptions(zap.IncreaseLevel(level))
_globalLevelLogger.Store(level, levelL)
}
}
// Sync flushes any buffered log entries.
func Sync() error {
err := L().Sync()
if err != nil {
if err := L().Sync(); err != nil {
return err
}
return S().Sync()
if err := S().Sync(); err != nil {
return err
}
var reterr error
_globalLevelLogger.Range(func(key, val interface{}) bool {
l := val.(*zap.Logger)
if err := l.Sync(); err != nil {
reterr = err
return false
}
return true
})
return reterr
}
......@@ -104,6 +104,11 @@ func TestInvalidFileConfig(t *testing.T) {
_, _, err := InitLogger(conf)
assert.Equal(t, "can't use directory as log file name", err.Error())
// invalid level
conf = &Config{Level: "debuge", DisableTimestamp: true}
_, _, err = InitLogger(conf)
assert.Error(t, err)
}
func TestLevelGetterAndSetter(t *testing.T) {
......@@ -235,3 +240,70 @@ func TestRatedLog(t *testing.T) {
assert.True(t, success)
Sync()
}
func TestLeveledLogger(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true, DisableCaller: true}
logger, _, _ := InitTestLogger(ts, conf)
replaceLeveledLoggers(logger)
debugL().Debug("DEBUG LOG")
debugL().Info("INFO LOG")
debugL().Warn("WARN LOG")
debugL().Error("ERROR LOG")
Sync()
ts.assertMessageContainAny(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessageContainAny(`[INFO] ["INFO LOG"]`)
ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
infoL().Debug("DEBUG LOG")
infoL().Info("INFO LOG")
infoL().Warn("WARN LOG")
infoL().Error("ERROR LOG")
Sync()
ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessageContainAny(`[INFO] ["INFO LOG"]`)
ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
warnL().Debug("DEBUG LOG")
warnL().Info("INFO LOG")
warnL().Warn("WARN LOG")
warnL().Error("ERROR LOG")
Sync()
ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessagesNotContains(`[INFO] ["INFO LOG"]`)
ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
errorL().Debug("DEBUG LOG")
errorL().Info("INFO LOG")
errorL().Warn("WARN LOG")
errorL().Error("ERROR LOG")
Sync()
ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessagesNotContains(`[INFO] ["INFO LOG"]`)
ts.assertMessagesNotContains(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
ctx := withLogLevel(context.TODO(), zapcore.DPanicLevel)
assert.Equal(t, Ctx(ctx).Logger, L())
// set invalid level
orgLevel := GetLevel()
SetLevel(zapcore.FatalLevel + 1)
assert.Equal(t, ctxL(), L())
SetLevel(orgLevel)
}
package log
import (
"encoding/json"
"github.com/milvus-io/milvus/internal/metastore/model"
"go.uber.org/zap"
)
type Operator string
const (
// CreateCollection operator
Creator Operator = "create"
Update Operator = "update"
Delete Operator = "delete"
Insert Operator = "insert"
Sealed Operator = "sealed"
)
type MetaLogger struct {
fields []zap.Field
logger *zap.Logger
}
func NewMetaLogger() *MetaLogger {
l := infoL()
fields := []zap.Field{zap.Bool("MetaLogInfo", true)}
return &MetaLogger{
fields: fields,
logger: l,
}
}
func (m *MetaLogger) WithCollectionMeta(coll *model.Collection) *MetaLogger {
payload, _ := json.Marshal(coll)
m.fields = append(m.fields, zap.Binary("CollectionMeta", payload))
return m
}
func (m *MetaLogger) WithIndexMeta(idx *model.Index) *MetaLogger {
payload, _ := json.Marshal(idx)
m.fields = append(m.fields, zap.Binary("IndexMeta", payload))
return m
}
func (m *MetaLogger) WithCollectionID(collID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("CollectionID", collID))
return m
}
func (m *MetaLogger) WithCollectionName(collname string) *MetaLogger {
m.fields = append(m.fields, zap.String("CollectionName", collname))
return m
}
func (m *MetaLogger) WithPartitionID(partID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("PartitionID", partID))
return m
}
func (m *MetaLogger) WithPartitionName(partName string) *MetaLogger {
m.fields = append(m.fields, zap.String("PartitionName", partName))
return m
}
func (m *MetaLogger) WithFieldID(fieldID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("FieldID", fieldID))
return m
}
func (m *MetaLogger) WithFieldName(fieldName string) *MetaLogger {
m.fields = append(m.fields, zap.String("FieldName", fieldName))
return m
}
func (m *MetaLogger) WithIndexID(idxID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("IndexID", idxID))
return m
}
func (m *MetaLogger) WithIndexName(idxName string) *MetaLogger {
m.fields = append(m.fields, zap.String("IndexName", idxName))
return m
}
func (m *MetaLogger) WithBuildID(buildID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("BuildID", buildID))
return m
}
func (m *MetaLogger) WithBuildIDS(buildIDs []int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64s("BuildIDs", buildIDs))
return m
}
func (m *MetaLogger) WithSegmentID(segID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("SegmentID", segID))
return m
}
func (m *MetaLogger) WithIndexFiles(files []string) *MetaLogger {
m.fields = append(m.fields, zap.Strings("IndexFiles", files))
return m
}
func (m *MetaLogger) WithIndexVersion(version int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("IndexVersion", version))
return m
}
func (m *MetaLogger) WithTSO(tso uint64) *MetaLogger {
m.fields = append(m.fields, zap.Uint64("TSO", tso))
return m
}
func (m *MetaLogger) WithAlias(alias string) *MetaLogger {
m.fields = append(m.fields, zap.String("Alias", alias))
return m
}
func (m *MetaLogger) WithOperation(op MetaOperation) *MetaLogger {
m.fields = append(m.fields, zap.Int("Operation", int(op)))
return m
}
func (m *MetaLogger) Info() {
m.logger.Info("", m.fields...)
}
package log
import (
"testing"
"github.com/milvus-io/milvus/internal/metastore/model"
)
func TestMetaLogger(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true, DisableCaller: true}
logger, _, _ := InitTestLogger(ts, conf)
replaceLeveledLoggers(logger)
NewMetaLogger().WithCollectionID(0).
WithIndexMeta(&model.Index{}).
WithCollectionMeta(&model.Collection{}).
WithCollectionName("coll").
WithPartitionID(0).
WithPartitionName("part").
WithFieldID(0).
WithFieldName("field").
WithIndexID(0).
WithIndexName("idx").
WithBuildID(0).
WithBuildIDS([]int64{0, 0}).
WithSegmentID(0).
WithIndexFiles([]string{"idx", "idx"}).
WithIndexVersion(0).
WithTSO(0).
WithAlias("alias").
WithOperation(DropCollection).Info()
ts.assertMessagesContains("CollectionID=0")
ts.assertMessagesContains("CollectionMeta=eyJUZW5hbnRJRCI6IiIsIkNvbGxlY3Rpb25JRCI6MCwiUGFydGl0aW9ucyI6bnVsbCwiTmFtZSI6IiIsIkRlc2NyaXB0aW9uIjoiIiwiQXV0b0lEIjpmYWxzZSwiRmllbGRzIjpudWxsLCJGaWVsZElEVG9JbmRleElEIjpudWxsLCJWaXJ0dWFsQ2hhbm5lbE5hbWVzIjpudWxsLCJQaHlzaWNhbENoYW5uZWxOYW1lcyI6bnVsbCwiU2hhcmRzTnVtIjowLCJTdGFydFBvc2l0aW9ucyI6bnVsbCwiQ3JlYXRlVGltZSI6MCwiQ29uc2lzdGVuY3lMZXZlbCI6MCwiQWxpYXNlcyI6bnVsbCwiRXh0cmEiOm51bGx9")
ts.assertMessagesContains("IndexMeta=eyJDb2xsZWN0aW9uSUQiOjAsIkZpZWxkSUQiOjAsIkluZGV4SUQiOjAsIkluZGV4TmFtZSI6IiIsIklzRGVsZXRlZCI6ZmFsc2UsIkNyZWF0ZVRpbWUiOjAsIkluZGV4UGFyYW1zIjpudWxsLCJTZWdtZW50SW5kZXhlcyI6bnVsbCwiRXh0cmEiOm51bGx9")
ts.assertMessagesContains("CollectionName=coll")
ts.assertMessagesContains("PartitionID=0")
ts.assertMessagesContains("PartitionName=part")
ts.assertMessagesContains("FieldID=0")
ts.assertMessagesContains("FieldName=field")
ts.assertMessagesContains("IndexID=0")
ts.assertMessagesContains("IndexName=idx")
ts.assertMessagesContains("BuildID=0")
ts.assertMessagesContains("\"[0,0]\"")
ts.assertMessagesContains("SegmentID=0")
ts.assertMessagesContains("IndexFiles=\"[idx,idx]\"")
ts.assertMessagesContains("IndexVersion=0")
ts.assertMessagesContains("TSO=0")
ts.assertMessagesContains("Alias=alias")
ts.assertMessagesContains("Operation=1")
}
package log
type MetaOperation int
const (
InvalidMetaOperation MetaOperation = iota - 1
CreateCollection
DropCollection
CreateCollectionAlias
AlterCollectionAlias
DropCollectionAlias
CreatePartition
DropPartition
CreateIndex
DropIndex
BuildSegmentIndex
)
package log
import "go.uber.org/zap"
type MLogger struct {
*zap.Logger
}
func (l *MLogger) RatedDebug(cost float64, msg string, fields ...zap.Field) bool {
if R().CheckCredit(cost) {
l.Debug(msg, fields...)
return true
}
return false
}
func (l *MLogger) RatedInfo(cost float64, msg string, fields ...zap.Field) bool {
if R().CheckCredit(cost) {
l.Info(msg, fields...)
return true
}
return false
}
func (l *MLogger) RatedWarn(cost float64, msg string, fields ...zap.Field) bool {
if R().CheckCredit(cost) {
l.Warn(msg, fields...)
return true
}
return false
}
package log
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestExporterV2(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true}
logger, properties, _ := InitTestLogger(ts, conf)
ReplaceGlobals(logger, properties)
replaceLeveledLoggers(logger)
ctx := WithTraceID(context.TODO(), "mock-trace")
Ctx(ctx).Info("Info Test")
Ctx(ctx).Debug("Debug Test")
Ctx(ctx).Warn("Warn Test")
Ctx(ctx).Error("Error Test")
Ctx(ctx).Sync()
ts.assertMessagesContains("log/mlogger_test.go")
ts.assertMessagesContains("traceID=mock-trace")
ts.CleanBuffer()
Ctx(nil).Info("empty context")
ts.assertMessagesNotContains("traceID")
fieldCtx := WithFields(ctx, zap.String("field", "test"))
reqCtx := WithReqID(ctx, 123456)
modCtx := WithModule(ctx, "test")
Ctx(fieldCtx).Info("Info Test")
Ctx(fieldCtx).Sync()
ts.assertLastMessageContains("field=test")
ts.assertLastMessageContains("traceID=mock-trace")
Ctx(reqCtx).Info("Info Test")
Ctx(reqCtx).Sync()
ts.assertLastMessageContains("reqID=123456")
ts.assertLastMessageContains("traceID=mock-trace")
ts.assertLastMessageNotContains("field=test")
Ctx(modCtx).Info("Info Test")
Ctx(modCtx).Sync()
ts.assertLastMessageContains("module=test")
ts.assertLastMessageContains("traceID=mock-trace")
ts.assertLastMessageNotContains("reqID=123456")
ts.assertLastMessageNotContains("field=test")
}
func TestMLoggerRatedLog(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true}
logger, p, _ := InitTestLogger(ts, conf)
ReplaceGlobals(logger, p)
ctx := WithTraceID(context.TODO(), "test-trace")
time.Sleep(time.Duration(1) * time.Second)
success := Ctx(ctx).RatedDebug(1.0, "debug test")
assert.True(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedDebug(100.0, "debug test")
assert.False(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedInfo(1.0, "info test")
assert.True(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedWarn(1.0, "warn test")
assert.True(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedWarn(100.0, "warn test")
assert.False(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedInfo(100.0, "info test")
assert.False(t, success)
successNum := 0
for i := 0; i < 1000; i++ {
if Ctx(ctx).RatedInfo(1.0, "info test") {
successNum++
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
assert.True(t, successNum < 1000)
assert.True(t, successNum > 10)
time.Sleep(time.Duration(3) * time.Second)
success = Ctx(ctx).RatedInfo(3.0, "info test")
assert.True(t, success)
Ctx(ctx).Sync()
}
......@@ -292,6 +292,10 @@ func (t *testLogSpy) FailNow() {
t.TB.FailNow()
}
func (t *testLogSpy) CleanBuffer() {
t.Messages = []string{}
}
func (t *testLogSpy) Logf(format string, args ...interface{}) {
// Log messages are in the format,
//
......@@ -320,3 +324,40 @@ func (t *testLogSpy) assertMessagesNotContains(msg string) {
assert.NotContains(t.TB, actualMsg, msg)
}
}
func (t *testLogSpy) assertLastMessageContains(msg string) {
if len(t.Messages) == 0 {
assert.Error(t.TB, fmt.Errorf("empty message"))
}
assert.Contains(t.TB, t.Messages[len(t.Messages)-1], msg)
}
func (t *testLogSpy) assertLastMessageNotContains(msg string) {
if len(t.Messages) == 0 {
assert.Error(t.TB, fmt.Errorf("empty message"))
}
assert.NotContains(t.TB, t.Messages[len(t.Messages)-1], msg)
}
func (t *testLogSpy) assertMessageContainAny(msg string) {
found := false
for _, actualMsg := range t.Messages {
if strings.Contains(actualMsg, msg) {
found = true
}
}
assert.True(t, found, "can't found any message contain `%s`, all messages: %v", msg, fmtMsgs(t.Messages))
}
func fmtMsgs(messages []string) string {
builder := strings.Builder{}
builder.WriteString("[")
for i, msg := range messages {
if i == len(messages)-1 {
builder.WriteString(fmt.Sprintf("`%s]`", msg))
} else {
builder.WriteString(fmt.Sprintf("`%s`, ", msg))
}
}
return builder.String()
}
......@@ -9,20 +9,18 @@ import (
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/proto/commonpb"
)
var (
colID = typeutil.UniqueID(1)
colName = "c"
fieldID = typeutil.UniqueID(101)
fieldName = "field110"
partID = typeutil.UniqueID(20)
partName = "testPart"
tenantID = "tenant-1"
typeParams = []*commonpb.KeyValuePair{
colID int64 = 1
colName = "c"
fieldID int64 = 101
fieldName = "field110"
partID int64 = 20
partName = "testPart"
tenantID = "tenant-1"
typeParams = []*commonpb.KeyValuePair{
{
Key: "field110-k1",
Value: "field110-v1",
......
......@@ -7,13 +7,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
var (
indexID = typeutil.UniqueID(1)
indexName = "idx"
indexParams = []*commonpb.KeyValuePair{
indexID int64 = 1
indexName = "idx"
indexParams = []*commonpb.KeyValuePair{
{
Key: "field110-i1",
Value: "field110-v1",
......
......@@ -4,13 +4,12 @@ import (
"testing"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
)
var (
segmentID = typeutil.UniqueID(1)
buildID = typeutil.UniqueID(1)
segmentID int64 = 1
buildID int64 = 1
segmentIdxPb = &pb.SegmentIndexInfo{
CollectionID: colID,
......
......@@ -2520,7 +2520,6 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Search")
defer sp.Finish()
traceID, _, _ := trace.InfoFromSpan(sp)
qt := &searchTask{
ctx: ctx,
......@@ -2541,9 +2540,8 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
travelTs := request.TravelTimestamp
guaranteeTs := request.GuaranteeTimestamp
log.Debug(
log.Ctx(ctx).Info(
rpcReceived(method),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
......@@ -2556,10 +2554,9 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
zap.Uint64("guarantee_timestamp", guaranteeTs))
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
log.Warn(
log.Ctx(ctx).Warn(
rpcFailedToEnqueue(method),
zap.Error(err),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
......@@ -2581,11 +2578,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
},
}, nil
}
tr.Record("search request enqueue")
tr.CtxRecord(ctx, "search request enqueue")
log.Debug(
log.Ctx(ctx).Debug(
rpcEnqueued(method),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.Uint64("timestamp", qt.Base.Timestamp),
......@@ -2600,10 +2596,9 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
zap.Uint64("guarantee_timestamp", guaranteeTs))
if err := qt.WaitToFinish(); err != nil {
log.Warn(
log.Ctx(ctx).Warn(
rpcFailedToWaitToFinish(method),
zap.Error(err),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
......@@ -2627,12 +2622,11 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
}, nil
}
span := tr.Record("wait search result")
span := tr.CtxRecord(ctx, "wait search result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(span.Milliseconds()))
log.Debug(
log.Ctx(ctx).Debug(
rpcDone(method),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
......@@ -2763,7 +2757,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Query")
defer sp.Finish()
traceID, _, _ := trace.InfoFromSpan(sp)
tr := timerecord.NewTimeRecorder("Query")
qt := &queryTask{
......@@ -2787,9 +2780,8 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
metrics.ProxyDQLFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method,
metrics.TotalLabel).Inc()
log.Debug(
log.Ctx(ctx).Info(
rpcReceived(method),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
......@@ -2800,10 +2792,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp))
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
log.Warn(
log.Ctx(ctx).Warn(
rpcFailedToEnqueue(method),
zap.Error(err),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
......@@ -2819,11 +2810,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
}, nil
}
tr.Record("query request enqueue")
tr.CtxRecord(ctx, "query request enqueue")
log.Debug(
log.Ctx(ctx).Debug(
rpcEnqueued(method),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
......@@ -2831,10 +2821,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
zap.Strings("partitions", request.PartitionNames))
if err := qt.WaitToFinish(); err != nil {
log.Warn(
log.Ctx(ctx).Warn(
rpcFailedToWaitToFinish(method),
zap.Error(err),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
......@@ -2851,12 +2840,11 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
}, nil
}
span := tr.Record("wait query result")
span := tr.CtxRecord(ctx, "wait query result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
metrics.QueryLabel).Observe(float64(span.Milliseconds()))
log.Debug(
log.Ctx(ctx).Debug(
rpcDone(method),
zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
......
......@@ -394,12 +394,12 @@ func TestMetaCache_GetShards(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards["channel-1"]))
// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
......
......@@ -40,7 +40,7 @@ func groupShardleadersWithSameQueryNode(
// check if all leaders were checked
for dml, idx := range nexts {
if idx >= len(shard2leaders[dml]) {
log.Warn("no shard leaders were available",
log.Ctx(ctx).Warn("no shard leaders were available",
zap.String("channel", dml),
zap.String("leaders", fmt.Sprintf("%v", shard2leaders[dml])))
if e, ok := errSet[dml]; ok {
......@@ -59,7 +59,7 @@ func groupShardleadersWithSameQueryNode(
if _, ok := qnSet[nodeInfo.nodeID]; !ok {
qn, err := mgr.GetClient(ctx, nodeInfo.nodeID)
if err != nil {
log.Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err))
log.Ctx(ctx).Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err))
// if get client failed, just record error and wait for next round to get client and do query
errSet[dml] = err
continue
......@@ -111,7 +111,7 @@ func mergeRoundRobinPolicy(
go func() {
defer wg.Done()
if err := query(ctx, nodeID, qn, channels); err != nil {
log.Warn("failed to do query with node", zap.Int64("nodeID", nodeID),
log.Ctx(ctx).Warn("failed to do query with node", zap.Int64("nodeID", nodeID),
zap.Strings("dmlChannels", channels), zap.Error(err))
mu.Lock()
defer mu.Unlock()
......@@ -138,7 +138,7 @@ func mergeRoundRobinPolicy(
nextSet[dml] = dml2leaders[dml][idx].nodeID
}
}
log.Warn("retry another query node with round robin", zap.Any("Nexts", nextSet))
log.Ctx(ctx).Warn("retry another query node with round robin", zap.Any("Nexts", nextSet))
}
}
return nil
......
......@@ -110,44 +110,44 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
collectionName := t.request.CollectionName
t.collectionName = collectionName
if err := validateCollectionName(collectionName); err != nil {
log.Warn("Invalid collection name.", zap.String("collectionName", collectionName),
log.Ctx(ctx).Warn("Invalid collection name.", zap.String("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
return err
}
log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
log.Ctx(ctx).Debug("Validate collection name.", zap.Any("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil {
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
log.Ctx(ctx).Warn("Failed to get collection id.", zap.Any("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
t.CollectionID = collID
log.Info("Get collection ID by name",
log.Ctx(ctx).Debug("Get collection ID by name",
zap.Int64("collectionID", t.CollectionID), zap.String("collection name", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
for _, tag := range t.request.PartitionNames {
if err := validatePartitionTag(tag, false); err != nil {
log.Warn("invalid partition name", zap.String("partition name", tag),
log.Ctx(ctx).Warn("invalid partition name", zap.String("partition name", tag),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
}
log.Debug("Validate partition names.",
log.Ctx(ctx).Debug("Validate partition names.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames())
if err != nil {
log.Warn("failed to get partitions in collection.", zap.String("collection name", collectionName),
log.Ctx(ctx).Warn("failed to get partitions in collection.", zap.String("collection name", collectionName),
zap.Error(err),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
log.Debug("Get partitions in collection.", zap.Any("collectionName", collectionName),
log.Ctx(ctx).Debug("Get partitions in collection.", zap.Any("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.RetrieveRequest.GetPartitionIDs())
......@@ -182,7 +182,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
log.Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields),
log.Ctx(ctx).Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema)
......@@ -191,7 +191,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
}
t.RetrieveRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs
log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId),
log.Ctx(ctx).Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan)
......@@ -219,7 +219,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
}
t.DbID = 0 // TODO
log.Info("Query PreExecute done.",
log.Ctx(ctx).Debug("Query PreExecute done.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"),
zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("travel_ts", t.GetTravelTimestamp()),
zap.Uint64("timeout_ts", t.GetTimeoutTimestamp()))
......@@ -228,7 +228,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
func (t *queryTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", t.ID()))
defer tr.Elapse("done")
defer tr.CtxElapse(ctx, "done")
executeQuery := func(withCache bool) error {
shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
......@@ -246,7 +246,7 @@ func (t *queryTask) Execute(ctx context.Context) error {
err := executeQuery(WithCache)
if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) {
log.Warn("invalid shard leaders cache, updating shardleader caches and retry search",
log.Ctx(ctx).Warn("invalid shard leaders cache, updating shardleader caches and retry search",
zap.Int64("msgID", t.ID()), zap.Error(err))
return executeQuery(WithoutCache)
}
......@@ -254,7 +254,7 @@ func (t *queryTask) Execute(ctx context.Context) error {
return fmt.Errorf("fail to search on all shard leaders, err=%s", err.Error())
}
log.Info("Query Execute done.",
log.Ctx(ctx).Debug("Query Execute done.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return nil
}
......@@ -262,27 +262,27 @@ func (t *queryTask) Execute(ctx context.Context) error {
func (t *queryTask) PostExecute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder("queryTask PostExecute")
defer func() {
tr.Elapse("done")
tr.CtxElapse(ctx, "done")
}()
var err error
select {
case <-t.TraceCtx().Done():
log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
log.Ctx(ctx).Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
return nil
default:
log.Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID()))
close(t.resultBuf)
for res := range t.resultBuf {
t.toReduceResults = append(t.toReduceResults, res)
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID()))
log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID()))
}
}
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
tr.Record("reduceResultStart")
t.result, err = mergeRetrieveResults(t.toReduceResults)
tr.CtxRecord(ctx, "reduceResultStart")
t.result, err = mergeRetrieveResults(ctx, t.toReduceResults)
if err != nil {
return err
}
......@@ -294,7 +294,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
ErrorCode: commonpb.ErrorCode_Success,
}
} else {
log.Info("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
log.Ctx(ctx).Warn("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.result.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_EmptyCollection,
Reason: "empty collection", // TODO
......@@ -315,7 +315,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
}
}
}
log.Info("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
log.Ctx(ctx).Debug("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
return nil
}
......@@ -328,21 +328,21 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
result, err := qn.Query(ctx, req)
if err != nil {
log.Warn("QueryNode query return error", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Warn("QueryNode query return error", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
return errInvalidShardLeaders
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
log.Ctx(ctx).Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
}
log.Debug("get query result", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channelIDs", channelIDs))
log.Ctx(ctx).Debug("get query result", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channelIDs", channelIDs))
t.resultBuf <- result
return nil
}
......@@ -360,7 +360,7 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
return fieldName + " in [ " + idsStr + " ]"
}
func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
func mergeRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
var ret *milvuspb.QueryResults
var skipDupCnt int64
var idSet = make(map[interface{}]struct{})
......@@ -394,7 +394,7 @@ func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvu
}
}
}
log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
if ret == nil {
ret = &milvuspb.QueryResults{
......
......@@ -210,7 +210,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
log.Debug("translate output fields", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()),
zap.Strings("output fields", t.request.GetOutputFields()))
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
......@@ -226,12 +226,12 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo)
if err != nil {
log.Debug("failed to create query plan", zap.Error(err), zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Warn("failed to create query plan", zap.Error(err), zap.Int64("msgID", t.ID()),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
return fmt.Errorf("failed to create query plan: %v", err)
}
log.Debug("create query plan", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Debug("create query plan", zap.Int64("msgID", t.ID()),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
......@@ -253,7 +253,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err := validateTopK(queryInfo.GetTopk()); err != nil {
return err
}
log.Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.String("plan", plan.String())) // may be very large if large term passed.
}
......@@ -282,7 +282,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if t.SearchRequest.Nq, err = getNq(t.request); err != nil {
return err
}
log.Info("search PreExecute done.", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()),
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()))
......@@ -294,7 +294,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
defer sp.Finish()
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
defer tr.Elapse("done")
defer tr.CtxElapse(ctx, "done")
executeSearch := func(withCache bool) error {
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
......@@ -304,7 +304,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil {
log.Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders)))
log.Ctx(ctx).Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders)))
return err
}
return nil
......@@ -312,7 +312,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
err := executeSearch(WithCache)
if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) {
log.Warn("first search failed, updating shardleader caches and retry search",
log.Ctx(ctx).Warn("first search failed, updating shardleader caches and retry search",
zap.Int64("msgID", t.ID()), zap.Error(err))
return executeSearch(WithoutCache)
}
......@@ -320,7 +320,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
return fmt.Errorf("fail to search on all shard leaders, err=%v", err)
}
log.Debug("Search Execute done.", zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Debug("Search Execute done.", zap.Int64("msgID", t.ID()))
return nil
}
......@@ -329,34 +329,34 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
defer sp.Finish()
tr := timerecord.NewTimeRecorder("searchTask PostExecute")
defer func() {
tr.Elapse("done")
tr.CtxElapse(ctx, "done")
}()
select {
// in case timeout happened
case <-t.TraceCtx().Done():
log.Debug("wait to finish timeout!", zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Debug("wait to finish timeout!", zap.Int64("msgID", t.ID()))
return nil
default:
log.Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID()))
close(t.resultBuf)
for res := range t.resultBuf {
t.toReduceResults = append(t.toReduceResults, res)
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID()))
}
}
tr.Record("decodeResultStart")
validSearchResults, err := decodeSearchResults(t.toReduceResults)
tr.CtxRecord(ctx, "decodeResultStart")
validSearchResults, err := decodeSearchResults(ctx, t.toReduceResults)
if err != nil {
return err
}
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
log.Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()),
zap.Int("len(validSearchResults)", len(validSearchResults)))
if len(validSearchResults) <= 0 {
log.Warn("search result is empty", zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Warn("search result is empty", zap.Int64("msgID", t.ID()))
t.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
......@@ -375,12 +375,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return nil
}
tr.Record("reduceResultStart")
tr.CtxRecord(ctx, "reduceResultStart")
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema)
if err != nil {
return err
}
t.result, err = reduceSearchResultData(validSearchResults, t.toReduceResults[0].NumQueries, t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType)
t.result, err = reduceSearchResultData(ctx, validSearchResults, t.toReduceResults[0].NumQueries,
t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType)
if err != nil {
return err
}
......@@ -403,7 +404,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
}
}
log.Info("Search post execute done", zap.Int64("msgID", t.ID()))
log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID()))
return nil
}
......@@ -415,17 +416,17 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
}
result, err := qn.Search(ctx, req)
if err != nil {
log.Warn("QueryNode search return error", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Warn("QueryNode search return error", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()),
log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
return errInvalidShardLeaders
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
log.Ctx(ctx).Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
}
......@@ -482,7 +483,7 @@ func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName stri
}
if len(resp.GetPartitionIDs()) > 0 {
log.Warn("collection not fully loaded, search on these partitions",
log.Ctx(ctx).Warn("collection not fully loaded, search on these partitions",
zap.String("collection", collectionName),
zap.Int64("collectionID", info.collID), zap.Int64s("partitionIDs", resp.GetPartitionIDs()))
return true, nil
......@@ -491,7 +492,7 @@ func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName stri
return false, nil
}
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
tr := timerecord.NewTimeRecorder("decodeSearchResults")
results := make([]*schemapb.SearchResultData, 0)
for _, partialSearchResult := range searchResults {
......@@ -507,7 +508,7 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb
results = append(results, &partialResultData)
}
tr.Elapse("decodeSearchResults done")
tr.CtxElapse(ctx, "decodeSearchResults done")
return results, nil
}
......@@ -544,14 +545,13 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffset
return sel
}
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) {
func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.Elapse("done")
tr.CtxElapse(ctx, "done")
}()
log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
......@@ -585,14 +585,14 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
}
for i, sData := range searchResultData {
log.Debug("reduceSearchResultData",
log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Any("len(topks)", len(sData.Topks)),
zap.Any("len(FieldsData)", len(sData.FieldsData)))
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Warn("invalid search results", zap.Error(err))
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
//printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
......@@ -637,13 +637,13 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
offsets[sel]++
}
if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
}
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
ret.Results.TopK = realTopK
if !distance.PositivelyRelated(metricType) {
......
......@@ -1354,7 +1354,7 @@ func Test_reduceSearchResultData_int(t *testing.T) {
},
}
reduced, err := reduceSearchResultData(results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64)
reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64)
assert.NoError(t, err)
assert.ElementsMatch(t, []int64{3, 4, 7, 8, 11, 12}, reduced.GetResults().GetIds().GetIntId().GetData())
// hard to compare floating point value.
......@@ -1393,7 +1393,7 @@ func Test_reduceSearchResultData_str(t *testing.T) {
},
}
reduced, err := reduceSearchResultData(results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar)
reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"3", "4", "7", "8", "11", "12"}, reduced.GetResults().GetIds().GetStrId().GetData())
// hard to compare floating point value.
......
......@@ -397,7 +397,7 @@ func (qc *QueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Releas
// ShowPartitions return all the partitions that have been loaded
func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
collectionID := req.CollectionID
log.Info("show partitions start",
log.Ctx(ctx).Debug("show partitions start",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", req.PartitionIDs),
......@@ -409,7 +409,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
err := errors.New("QueryCoord is not healthy")
status.Reason = err.Error()
log.Error("show partition failed", zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
log.Ctx(ctx).Warn("show partition failed", zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
return &querypb.ShowPartitionsResponse{
Status: status,
}, nil
......@@ -420,7 +420,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
err = fmt.Errorf("collection %d has not been loaded into QueryNode", collectionID)
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
log.Warn("show partitions failed",
log.Ctx(ctx).Warn("show partitions failed",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
......@@ -439,7 +439,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
for _, id := range inMemoryPartitionIDs {
inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage)
}
log.Info("show partitions end",
log.Ctx(ctx).Debug("show partitions end",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", req.Base.MsgID),
......@@ -456,7 +456,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
err = fmt.Errorf("partition %d of collection %d has not been loaded into QueryNode", id, collectionID)
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
log.Warn("show partitions failed",
log.Ctx(ctx).Warn("show partitions failed",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64("partitionID", id),
......@@ -469,7 +469,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage)
}
log.Info("show partitions end",
log.Ctx(ctx).Debug("show partitions end",
zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", req.PartitionIDs),
......
......@@ -84,7 +84,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
searchReq, err := newSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup())
assert.NoError(b, err)
for j := 0; j < 10000; j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
_, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
assert.NoError(b, err)
}
......@@ -108,7 +108,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := int64(0); j < benchmarkMaxNQ/nq; j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
_, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
assert.NoError(b, err)
}
}
......@@ -153,7 +153,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
searchReq, _ := genSearchPlanAndRequests(collection, indexType, nq)
for j := 0; j < 10000; j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
_, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(b, err)
}
......@@ -178,7 +178,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < benchmarkMaxNQ/int(nq); j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
_, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(b, err)
}
}
......
......@@ -568,7 +568,7 @@ func (node *QueryNode) isHealthy() bool {
// Search performs replica search tasks.
func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (*internalpb.SearchResults, error) {
log.Debug("Received SearchRequest",
log.Ctx(ctx).Debug("Received SearchRequest",
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Strings("vChannels", req.GetDmlChannels()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
......@@ -613,7 +613,7 @@ func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (
if err := runningGp.Wait(); err != nil {
return failRet, nil
}
ret, err := reduceSearchResults(toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
ret, err := reduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error()
......@@ -641,7 +641,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
}
msgID := req.GetReq().GetBase().GetMsgID()
log.Debug("Received SearchRequest",
log.Ctx(ctx).Debug("Received SearchRequest",
zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel),
......@@ -656,7 +656,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
qs, err := node.queryShardService.getQueryShard(dmlChannel)
if err != nil {
log.Warn("Search failed, failed to get query shard",
log.Ctx(ctx).Warn("Search failed, failed to get query shard",
zap.Int64("msgID", msgID),
zap.String("dml channel", dmlChannel),
zap.Error(err))
......@@ -665,7 +665,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
return failRet, nil
}
log.Debug("start do search",
log.Ctx(ctx).Debug("start do search",
zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel),
......@@ -692,7 +692,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
return failRet, nil
}
tr.Elapse(fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
tr.CtxElapse(ctx, fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
......@@ -747,22 +747,22 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
// shard leader dispatches request to its shard cluster
results, errCluster = cluster.Search(searchCtx, req, withStreaming)
if errCluster != nil {
log.Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
log.Ctx(ctx).Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
failRet.Status.Reason = errCluster.Error()
return failRet, nil
}
tr.Elapse(fmt.Sprintf("start reduce search result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
tr.CtxElapse(ctx, fmt.Sprintf("start reduce search result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
results = append(results, streamingResult)
ret, err2 := reduceSearchResults(results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
ret, err2 := reduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err2 != nil {
failRet.Status.Reason = err2.Error()
return failRet, nil
}
tr.Elapse(fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
tr.CtxElapse(ctx, fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
......@@ -793,7 +793,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
}
msgID := req.GetReq().GetBase().GetMsgID()
log.Debug("Received QueryRequest",
log.Ctx(ctx).Debug("Received QueryRequest",
zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel),
......@@ -808,12 +808,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
qs, err := node.queryShardService.getQueryShard(dmlChannel)
if err != nil {
log.Warn("Query failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err))
log.Ctx(ctx).Warn("Query failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
log.Debug("start do query",
log.Ctx(ctx).Debug("start do query",
zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel),
......@@ -837,7 +837,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
return failRet, nil
}
tr.Elapse(fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
tr.CtxElapse(ctx, fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
......@@ -890,22 +890,22 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
// shard leader dispatches request to its shard cluster
results, errCluster = cluster.Query(queryCtx, req, withStreaming)
if errCluster != nil {
log.Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
log.Ctx(ctx).Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
failRet.Status.Reason = errCluster.Error()
return failRet, nil
}
tr.Elapse(fmt.Sprintf("start reduce query result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
results = append(results, streamingResult)
ret, err2 := mergeInternalRetrieveResults(results)
ret, err2 := mergeInternalRetrieveResults(ctx, results)
if err2 != nil {
failRet.Status.Reason = err2.Error()
return failRet, nil
}
tr.Elapse(fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
tr.CtxElapse(ctx, fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
......@@ -917,7 +917,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
// Query performs replica query tasks.
func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
log.Debug("Received QueryRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
log.Ctx(ctx).Debug("Received QueryRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Strings("vChannels", req.GetDmlChannels()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
......@@ -963,7 +963,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if err := runningGp.Wait(); err != nil {
return failRet, nil
}
ret, err := mergeInternalRetrieveResults(toMergeResults)
ret, err := mergeInternalRetrieveResults(ctx, toMergeResults)
if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error()
......
......@@ -159,7 +159,7 @@ func TestReduceSearchResultData(t *testing.T) {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, nq, topk)
res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk)
assert.Nil(t, err)
assert.Equal(t, ids, res.Ids.GetIntId().Data)
assert.Equal(t, scores, res.Scores)
......@@ -176,7 +176,7 @@ func TestReduceSearchResultData(t *testing.T) {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, nq, topk)
res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Ids.GetIntId().Data)
})
......@@ -223,12 +223,13 @@ func TestMergeInternalRetrieveResults(t *testing.T) {
// Offset: []int64{0, 1},
FieldsData: fieldDataArray2,
}
ctx := context.TODO()
result, err := mergeInternalRetrieveResults([]*internalpb.RetrieveResults{result1, result2})
result, err := mergeInternalRetrieveResults(ctx, []*internalpb.RetrieveResults{result1, result2})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.FieldsData[0].GetScalars().GetLongData().Data))
assert.Equal(t, 2*Dim, len(result.FieldsData[1].GetVectors().GetFloatVector().Data))
_, err = mergeInternalRetrieveResults(nil)
_, err = mergeInternalRetrieveResults(ctx, nil)
assert.NoError(t, err)
}
......@@ -17,6 +17,7 @@
package querynode
import (
"context"
"fmt"
"math"
"strconv"
......@@ -73,24 +74,24 @@ func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*inte
return ret, nil
}
func reduceSearchResults(results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) {
func reduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) {
searchResultData, err := decodeSearchResults(results)
if err != nil {
log.Warn("shard leader decode search results errors", zap.Error(err))
log.Ctx(ctx).Warn("shard leader decode search results errors", zap.Error(err))
return nil, err
}
log.Debug("shard leader get valid search results", zap.Int("numbers", len(searchResultData)))
log.Ctx(ctx).Debug("shard leader get valid search results", zap.Int("numbers", len(searchResultData)))
for i, sData := range searchResultData {
log.Debug("reduceSearchResultData",
log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK))
}
reducedResultData, err := reduceSearchResultData(searchResultData, nq, topk)
reducedResultData, err := reduceSearchResultData(ctx, searchResultData, nq, topk)
if err != nil {
log.Warn("shard leader reduce errors", zap.Error(err))
log.Ctx(ctx).Warn("shard leader reduce errors", zap.Error(err))
return nil, err
}
searchResults, err := encodeSearchResultData(reducedResultData, nq, topk, metricType)
......@@ -110,7 +111,7 @@ func reduceSearchResults(results []*internalpb.SearchResults, nq int64, topk int
return searchResults, nil
}
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) {
func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) {
if len(searchResultData) == 0 {
return &schemapb.SearchResultData{
NumQueries: nq,
......@@ -174,7 +175,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
// }
ret.Topks = append(ret.Topks, j)
}
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
return ret, nil
}
......@@ -234,7 +235,7 @@ func encodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int6
}
// TODO: largely based on function mergeSegcoreRetrieveResults, need rewriting
func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
var ret *internalpb.RetrieveResults
var skipDupCnt int64
var idSet = make(map[interface{}]struct{})
......@@ -254,7 +255,7 @@ func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults)
}
if len(ret.FieldsData) != len(rr.FieldsData) {
log.Warn("mismatch FieldData in RetrieveResults")
log.Ctx(ctx).Warn("mismatch FieldData in RetrieveResults")
return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
}
......@@ -283,7 +284,7 @@ func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults)
return ret, nil
}
func mergeSegcoreRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
var ret *segcorepb.RetrieveResults
var skipDupCnt int64
var idSet = make(map[interface{}]struct{})
......@@ -319,7 +320,7 @@ func mergeSegcoreRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) (
}
}
}
log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
// not found, return default values indicating not result found
if ret == nil {
......
......@@ -17,6 +17,8 @@
package querynode
import (
"context"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
)
......@@ -44,12 +46,12 @@ func retrieveOnSegments(replica ReplicaInterface, segType segmentType, collID Un
}
// retrieveHistorical will retrieve all the target segments in historical
func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
func retrieveHistorical(ctx context.Context, replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error
var retrieveResults []*segcorepb.RetrieveResults
var retrieveSegmentIDs []UniqueID
var retrievePartIDs []UniqueID
retrievePartIDs, retrieveSegmentIDs, err = validateOnHistoricalReplica(replica, collID, partIDs, segIDs)
retrievePartIDs, retrieveSegmentIDs, err = validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs)
if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
}
......@@ -59,13 +61,13 @@ func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID Uni
}
// retrieveStreaming will retrieve all the target segments in streaming
func retrieveStreaming(replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, vChannel Channel, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
func retrieveStreaming(ctx context.Context, replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, vChannel Channel, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error
var retrieveResults []*segcorepb.RetrieveResults
var retrievePartIDs []UniqueID
var retrieveSegmentIDs []UniqueID
retrievePartIDs, retrieveSegmentIDs, err = validateOnStreamReplica(replica, collID, partIDs, vChannel)
retrievePartIDs, retrieveSegmentIDs, err = validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel)
if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
}
......
......@@ -17,6 +17,7 @@
package querynode
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
......@@ -49,7 +50,7 @@ func TestStreaming_retrieve(t *testing.T) {
assert.NoError(t, err)
t.Run("test retrieve", func(t *testing.T) {
res, _, ids, err := retrieveStreaming(streaming, plan,
res, _, ids, err := retrieveStreaming(context.TODO(), streaming, plan,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel,
......@@ -61,7 +62,7 @@ func TestStreaming_retrieve(t *testing.T) {
t.Run("test empty partition", func(t *testing.T) {
res, _, ids, err := retrieveStreaming(streaming, plan,
res, _, ids, err := retrieveStreaming(context.TODO(), streaming, plan,
defaultCollectionID,
nil,
defaultDMLChannel,
......
......@@ -17,6 +17,7 @@
package querynode
import (
"context"
"fmt"
"sync"
......@@ -28,7 +29,7 @@ import (
// searchOnSegments performs search on listed segments
// all segment ids are validated before calling this function
func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) {
func searchOnSegments(ctx context.Context, replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) {
// results variables
searchResults := make([]*SearchResult, len(segIDs))
errs := make([]error, len(segIDs))
......@@ -72,31 +73,31 @@ func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq *
// if segIDs is not specified, it will search on all the historical segments speficied by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func searchHistorical(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*SearchResult, []UniqueID, []UniqueID, error) {
func searchHistorical(ctx context.Context, replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*SearchResult, []UniqueID, []UniqueID, error) {
var err error
var searchResults []*SearchResult
var searchSegmentIDs []UniqueID
var searchPartIDs []UniqueID
searchPartIDs, searchSegmentIDs, err = validateOnHistoricalReplica(replica, collID, partIDs, segIDs)
searchPartIDs, searchSegmentIDs, err = validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs)
if err != nil {
return searchResults, searchSegmentIDs, searchPartIDs, err
}
searchResults, err = searchOnSegments(replica, segmentTypeSealed, searchReq, searchSegmentIDs)
searchResults, err = searchOnSegments(ctx, replica, segmentTypeSealed, searchReq, searchSegmentIDs)
return searchResults, searchPartIDs, searchSegmentIDs, err
}
// searchStreaming will search all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func searchStreaming(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]*SearchResult, []UniqueID, []UniqueID, error) {
func searchStreaming(ctx context.Context, replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]*SearchResult, []UniqueID, []UniqueID, error) {
var err error
var searchResults []*SearchResult
var searchPartIDs []UniqueID
var searchSegmentIDs []UniqueID
searchPartIDs, searchSegmentIDs, err = validateOnStreamReplica(replica, collID, partIDs, vChannel)
searchPartIDs, searchSegmentIDs, err = validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel)
if err != nil {
return searchResults, searchSegmentIDs, searchPartIDs, err
}
searchResults, err = searchOnSegments(replica, segmentTypeGrowing, searchReq, searchSegmentIDs)
searchResults, err = searchOnSegments(ctx, replica, segmentTypeGrowing, searchReq, searchSegmentIDs)
return searchResults, searchPartIDs, searchSegmentIDs, err
}
......@@ -36,7 +36,7 @@ func TestHistorical_Search(t *testing.T) {
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
_, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(t, err)
})
......@@ -52,7 +52,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, nil)
_, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil)
assert.Error(t, err)
})
......@@ -68,7 +68,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, nil)
_, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, nil)
assert.Error(t, err)
})
......@@ -88,7 +88,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, nil)
_, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil)
assert.Error(t, err)
})
......@@ -104,7 +104,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
res, _, ids, err := searchHistorical(his, searchReq, defaultCollectionID, nil, nil)
res, _, ids, err := searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil)
assert.Equal(t, 0, len(res))
assert.Equal(t, 0, len(ids))
assert.NoError(t, err)
......@@ -121,7 +121,7 @@ func TestStreaming_search(t *testing.T) {
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq,
res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -138,7 +138,7 @@ func TestStreaming_search(t *testing.T) {
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq,
res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -162,7 +162,7 @@ func TestStreaming_search(t *testing.T) {
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq,
res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -187,7 +187,7 @@ func TestStreaming_search(t *testing.T) {
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, _, err = searchStreaming(streaming, searchReq,
_, _, _, err = searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -206,7 +206,7 @@ func TestStreaming_search(t *testing.T) {
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq,
res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,
[]UniqueID{},
defaultDMLChannel)
......@@ -228,7 +228,7 @@ func TestStreaming_search(t *testing.T) {
seg.segmentPtr = nil
_, _, _, err = searchStreaming(streaming, searchReq,
_, _, _, err = searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,
[]UniqueID{},
defaultDMLChannel)
......
......@@ -31,7 +31,7 @@ func TestHistorical_statistic(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
_, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(t, err)
})
......@@ -42,7 +42,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, nil, nil)
_, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil)
assert.Error(t, err)
})
......@@ -53,7 +53,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, []UniqueID{defaultPartitionID}, nil)
_, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, nil)
assert.Error(t, err)
})
......@@ -68,7 +68,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, nil, nil)
_, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil)
assert.Error(t, err)
})
......@@ -79,7 +79,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
res, _, ids, err := statisticHistorical(his, defaultCollectionID, nil, nil)
res, _, ids, err := statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil)
assert.Equal(t, 0, len(res))
assert.Equal(t, 0, len(ids))
assert.NoError(t, err)
......@@ -91,7 +91,7 @@ func TestStreaming_statistics(t *testing.T) {
streaming, err := genSimpleReplicaWithGrowingSegment()
assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming,
res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -103,7 +103,7 @@ func TestStreaming_statistics(t *testing.T) {
streaming, err := genSimpleReplicaWithGrowingSegment()
assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming,
res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -122,7 +122,7 @@ func TestStreaming_statistics(t *testing.T) {
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming,
res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -142,7 +142,7 @@ func TestStreaming_statistics(t *testing.T) {
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, _, err = statisticStreaming(streaming,
_, _, _, err = statisticStreaming(context.TODO(), streaming,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultDMLChannel)
......@@ -156,7 +156,7 @@ func TestStreaming_statistics(t *testing.T) {
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming,
res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID,
[]UniqueID{},
defaultDMLChannel)
......
package querynode
import (
"context"
"fmt"
"sync"
......@@ -55,8 +56,8 @@ func statisticOnSegments(replica ReplicaInterface, segType segmentType, segIDs [
// if segIDs is not specified, it will search on all the historical segments specified by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func statisticHistorical(replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]map[string]interface{}, []UniqueID, []UniqueID, error) {
searchPartIDs, searchSegmentIDs, err := validateOnHistoricalReplica(replica, collID, partIDs, segIDs)
func statisticHistorical(ctx context.Context, replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]map[string]interface{}, []UniqueID, []UniqueID, error) {
searchPartIDs, searchSegmentIDs, err := validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs)
if err != nil {
return nil, searchSegmentIDs, searchPartIDs, err
}
......@@ -66,8 +67,8 @@ func statisticHistorical(replica ReplicaInterface, collID UniqueID, partIDs []Un
// statisticStreaming will do statistics all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func statisticStreaming(replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]map[string]interface{}, []UniqueID, []UniqueID, error) {
searchPartIDs, searchSegmentIDs, err := validateOnStreamReplica(replica, collID, partIDs, vChannel)
func statisticStreaming(ctx context.Context, replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]map[string]interface{}, []UniqueID, []UniqueID, error) {
searchPartIDs, searchSegmentIDs, err := validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel)
if err != nil {
return nil, searchSegmentIDs, searchPartIDs, err
}
......
......@@ -53,6 +53,7 @@ func (q *queryTask) PreExecute(ctx context.Context) error {
// TODO: merge queryOnStreaming and queryOnHistorical?
func (q *queryTask) queryOnStreaming() error {
// check ctx timeout
ctx := q.Ctx()
if !funcutil.CheckCtxValid(q.Ctx()) {
return errors.New("query context timeout")
}
......@@ -66,7 +67,7 @@ func (q *queryTask) queryOnStreaming() error {
q.QS.collection.RLock() // locks the collectionPtr
defer q.QS.collection.RUnlock()
if _, released := q.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", q.ID()),
log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", q.ID()),
zap.Int64("collectionID", q.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID)
}
......@@ -78,13 +79,13 @@ func (q *queryTask) queryOnStreaming() error {
}
defer plan.delete()
sResults, _, _, sErr := retrieveStreaming(q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager)
sResults, _, _, sErr := retrieveStreaming(ctx, q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager)
if sErr != nil {
return sErr
}
q.tr.RecordSpan()
mergedResult, err := mergeSegcoreRetrieveResults(sResults)
mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults)
if err != nil {
return err
}
......@@ -100,7 +101,8 @@ func (q *queryTask) queryOnStreaming() error {
func (q *queryTask) queryOnHistorical() error {
// check ctx timeout
if !funcutil.CheckCtxValid(q.Ctx()) {
ctx := q.Ctx()
if !funcutil.CheckCtxValid(ctx) {
return errors.New("search context timeout3$")
}
......@@ -114,7 +116,7 @@ func (q *queryTask) queryOnHistorical() error {
defer q.QS.collection.RUnlock()
if _, released := q.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", q.ID()),
log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", q.ID()),
zap.Int64("collectionID", q.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID)
}
......@@ -125,11 +127,11 @@ func (q *queryTask) queryOnHistorical() error {
return err
}
defer plan.delete()
retrieveResults, _, _, err := retrieveHistorical(q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager)
retrieveResults, _, _, err := retrieveHistorical(ctx, q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager)
if err != nil {
return err
}
mergedResult, err := mergeSegcoreRetrieveResults(retrieveResults)
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults)
if err != nil {
return err
}
......
......@@ -85,7 +85,8 @@ func (s *searchTask) init() error {
// TODO: merge searchOnStreaming and searchOnHistorical?
func (s *searchTask) searchOnStreaming() error {
// check ctx timeout
if !funcutil.CheckCtxValid(s.Ctx()) {
ctx := s.Ctx()
if !funcutil.CheckCtxValid(ctx) {
return errors.New("search context timeout")
}
......@@ -102,7 +103,7 @@ func (s *searchTask) searchOnStreaming() error {
s.QS.collection.RLock() // locks the collectionPtr
defer s.QS.collection.RUnlock()
if _, released := s.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", s.ID()),
log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", s.ID()),
zap.Int64("collectionID", s.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID)
}
......@@ -113,20 +114,20 @@ func (s *searchTask) searchOnStreaming() error {
}
defer searchReq.delete()
// TODO add context
partResults, _, _, sErr := searchStreaming(s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
partResults, _, _, sErr := searchStreaming(ctx, s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
if sErr != nil {
log.Debug("failed to search streaming data", zap.Int64("msgID", s.ID()),
log.Ctx(ctx).Warn("failed to search streaming data", zap.Int64("msgID", s.ID()),
zap.Int64("collectionID", s.CollectionID), zap.Error(sErr))
return sErr
}
defer deleteSearchResults(partResults)
return s.reduceResults(searchReq, partResults)
return s.reduceResults(ctx, searchReq, partResults)
}
func (s *searchTask) searchOnHistorical() error {
// check ctx timeout
if !funcutil.CheckCtxValid(s.Ctx()) {
ctx := s.Ctx()
if !funcutil.CheckCtxValid(ctx) {
return errors.New("search context timeout")
}
......@@ -139,7 +140,7 @@ func (s *searchTask) searchOnHistorical() error {
s.QS.collection.RLock() // locks the collectionPtr
defer s.QS.collection.RUnlock()
if _, released := s.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", s.ID()),
log.Ctx(ctx).Warn("collection release before search", zap.Int64("msgID", s.ID()),
zap.Int64("collectionID", s.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID)
}
......@@ -151,12 +152,12 @@ func (s *searchTask) searchOnHistorical() error {
}
defer searchReq.delete()
partResults, _, _, err := searchHistorical(s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs)
partResults, _, _, err := searchHistorical(ctx, s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs)
if err != nil {
return err
}
defer deleteSearchResults(partResults)
return s.reduceResults(searchReq, partResults)
return s.reduceResults(ctx, searchReq, partResults)
}
func (s *searchTask) Execute(ctx context.Context) error {
......@@ -217,7 +218,7 @@ func (s *searchTask) CPUUsage() int32 {
}
// reduceResults reduce search results
func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchResult) error {
func (s *searchTask) reduceResults(ctx context.Context, searchReq *searchRequest, results []*SearchResult) error {
isEmpty := len(results) == 0
cnt := 1 + len(s.otherTasks)
var t *searchTask
......@@ -227,7 +228,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
numSegment := int64(len(results))
blobs, err := reduceSearchResultsAndFillData(searchReq.plan, results, numSegment, sInfo.sliceNQs, sInfo.sliceTopKs)
if err != nil {
log.Debug("marshal for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err))
log.Ctx(ctx).Warn("marshal for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err))
return err
}
defer deleteSearchResultDataBlobs(blobs)
......@@ -235,7 +236,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
for i := 0; i < cnt; i++ {
blob, err := getSearchResultDataBlob(blobs, i)
if err != nil {
log.Debug("getSearchResultDataBlob for historical results error", zap.Int64("msgID", s.ID()),
log.Ctx(ctx).Warn("getSearchResultDataBlob for historical results error", zap.Int64("msgID", s.ID()),
zap.Error(err))
return err
}
......
......@@ -34,7 +34,8 @@ type statistics struct {
func (s *statistics) statisticOnStreaming() error {
// check ctx timeout
if !funcutil.CheckCtxValid(s.ctx) {
ctx := s.ctx
if !funcutil.CheckCtxValid(ctx) {
return errors.New("get statistics context timeout")
}
......@@ -47,14 +48,15 @@ func (s *statistics) statisticOnStreaming() error {
s.qs.collection.RLock() // locks the collectionPtr
defer s.qs.collection.RUnlock()
if _, released := s.qs.collection.getReleaseTime(); released {
log.Debug("collection release before do statistics", zap.Int64("msgID", s.id),
log.Ctx(ctx).Warn("collection release before do statistics", zap.Int64("msgID", s.id),
zap.Int64("collectionID", s.iReq.GetCollectionID()))
return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID())
}
results, _, _, err := statisticStreaming(s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
results, _, _, err := statisticStreaming(ctx, s.qs.metaReplica, s.iReq.GetCollectionID(),
s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
if err != nil {
log.Debug("failed to statistic on streaming data", zap.Int64("msgID", s.id),
log.Ctx(ctx).Warn("failed to statistic on streaming data", zap.Int64("msgID", s.id),
zap.Int64("collectionID", s.iReq.GetCollectionID()), zap.Error(err))
return err
}
......@@ -63,7 +65,8 @@ func (s *statistics) statisticOnStreaming() error {
func (s *statistics) statisticOnHistorical() error {
// check ctx timeout
if !funcutil.CheckCtxValid(s.ctx) {
ctx := s.ctx
if !funcutil.CheckCtxValid(ctx) {
return errors.New("get statistics context timeout")
}
......@@ -76,13 +79,13 @@ func (s *statistics) statisticOnHistorical() error {
s.qs.collection.RLock() // locks the collectionPtr
defer s.qs.collection.RUnlock()
if _, released := s.qs.collection.getReleaseTime(); released {
log.Debug("collection release before do statistics", zap.Int64("msgID", s.id),
log.Ctx(ctx).Debug("collection release before do statistics", zap.Int64("msgID", s.id),
zap.Int64("collectionID", s.iReq.GetCollectionID()))
return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID())
}
segmentIDs := s.req.GetSegmentIDs()
results, _, _, err := statisticHistorical(s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), segmentIDs)
results, _, _, err := statisticHistorical(ctx, s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), segmentIDs)
if err != nil {
return err
}
......
......@@ -17,6 +17,7 @@
package querynode
import (
"context"
"errors"
"fmt"
......@@ -26,7 +27,7 @@ import (
)
// TODO: merge validate?
func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
func validateOnHistoricalReplica(ctx context.Context, replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
var err error
var searchPartIDs []UniqueID
......@@ -46,7 +47,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID
}
}
log.Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
col, err2 := replica.getCollectionByID(collectionID)
if err2 != nil {
return searchPartIDs, segmentIDs, err2
......@@ -86,7 +87,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID
return searchPartIDs, newSegmentIDs, nil
}
func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, vChannel Channel) ([]UniqueID, []UniqueID, error) {
func validateOnStreamReplica(ctx context.Context, replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, vChannel Channel) ([]UniqueID, []UniqueID, error) {
var err error
var searchPartIDs []UniqueID
var segmentIDs []UniqueID
......@@ -107,7 +108,7 @@ func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, pa
}
}
log.Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
col, err2 := replica.getCollectionByID(collectionID)
if err2 != nil {
return searchPartIDs, segmentIDs, err2
......@@ -123,7 +124,7 @@ func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, pa
}
segmentIDs, err = replica.getSegmentIDsByVChannel(searchPartIDs, vChannel, segmentTypeGrowing)
log.Debug("validateOnStreamReplica getSegmentIDsByVChannel",
log.Ctx(ctx).Debug("validateOnStreamReplica getSegmentIDsByVChannel",
zap.Any("collectionID", collectionID),
zap.Any("vChannel", vChannel),
zap.Any("partitionIDs", searchPartIDs),
......
......@@ -30,35 +30,35 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
t.Run("test normal validate", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID})
assert.NoError(t, err)
})
t.Run("test normal validate2", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.NoError(t, err)
})
t.Run("test validate non-existent collection", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID})
assert.Error(t, err)
})
t.Run("test validate non-existent partition", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID})
assert.Error(t, err)
})
t.Run("test validate non-existent segment", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1})
assert.Error(t, err)
})
......@@ -79,7 +79,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
assert.NoError(t, err)
// Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1)
// that does not belong to defaultPartition
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1})
assert.Error(t, err)
})
......@@ -88,7 +88,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
assert.NoError(t, err)
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.Error(t, err)
})
......@@ -100,7 +100,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
col.setLoadType(loadTypePartition)
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.Error(t, err)
})
......@@ -112,7 +112,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
col.setLoadType(loadTypeCollection)
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.NoError(t, err)
})
}
......@@ -1580,7 +1580,7 @@ func (c *Core) DescribeCollection(ctx context.Context, in *milvuspb.DescribeColl
}
tr := timerecord.NewTimeRecorder("DescribeCollection")
log.Debug("DescribeCollection", zap.String("role", typeutil.RootCoordRole),
log.Ctx(ctx).Debug("DescribeCollection", zap.String("role", typeutil.RootCoordRole),
zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID))
t := &DescribeCollectionReqTask{
baseReqTask: baseReqTask{
......@@ -1592,14 +1592,14 @@ func (c *Core) DescribeCollection(ctx context.Context, in *milvuspb.DescribeColl
}
err := executeTask(t)
if err != nil {
log.Error("DescribeCollection failed", zap.String("role", typeutil.RootCoordRole),
log.Ctx(ctx).Warn("DescribeCollection failed", zap.String("role", typeutil.RootCoordRole),
zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID), zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.FailLabel).Inc()
return &milvuspb.DescribeCollectionResponse{
Status: failStatus(commonpb.ErrorCode_UnexpectedError, "DescribeCollection failed: "+err.Error()),
}, nil
}
log.Debug("DescribeCollection success", zap.String("role", typeutil.RootCoordRole),
log.Ctx(ctx).Debug("DescribeCollection success", zap.String("role", typeutil.RootCoordRole),
zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID))
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.SuccessLabel).Inc()
......
......@@ -274,6 +274,7 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error {
return err
}
log.NewMetaLogger().WithCollectionMeta(&collInfo).WithOperation(log.CreateCollection).WithTSO(ts).Info()
return nil
}
......@@ -391,6 +392,8 @@ func (t *DropCollectionReqTask) Execute(ctx context.Context) error {
return err
}
log.NewMetaLogger().WithCollectionID(collMeta.CollectionID).
WithCollectionName(collMeta.Name).WithTSO(ts).WithOperation(log.DropCollection).Info()
return nil
}
......@@ -594,6 +597,8 @@ func (t *CreatePartitionReqTask) Execute(ctx context.Context) error {
return err
}
log.NewMetaLogger().WithCollectionName(collMeta.Name).WithCollectionID(collMeta.CollectionID).
WithPartitionID(partID).WithPartitionName(t.Req.PartitionName).WithTSO(ts).WithOperation(log.CreatePartition).Info()
return nil
}
......@@ -691,6 +696,8 @@ func (t *DropPartitionReqTask) Execute(ctx context.Context) error {
// return err
//}
log.NewMetaLogger().WithCollectionID(collInfo.CollectionID).WithCollectionName(collInfo.Name).
WithPartitionName(t.Req.PartitionName).WithTSO(ts).WithOperation(log.DropCollection).Info()
return nil
}
......@@ -1038,6 +1045,11 @@ func (t *CreateIndexReqTask) Execute(ctx context.Context) error {
}
}
idxMeta, err := t.core.MetaTable.GetIndexByID(indexID)
if err == nil {
log.NewMetaLogger().WithIndexMeta(idxMeta).WithOperation(log.CreateIndex).WithTSO(createTS).Info()
}
return nil
}
......@@ -1098,6 +1110,15 @@ func (t *DropIndexReqTask) Execute(ctx context.Context) error {
if err := t.core.MetaTable.MarkIndexDeleted(t.Req.CollectionName, t.Req.FieldName, t.Req.IndexName); err != nil {
return err
}
deleteTS, err := t.core.TSOAllocator(1)
if err != nil {
return err
}
log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).
WithFieldName(t.Req.FieldName).
WithIndexName(t.Req.IndexName).
WithOperation(log.DropIndex).WithTSO(deleteTS).Info()
return nil
}
......@@ -1127,6 +1148,7 @@ func (t *CreateAliasReqTask) Execute(ctx context.Context) error {
return fmt.Errorf("meta table add alias failed, error = %w", err)
}
log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).WithAlias(t.Req.Alias).WithTSO(ts).WithOperation(log.CreateCollectionAlias).Info()
return nil
}
......@@ -1156,7 +1178,12 @@ func (t *DropAliasReqTask) Execute(ctx context.Context) error {
return fmt.Errorf("meta table drop alias failed, error = %w", err)
}
return t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts)
if err := t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts); err != nil {
return err
}
log.NewMetaLogger().WithAlias(t.Req.Alias).WithOperation(log.DropCollectionAlias).WithTSO(ts).Info()
return nil
}
// AlterAliasReqTask alter alias request task
......@@ -1185,5 +1212,11 @@ func (t *AlterAliasReqTask) Execute(ctx context.Context) error {
return fmt.Errorf("meta table alter alias failed, error = %w", err)
}
return t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts)
if err := t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts); err != nil {
return nil
}
log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).
WithAlias(t.Req.Alias).WithOperation(log.AlterCollectionAlias).WithTSO(ts).Info()
return nil
}
package logutil
import (
"context"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
const (
logLevelRPCMetaKey = "log_level"
clientRequestIDKey = "client_request_id"
)
// UnaryTraceLoggerInterceptor adds a traced logger in unary rpc call ctx
func UnaryTraceLoggerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
newctx := withLevelAndTrace(ctx)
return handler(newctx, req)
}
// StreamTraceLoggerInterceptor add a traced logger in stream rpc call ctx
func StreamTraceLoggerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := ss.Context()
newctx := withLevelAndTrace(ctx)
wrappedStream := grpc_middleware.WrapServerStream(ss)
wrappedStream.WrappedContext = newctx
return handler(srv, wrappedStream)
}
func withLevelAndTrace(ctx context.Context) context.Context {
newctx := ctx
var traceID string
if md, ok := metadata.FromIncomingContext(ctx); ok {
levels := md.Get(logLevelRPCMetaKey)
// get log level
if len(levels) >= 1 {
level := zapcore.DebugLevel
if err := level.UnmarshalText([]byte(levels[0])); err != nil {
newctx = ctx
} else {
switch level {
case zapcore.DebugLevel:
newctx = log.WithDebugLevel(ctx)
case zapcore.InfoLevel:
newctx = log.WithInfoLevel(ctx)
case zapcore.WarnLevel:
newctx = log.WithWarnLevel(ctx)
case zapcore.ErrorLevel:
newctx = log.WithErrorLevel(ctx)
case zapcore.FatalLevel:
newctx = log.WithFatalLevel(ctx)
default:
newctx = ctx
}
}
// inject log level to outgoing meta
newctx = metadata.AppendToOutgoingContext(newctx, logLevelRPCMetaKey, level.String())
}
// client request id
requestID := md.Get(clientRequestIDKey)
if len(requestID) >= 1 {
traceID = requestID[0]
// inject traceid in order to pass client request id
newctx = metadata.AppendToOutgoingContext(newctx, clientRequestIDKey, traceID)
}
}
if traceID == "" {
traceID, _, _ = trace.InfoFromContext(newctx)
}
if traceID != "" {
newctx = log.WithTraceID(newctx, traceID)
}
return newctx
}
package logutil
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/log"
"github.com/stretchr/testify/assert"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc/metadata"
)
func TestCtxWithLevelAndTrace(t *testing.T) {
t.Run("debug level", func(t *testing.T) {
ctx := withMetaData(context.TODO(), zapcore.DebugLevel)
newctx := withLevelAndTrace(ctx)
assert.Equal(t, log.Ctx(log.WithDebugLevel(context.TODO())), log.Ctx(newctx))
})
t.Run("info level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.InfoLevel))
assert.Equal(t, log.Ctx(log.WithInfoLevel(ctx)), log.Ctx(newctx))
})
t.Run("warn level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.WarnLevel))
assert.Equal(t, log.Ctx(log.WithWarnLevel(ctx)), log.Ctx(newctx))
})
t.Run("error level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.ErrorLevel))
assert.Equal(t, log.Ctx(log.WithErrorLevel(ctx)), log.Ctx(newctx))
})
t.Run("fatal level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.FatalLevel))
assert.Equal(t, log.Ctx(log.WithFatalLevel(ctx)), log.Ctx(newctx))
})
t.Run(("pass through variables"), func(t *testing.T) {
md := metadata.New(map[string]string{
logLevelRPCMetaKey: zapcore.ErrorLevel.String(),
clientRequestIDKey: "client-req-id",
})
ctx := metadata.NewIncomingContext(context.TODO(), md)
newctx := withLevelAndTrace(ctx)
md, ok := metadata.FromOutgoingContext(newctx)
assert.True(t, ok)
assert.Equal(t, "client-req-id", md.Get(clientRequestIDKey)[0])
assert.Equal(t, zapcore.ErrorLevel.String(), md.Get(logLevelRPCMetaKey)[0])
})
}
func withMetaData(ctx context.Context, level zapcore.Level) context.Context {
md := metadata.New(map[string]string{
logLevelRPCMetaKey: level.String(),
})
return metadata.NewIncomingContext(context.TODO(), md)
}
......@@ -141,7 +141,7 @@ var once sync.Once
func SetupLogger(cfg *log.Config) {
once.Do(func() {
// Initialize logger.
logger, p, err := log.InitLogger(cfg, zap.AddStacktrace(zap.ErrorLevel), zap.AddCallerSkip(1))
logger, p, err := log.InitLogger(cfg, zap.AddStacktrace(zap.ErrorLevel))
if err == nil {
log.ReplaceGlobals(logger, p)
} else {
......@@ -167,54 +167,10 @@ func SetupLogger(cfg *log.Config) {
})
}
type logKey int
const logCtxKey logKey = iota
// WithField adds given kv field to the logger in ctx
func WithField(ctx context.Context, key string, value string) context.Context {
logger := log.L()
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
logger = ctxLogger
}
return context.WithValue(ctx, logCtxKey, logger.With(zap.String(key, value)))
}
// WithReqID adds given reqID field to the logger in ctx
func WithReqID(ctx context.Context, reqID int64) context.Context {
logger := log.L()
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
logger = ctxLogger
}
return context.WithValue(ctx, logCtxKey, logger.With(zap.Int64("reqID", reqID)))
}
// WithModule adds given module field to the logger in ctx
func WithModule(ctx context.Context, module string) context.Context {
logger := log.L()
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
logger = ctxLogger
}
return context.WithValue(ctx, logCtxKey, logger.With(zap.String("module", module)))
}
func WithLogger(ctx context.Context, logger *zap.Logger) context.Context {
if logger == nil {
logger = log.L()
}
return context.WithValue(ctx, logCtxKey, logger)
}
func Logger(ctx context.Context) *zap.Logger {
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
return ctxLogger
}
return log.L()
return log.Ctx(ctx).Logger
}
func BgLogger() *zap.Logger {
return log.L()
func WithModule(ctx context.Context, module string) context.Context {
return log.WithModule(ctx, module)
}
......@@ -25,6 +25,7 @@ type TimeRecorder struct {
header string
start time.Time
last time.Time
ctx context.Context
}
// NewTimeRecorder creates a new TimeRecorder
......@@ -55,18 +56,30 @@ func (tr *TimeRecorder) ElapseSpan() time.Duration {
// Record calculates the time span from previous Record call
func (tr *TimeRecorder) Record(msg string) time.Duration {
span := tr.RecordSpan()
tr.printTimeRecord(msg, span)
tr.printTimeRecord(context.TODO(), msg, span)
return span
}
func (tr *TimeRecorder) CtxRecord(ctx context.Context, msg string) time.Duration {
span := tr.RecordSpan()
tr.printTimeRecord(ctx, msg, span)
return span
}
// Elapse calculates the time span from the beginning of this TimeRecorder
func (tr *TimeRecorder) Elapse(msg string) time.Duration {
span := tr.ElapseSpan()
tr.printTimeRecord(msg, span)
tr.printTimeRecord(context.TODO(), msg, span)
return span
}
func (tr *TimeRecorder) CtxElapse(ctx context.Context, msg string) time.Duration {
span := tr.ElapseSpan()
tr.printTimeRecord(ctx, msg, span)
return span
}
func (tr *TimeRecorder) printTimeRecord(msg string, span time.Duration) {
func (tr *TimeRecorder) printTimeRecord(ctx context.Context, msg string, span time.Duration) {
str := ""
if tr.header != "" {
str += tr.header + ": "
......@@ -75,7 +88,7 @@ func (tr *TimeRecorder) printTimeRecord(msg string, span time.Duration) {
str += " ("
str += strconv.Itoa(int(span.Milliseconds()))
str += "ms)"
log.Debug(str)
log.Ctx(ctx).Debug(str)
}
// LongTermChecker checks we receive at least one msg in d duration. If not, checker
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册