未验证 提交 543c4891 编写于 作者: Z zhenshan.cao 提交者: GitHub

Prevent the client from closing grpc connection by mistake (#11918)

Signed-off-by: Nzhenshan.cao <zhenshan.cao@zilliz.com>
上级 6d652263
......@@ -27,6 +27,7 @@ import (
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -192,14 +193,15 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("DataCoord Client grpc error", zap.Error(err))
c.resetConnection()
ret, err = caller()
if err == nil {
return ret, nil
}
return ret, err
}
......@@ -229,7 +231,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -245,7 +249,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
......@@ -261,7 +267,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -276,7 +284,9 @@ func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.Flush(ctx, req)
})
if err != nil || ret == nil {
......@@ -304,7 +314,9 @@ func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AssignSegmentID(ctx, req)
})
if err != nil || ret == nil {
......@@ -328,7 +340,9 @@ func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentSta
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentStates(ctx, req)
})
if err != nil || ret == nil {
......@@ -351,7 +365,9 @@ func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsert
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetInsertBinlogPaths(ctx, req)
})
if err != nil || ret == nil {
......@@ -374,7 +390,9 @@ func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCol
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetCollectionStatistics(ctx, req)
})
if err != nil || ret == nil {
......@@ -397,7 +415,9 @@ func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetPartitionStatistics(ctx, req)
})
if err != nil || ret == nil {
......@@ -414,7 +434,9 @@ func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringRes
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{})
})
if err != nil || ret == nil {
......@@ -436,7 +458,9 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfo(ctx, req)
})
if err != nil || ret == nil {
......@@ -483,7 +507,9 @@ func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetRecoveryInfo(ctx, req)
})
if err != nil || ret == nil {
......@@ -506,7 +532,9 @@ func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetFlushedSegments(ctx, req)
})
if err != nil || ret == nil {
......@@ -522,7 +550,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
......@@ -537,7 +567,9 @@ func (c *Client) CompleteCompaction(ctx context.Context, req *datapb.CompactionR
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CompleteCompaction(ctx, req)
})
if err != nil || ret == nil {
......@@ -552,7 +584,9 @@ func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompa
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ManualCompaction(ctx, req)
})
if err != nil || ret == nil {
......@@ -567,7 +601,9 @@ func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompac
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetCompactionState(ctx, req)
})
if err != nil || ret == nil {
......@@ -582,7 +618,9 @@ func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetCompactionStateWithPlans(ctx, req)
})
if err != nil || ret == nil {
......@@ -597,7 +635,9 @@ func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchChannels(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -22,16 +22,16 @@ import (
"sync"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/retry"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"google.golang.org/grpc/codes"
......@@ -174,6 +174,10 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("DataNode Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -214,7 +218,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -229,7 +235,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -244,7 +252,9 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchDmChannels(ctx, req)
})
if err != nil || ret == nil {
......@@ -267,7 +277,9 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.FlushSegments(ctx, req)
})
if err != nil || ret == nil {
......@@ -282,7 +294,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
......@@ -297,7 +311,9 @@ func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*c
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.Compaction(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -22,17 +22,17 @@ import (
"sync"
"time"
"google.golang.org/grpc"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
......@@ -181,6 +181,10 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("IndexCoord Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -220,7 +224,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -236,7 +242,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
......@@ -252,7 +260,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -268,7 +278,9 @@ func (c *Client) BuildIndex(ctx context.Context, req *indexpb.BuildIndexRequest)
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.BuildIndex(ctx, req)
})
if err != nil || ret == nil {
......@@ -284,7 +296,9 @@ func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropIndex(ctx, req)
})
if err != nil || ret == nil {
......@@ -300,7 +314,9 @@ func (c *Client) GetIndexStates(ctx context.Context, req *indexpb.GetIndexStates
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetIndexStates(ctx, req)
})
if err != nil || ret == nil {
......@@ -316,7 +332,9 @@ func (c *Client) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFil
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetIndexFilePaths(ctx, req)
})
if err != nil || ret == nil {
......@@ -332,7 +350,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -27,6 +27,7 @@ import (
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap"
......@@ -169,6 +170,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("IndexNode Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -208,7 +212,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -224,7 +230,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
......@@ -240,7 +248,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -256,7 +266,9 @@ func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateIndex(ctx, req)
})
if err != nil || ret == nil {
......@@ -272,7 +284,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap"
......@@ -171,6 +172,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("Proxy Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -208,7 +212,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -223,7 +229,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -238,7 +246,9 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.InvalidateCollectionMetaCache(ctx, req)
})
if err != nil || ret == nil {
......@@ -253,7 +263,9 @@ func (c *Client) ReleaseDQLMessageStream(ctx context.Context, req *proxypb.Relea
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseDQLMessageStream(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -25,6 +25,7 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -195,6 +196,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("QueryCoord Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -234,7 +238,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -250,7 +256,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
......@@ -266,7 +274,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -282,7 +292,9 @@ func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowCollections(ctx, req)
})
if err != nil || ret == nil {
......@@ -298,7 +310,9 @@ func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollection
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadCollection(ctx, req)
})
if err != nil || ret == nil {
......@@ -314,7 +328,9 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseCollection(ctx, req)
})
if err != nil || ret == nil {
......@@ -330,7 +346,9 @@ func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowPartitions(ctx, req)
})
if err != nil || ret == nil {
......@@ -346,7 +364,9 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadPartitions(ctx, req)
})
if err != nil || ret == nil {
......@@ -362,7 +382,9 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleasePartitions(ctx, req)
})
if err != nil || ret == nil {
......@@ -378,7 +400,9 @@ func (c *Client) CreateQueryChannel(ctx context.Context, req *querypb.CreateQuer
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateQueryChannel(ctx, req)
})
if err != nil || ret == nil {
......@@ -394,7 +418,9 @@ func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetPartitionStates(ctx, req)
})
if err != nil || ret == nil {
......@@ -410,7 +436,9 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfo(ctx, req)
})
if err != nil || ret == nil {
......@@ -426,7 +454,9 @@ func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadBalance(ctx, req)
})
if err != nil || ret == nil {
......@@ -442,7 +472,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
)
......@@ -161,6 +162,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("QueryNode Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -200,7 +204,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -216,7 +222,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
......@@ -232,7 +240,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -248,7 +258,9 @@ func (c *Client) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChann
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AddQueryChannel(ctx, req)
})
if err != nil || ret == nil {
......@@ -264,7 +276,9 @@ func (c *Client) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQuer
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.RemoveQueryChannel(ctx, req)
})
if err != nil || ret == nil {
......@@ -280,7 +294,9 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchDmChannels(ctx, req)
})
if err != nil || ret == nil {
......@@ -296,7 +312,9 @@ func (c *Client) WatchDeltaChannels(ctx context.Context, req *querypb.WatchDelta
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchDeltaChannels(ctx, req)
})
if err != nil || ret == nil {
......@@ -312,7 +330,9 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadSegments(ctx, req)
})
if err != nil || ret == nil {
......@@ -328,7 +348,9 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseCollection(ctx, req)
})
if err != nil || ret == nil {
......@@ -344,7 +366,9 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleasePartitions(ctx, req)
})
if err != nil || ret == nil {
......@@ -360,7 +384,9 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseSegments(ctx, req)
})
if err != nil || ret == nil {
......@@ -376,7 +402,9 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfo(ctx, req)
})
if err != nil || ret == nil {
......@@ -392,7 +420,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
......@@ -223,6 +224,9 @@ func (c *GrpcClient) recall(caller func() (interface{}, error)) (interface{}, er
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("RootCoord Client grpc error", zap.Error(err))
c.resetConnection()
......@@ -241,7 +245,9 @@ func (c *GrpcClient) GetComponentStates(ctx context.Context) (*internalpb.Compon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
......@@ -257,7 +263,9 @@ func (c *GrpcClient) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRe
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
......@@ -273,7 +281,9 @@ func (c *GrpcClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.String
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
......@@ -289,7 +299,9 @@ func (c *GrpcClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCo
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateCollection(ctx, in)
})
if err != nil || ret == nil {
......@@ -305,7 +317,9 @@ func (c *GrpcClient) DropCollection(ctx context.Context, in *milvuspb.DropCollec
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropCollection(ctx, in)
})
if err != nil || ret == nil {
......@@ -321,7 +335,9 @@ func (c *GrpcClient) HasCollection(ctx context.Context, in *milvuspb.HasCollecti
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.HasCollection(ctx, in)
})
if err != nil || ret == nil {
......@@ -337,7 +353,9 @@ func (c *GrpcClient) DescribeCollection(ctx context.Context, in *milvuspb.Descri
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DescribeCollection(ctx, in)
})
if err != nil || ret == nil {
......@@ -353,7 +371,9 @@ func (c *GrpcClient) ShowCollections(ctx context.Context, in *milvuspb.ShowColle
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowCollections(ctx, in)
})
if err != nil || ret == nil {
......@@ -369,7 +389,9 @@ func (c *GrpcClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePar
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreatePartition(ctx, in)
})
if err != nil || ret == nil {
......@@ -385,7 +407,9 @@ func (c *GrpcClient) DropPartition(ctx context.Context, in *milvuspb.DropPartiti
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropPartition(ctx, in)
})
if err != nil || ret == nil {
......@@ -401,7 +425,9 @@ func (c *GrpcClient) HasPartition(ctx context.Context, in *milvuspb.HasPartition
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.HasPartition(ctx, in)
})
if err != nil || ret == nil {
......@@ -417,7 +443,9 @@ func (c *GrpcClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartit
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowPartitions(ctx, in)
})
if err != nil || ret == nil {
......@@ -433,7 +461,9 @@ func (c *GrpcClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRe
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateIndex(ctx, in)
})
if err != nil || ret == nil {
......@@ -449,7 +479,9 @@ func (c *GrpcClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexReques
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropIndex(ctx, in)
})
if err != nil || ret == nil {
......@@ -465,7 +497,9 @@ func (c *GrpcClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeInd
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DescribeIndex(ctx, in)
})
if err != nil || ret == nil {
......@@ -481,7 +515,9 @@ func (c *GrpcClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTi
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AllocTimestamp(ctx, in)
})
if err != nil || ret == nil {
......@@ -497,7 +533,9 @@ func (c *GrpcClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AllocID(ctx, in)
})
if err != nil || ret == nil {
......@@ -513,7 +551,9 @@ func (c *GrpcClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.C
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.UpdateChannelTimeTick(ctx, in)
})
if err != nil || ret == nil {
......@@ -529,7 +569,9 @@ func (c *GrpcClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DescribeSegment(ctx, in)
})
if err != nil || ret == nil {
......@@ -545,7 +587,9 @@ func (c *GrpcClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegments
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowSegments(ctx, in)
})
if err != nil || ret == nil {
......@@ -561,7 +605,9 @@ func (c *GrpcClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.Re
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseDQLMessageStream(ctx, in)
})
if err != nil || ret == nil {
......@@ -577,7 +623,9 @@ func (c *GrpcClient) SegmentFlushCompleted(ctx context.Context, in *datapb.Segme
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.SegmentFlushCompleted(ctx, in)
})
if err != nil || ret == nil {
......@@ -593,7 +641,9 @@ func (c *GrpcClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequ
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, in)
})
if err != nil || ret == nil {
......@@ -609,7 +659,9 @@ func (c *GrpcClient) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasR
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateAlias(ctx, req)
})
if err != nil || ret == nil {
......@@ -625,7 +677,9 @@ func (c *GrpcClient) DropAlias(ctx context.Context, req *milvuspb.DropAliasReque
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropAlias(ctx, req)
})
if err != nil || ret == nil {
......@@ -641,7 +695,9 @@ func (c *GrpcClient) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasReq
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AlterAlias(ctx, req)
})
if err != nil || ret == nil {
......
......@@ -184,3 +184,7 @@ func GetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (strin
return "", errors.New("key " + key + " not found")
}
func CheckCtxValid(ctx context.Context) bool {
return ctx.Err() != context.DeadlineExceeded && ctx.Err() != context.Canceled
}
......@@ -264,3 +264,26 @@ func TestGetAttrByKeyFromRepeatedKV(t *testing.T) {
assert.Equal(t, test.errIsNil, err == nil)
}
}
func TestCheckCtxValid(t *testing.T) {
bgCtx := context.Background()
timeout := 20 * time.Millisecond
deltaTime := 5 * time.Millisecond
ctx1, cancel1 := context.WithTimeout(bgCtx, timeout)
defer cancel1()
assert.True(t, CheckCtxValid(ctx1))
time.Sleep(timeout + deltaTime)
assert.False(t, CheckCtxValid(ctx1))
ctx2, cancel2 := context.WithTimeout(bgCtx, timeout)
assert.True(t, CheckCtxValid(ctx2))
cancel2()
assert.False(t, CheckCtxValid(ctx2))
futureTime := time.Now().Add(timeout)
ctx3, cancel3 := context.WithDeadline(bgCtx, futureTime)
defer cancel3()
assert.True(t, CheckCtxValid(ctx3))
time.Sleep(timeout + deltaTime)
assert.False(t, CheckCtxValid(ctx3))
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册