未验证 提交 2fe8677c 编写于 作者: J Jiquan Long 提交者: GitHub

Enable dimension check in Proxy when create index request received (#16718)

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 bb9ccbb7
package proxy
import (
"context"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type getCollectionIDFunc func(ctx context.Context, collectionName string) (typeutil.UniqueID, error)
type getCollectionSchemaFunc func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
type mockCache struct {
Cache
getIDFunc getCollectionIDFunc
getSchemaFunc getCollectionSchemaFunc
}
func (m *mockCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
if m.getIDFunc != nil {
return m.getIDFunc(ctx, collectionName)
}
return 0, nil
}
func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
if m.getSchemaFunc != nil {
return m.getSchemaFunc(ctx, collectionName)
}
return nil, nil
}
func (m *mockCache) setGetIDFunc(f getCollectionIDFunc) {
m.getIDFunc = f
}
func (m *mockCache) setGetSchemaFunc(f getCollectionSchemaFunc) {
m.getSchemaFunc = f
}
func newMockCache() *mockCache {
return &mockCache{}
}
......@@ -1795,36 +1795,13 @@ func (cit *createIndexTask) OnEnqueue() error {
return nil
}
func (cit *createIndexTask) PreExecute(ctx context.Context) error {
cit.Base.MsgType = commonpb.MsgType_CreateIndex
cit.Base.SourceID = Params.ProxyCfg.GetNodeID()
collName, fieldName := cit.CollectionName, cit.FieldName
col, err := globalMetaCache.GetCollectionInfo(ctx, collName)
if err != nil {
return err
}
cit.collectionID = col.collID
if err := validateCollectionName(collName); err != nil {
return err
}
if err := validateFieldName(fieldName); err != nil {
return err
}
// check index param, not accurate, only some static rules
func parseIndexParams(m []*commonpb.KeyValuePair) (map[string]string, error) {
indexParams := make(map[string]string)
for _, kv := range cit.CreateIndexRequest.ExtraParams {
for _, kv := range m {
if kv.Key == "params" { // TODO(dragondriver): change `params` to const variable
params, err := funcutil.ParseIndexParamsMap(kv.Value)
if err != nil {
log.Warn("Failed to parse index params",
zap.String("params", kv.Value),
zap.Error(err))
continue
return nil, err
}
for k, v := range params {
indexParams[k] = v
......@@ -1833,23 +1810,68 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
indexParams[kv.Key] = kv.Value
}
}
indexType, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable
_, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable
if !exist {
indexType = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type
indexParams["index_type"] = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type
}
return indexParams, nil
}
//TODO:: add default index type for VarChar type field
func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) {
schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.GetCollectionName())
if err != nil {
log.Error("failed to get collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to get collection schema: %s", err)
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
if err != nil {
log.Error("failed to parse collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to parse collection schema: %s", err)
}
field, err := schemaHelper.GetFieldFromName(cit.GetFieldName())
if err != nil {
log.Error("create index on non-exist field", zap.Error(err))
return nil, fmt.Errorf("cannot create index on non-exist field: %s", cit.GetFieldName())
}
return field, nil
}
// skip params check of non-vector field.
func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
}
for _, f := range col.schema.GetFields() {
if f.GetName() == fieldName && !funcutil.SliceContain(vecDataTypes, f.GetDataType()) {
return indexparamcheck.CheckIndexValid(f.GetDataType(), indexType, indexParams)
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
return nil
}
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
params = append(params, field.GetTypeParams()...)
params = append(params, field.GetIndexParams()...)
dimensionInSchema, err := funcutil.GetAttrByKeyFromRepeatedKV("dim", params)
if err != nil {
return fmt.Errorf("dimension not found in schema")
}
dimension, exist := indexParams["dim"]
if exist {
if dimensionInSchema != dimension {
return fmt.Errorf("dimension mismatch, dimension in schema: %s, dimension: %s", dimensionInSchema, dimension)
}
} else {
indexParams["dim"] = dimensionInSchema
}
return nil
}
func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error {
indexType := indexParams["index_type"]
// skip params check of non-vector field.
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
}
adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType)
......@@ -1858,15 +1880,46 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
return fmt.Errorf("invalid index type: %s", indexType)
}
if err := fillDimension(field, indexParams); err != nil {
return err
}
ok := adapter.CheckTrain(indexParams)
if !ok {
log.Warn("Create index with invalid params", zap.Any("index_params", indexParams))
return fmt.Errorf("invalid index params: %v", cit.CreateIndexRequest.ExtraParams)
return fmt.Errorf("invalid index params: %v", indexParams)
}
return nil
}
func (cit *createIndexTask) PreExecute(ctx context.Context) error {
cit.Base.MsgType = commonpb.MsgType_CreateIndex
cit.Base.SourceID = Params.ProxyCfg.GetNodeID()
collName := cit.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collName)
if err != nil {
return err
}
cit.collectionID = collID
field, err := cit.getIndexedField(ctx)
if err != nil {
return err
}
// check index param, not accurate, only some static rules
indexParams, err := parseIndexParams(cit.GetExtraParams())
if err != nil {
log.Error("failed to parse index params", zap.Error(err))
return fmt.Errorf("failed to parse index params: %s", err)
}
return checkTrain(field, indexParams)
}
func (cit *createIndexTask) Execute(ctx context.Context) error {
var err error
cit.result, err = cit.rootCoord.CreateIndex(ctx, cit.CreateIndexRequest)
......
......@@ -21,11 +21,14 @@ import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"math/rand"
"strconv"
"testing"
"time"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
......@@ -2169,3 +2172,265 @@ func TestAlterAlias_all(t *testing.T) {
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
}
func Test_createIndexTask_getIndexedField(t *testing.T) {
collectionName := "test"
fieldName := "test"
cit := &createIndexTask{
CreateIndexRequest: &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: fieldName,
},
}
t.Run("normal", func(t *testing.T) {
cache := newMockCache()
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: fieldName,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
},
},
}, nil
})
globalMetaCache = cache
field, err := cit.getIndexedField(context.Background())
assert.NoError(t, err)
assert.Equal(t, fieldName, field.GetName())
})
t.Run("schema not found", func(t *testing.T) {
cache := newMockCache()
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return nil, errors.New("mock")
})
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
assert.Error(t, err)
})
t.Run("invalid schema", func(t *testing.T) {
cache := newMockCache()
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: fieldName,
},
{
Name: fieldName, // duplicate
},
},
}, nil
})
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
assert.Error(t, err)
})
t.Run("field not found", func(t *testing.T) {
cache := newMockCache()
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: fieldName + fieldName,
},
},
}, nil
})
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
assert.Error(t, err)
})
}
func Test_fillDimension(t *testing.T) {
t.Run("scalar", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int64,
}
assert.NoError(t, fillDimension(f, nil))
})
t.Run("no dim in schema", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
}
assert.Error(t, fillDimension(f, nil))
})
t.Run("dimension mismatch", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
}
assert.Error(t, fillDimension(f, map[string]string{"dim": "8"}))
})
t.Run("normal", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
}
m := map[string]string{}
assert.NoError(t, fillDimension(f, m))
assert.Equal(t, "128", m["dim"])
})
}
func Test_checkTrain(t *testing.T) {
t.Run("normal", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
}
m := map[string]string{
"index_type": "IVF_FLAT",
"nlist": "1024",
"metric_type": "L2",
}
assert.NoError(t, checkTrain(f, m))
})
t.Run("scalar", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int64,
}
m := map[string]string{
"index_type": "scalar",
}
assert.NoError(t, checkTrain(f, m))
})
t.Run("dimension mismatch", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
}
m := map[string]string{
"index_type": "IVF_FLAT",
"nlist": "1024",
"metric_type": "L2",
"dim": "8",
}
assert.Error(t, checkTrain(f, m))
})
t.Run("invalid params", func(t *testing.T) {
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_FloatVector,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
}
m := map[string]string{
"index_type": "IVF_FLAT",
"metric_type": "L2",
}
assert.Error(t, checkTrain(f, m))
})
}
func Test_createIndexTask_PreExecute(t *testing.T) {
collectionName := "test"
fieldName := "test"
cit := &createIndexTask{
CreateIndexRequest: &milvuspb.CreateIndexRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreateIndex,
},
CollectionName: collectionName,
FieldName: fieldName,
},
}
t.Run("normal", func(t *testing.T) {
cache := newMockCache()
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 100, nil
})
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: fieldName,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
},
},
}, nil
})
globalMetaCache = cache
cit.CreateIndexRequest.ExtraParams = []*commonpb.KeyValuePair{
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: "1024",
},
{
Key: "metric_type",
Value: "L2",
},
}
assert.NoError(t, cit.PreExecute(context.Background()))
})
t.Run("collection not found", func(t *testing.T) {
cache := newMockCache()
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 0, errors.New("mock")
})
globalMetaCache = cache
assert.Error(t, cit.PreExecute(context.Background()))
})
}
......@@ -131,10 +131,9 @@ type BaseConfAdapter struct {
// CheckTrain check whether the params contains supported metrics types
func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool {
// dimension is specified when create collection
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
// return false
//}
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
return CheckStrByValues(params, Metric, METRICS)
}
......@@ -179,8 +178,8 @@ func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool {
func (adapter *IVFPQConfAdapter) checkPQParams(params map[string]string) bool {
dimStr, dimensionExist := params[DIM]
if !dimensionExist { // dimension is specified when creating collection
return true
if !dimensionExist {
return false
}
dimension, err := strconv.Atoi(dimStr)
......@@ -260,10 +259,9 @@ type BinIDMAPConfAdapter struct {
// CheckTrain checks if a binary flat index can be built with the specific parameters.
func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool {
// dimension is specified when create collection
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
// return false
//}
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
return CheckStrByValues(params, Metric, BinIDMapMetrics)
}
......@@ -278,10 +276,9 @@ type BinIVFConfAdapter struct {
// CheckTrain checks if a binary ivf index can be built with specific parameters.
func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool {
// dimension is specified when create collection
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
// return false
//}
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return false
......
......@@ -12,6 +12,7 @@
package indexparamcheck
import (
"fmt"
"strconv"
"testing"
)
......@@ -50,11 +51,15 @@ func TestBaseConfAdapter_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
Metric: L2,
}
paramsWithoutDim := map[string]string{
Metric: L2,
}
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{paramsWithoutDim, false},
}
adapter := newBaseConfAdapter()
......@@ -141,7 +146,7 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
{validParamsWithoutNbits, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{validParamsWithoutDim, true},
{validParamsWithoutDim, false},
{invalidParamsDim, false},
{invalidParamsNbits, false},
{invalidParamsWithoutIVF, false},
......@@ -150,8 +155,9 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
}
adapter := newIVFPQConfAdapter()
for _, test := range cases {
for i, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
fmt.Printf("i: %d, params: %v\n", i, test.params)
t.Errorf("IVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
......@@ -187,11 +193,15 @@ func TestBinIDMAPConfAdapter_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
Metric: JACCARD,
}
paramsWithoutDim := map[string]string{
Metric: JACCARD,
}
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{paramsWithoutDim, false},
}
adapter := newBinIDMAPConfAdapter()
......@@ -211,6 +221,12 @@ func TestBinIVFConfAdapter_CheckTrain(t *testing.T) {
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
paramsWithoutDim := map[string]string{
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
invalidParams := copyParams(validParams)
invalidParams[Metric] = L2
......@@ -220,6 +236,7 @@ func TestBinIVFConfAdapter_CheckTrain(t *testing.T) {
want bool
}{
{validParams, true},
{paramsWithoutDim, false},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{invalidParams, false},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册