Unverified commit 19387754 authored by zhuwenxing, committed by GitHub

[test]Add restful api test (#21336)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
Parent 396a85c9
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Alias(RestClient):
def drop_alias(self):
pass
def alter_alias(self):
pass
def create_alias(self):
pass
import json
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Collection(RestClient):
@DELETE("collection")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def drop_collection(self, payload):
"""Drop a collection"""
@GET("collection")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def describe_collection(self, payload):
"""Describe a collection"""
@POST("collection")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def create_collection(self, payload):
"""Create a collection"""
@GET("collection/existence")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def has_collection(self, payload):
"""Check if a collection exists"""
@DELETE("collection/load")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def release_collection(self, payload):
"""Release a collection"""
@POST("collection/load")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def load_collection(self, payload):
"""Load a collection"""
@GET("collection/statistics")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_collection_statistics(self, payload):
"""Get collection statistics"""
@GET("collections")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def show_collections(self, payload):
"""Show collections"""
if __name__ == '__main__':
client = Collection("http://localhost:19121/api/v1")
print(client)
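# A hedged usage sketch (assumes a Milvus proxy serving the RESTful API at this endpoint):
# rsp = client.has_collection({"collection_name": "book", "time_stamp": 0})
# print(rsp)
# Each method serializes the payload dict with json.dumps (via @body) and returns the parsed
# JSON response on HTTP 200 (via @on), so callers work with plain dicts on both sides.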
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Credential(RestClient):
def delete_credential(self):
pass
def update_credential(self):
pass
def create_credential(self):
pass
def list_credentials(self):
pass
import json
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Entity(RestClient):
@POST("distance")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def calc_distance(self, payload):
""" Calculate distance between two points """
@DELETE("entities")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def delete(self, payload):
"""delete entities"""
@POST("entities")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def insert(self, payload):
"""insert entities"""
@POST("persist")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def flush(self, payload):
"""flush entities"""
@POST("persist/segment-info")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_persistent_segment_info(self, payload):
"""get persistent segment info"""
@POST("persist/state")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_flush_state(self, payload):
"""get flush state"""
@POST("query")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def query(self, payload):
"""query entities"""
@POST("query-segment-info")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_query_segment_info(self, payload):
"""get query segment info"""
@POST("search")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def search(self, payload):
"""search entities"""
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Import(RestClient):
def list_import_tasks(self):
pass
def exec_import(self):
pass
def get_import_state(self):
pass
import json
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Index(RestClient):
@DELETE("/index")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def drop_index(self, payload):
"""Drop an index"""
@GET("/index")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def describe_index(self, payload):
"""Describe an index"""
@POST("index")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def create_index(self, payload):
"""create index"""
@GET("index/progress")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_index_build_progress(self, payload):
"""get index build progress"""
@GET("index/state")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_index_state(self, payload):
"""get index state"""
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Metrics(RestClient):
def get_metrics(self):
pass
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Ops(RestClient):
def manual_compaction(self):
pass
def get_compaction_plans(self):
pass
def get_compaction_state(self):
pass
def load_balance(self):
pass
def get_replicas(self):
pass
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Partition(RestClient):
def drop_partition(self):
pass
def create_partition(self):
pass
def has_partition(self):
pass
def get_partition_statistics(self):
pass
def show_partitions(self):
pass
def release_partition(self):
pass
def load_partition(self):
pass
from datetime import date, datetime
from typing import List, Union, Optional
from pydantic import BaseModel, UUID4, conlist
from pydantic_factories import ModelFactory
class Person(BaseModel):
id: UUID4
name: str
hobbies: List[str]
age: Union[float, int]
birthday: Union[datetime, date]
class Pet(BaseModel):
name: str
age: int
class PetFactory(BaseModel):
name: str
pet: Pet
age: Optional[int] = None
sample = {
"name": "John",
"pet": {
"name": "Fido",
"age": 3
}
}
result = PetFactory(**sample)
print(result)
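# ModelFactory is imported above but not exercised; a minimal sketch of its typical use with
# pydantic-factories (the factory name below is illustrative, not part of the original code):
class PersonModelFactory(ModelFactory):
    __model__ = Person

print(PersonModelFactory.build())  # builds a Person with randomly generated field values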
from time import sleep
from decorest import HttpStatus, RestClient
from models.schema import CollectionSchema
from base.collection_service import CollectionService
from base.index_service import IndexService
from base.entity_service import EntityService
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
class Base:
"""init base class"""
endpoint = None
collection_service = None
index_service = None
entity_service = None
collection_name = None
collection_object_list = []
def setup_class(self):
log.info("setup class")
def teardown_class(self):
log.info("teardown class")
def setup_method(self, method):
log.info(("*" * 35) + " setup " + ("*" * 35))
log.info("[setup_method] Start setup test case %s." % method.__name__)
host = cf.param_info.param_host
port = cf.param_info.param_port
self.endpoint = "http://" + host + ":" + str(port) + "/api/v1"
self.collection_service = CollectionService(self.endpoint)
self.index_service = IndexService(self.endpoint)
self.entity_service = EntityService(self.endpoint)
def teardown_method(self, method):
res = self.collection_service.has_collection(collection_name=self.collection_name)
log.info(f"collection {self.collection_name} exists: {res}")
if res["value"] is True:
res = self.collection_service.drop_collection(self.collection_name)
log.info(f"drop collection {self.collection_name} res: {res}")
res = self.collection_service.show_collections()
all_collections = res["collection_names"]
union_collections = set(all_collections) & set(self.collection_object_list)
for collection in union_collections:
res = self.collection_service.drop_collection(collection)
log.info(f"drop collection {collection} res: {res}")
log.info("[teardown_method] Start teardown test case %s." % method.__name__)
log.info(("*" * 35) + " teardown " + ("*" * 35))
class TestBase(Base):
"""init test base class"""
def init_collection(self, name=None, schema=None):
collection_name = cf.gen_unique_str("test") if name is None else name
self.collection_name = collection_name
self.collection_object_list.append(collection_name)
if schema is None:
schema = cf.gen_default_schema(collection_name=collection_name)
# create collection
res = self.collection_service.create_collection(collection_name=collection_name, schema=schema)
log.info(f"create collection name: {collection_name} with schema: {schema}")
return collection_name, schema
from api.collection import Collection
from utils.util_log import test_log as log
from models import milvus
TIMEOUT = 30
class CollectionService:
def __init__(self, endpoint=None, timeout=None):
if timeout is None:
timeout = TIMEOUT
if endpoint is None:
endpoint = "http://localhost:9091/api/v1"
self._collection = Collection(endpoint=endpoint)
def create_collection(self, collection_name, consistency_level=1, schema=None, shards_num=2):
payload = {
"collection_name": collection_name,
"consistency_level": consistency_level,
"schema": schema,
"shards_num": shards_num
}
log.info(f"payload: {payload}")
# payload = milvus.CreateCollectionRequest(collection_name=collection_name,
# consistency_level=consistency_level,
# schema=schema,
# shards_num=shards_num)
# payload = payload.dict()
rsp = self._collection.create_collection(payload)
return rsp
def has_collection(self, collection_name=None, time_stamp=0):
payload = {
"collection_name": collection_name,
"time_stamp": time_stamp
}
# payload = milvus.HasCollectionRequest(collection_name=collection_name, time_stamp=time_stamp)
# payload = payload.dict()
return self._collection.has_collection(payload)
def drop_collection(self, collection_name):
payload = {
"collection_name": collection_name
}
# payload = milvus.DropCollectionRequest(collection_name=collection_name)
# payload = payload.dict()
return self._collection.drop_collection(payload)
def describe_collection(self, collection_name, collection_id=None, time_stamp=0):
payload = {
"collection_name": collection_name,
"collection_id": collection_id,
"time_stamp": time_stamp
}
# payload = milvus.DescribeCollectionRequest(collection_name=collection_name,
# collectionID=collection_id,
# time_stamp=time_stamp)
# payload = payload.dict()
return self._collection.describe_collection(payload)
def load_collection(self, collection_name, replica_number=1):
payload = {
"collection_name": collection_name,
"replica_number": replica_number
}
# payload = milvus.LoadCollectionRequest(collection_name=collection_name, replica_number=replica_number)
# payload = payload.dict()
return self._collection.load_collection(payload)
def release_collection(self, collection_name):
payload = {
"collection_name": collection_name
}
# payload = milvus.ReleaseCollectionRequest(collection_name=collection_name)
# payload = payload.dict()
return self._collection.release_collection(payload)
def get_collection_statistics(self, collection_name):
payload = {
"collection_name": collection_name
}
# payload = milvus.GetCollectionStatisticsRequest(collection_name=collection_name)
# payload = payload.dict()
return self._collection.get_collection_statistics(payload)
def show_collections(self, collection_names=None, type=0):
payload = {
"collection_names": collection_names,
"type": type
}
# payload = milvus.ShowCollectionsRequest(collection_names=collection_names, type=type)
# payload = payload.dict()
return self._collection.show_collections(payload)
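# A hedged usage sketch (assumes the Milvus RESTful proxy is reachable at this endpoint and a
# schema dict such as the one produced by common.common_func.gen_default_schema):
# svc = CollectionService(endpoint="http://localhost:9091/api/v1")
# svc.create_collection("book", schema=schema)
# exists = svc.has_collection("book")["value"]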
from api.entity import Entity
from common import common_type as ct
from utils.util_log import test_log as log
from models import common, schema, milvus, server
TIMEOUT = 30
class EntityService:
def __init__(self, endpoint=None, timeout=None):
if timeout is None:
timeout = TIMEOUT
if endpoint is None:
endpoint = "http://localhost:9091/api/v1"
self._entity = Entity(endpoint=endpoint)
def calc_distance(self, base=None, op_left=None, op_right=None, params=None):
payload = {
"base": base,
"op_left": op_left,
"op_right": op_right,
"params": params
}
# payload = milvus.CalcDistanceRequest(base=base, op_left=op_left, op_right=op_right, params=params)
# payload = payload.dict()
return self._entity.calc_distance(payload)
def delete(self, base=None, collection_name=None, db_name=None, expr=None, hash_keys=None, partition_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"expr": expr,
"hash_keys": hash_keys,
"partition_name": partition_name
}
# payload = server.DeleteRequest(base=base,
# collection_name=collection_name,
# db_name=db_name,
# expr=expr,
# hash_keys=hash_keys,
# partition_name=partition_name)
# payload = payload.dict()
return self._entity.delete(payload)
def insert(self, base=None, collection_name=None, db_name=None, fields_data=None, hash_keys=None, num_rows=None,
partition_name=None, check_task=True):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"fields_data": fields_data,
"hash_keys": hash_keys,
"num_rows": num_rows,
"partition_name": partition_name
}
# payload = milvus.InsertRequest(base=base,
# collection_name=collection_name,
# db_name=db_name,
# fields_data=fields_data,
# hash_keys=hash_keys,
# num_rows=num_rows,
# partition_name=partition_name)
# payload = payload.dict()
rsp = self._entity.insert(payload)
if check_task:
assert rsp["status"] == {}
assert rsp["insert_cnt"] == num_rows
return rsp
def flush(self, base=None, collection_names=None, db_name=None, check_task=True):
payload = {
"base": base,
"collection_names": collection_names,
"db_name": db_name
}
# payload = server.FlushRequest(base=base,
# collection_names=collection_names,
# db_name=db_name)
# payload = payload.dict()
rsp = self._entity.flush(payload)
if check_task:
assert rsp["status"] == {}
def get_persistent_segment_info(self, base=None, collection_name=None, db_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name
}
# payload = server.GetPersistentSegmentInfoRequest(base=base,
# collection_name=collection_name,
# db_name=db_name)
# payload = payload.dict()
return self._entity.get_persistent_segment_info(payload)
def get_flush_state(self, segment_ids=None):
payload = {
"segment_ids": segment_ids
}
# payload = server.GetFlushStateRequest(segment_ids=segment_ids)
# payload = payload.dict()
return self._entity.get_flush_state(payload)
def query(self, base=None, collection_name=None, db_name=None, expr=None,
guarantee_timestamp=None, output_fields=None, partition_names=None, travel_timestamp=None,
check_task=True):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"expr": expr,
"guarantee_timestamp": guarantee_timestamp,
"output_fields": output_fields,
"partition_names": partition_names,
"travel_timestamp": travel_timestamp
}
#
# payload = server.QueryRequest(base=base, collection_name=collection_name, db_name=db_name, expr=expr,
# guarantee_timestamp=guarantee_timestamp, output_fields=output_fields,
# partition_names=partition_names, travel_timestamp=travel_timestamp)
# payload = payload.dict()
rsp = self._entity.query(payload)
if check_task:
fields_data = rsp["fields_data"]
for field_data in fields_data:
if field_data["field_name"] in expr:
data = field_data["Field"]["Scalars"]["Data"]["LongData"]["data"]
for d in data:
s = expr.replace(field_data["field_name"], str(d))
assert eval(s) is True
return rsp
def get_query_segment_info(self, base=None, collection_name=None, db_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name
}
# payload = server.GetQuerySegmentInfoRequest(base=base,
# collection_name=collection_name,
# db_name=db_name)
# payload = payload.dict()
return self._entity.get_query_segment_info(payload)
def search(self, base=None, collection_name=None, vectors=None, db_name=None, dsl=None,
output_fields=None, dsl_type=1,
guarantee_timestamp=None, partition_names=None, placeholder_group=None,
search_params=None, travel_timestamp=None, check_task=True):
payload = {
"base": base,
"collection_name": collection_name,
"output_fields": output_fields,
"vectors": vectors,
"db_name": db_name,
"dsl": dsl,
"dsl_type": dsl_type,
"guarantee_timestamp": guarantee_timestamp,
"partition_names": partition_names,
"placeholder_group": placeholder_group,
"search_params": search_params,
"travel_timestamp": travel_timestamp
}
# payload = server.SearchRequest(base=base, collection_name=collection_name, db_name=db_name, dsl=dsl,
# dsl_type=dsl_type, guarantee_timestamp=guarantee_timestamp,
# partition_names=partition_names, placeholder_group=placeholder_group,
# search_params=search_params, travel_timestamp=travel_timestamp)
# payload = payload.dict()
rsp = self._entity.search(payload)
if check_task:
assert rsp["status"] == {}
assert rsp["results"]["num_queries"] == len(vectors)
assert len(rsp["results"]["ids"]["IdField"]["IntId"]["data"]) == sum(rsp["results"]["topks"])
return rsp
from api.index import Index
from models import common, schema, milvus, server
TIMEOUT = 30
class IndexService:
def __init__(self, endpoint=None, timeout=None):
if timeout is None:
timeout = TIMEOUT
if endpoint is None:
endpoint = "http://localhost:9091/api/v1"
self._index = Index(endpoint=endpoint)
def drop_index(self, base, collection_name, db_name, field_name, index_name):
payload = server.DropIndexRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.drop_index(payload)
def describe_index(self, base, collection_name, db_name, field_name, index_name):
payload = server.DescribeIndexRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.describe_index(payload)
def create_index(self, base=None, collection_name=None, db_name=None, extra_params=None,
field_name=None, index_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"extra_params": extra_params,
"field_name": field_name,
"index_name": index_name
}
# payload = server.CreateIndexRequest(base=base, collection_name=collection_name, db_name=db_name,
# extra_params=extra_params, field_name=field_name, index_name=index_name)
# payload = payload.dict()
return self._index.create_index(payload)
def get_index_build_progress(self, base, collection_name, db_name, field_name, index_name):
payload = server.GetIndexBuildProgressRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.get_index_build_progress(payload)
def get_index_state(self, base, collection_name, db_name, field_name, index_name):
payload = server.GetIndexStateRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.get_index_state(payload)
class CheckTasks:
""" The name of the method used to check the result """
check_nothing = "check_nothing"
err_res = "error_response"
ccr = "check_connection_result"
check_collection_property = "check_collection_property"
check_partition_property = "check_partition_property"
check_search_results = "check_search_results"
check_query_results = "check_query_results"
check_query_empty = "check_query_empty" # verify that query result is empty
check_query_not_empty = "check_query_not_empty"
check_distance = "check_distance"
check_delete_compact = "check_delete_compact"
check_merge_compact = "check_merge_compact"
check_role_property = "check_role_property"
check_permission_deny = "check_permission_deny"
check_value_equal = "check_value_equal"
class ResponseChecker:
def __init__(self, check_task, check_items):
self.check_task = check_task
self.check_items = check_items
from utils.util_log import test_log as log
def ip_check(ip):
if ip == "localhost":
return True
if not isinstance(ip, str):
log.error("[IP_CHECK] IP(%s) is not a string." % ip)
return False
return True
def number_check(num):
if str(num).isdigit():
return True
else:
log.error("[NUMBER_CHECK] Number(%s) is not a numbers." % num)
return False
import json
import os
import random
import string
import numpy as np
from enum import Enum
from common import common_type as ct
from utils.util_log import test_log as log
class ParamInfo:
def __init__(self):
self.param_host = ""
self.param_port = ""
def prepare_param_info(self, host, http_port):
self.param_host = host
self.param_port = http_port
param_info = ParamInfo()
class DataType(Enum):
Bool = 1
Int8 = 2
Int16 = 3
Int32 = 4
Int64 = 5
Float = 10
Double = 11
String = 20
VarChar = 21
BinaryVector = 100
FloatVector = 101
def gen_unique_str(str_value=None):
prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
return "test_" + prefix if str_value is None else str_value + "_" + prefix
def gen_field(name=ct.default_bool_field_name, description=ct.default_desc, type_params=None, index_params=None,
data_type="Int64", is_primary_key=False, auto_id=False, dim=128, max_length=256):
data_type_map = {
"Bool": 1,
"Int8": 2,
"Int16": 3,
"Int32": 4,
"Int64": 5,
"Float": 10,
"Double": 11,
"String": 20,
"VarChar": 21,
"BinaryVector": 100,
"FloatVector": 101,
}
if data_type == "Int64":
is_primary_key = True
auto_id = True
if type_params is None:
type_params = []
if index_params is None:
index_params = []
if data_type in ["FloatVector", "BinaryVector"]:
type_params = [{"key": "dim", "value": str(dim)}]
if data_type in ["String", "VarChar"]:
type_params = [{"key": "max_length", "value": str(dim)}]
return {
"name": name,
"description": description,
"data_type": data_type_map.get(data_type, 0),
"type_params": type_params,
"index_params": index_params,
"is_primary_key": is_primary_key,
"auto_id": auto_id,
}
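# For example, gen_field(name="FloatVector", data_type="FloatVector", dim=128) returns a dict like:
# {"name": "FloatVector", "description": "", "data_type": 101,
#  "type_params": [{"key": "dim", "value": "128"}], "index_params": [],
#  "is_primary_key": False, "auto_id": False}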
def gen_schema(name, fields, description=ct.default_desc, auto_id=False):
return {
"name": name,
"description": description,
"auto_id": auto_id,
"fields": fields,
}
def gen_default_schema(data_types=None, dim=ct.default_dim, collection_name=None):
if data_types is None:
data_types = ["Int64", "Float", "VarChar", "FloatVector"]
fields = []
for data_type in data_types:
if data_type in ["FloatVector", "BinaryVector"]:
fields.append(gen_field(name=data_type, data_type=data_type, dim=dim))  # gen_field builds the dim type_params for vector fields
else:
fields.append(gen_field(name=data_type, data_type=data_type))
return {
"autoID": True,
"fields": fields,
"description": ct.default_desc,
"name": collection_name,
}
def gen_fields_data(schema=None, nb=ct.default_nb,):
if schema is None:
schema = gen_default_schema()
fields = schema["fields"]
fields_data = []
for field in fields:
if field["data_type"] == 1:
fields_data.append([random.choice([True, False]) for i in range(nb)])
elif field["data_type"] == 2:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 3:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 4:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 5:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 10:
fields_data.append([np.float64(i) for i in range(nb)]) # json not support float32
elif field["data_type"] == 11:
fields_data.append([np.float64(i) for i in range(nb)])
elif field["data_type"] == 20:
fields_data.append([gen_unique_str((str(i))) for i in range(nb)])
elif field["data_type"] == 21:
fields_data.append([gen_unique_str(str(i)) for i in range(nb)])
elif field["data_type"] == 100:
dim = ct.default_dim
for k, v in field["type_params"]:
if k == "dim":
dim = int(v)
break
fields_data.append(gen_binary_vectors(nb, dim))
elif field["data_type"] == 101:
dim = ct.default_dim
for k, v in field["type_params"]:
if k == "dim":
dim = int(v)
break
fields_data.append(gen_float_vectors(nb, dim))
else:
log.error("Unknown data type.")
fields_data_body = []
for i, field in enumerate(fields):
fields_data_body.append({
"field_name": field["name"],
"type": field["data_type"],
"field": fields_data[i],
})
return fields_data_body
def get_vector_field(schema):
for field in schema["fields"]:
if field["data_type"] in [100, 101]:
return field["name"]
return None
def get_varchar_field(schema):
for field in schema["fields"]:
if field["data_type"] == 21:
return field["name"]
return None
def gen_vectors(nq=None, schema=None):
if nq is None:
nq = ct.default_nq
dim = ct.default_dim
data_type = 101
for field in schema["fields"]:
if field["data_type"] in [100, 101]:
dim = ct.default_dim
data_type = field["data_type"]
for k, v in field["type_params"]:
if k == "dim":
dim = int(v)
break
if data_type == 100:
return gen_binary_vectors(nq, dim)
if data_type == 101:
return gen_float_vectors(nq, dim)
def gen_float_vectors(nb, dim):
return [[np.float64(random.uniform(-1.0, 1.0)) for _ in range(dim)] for _ in range(nb)] # json not support float32
def gen_binary_vectors(nb, dim):
raw_vectors = []
binary_vectors = []
for _ in range(nb):
raw_vector = [random.randint(0, 1) for _ in range(dim)]
raw_vectors.append(raw_vector)
# pack the binary-valued array into bits of a uint8 array, then convert the packed ints to a bytes object
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
return binary_vectors
def gen_index_params(index_type=None):
if index_type is None:
index_params = ct.default_index_params
else:
index_params = ct.all_index_params_map[index_type]
extra_params = []
for k, v in index_params.items():
item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)}
extra_params.append(item)
return extra_params
def gen_search_param_by_index_type(index_type, metric_type="L2"):
search_params = []
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
for nprobe in [10]:
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
search_params.append(ivf_search_params)
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
for nprobe in [10]:
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
search_params.append(bin_search_params)
elif index_type in ["HNSW"]:
for ef in [64]:
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == "ANNOY":
for search_k in [1000]:
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
search_params.append(annoy_search_param)
else:
log.info("Invalid index_type.")
raise Exception("Invalid index_type.")
return search_params
def gen_search_params(index_type=None, anns_field=ct.default_float_vec_field_name,
topk=ct.default_top_k):
if index_type is None:
search_params = gen_search_param_by_index_type(ct.default_index_type)[0]
else:
search_params = gen_search_param_by_index_type(index_type)[0]
extra_params = []
for k, v in search_params.items():
item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)}
extra_params.append(item)
extra_params.append({"key": "anns_field", "value": anns_field})
extra_params.append({"key": "topk", "value": str(topk)})
return extra_params
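# For example, gen_search_params() with the default HNSW index type yields:
# [{"key": "metric_type", "value": "L2"}, {"key": "params", "value": '{"ef": 64}'},
#  {"key": "anns_field", "value": "FloatVector"}, {"key": "topk", "value": "10"}]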
def gen_search_vectors(dim, nb, is_binary=False):
if is_binary:
return gen_binary_vectors(nb, dim)
return gen_float_vectors(nb, dim)
def modify_file(file_path_list, is_modify=False, input_content=""):
"""
file_path_list : file list -> list[<file_path>]
is_modify : does the file need to be reset
input_content :the content that need to insert to the file
"""
if not isinstance(file_path_list, list):
log.error("[modify_file] file is not a list.")
for file_path in file_path_list:
folder_path, file_name = os.path.split(file_path)
if not os.path.isdir(folder_path):
log.debug("[modify_file] folder(%s) is not exist." % folder_path)
os.makedirs(folder_path)
if not os.path.isfile(file_path):
log.error("[modify_file] file(%s) is not exist." % file_path)
else:
if is_modify is True:
log.debug("[modify_file] start modifying file(%s)..." % file_path)
with open(file_path, "r+") as f:
f.seek(0)
f.truncate()
f.write(input_content)
log.info("[modify_file] files(%s) modification is complete." % file_path_list)
if __name__ == '__main__':
a = gen_binary_vectors(10, 128)
print(a)
""" Initialized parameters """
port = 19530
epsilon = 0.000001
namespace = "milvus"
default_flush_interval = 1
big_flush_interval = 1000
default_drop_interval = 3
default_dim = 128
default_nb = 3000
default_nb_medium = 5000
default_top_k = 10
default_nq = 2
default_limit = 10
default_search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
default_search_ip_params = {"metric_type": "IP", "params": {"nprobe": 10}}
default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
default_index_type = "HNSW"
default_index_params = {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"}
default_varchar_index = {}
default_binary_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"}
default_diskann_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}}
default_diskann_search_params = {"metric_type": "L2", "params": {"search_list": 30}}
max_top_k = 16384
max_partition_num = 4096 # 256
default_segment_row_limit = 1000
default_server_segment_row_limit = 1024 * 512
default_alias = "default"
default_user = "root"
default_password = "Milvus"
default_bool_field_name = "Bool"
default_int8_field_name = "Int8"
default_int16_field_name = "Int16"
default_int32_field_name = "Int32"
default_int64_field_name = "Int64"
default_float_field_name = "Float"
default_double_field_name = "Double"
default_string_field_name = "Varchar"
default_float_vec_field_name = "FloatVector"
another_float_vec_field_name = "FloatVector1"
default_binary_vec_field_name = "BinaryVector"
default_partition_name = "_default"
default_tag = "1970_01_01"
row_count = "row_count"
default_length = 65535
default_desc = ""
default_collection_desc = "default collection"
default_index_name = "default_index_name"
default_binary_desc = "default binary collection"
collection_desc = "collection"
int_field_desc = "int64 type field"
float_field_desc = "float type field"
float_vec_field_desc = "float vector type field"
binary_vec_field_desc = "binary vector type field"
max_dim = 32768
min_dim = 1
gracefulTime = 1
default_nlist = 128
compact_segment_num_threshold = 4
compact_delta_ratio_reciprocal = 5 # compact_delta_binlog_ratio is 0.2
compact_retention_duration = 40 # compaction travel time retention range 20s
max_compaction_interval = 60 # the max time interval (s) from the last compaction
max_field_num = 256 # Maximum number of fields in a collection
default_dsl = f"{default_int64_field_name} in [2,4,6,8]"
default_expr = f"{default_int64_field_name} in [2,4,6,8]"
metric_types = []
all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY", "DISKANN", "BIN_FLAT", "BIN_IVF_FLAT"]
all_index_params_map = {"FLAT": {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"},
"IVF_FLAT": {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"},
"IVF_SQ8": {"index_type": "IVF_SQ8", "params": {"nlist": 128}, "metric_type": "L2"},
"IVF_PQ": {"index_type": "IVF_PQ", "params": {"nlist": 128, "m": 16, "nbits": 8},
"metric_type": "L2"},
"HNSW": {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"},
"ANNOY": {"index_type": "ANNOY", "params": {"n_trees": 50}, "metric_type": "L2"},
"DISKANN": {"index_type": "DISKANN", "params": {}, "metric_type": "L2"},
"BIN_FLAT": {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"},
"BIN_IVF_FLAT": {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128},
"metric_type": "JACCARD"}
}
import os
class LogConfig:
def __init__(self):
self.log_debug = ""
self.log_err = ""
self.log_info = ""
self.log_worker = ""
self.get_default_config()
@staticmethod
def get_env_variable(var="CI_LOG_PATH"):
""" get log path for testing """
try:
log_path = os.environ[var]
return str(log_path)
except Exception as e:
# now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
log_path = f"/tmp/ci_logs"
print("[get_env_variable] failed to get environment variables : %s, use default path : %s" % (str(e), log_path))
return log_path
@staticmethod
def create_path(log_path):
if not os.path.isdir(str(log_path)):
print("[create_path] folder(%s) is not exist." % log_path)
print("[create_path] create path now...")
os.makedirs(log_path)
def get_default_config(self):
""" Make sure the path exists """
log_dir = self.get_env_variable()
self.log_debug = "%s/ci_test_log.debug" % log_dir
self.log_info = "%s/ci_test_log.log" % log_dir
self.log_err = "%s/ci_test_log.err" % log_dir
work_log = os.environ.get('PYTEST_XDIST_WORKER')
if work_log is not None:
self.log_worker = f'{log_dir}/{work_log}.log'
self.create_path(log_dir)
log_config = LogConfig()
import pytest
import common.common_func as cf
from check.param_check import ip_check, number_check
from config.log_config import log_config
from utils.util_log import test_log as log
from common.common_func import param_info
def pytest_addoption(parser):
parser.addoption("--host", action="store", default="127.0.0.1", help="Milvus host")
parser.addoption("--port", action="store", default="9091", help="Milvus http port")
parser.addoption('--clean_log', action='store_true', default=False, help="clean log before testing")
@pytest.fixture
def host(request):
return request.config.getoption("--host")
@pytest.fixture
def port(request):
return request.config.getoption("--port")
@pytest.fixture
def clean_log(request):
return request.config.getoption("--clean_log")
@pytest.fixture(scope="session", autouse=True)
def initialize_env(request):
""" clean log before testing """
host = request.config.getoption("--host")
port = request.config.getoption("--port")
clean_log = request.config.getoption("--clean_log")
""" params check """
assert ip_check(host) and number_check(port)
""" modify log files """
file_path_list = [log_config.log_debug, log_config.log_info, log_config.log_err]
if log_config.log_worker != "":
file_path_list.append(log_config.log_worker)
cf.modify_file(file_path_list=file_path_list, is_modify=clean_log)
log.info("#" * 80)
log.info("[initialize_milvus] Log cleaned up, start testing...")
param_info.prepare_param_info(host, port)
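# A hedged example of supplying these options on the command line (paths and values are
# illustrative only):
#   pytest testcases/test_e2e.py --host 127.0.0.1 --port 9091 --clean_log
# The session-scoped fixture above validates host/port, truncates the log files when
# --clean_log is set, and stores the endpoint parameters in param_info for the test base class.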
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, Field
class KeyDataPair(BaseModel):
data: Optional[List[int]] = None
key: Optional[str] = None
class KeyValuePair(BaseModel):
key: Optional[str] = Field(None, example='dim')
value: Optional[str] = Field(None, example='128')
class MsgBase(BaseModel):
msg_type: Optional[int] = Field(None, description='Not useful for now')
class Status(BaseModel):
error_code: Optional[int] = None
reason: Optional[str] = None
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, Field
from models import common, schema
class DescribeCollectionRequest(BaseModel):
collection_name: Optional[str] = None
collectionID: Optional[int] = Field(
None, description='The collection ID you want to describe'
)
time_stamp: Optional[int] = Field(
None,
description='If time_stamp is not zero, will describe collection success when time_stamp >= created collection timestamp, otherwise will throw error.',
)
class DropCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The unique collection name in milvus.(Required)'
)
class FieldData(BaseModel):
field: Optional[List] = None
field_id: Optional[int] = None
field_name: Optional[str] = None
type: Optional[int] = Field(
None,
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
)
class GetCollectionStatisticsRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The collection name you want get statistics'
)
class HasCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The unique collection name in milvus.(Required)'
)
time_stamp: Optional[int] = Field(
None,
description='If time_stamp is not zero, will return true when time_stamp >= created collection timestamp, otherwise will return false.',
)
class InsertRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = None
db_name: Optional[str] = None
fields_data: Optional[List[FieldData]] = None
hash_keys: Optional[List[int]] = None
num_rows: Optional[int] = None
partition_name: Optional[str] = None
class LoadCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The collection name you want to load'
)
replica_number: Optional[int] = Field(
None, description='The replica number to load, default by 1'
)
class ReleaseCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The collection name you want to release'
)
class ShowCollectionsRequest(BaseModel):
collection_names: Optional[List[str]] = Field(
None,
description="When type is InMemory, will return these collection's inMemory_percentages.(Optional)",
)
type: Optional[int] = Field(
None,
description='Decide return Loaded collections or All collections(Optional)',
)
class VectorIDs(BaseModel):
collection_name: Optional[str] = None
field_name: Optional[str] = None
id_array: Optional[List[int]] = None
partition_names: Optional[List[str]] = None
class VectorsArray(BaseModel):
binary_vectors: Optional[List[int]] = Field(
None,
description='Vectors is an array of binary vector divided by given dim. Disabled when IDs is set',
)
dim: Optional[int] = Field(
None, description='Dim of vectors or binary_vectors, not needed when use ids'
)
ids: Optional[VectorIDs] = None
vectors: Optional[List[float]] = Field(
None,
description='Vectors is an array of vector divided by given dim. Disabled when ids or binary_vectors is set',
)
class CalcDistanceRequest(BaseModel):
base: Optional[common.MsgBase] = None
op_left: Optional[VectorsArray] = None
op_right: Optional[VectorsArray] = None
params: Optional[List[common.KeyValuePair]] = None
class CreateCollectionRequest(BaseModel):
collection_name: str = Field(
...,
description='The unique collection name in milvus.(Required)',
example='book',
)
consistency_level: int = Field(
...,
description='The consistency level that the collection used, modification is not supported now.\n"Strong": 0,\n"Session": 1,\n"Bounded": 2,\n"Eventually": 3,\n"Customized": 4,',
example=1,
)
schema_: schema.CollectionSchema = Field(..., alias='schema')
shards_num: Optional[int] = Field(
None,
description='Once set, no modification is allowed (Optional)\nhttps://github.com/milvus-io/milvus/issues/6690',
example=1,
)
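# A hedged sketch of building a request body with these generated models (it mirrors the
# commented-out usage in base/collection_service.py):
# req = CreateCollectionRequest(collection_name="book", consistency_level=1,
#                               schema=schema.CollectionSchema(name="book"), shards_num=2)
# payload = req.dict(by_alias=True)  # by_alias keeps the "schema" key expected by the REST API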
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from models import common
class FieldData(BaseModel):
field: Optional[Any] = Field(
None,
description='Types that are assignable to Field:\n\t*FieldData_Scalars\n\t*FieldData_Vectors',
)
field_id: Optional[int] = None
field_name: Optional[str] = None
type: Optional[int] = Field(
None,
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
)
class FieldSchema(BaseModel):
autoID: Optional[bool] = None
data_type: int = Field(
...,
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
example=101,
)
description: Optional[str] = Field(
None, example='embedded vector of book introduction'
)
fieldID: Optional[int] = None
index_params: Optional[List[common.KeyValuePair]] = None
is_primary_key: Optional[bool] = Field(None, example=False)
name: str = Field(..., example='book_intro')
type_params: Optional[List[common.KeyValuePair]] = None
class IDs(BaseModel):
idField: Optional[Any] = Field(
None,
description='Types that are assignable to IdField:\n\t*IDs_IntId\n\t*IDs_StrId',
)
class LongArray(BaseModel):
data: Optional[List[int]] = None
class SearchResultData(BaseModel):
fields_data: Optional[List[FieldData]] = None
ids: Optional[IDs] = None
num_queries: Optional[int] = None
scores: Optional[List[float]] = None
top_k: Optional[int] = None
topks: Optional[List[int]] = None
class CollectionSchema(BaseModel):
autoID: Optional[bool] = Field(
None,
description='deprecated later, keep compatible with c++ part now',
example=False,
)
description: Optional[str] = Field(None, example='Test book search')
fields: Optional[List[FieldSchema]] = None
name: str = Field(..., example='book')
This diff has been collapsed.
[pytest]
addopts = --host 10.101.178.131 --html=/tmp/ci_logs/report.html --self-contained-html -v
# python3 -W ignore -m pytest
log_format = [%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_date_format = %Y-%m-%d %H:%M:%S
filterwarnings =
ignore::DeprecationWarning
decorest~=0.1.0
pydantic~=1.10.2
from time import sleep
from common import common_type as ct
from common import common_func as cf
from base.client_base import TestBase
from utils.util_log import test_log as log
class TestDefault(TestBase):
def test_e2e(self):
collection_name, schema = self.init_collection()
nb = ct.default_nb
# insert
res = self.entity_service.insert(collection_name=collection_name, fields_data=cf.gen_fields_data(schema, nb=nb),
num_rows=nb)
log.info(f"insert {nb} rows into collection {collection_name}, response: {res}")
# flush
res = self.entity_service.flush(collection_names=[collection_name])
log.info(f"flush collection {collection_name}, response: {res}")
# create index for vector field
vector_field_name = cf.get_vector_field(schema)
vector_index_params = cf.gen_index_params(index_type="HNSW")
res = self.index_service.create_index(collection_name=collection_name, field_name=vector_field_name,
extra_params=vector_index_params)
log.info(f"create index for vector field {vector_field_name}, response: {res}")
# load
res = self.collection_service.load_collection(collection_name=collection_name)
log.info(f"load collection {collection_name}, response: {res}")
sleep(5)
# search
vectors = cf.gen_vectors(nq=ct.default_nq, schema=schema)
res = self.entity_service.search(collection_name=collection_name, vectors=vectors,
output_fields=[ct.default_int64_field_name],
search_params=cf.gen_search_params())
log.info(f"search collection {collection_name}, response: {res}")
# hybrid search
res = self.entity_service.search(collection_name=collection_name, vectors=vectors,
output_fields=[ct.default_int64_field_name],
search_params=cf.gen_search_params(),
dsl=ct.default_dsl)
log.info(f"hybrid search collection {collection_name}, response: {res}")
# query
res = self.entity_service.query(collection_name=collection_name, expr=ct.default_expr)
log.info(f"query collection {collection_name}, response: {res}")
import logging
import sys
from config.log_config import log_config
class TestLog:
def __init__(self, logger, log_debug, log_file, log_err, log_worker):
self.logger = logger
self.log_debug = log_debug
self.log_file = log_file
self.log_err = log_err
self.log_worker = log_worker
self.log = logging.getLogger(self.logger)
self.log.setLevel(logging.DEBUG)
try:
formatter = logging.Formatter("[%(asctime)s - %(levelname)s - %(name)s]: "
"%(message)s (%(filename)s:%(lineno)s)")
# [%(process)s] process NO.
dh = logging.FileHandler(self.log_debug)
dh.setLevel(logging.DEBUG)
dh.setFormatter(formatter)
self.log.addHandler(dh)
fh = logging.FileHandler(self.log_file)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
self.log.addHandler(fh)
eh = logging.FileHandler(self.log_err)
eh.setLevel(logging.ERROR)
eh.setFormatter(formatter)
self.log.addHandler(eh)
if self.log_worker != "":
wh = logging.FileHandler(self.log_worker)
wh.setLevel(logging.DEBUG)
wh.setFormatter(formatter)
self.log.addHandler(wh)
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)
# self.log.addHandler(ch)
except Exception as e:
print("Can not use %s or %s or %s to log. error : %s" % (log_debug, log_file, log_err, str(e)))
"""All modules share this unified log"""
log_debug = log_config.log_debug
log_info = log_config.log_info
log_err = log_config.log_err
log_worker = log_config.log_worker
test_log = TestLog('ci_test', log_debug, log_info, log_err, log_worker).log
import time
from datetime import datetime
import functools
from utils.util_log import test_log as log
DEFAULT_FMT = '[{start_time}] [{elapsed:0.8f}s] {collection_name} {func_name} -> {res!r}'
def trace(fmt=DEFAULT_FMT, prefix='test', flag=True):
def decorate(func):
@functools.wraps(func)
def inner_wrapper(*args, **kwargs):
# args[0] is the wrapper instance that owns the decorated method (e.g. an ApiCollectionWrapper)
flag = args[0].active_trace
if flag:
start_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
t0 = time.perf_counter()
res, result = func(*args, **kwargs)
elapsed = time.perf_counter() - t0
end_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
func_name = func.__name__
collection_name = args[0].collection.name
# arg_lst = [repr(arg) for arg in args[1:]][:100]
# arg_lst.extend(f'{k}={v!r}' for k, v in kwargs.items())
# arg_str = ', '.join(arg_lst)[:200]
log_str = f"[{prefix}]" + fmt.format(**locals())
# TODO: add report function in this place, like uploading to influxdb
# it is better a async way to do this, in case of blocking the request processing
log.info(log_str)
return res, result
else:
res, result = func(*args, **kwargs)
return res, result
return inner_wrapper
return decorate
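# A hedged usage sketch: trace() assumes it wraps methods of a wrapper object exposing an
# `active_trace` flag and a `collection.name` attribute, and that wrapped methods return a
# (response, check_result) tuple. The class and method names below are illustrative only:
# class CollectionWrapper:
#     active_trace = True
#     def __init__(self, collection):
#         self.collection = collection
#     @trace()
#     def load(self, replica_number=1):
#         res = self.collection.load(replica_number=replica_number)
#         return res, True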