提交 a5c6f40c 编写于 作者: X XuanYang-cn 提交者: yefu.chen

Enable tests in test_index.py

Signed-off-by: NXuanYang-cn <xuan.yang@zilliz.com>
上级 4ad0338d
......@@ -134,8 +134,8 @@ func (node *NodeImpl) HasCollection(ctx context.Context, request *milvuspb.HasCo
func (node *NodeImpl) LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
log.Println("load collection: ", request)
ctx, cancel := context.WithTimeout(ctx, reqTimeoutInterval)
defer cancel()
//ctx, cancel := context.WithTimeout(ctx, reqTimeoutInterval)
//defer cancel()
lct := &LoadCollectionTask{
ctx: ctx,
......
package types
import (
"context"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
"github.com/zilliztech/milvus-distributed/internal/proto/indexpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/proxypb"
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
)
type TimeTickProvider interface {
GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error)
}
type Component interface {
Init() error
Start() error
Stop() error
GetComponentStates(ctx context.Context) (*internalpb2.ComponentStates, error)
GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error)
}
type DataNodeService interface {
Component
WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelRequest) (*commonpb.Status, error)
FlushSegments(ctx context.Context, in *datapb.FlushSegRequest) (*commonpb.Status, error)
}
type DataService interface {
Component
TimeTickProvider
RegisterNode(ctx context.Context, req *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error)
Flush(ctx context.Context, req *datapb.FlushRequest) (*commonpb.Status, error)
AssignSegmentID(ctx context.Context, req *datapb.AssignSegIDRequest) (*datapb.AssignSegIDResponse, error)
ShowSegments(ctx context.Context, req *datapb.ShowSegmentRequest) (*datapb.ShowSegmentResponse, error)
GetSegmentStates(ctx context.Context, req *datapb.SegmentStatesRequest) (*datapb.SegmentStatesResponse, error)
GetInsertBinlogPaths(ctx context.Context, req *datapb.InsertBinlogPathRequest) (*datapb.InsertBinlogPathsResponse, error)
GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error)
GetInsertChannels(ctx context.Context, req *datapb.InsertChannelRequest) (*internalpb2.StringList, error)
GetCollectionStatistics(ctx context.Context, req *datapb.CollectionStatsRequest) (*datapb.CollectionStatsResponse, error)
GetPartitionStatistics(ctx context.Context, req *datapb.PartitionStatsRequest) (*datapb.PartitionStatsResponse, error)
GetCount(ctx context.Context, req *datapb.CollectionCountRequest) (*datapb.CollectionCountResponse, error)
GetSegmentInfo(ctx context.Context, req *datapb.SegmentInfoRequest) (*datapb.SegmentInfoResponse, error)
}
type IndexNodeService interface {
Component
TimeTickProvider
BuildIndex(ctx context.Context, req *indexpb.BuildIndexCmd) (*commonpb.Status, error)
DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error)
}
type IndexService interface {
Component
TimeTickProvider
RegisterNode(ctx context.Context, req *indexpb.RegisterNodeRequest) (*indexpb.RegisterNodeResponse, error)
BuildIndex(ctx context.Context, req *indexpb.BuildIndexRequest) (*indexpb.BuildIndexResponse, error)
DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error)
GetIndexStates(ctx context.Context, req *indexpb.IndexStatesRequest) (*indexpb.IndexStatesResponse, error)
GetIndexFilePaths(ctx context.Context, req *indexpb.IndexFilePathsRequest) (*indexpb.IndexFilePathsResponse, error)
NotifyBuildIndex(ctx context.Context, nty *indexpb.BuildIndexNotification) (*commonpb.Status, error)
}
type MasterService interface {
Component
//DDL request
CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error)
DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest) (*commonpb.Status, error)
HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error)
DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionRequest) (*milvuspb.ShowCollectionResponse, error)
CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error)
DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error)
HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error)
ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error)
//index builder service
CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRequest) (*commonpb.Status, error)
DescribeIndex(ctx context.Context, in *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error)
DropIndex(ctx context.Context, in *milvuspb.DropIndexRequest) (*commonpb.Status, error)
//global timestamp allocator
AllocTimestamp(ctx context.Context, in *masterpb.TsoRequest) (*masterpb.TsoResponse, error)
AllocID(ctx context.Context, in *masterpb.IDRequest) (*masterpb.IDResponse, error)
//TODO, master load these channel form config file ?
//receiver time tick from proxy service, and put it into this channel
GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error)
//receive ddl from rpc and time tick from proxy service, and put them into this channel
GetDdChannel(ctx context.Context) (*milvuspb.StringResponse, error)
//just define a channel, not used currently
GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error)
//segment
DescribeSegment(ctx context.Context, in *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error)
ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentRequest) (*milvuspb.ShowSegmentResponse, error)
}
type ProxyNodeService interface {
Component
InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error)
CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error)
DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error)
HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error)
LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error)
ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error)
DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
GetCollectionStatistics(ctx context.Context, request *milvuspb.CollectionStatsRequest) (*milvuspb.CollectionStatsResponse, error)
ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionRequest) (*milvuspb.ShowCollectionResponse, error)
CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error)
DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error)
HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error)
LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitonRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionRequest) (*commonpb.Status, error)
GetPartitionStatistics(ctx context.Context, request *milvuspb.PartitionStatsRequest) (*milvuspb.PartitionStatsResponse, error)
ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error)
CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error)
DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error)
GetIndexState(ctx context.Context, request *milvuspb.IndexStateRequest) (*milvuspb.IndexStateResponse, error)
DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error)
Insert(ctx context.Context, request *milvuspb.InsertRequest) (*milvuspb.InsertResponse, error)
Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error)
Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error)
GetDdChannel(ctx context.Context, request *commonpb.Empty) (*milvuspb.StringResponse, error)
GetQuerySegmentInfo(ctx context.Context, req *milvuspb.QuerySegmentInfoRequest) (*milvuspb.QuerySegmentInfoResponse, error)
GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.PersistentSegmentInfoRequest) (*milvuspb.PersistentSegmentInfoResponse, error)
}
type ProxyService interface {
Component
TimeTickProvider
RegisterNode(ctx context.Context, request *proxypb.RegisterNodeRequest) (*proxypb.RegisterNodeResponse, error)
InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error)
}
type QueryNodeService interface {
Component
TimeTickProvider
AddQueryChannel(ctx context.Context, in *querypb.AddQueryChannelsRequest) (*commonpb.Status, error)
RemoveQueryChannel(ctx context.Context, in *querypb.RemoveQueryChannelsRequest) (*commonpb.Status, error)
WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error)
LoadSegments(ctx context.Context, in *querypb.LoadSegmentRequest) (*commonpb.Status, error)
ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionRequest) (*commonpb.Status, error)
ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentRequest) (*commonpb.Status, error)
GetSegmentInfo(ctx context.Context, in *querypb.SegmentInfoRequest) (*querypb.SegmentInfoResponse, error)
}
type QueryService interface {
Component
TimeTickProvider
ShowCollections(ctx context.Context, req *querypb.ShowCollectionRequest) (*querypb.ShowCollectionResponse, error)
LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error)
ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)
ShowPartitions(ctx context.Context, req *querypb.ShowPartitionRequest) (*querypb.ShowPartitionResponse, error)
LoadPartitions(ctx context.Context, req *querypb.LoadPartitionRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionRequest) (*commonpb.Status, error)
CreateQueryChannel(ctx context.Context) (*querypb.CreateQueryChannelResponse, error)
GetPartitionStates(ctx context.Context, req *querypb.PartitionStatesRequest) (*querypb.PartitionStatesResponse, error)
GetSegmentInfo(ctx context.Context, req *querypb.SegmentInfoRequest) (*querypb.SegmentInfoResponse, error)
}
......@@ -46,6 +46,7 @@ class TestIndexBase:
******************************************************************
"""
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index(self, connect, collection, get_simple_index):
'''
......@@ -95,6 +96,7 @@ class TestIndexBase:
index = connect.describe_index(collection, field_name)
assert index == get_simple_index
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_partition(self, connect, collection, get_simple_index):
'''
......@@ -104,11 +106,11 @@ class TestIndexBase:
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
index = connect.describe_index(collection, field_name)
assert index == get_simple_index
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_partition_flush(self, connect, collection, get_simple_index):
'''
......@@ -118,7 +120,7 @@ class TestIndexBase:
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush()
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
index = connect.describe_index(collection, field_name)
assert index == get_simple_index
......@@ -151,6 +153,7 @@ class TestIndexBase:
res = connect.search(collection, query)
assert len(res) == nq
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
@pytest.mark.level(2)
def test_create_index_multithread(self, connect, collection, args):
......@@ -189,6 +192,7 @@ class TestIndexBase:
with pytest.raises(Exception) as e:
connect.create_index(collection_name, field_name, default_index)
@pytest.mark.tags("0331")
@pytest.mark.level(2)
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_insert_flush(self, connect, collection, get_simple_index):
......@@ -232,9 +236,12 @@ class TestIndexBase:
indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}]
for index in indexs:
connect.create_index(collection, field_name, index)
connect.release_collection(collection)
connect.load_collection(collection)
index = connect.describe_index(collection, field_name)
assert index == indexs[-1]
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_ip(self, connect, collection, get_simple_index):
'''
......@@ -261,6 +268,7 @@ class TestIndexBase:
index = connect.describe_index(collection, field_name)
assert index == get_simple_index
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_partition_ip(self, connect, collection, get_simple_index):
'''
......@@ -270,12 +278,12 @@ class TestIndexBase:
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
index = connect.describe_index(collection, field_name)
assert index == get_simple_index
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_partition_flush_ip(self, connect, collection, get_simple_index):
'''
......@@ -285,7 +293,7 @@ class TestIndexBase:
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush()
connect.flush([collection])
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
index = connect.describe_index(collection, field_name)
......@@ -302,7 +310,8 @@ class TestIndexBase:
ids = connect.insert(collection, default_entities)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
logging.getLogger().info(connect.describe_index(collection))
connect.load_collection(collection)
logging.getLogger().info(connect.describe_index(collection, field_name))
nq = get_nq
index_type = get_simple_index["index_type"]
search_param = get_search_param(index_type)
......@@ -310,6 +319,7 @@ class TestIndexBase:
res = connect.search(collection, query)
assert len(res) == nq
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
@pytest.mark.level(2)
def test_create_index_multithread_ip(self, connect, collection, args):
......@@ -350,6 +360,7 @@ class TestIndexBase:
with pytest.raises(Exception) as e:
connect.create_index(collection_name, field_name, default_index)
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_no_vectors_insert_ip(self, connect, collection):
'''
......@@ -391,12 +402,15 @@ class TestIndexBase:
expected: return code 0, and describe index result equals with the second index params
'''
ids = connect.insert(collection, default_entities)
connect.load_collection(collection)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == default_nb
default_index["metric_type"] = "IP"
indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}]
for index in indexs:
connect.create_index(collection, field_name, index)
connect.release_collection(collection)
connect.load_collection(collection)
index = connect.describe_index(collection, field_name)
assert index == indexs[-1]
......@@ -585,6 +599,7 @@ class TestIndexBinary:
******************************************************************
"""
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index(self, connect, binary_collection, get_jaccard_index):
'''
......@@ -597,6 +612,7 @@ class TestIndexBinary:
binary_index = connect.describe_index(binary_collection, binary_field_name)
assert binary_index == get_jaccard_index
@pytest.mark.tags("0331")
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_partition(self, connect, binary_collection, get_jaccard_index):
'''
......@@ -620,10 +636,10 @@ class TestIndexBinary:
nq = get_nq
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
connect.load_collection(binary_collection)
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
logging.getLogger().info(search_param)
connect.load_collection(binary_collection)
res = connect.search(binary_collection, query, search_params=search_param)
assert len(res) == nq
......@@ -650,6 +666,7 @@ class TestIndexBinary:
The following cases are used to test `describe_index` function
***************************************************************
"""
@pytest.mark.skip("repeat with test_create_index binary")
def test_get_index_info(self, connect, binary_collection, get_jaccard_index):
'''
target: test describe index interface
......@@ -669,6 +686,7 @@ class TestIndexBinary:
if "index_type" in file:
assert file["index_type"] == get_jaccard_index["index_type"]
@pytest.mark.skip("repeat with test_create_index_partition binary")
def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index):
'''
target: test describe index interface
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册