未验证 提交 b074f530 编写于 作者: B bigsheeper 提交者: GitHub

Forbid createIndex if collection loaded before (#20100)

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 3ff0112e
......@@ -1924,6 +1924,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
req: request,
rootCoord: node.rootCoord,
indexCoord: node.indexCoord,
queryCoord: node.queryCoord,
}
method := "CreateIndex"
......
......@@ -869,7 +869,7 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
return err
}
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, []int64{collID})
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID)
if err != nil {
return err
}
......
......@@ -55,6 +55,7 @@ type createIndexTask struct {
ctx context.Context
rootCoord types.RootCoord
indexCoord types.IndexCoord
queryCoord types.QueryCoord
result *commonpb.Status
isAutoIndex bool
......@@ -283,7 +284,20 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
}
cit.fieldSchema = field
// check index param, not accurate, only some static rules
return cit.parseIndexParams()
err = cit.parseIndexParams()
if err != nil {
return err
}
loaded, err := isCollectionLoaded(ctx, cit.queryCoord, collID)
if err != nil {
return err
}
if loaded {
return fmt.Errorf("create index failed, collection is loaded, please release it first")
}
return nil
}
func (cit *createIndexTask) Execute(ctx context.Context) error {
......@@ -504,7 +518,7 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error {
}
dit.collectionID = collID
loaded, err := isCollectionLoaded(ctx, dit.queryCoord, []int64{collID})
loaded, err := isCollectionLoaded(ctx, dit.queryCoord, collID)
if err != nil {
return err
}
......
......@@ -21,14 +21,15 @@ import (
"errors"
"testing"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestGetIndexStateTask_Execute(t *testing.T) {
......@@ -203,3 +204,87 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
assert.Error(t, err)
})
}
func TestCreateIndexTask_PreExecute(t *testing.T) {
collectionName := "collection1"
collectionID := UniqueID(1)
fieldName := newTestSchema().Fields[0].Name
Params.Init()
ic := newMockIndexCoord()
ctx := context.Background()
mockCache := newMockCache()
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return collectionID, nil
})
mockCache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return newTestSchema(), nil
})
globalMetaCache = mockCache
cit := createIndexTask{
ctx: ctx,
req: &milvuspb.CreateIndexRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreateIndex,
},
CollectionName: collectionName,
FieldName: fieldName,
},
indexCoord: ic,
queryCoord: nil,
result: nil,
collectionID: collectionID,
}
t.Run("normal", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
assert.NoError(t, err)
})
t.Run("coll has been loaded", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{collectionID},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("check load error", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
CollectionIDs: nil,
}, errors.New("error")
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
assert.Error(t, err)
})
}
......@@ -2136,6 +2136,17 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
FieldName: fieldName,
},
}
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
t.Run("normal", func(t *testing.T) {
cache := newMockCache()
......
......@@ -836,7 +836,7 @@ func validateIndexName(indexName string) error {
return nil
}
func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int64) (bool, error) {
func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) (bool, error) {
// get all loading collections
resp, err := qc.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
CollectionIDs: nil,
......@@ -848,23 +848,18 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int6
return false, errors.New(resp.Status.Reason)
}
loaded := false
LOOP:
for _, loadedCollID := range resp.GetCollectionIDs() {
for _, collID := range collIDs {
if collID == loadedCollID {
loaded = true
break LOOP
}
if collID == loadedCollID {
return true, nil
}
}
return loaded, nil
return false, nil
}
func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64, partIDs []int64) (bool, error) {
func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, partIDs []int64) (bool, error) {
// get all loading collections
resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
CollectionID: collIDs,
CollectionID: collID,
PartitionIDs: nil,
})
if err != nil {
......@@ -874,15 +869,12 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64,
return false, errors.New(resp.Status.Reason)
}
loaded := false
LOOP:
for _, loadedPartID := range resp.GetPartitionIDs() {
for _, partID := range partIDs {
if partID == loadedPartID {
loaded = true
break LOOP
return true, nil
}
}
}
return loaded, nil
return false, nil
}
......@@ -825,7 +825,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID})
loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.NoError(t, err)
assert.True(t, loaded)
})
......@@ -843,7 +843,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID})
loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.Error(t, err)
assert.False(t, loaded)
})
......@@ -861,7 +861,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID})
loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.Error(t, err)
assert.False(t, loaded)
})
......
......@@ -1398,9 +1398,9 @@ class TestIndexString(TestcaseBase):
data = cf.gen_default_list_data(ct.default_nb)
collection_w.insert(data=data)
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index, index_name="vector_flat")
collection_w.load()
index, _ = self.index_wrap.init_index(collection_w.collection, default_string_field_name,
default_string_index_params)
collection_w.load()
cf.assert_equal_index(index, collection_w.indexes[0])
assert collection_w.num_entities == default_nb
......
......@@ -748,7 +748,7 @@ class TestQueryParams(TestcaseBase):
expected: verify query result
"""
# init collection with fields: int64, float, float_vec
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=True)[0:2]
df = vectors[0]
# query with output_fields=["*", float_vector)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册