diff --git a/internal/proxy/collection_task_test.go b/internal/proxy/collection_task_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..12da8e82fa46529f94a1206f3a6b269d7f527e9f
--- /dev/null
+++ b/internal/proxy/collection_task_test.go
@@ -0,0 +1,280 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+package proxy
+
+import (
+	"context"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
+
+	"github.com/milvus-io/milvus/internal/proto/commonpb"
+
+	"github.com/milvus-io/milvus/internal/util/uniquegenerator"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/golang/protobuf/proto"
+
+	"github.com/milvus-io/milvus/internal/util/funcutil"
+
+	"github.com/milvus-io/milvus/internal/proto/milvuspb"
+)
+
+// TestCreateCollectionTask drives a createCollectionTask against a mock root
+// coordinator: it checks the task accessors, the create / re-create flow, and
+// that PreExecute rejects each kind of invalid request or schema.
+func TestCreateCollectionTask(t *testing.T) {
+	Params.Init()
+
+	rc := NewRootCoordMock()
+	ctx := context.Background()
+	shardsNum := int32(2)
+	prefix := "TestCreateCollectionTask"
+	dbName := ""
+	collectionName := prefix + funcutil.GenRandomStr()
+	int64Field := "int64"
+	floatVecField := "fvec"
+	dim := 128
+
+	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
+	// Keep the serialized form of the valid schema; the PreExecute sub-test restores task.Schema from it.
+	marshaledSchema, err := proto.Marshal(schema)
+	assert.NoError(t, err)
+
+	task := &createCollectionTask{
+		Condition: NewTaskCondition(ctx),
+		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
+			Base:           nil,
+			DbName:         dbName,
+			CollectionName: collectionName,
+			Schema:         marshaledSchema,
+			ShardsNum:      shardsNum,
+		},
+		ctx:       ctx,
+		rootCoord: rc,
+		result:    nil,
+		schema:    nil,
+	}
+
+	t.Run("on enqueue", func(t *testing.T) {
+		err := task.OnEnqueue()
+		assert.NoError(t, err)
+		assert.Equal(t, commonpb.MsgType_CreateCollection, task.Type())
+	})
+
+	t.Run("ctx", func(t *testing.T) {
+		traceCtx := task.TraceCtx()
+		assert.NotNil(t, traceCtx)
+	})
+
+	t.Run("id", func(t *testing.T) {
+		id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
+		task.SetID(id)
+		assert.Equal(t, id, task.ID())
+	})
+
+	t.Run("name", func(t *testing.T) {
+		assert.Equal(t, CreateCollectionTaskName, task.Name())
+	})
+
+	t.Run("ts", func(t *testing.T) {
+		ts := Timestamp(time.Now().UnixNano())
+		task.SetTs(ts)
+		assert.Equal(t, ts, task.BeginTs())
+		assert.Equal(t, ts, task.EndTs())
+	})
+
+	t.Run("process task", func(t *testing.T) {
+		var err error
+
+		err = task.PreExecute(ctx)
+		assert.NoError(t, err)
+
+		err = task.Execute(ctx)
+		assert.NoError(t, err)
+		assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
+
+		// recreate -> the call returns no error, but the status must not be Success
+		err = task.Execute(ctx)
+		assert.NoError(t, err)
+		assert.NotEqual(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
+
+		err = task.PostExecute(ctx)
+		assert.NoError(t, err)
+	})
+
+	t.Run("PreExecute", func(t *testing.T) {
+		var err error
+
+		err = task.PreExecute(ctx)
+		assert.NoError(t, err)
+
+		// schema bytes that are not a marshaled CollectionSchema
+		task.Schema = []byte{0x1, 0x2, 0x3, 0x4}
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+		task.Schema = marshaledSchema
+
+		// one shard more than Params.MaxShardNum
+		task.ShardsNum = Params.MaxShardNum + 1
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+		task.ShardsNum = shardsNum
+
+		reqBackup := proto.Clone(task.CreateCollectionRequest).(*milvuspb.CreateCollectionRequest)
+		schemaBackup := proto.Clone(schema).(*schemapb.CollectionSchema)
+
+		// one field slot more than Params.MaxFieldNum (the entries themselves are nil)
+		schemaWithTooManyFields := &schemapb.CollectionSchema{
+			Name:        collectionName,
+			Description: "",
+			AutoID:      false,
+			Fields:      make([]*schemapb.FieldSchema, Params.MaxFieldNum+1),
+		}
+		marshaledSchemaWithTooManyFields, err := proto.Marshal(schemaWithTooManyFields)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = marshaledSchemaWithTooManyFields
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		task.CreateCollectionRequest = reqBackup
+
+		// ValidateCollectionName
+
+		schema.Name = " " // whitespace-only name
+		emptyNameSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = emptyNameSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// name longer than Params.MaxNameLength
+		schema.Name = prefix
+		for i := 0; i < int(Params.MaxNameLength); i++ {
+			schema.Name += strconv.Itoa(i % 10)
+		}
+		tooLongNameSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = tooLongNameSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		schema.Name = "$" // invalid first char
+		invalidFirstCharSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = invalidFirstCharSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// ValidateDuplicatedFieldName
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		schema.Fields = append(schema.Fields, schema.Fields[0])
+		duplicatedFieldsSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = duplicatedFieldsSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// ValidatePrimaryKey
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		for idx := range schema.Fields {
+			schema.Fields[idx].IsPrimaryKey = false
+		}
+		noPrimaryFieldsSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = noPrimaryFieldsSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// ValidateFieldName
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		for idx := range schema.Fields {
+			schema.Fields[idx].Name = "$"
+		}
+		invalidFieldNameSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = invalidFieldNameSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// ValidateVectorField
+		// vector field without any type params (no "dim")
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		for idx := range schema.Fields {
+			if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
+				schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
+				schema.Fields[idx].TypeParams = nil
+			}
+		}
+		noDimSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = noDimSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// "dim" that is not an integer
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		for idx := range schema.Fields {
+			if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
+				schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
+				schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{
+					{
+						Key:   "dim",
+						Value: "not int",
+					},
+				}
+			}
+		}
+		dimNotIntSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = dimNotIntSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// "dim" one above Params.MaxDimension
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		for idx := range schema.Fields {
+			if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
+				schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
+				schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{
+					{
+						Key:   "dim",
+						Value: strconv.Itoa(int(Params.MaxDimension) + 1),
+					},
+				}
+			}
+		}
+		tooLargeDimSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = tooLargeDimSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		// binary vector field with "dim" one above Params.MaxDimension
+		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
+		schema.Fields[1].DataType = schemapb.DataType_BinaryVector
+		schema.Fields[1].TypeParams = []*commonpb.KeyValuePair{
+			{
+				Key:   "dim",
+				Value: strconv.Itoa(int(Params.MaxDimension) + 1),
+			},
+		}
+		binaryTooLargeDimSchema, err := proto.Marshal(schema)
+		assert.NoError(t, err)
+		task.CreateCollectionRequest.Schema = binaryTooLargeDimSchema
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+	})
+}
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index 723d76f1b9d1c6c894a1de21b7fc8f1daf2291a5..1637785b886e717981928bfc506b6dcc66e3439e 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -48,12 +48,12 @@ import (
 )
 
 const (
-	InsertTaskName                  = "insertTask"
+	InsertTaskName                  = "InsertTask"
 	CreateCollectionTaskName        = "CreateCollectionTask"
 	DropCollectionTaskName          = "DropCollectionTask"
-	SearchTaskName                  = "searchTask"
+	SearchTaskName                  = "SearchTask"
 	RetrieveTaskName                = "RetrieveTask"
-	QueryTaskName                   = "queryTask"
+	QueryTaskName                   = "QueryTask"
 	AnnsFieldKey                    = "anns_field"
 	TopKKey                         = "topk"
 	MetricTypeKey                   = "metric_type"
@@ -67,12 +67,12 @@ const (
 	DropPartitionTaskName           = "DropPartitionTask"
 	HasPartitionTaskName            = "HasPartitionTask"
 	ShowPartitionTaskName           = "ShowPartitionTask"
-	CreateIndexTaskName             = "createIndexTask"
-	DescribeIndexTaskName           = "describeIndexTask"
-	DropIndexTaskName               = "dropIndexTask"
-	GetIndexStateTaskName           = "getIndexStateTask"
-	GetIndexBuildProgressTaskName   = "getIndexBuildProgressTask"
-	FlushTaskName                   = "flushTask"
+	CreateIndexTaskName             = "CreateIndexTask"
+	DescribeIndexTaskName           = "DescribeIndexTask"
+	DropIndexTaskName               = "DropIndexTask"
+	GetIndexStateTaskName           = "GetIndexStateTask"
+	GetIndexBuildProgressTaskName   = "GetIndexBuildProgressTask"
+	FlushTaskName                   = "FlushTask"
 	LoadCollectionTaskName          = "LoadCollectionTask"
 	ReleaseCollectionTaskName       = "ReleaseCollectionTask"
 	LoadPartitionTaskName           = "LoadPartitionsTask"
@@ -1112,6 +1112,8 @@ func (cct *createCollectionTask) SetTs(ts Timestamp) {
 
 func (cct *createCollectionTask) OnEnqueue() error {
 	cct.Base = &commonpb.MsgBase{}
+	cct.Base.MsgType = commonpb.MsgType_CreateCollection
+	cct.Base.SourceID = Params.ProxyID
 	return nil
 }
 
@@ -1121,8 +1123,11 @@ func (cct *createCollectionTask) PreExecute(ctx context.Context) error {
 
 	cct.schema = &schemapb.CollectionSchema{}
 	err := proto.Unmarshal(cct.Schema, cct.schema)
+	if err != nil {
+		return err
+	}
 	cct.schema.AutoID = false
-	cct.CreateCollectionRequest.Schema, _ = proto.Marshal(cct.schema)
+	cct.CreateCollectionRequest.Schema, err = proto.Marshal(cct.schema)
 	if err != nil {
 		return err
 	}
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
index b3f86d701848285b1f5f9e387a148c9221a905a8..c3bf0c64f9a4697e3411792976104222feb998ae 100644
--- a/internal/proxy/task_test.go
+++ b/internal/proxy/task_test.go
@@ -1,8 +1,13 @@
 package proxy
 
 import (
+	"fmt"
+	"strconv"
 	"testing"
 
+	"github.com/golang/protobuf/proto"
+	"github.com/milvus-io/milvus/internal/proto/commonpb"
+
 	"github.com/milvus-io/milvus/internal/proto/milvuspb"
 	"github.com/milvus-io/milvus/internal/proto/schemapb"
 
@@ -12,6 +17,76 @@ import (
 
 // TODO(dragondriver): add more test cases
 
+// constructCollectionSchema builds a two-field test schema named collectionName:
+// an Int64 primary key called int64Field (with AutoID set) and a FloatVector
+// field called floatVecField whose "dim" type param is dim.
+func constructCollectionSchema(
+	int64Field, floatVecField string,
+	dim int,
+	collectionName string,
+) *schemapb.CollectionSchema {
+
+	pk := &schemapb.FieldSchema{
+		FieldID:      0,
+		Name:         int64Field,
+		IsPrimaryKey: true,
+		Description:  "",
+		DataType:     schemapb.DataType_Int64,
+		TypeParams:   nil,
+		IndexParams:  nil,
+		AutoID:       true,
+	}
+	fVec := &schemapb.FieldSchema{
+		FieldID:      0,
+		Name:         floatVecField,
+		IsPrimaryKey: false,
+		Description:  "",
+		DataType:     schemapb.DataType_FloatVector,
+		TypeParams: []*commonpb.KeyValuePair{
+			{
+				Key:   "dim",
+				Value: strconv.Itoa(dim),
+			},
+		},
+		IndexParams: nil,
+		AutoID:      false,
+	}
+	return &schemapb.CollectionSchema{
+		Name:        collectionName,
+		Description: "",
+		AutoID:      false,
+		Fields: []*schemapb.FieldSchema{
+			pk,
+			fVec,
+		},
+	}
+}
+
+// constructCreateCollectionRequest wraps the marshaled schema in a
+// CreateCollectionRequest with the given database, collection name and shard
+// number. It panics if the schema cannot be marshaled.
+func constructCreateCollectionRequest(
+	schema *schemapb.CollectionSchema,
+	dbName, collectionName string,
+	shardsNum int32,
+) *milvuspb.CreateCollectionRequest {
+	bs, err := proto.Marshal(schema)
+	if err != nil {
+		panic(
+			fmt.Sprintf(
+				"failed to marshal collection schema, schema: %v, error: %v",
+				schema,
+				err))
+	}
+	return &milvuspb.CreateCollectionRequest{
+		Base:           nil,
+		DbName:         dbName,
+		CollectionName: collectionName,
+		Schema:         bs,
+		ShardsNum:      shardsNum,
+	}
+}
+
 func TestGetNumRowsOfScalarField(t *testing.T) {
 	cases := []struct {
 		datas interface{}
@@ -508,7 +583,3 @@ func TestTranslateOutputFields(t *testing.T) {
 	assert.Equal(t, nil, err)
 	assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
 }
-
-func TestCreateCollectionTask(t *testing.T) {
-
-}