未验证 提交 8db24e0a 编写于 作者: C congqixia 提交者: GitHub

Support queryHook in querynodev2 (#23140)

Signed-off-by: NCongqi Xia <congqi.xia@zilliz.com>
上级 38bb3599
......@@ -75,7 +75,12 @@ const (
SegmentIndexPath = `index_files`
)
// Search, Index parameter keys
const (
TopKKey = "topk"
SearchParamKey = "search_param"
SegmentNumKey = "segment_num"
IndexParamsKey = "params"
IndexTypeKey = "index_type"
MetricTypeKey = "metric_type"
......
......@@ -21,10 +21,13 @@ import (
"fmt"
"strconv"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
......@@ -33,8 +36,10 @@ import (
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/merr"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/samber/lo"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
......@@ -249,6 +254,69 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ
}, nil
}
func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) {
// no hook applied, just return
if node.queryHook == nil {
return req, nil
}
log := log.Ctx(ctx)
serializedPlan := req.GetReq().GetSerializedExprPlan()
// plan not found
if serializedPlan == nil {
log.Warn("serialized plan not found")
return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
}
channelNum := req.GetTotalChannelNum()
// not set, change to conservative channel num 1
if channelNum <= 0 {
channelNum = 1
}
plan := planpb.PlanNode{}
err := proto.Unmarshal(serializedPlan, &plan)
if err != nil {
log.Warn("failed to unmarshal plan", zap.Error(err))
return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
}
switch plan.GetNode().(type) {
case *planpb.PlanNode_VectorAnns:
// ignore growing ones for now since they will always be brute force
sealed, _ := deleg.GetSegmentInfo()
sealedNum := lo.Reduce(sealed, func(sum int, item delegator.SnapshotItem, _ int) int {
return sum + len(item.Segments)
}, 0)
// use shardNum * segments num in shard to estimate total segment number
estSegmentNum := sealedNum * int(channelNum)
queryInfo := plan.GetVectorAnns().GetQueryInfo()
params := map[string]any{
common.TopKKey: queryInfo.GetTopk(),
common.SearchParamKey: queryInfo.GetSearchParams(),
common.SegmentNumKey: estSegmentNum,
}
err := node.queryHook.Run(params)
if err != nil {
log.Warn("failed to execute queryHook", zap.Error(err))
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
}
queryInfo.Topk = params[common.TopKKey].(int64)
queryInfo.SearchParams = params[common.SearchParamKey].(string)
serializedExprPlan, err := proto.Marshal(&plan)
if err != nil {
log.Warn("failed to marshal optimized plan", zap.Error(err))
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
}
req.Req.SerializedExprPlan = serializedExprPlan
default:
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
}
return req, nil
}
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
......@@ -270,7 +338,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
searchCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO From Shard Delegator
if req.GetFromShardLeader() {
tr := timerecord.NewTimeRecorder("searchChannel")
log.Debug("search channel...")
......@@ -313,6 +380,11 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
return nil, ErrGetDelegatorFailed
}
req, err := node.optimizeSearchParams(ctx, req, sd)
if err != nil {
log.Warn("failed to optimize search params", zap.Error(err))
return nil, err
}
// do search
results, err := sd.Search(searchCtx, req)
if err != nil {
......
......@@ -21,12 +21,17 @@ import (
"os"
"testing"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/merr"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
......@@ -136,3 +141,184 @@ func (suite *HandlersSuite) TestLoadGrowingSegments() {
func TestHandlersSuite(t *testing.T) {
suite.Run(t, new(HandlersSuite))
}
type OptimizeSearchParamSuite struct {
suite.Suite
// Data
collectionID int64
collectionName string
segmentID int64
channel string
node *QueryNode
delegator *delegator.MockShardDelegator
// Mock
factory *dependency.MockFactory
}
func (suite *OptimizeSearchParamSuite) SetupSuite() {
suite.collectionID = 111
suite.collectionName = "test-collection"
suite.segmentID = 1
suite.channel = "test-channel"
suite.delegator = &delegator.MockShardDelegator{}
suite.delegator.EXPECT().GetSegmentInfo().Return([]delegator.SnapshotItem{{NodeID: 1, Segments: []delegator.SegmentEntry{{SegmentID: 100}}}}, []delegator.SegmentEntry{})
}
func (suite *OptimizeSearchParamSuite) SetupTest() {
suite.factory = dependency.NewMockFactory(suite.T())
suite.node = NewQueryNode(context.Background(), suite.factory)
}
func (suite *OptimizeSearchParamSuite) TearDownTest() {
}
func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
suite.Run("normal_run", func() {
mockHook := &MockQueryHook{}
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}`
}).Return(nil)
suite.node.queryHook = mockHook
defer func() { suite.node.queryHook = nil }()
plan := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
VectorAnns: &planpb.VectorANNS{
QueryInfo: &planpb.QueryInfo{
Topk: 100,
SearchParams: `{"param": 1}`,
},
},
},
}
bs, err := proto.Marshal(plan)
suite.Require().NoError(err)
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
SerializedExprPlan: bs,
},
TotalChannelNum: 2,
}, suite.delegator)
suite.NoError(err)
suite.verifyQueryInfo(req, 50, `{"param": 2}`)
})
suite.Run("no_hook", func() {
suite.node.queryHook = nil
plan := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
VectorAnns: &planpb.VectorANNS{
QueryInfo: &planpb.QueryInfo{
Topk: 100,
SearchParams: `{"param": 1}`,
},
},
},
}
bs, err := proto.Marshal(plan)
suite.Require().NoError(err)
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
SerializedExprPlan: bs,
},
TotalChannelNum: 2,
}, suite.delegator)
suite.NoError(err)
suite.verifyQueryInfo(req, 100, `{"param": 1}`)
})
suite.Run("other_plannode", func() {
mockHook := &MockQueryHook{}
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}`
}).Return(nil).Maybe()
suite.node.queryHook = mockHook
defer func() { suite.node.queryHook = nil }()
plan := &planpb.PlanNode{
Node: &planpb.PlanNode_Query{},
}
bs, err := proto.Marshal(plan)
suite.Require().NoError(err)
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
SerializedExprPlan: bs,
},
TotalChannelNum: 2,
}, suite.delegator)
suite.NoError(err)
suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
})
suite.Run("no_serialized_plan", func() {
mockHook := &MockQueryHook{}
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}`
}).Return(nil)
suite.node.queryHook = mockHook
defer func() { suite.node.queryHook = nil }()
_, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{},
TotalChannelNum: 2,
}, suite.delegator)
suite.Error(err)
})
suite.Run("hook_run_error", func() {
mockHook := &MockQueryHook{}
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}`
}).Return(merr.WrapErrServiceInternal("mocked"))
suite.node.queryHook = mockHook
defer func() { suite.node.queryHook = nil }()
plan := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
VectorAnns: &planpb.VectorANNS{
QueryInfo: &planpb.QueryInfo{
Topk: 100,
SearchParams: `{"param": 1}`,
},
},
},
}
bs, err := proto.Marshal(plan)
suite.Require().NoError(err)
_, err = suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
SerializedExprPlan: bs,
},
}, suite.delegator)
suite.Error(err)
})
}
func (suite *OptimizeSearchParamSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
planBytes := req.GetReq().GetSerializedExprPlan()
plan := planpb.PlanNode{}
err := proto.Unmarshal(planBytes, &plan)
suite.Require().NoError(err)
queryInfo := plan.GetVectorAnns().GetQueryInfo()
suite.Equal(topK, queryInfo.GetTopk())
suite.Equal(param, queryInfo.GetSearchParams())
}
func TestOptimizeSearchParam(t *testing.T) {
suite.Run(t, new(OptimizeSearchParamSuite))
}
package querynodev2
import "github.com/stretchr/testify/mock"
type MockQueryHook struct {
mock.Mock
}
type MockQueryHookExpecter struct {
mock *mock.Mock
}
func (_m *MockQueryHook) EXPECT() *MockQueryHookExpecter {
return &MockQueryHookExpecter{mock: &_m.Mock}
}
func (_m *MockQueryHook) Run(params map[string]any) error {
ret := _m.Called(params)
var r0 error
if rf, ok := ret.Get(0).(func(params map[string]any) error); ok {
r0 = rf(params)
} else {
r0 = ret.Error(0)
}
return r0
}
// Run is a helper method to define mock.On call
func (_e *MockQueryHookExpecter) Run(params any) *MockQueryHookRunCall {
return &MockQueryHookRunCall{Call: _e.mock.On("Run", params)}
}
// MockQueryHook_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
type MockQueryHookRunCall struct {
*mock.Call
}
func (_c *MockQueryHookRunCall) Run(run func(params map[string]any)) *MockQueryHookRunCall {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[string]any))
})
return _c
}
func (_c *MockQueryHookRunCall) Return(_a0 error) *MockQueryHookRunCall {
_c.Call.Return(_a0)
return _c
}
func (_m *MockQueryHook) Init(param string) error {
ret := _m.Called(param)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(param)
} else {
r0 = ret.Error(0)
}
return r0
}
// Init is a helper method to define mock.On call
func (_e *MockQueryHookExpecter) Init(params any) *MockQueryHookRunCall {
return &MockQueryHookRunCall{Call: _e.mock.On("Init", params)}
}
// MockQueryHook_Init_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
type MockQueryHookInitCall struct {
*mock.Call
}
func (_c *MockQueryHookInitCall) Run(run func(params string)) *MockQueryHookInitCall {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockQueryHookInitCall) Return(_a0 error) *MockQueryHookInitCall {
_c.Call.Return(_a0)
return _c
}
......@@ -32,6 +32,7 @@ import (
"fmt"
"os"
"path"
"plugin"
"runtime/debug"
"sync"
"syscall"
......@@ -43,6 +44,7 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/config"
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
......@@ -121,13 +123,16 @@ type QueryNode struct {
loadPool *conc.Pool
// Pool for search/query
taskPool *conc.Pool
// parameter turning hook
queryHook queryHook
}
// NewQueryNode will return a QueryNode with abnormal state.
func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode {
ctx1, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
node := &QueryNode{
ctx: ctx1,
ctx: ctx,
cancel: cancel,
factory: factory,
lifetime: lifetime.NewLifetime(commonpb.StateCode_Abnormal),
......@@ -224,6 +229,16 @@ func (node *QueryNode) Init() error {
return
}
err = node.initHook()
if err != nil {
log.Error("QueryNode init hook failed", zap.Error(err))
// auto index cannot work if hook init failed
if paramtable.Get().AutoIndexConfig.Enable.GetAsBool() {
initError = err
return
}
}
node.factory.Init(paramtable.Get())
localChunkManager := storage.NewLocalChunkManager(storage.RootPath(paramtable.Get().LocalStorageCfg.Path.GetValue()))
......@@ -356,3 +371,48 @@ func (node *QueryNode) GetAddress() string {
func (node *QueryNode) SetAddress(address string) {
node.address = address
}
type queryHook interface {
Run(map[string]any) error
Init(string) error
}
// initHook initializes parameter tuning hook.
func (node *QueryNode) initHook() error {
path := paramtable.Get().QueryNodeCfg.SoPath.GetValue()
if path == "" {
return fmt.Errorf("fail to set the plugin path")
}
log.Debug("start to load plugin", zap.String("path", path))
p, err := plugin.Open(path)
if err != nil {
return fmt.Errorf("fail to open the plugin, error: %s", err.Error())
}
log.Debug("plugin open")
h, err := p.Lookup("QueryNodePlugin")
if err != nil {
return fmt.Errorf("fail to find the 'QueryNodePlugin' object in the plugin, error: %s", err.Error())
}
hoo, ok := h.(queryHook)
if !ok {
return fmt.Errorf("fail to convert the `Hook` interface")
}
if err = hoo.Init(paramtable.Get().HookCfg.QueryNodePluginConfig.GetValue()); err != nil {
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
}
node.queryHook = hoo
onEvent := func(event *config.Event) {
if node.queryHook != nil {
if err := node.queryHook.Init(event.Value); err != nil {
log.Error("failed to refresh hook config", zap.Error(err))
}
}
}
paramtable.Get().Watch(paramtable.Get().HookCfg.QueryNodePluginConfig.Key, config.NewHandler("queryHook", onEvent))
return nil
}
......@@ -622,7 +622,9 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
FromShardLeader: req.FromShardLeader,
Scope: req.Scope,
}
runningGp.Go(func() error {
ret, err := node.searchChannel(runningCtx, req, ch)
mu.Lock()
defer mu.Unlock()
......@@ -632,9 +634,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return err
}
if ret.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
failRet.Status.Reason = ret.Status.Reason
failRet.Status.ErrorCode = ret.Status.ErrorCode
return fmt.Errorf("%s", ret.Status.Reason)
return merr.Error(failRet.GetStatus())
}
toReduceResults = append(toReduceResults, ret)
return nil
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册