未验证 提交 10ca86ef 编写于 作者: W wt 提交者: GitHub

[skip ci] Add comment for the accuracy test (#8092)

Signed-off-by: Nwangting0128 <ting.wang@zilliz.com>
上级 5a008c72
......@@ -120,6 +120,7 @@ class AccAccuracyRunner(AccuracyRunner):
def extract_cases(self, collection):
collection_name = collection["collection_name"] if "collection_name" in collection else None
(data_type, dimension, metric_type) = parser.parse_ann_collection_name(collection_name)
# hdf5_source_file: The path of the source data file saved on the NAS
hdf5_source_file = collection["source_file"]
index_types = collection["index_types"]
index_params = collection["index_params"]
......@@ -136,11 +137,14 @@ class AccAccuracyRunner(AccuracyRunner):
}
filters = collection["filters"] if "filters" in collection else []
filter_query = []
# Convert the list-valued parameters into a collection of parameter dictionaries
search_params = utils.generate_combinations(search_params)
index_params = utils.generate_combinations(index_params)
cases = list()
case_metrics = list()
self.init_metric(self.name, collection_info, {}, search_info=None)
# true_ids: The ground-truth neighbor ids used to verify the results returned by the query
true_ids = np.array(dataset["neighbors"])
for index_type in index_types:
for index_param in index_params:
......@@ -192,11 +196,14 @@ class AccAccuracyRunner(AccuracyRunner):
"vector_query": vector_query,
"true_ids": true_ids
}
# Collect the parameters of the test case to be run
cases.append(case)
case_metrics.append(case_metric)
return cases, case_metrics
def prepare(self, **case_param):
""" According to the test case parameters, initialize the test """
collection_name = case_param["collection_name"]
metric_type = case_param["metric_type"]
dimension = case_param["dimension"]
......@@ -211,6 +218,7 @@ class AccAccuracyRunner(AccuracyRunner):
self.milvus.drop()
dataset = case_param["dataset"]
self.milvus.create_collection(dimension, data_type=vector_type)
# Get the training data set to insert into the collection
insert_vectors = utils.normalize(metric_type, np.array(dataset["train"]))
if len(insert_vectors) != dataset["train"].shape[0]:
raise Exception("Row count of insert vectors: %d is not equal to dataset size: %d" % (
......@@ -224,6 +232,7 @@ class AccAccuracyRunner(AccuracyRunner):
start = i * INSERT_INTERVAL
end = min((i + 1) * INSERT_INTERVAL, len(insert_vectors))
if start < end:
# Insert at most INSERT_INTERVAL (50000) vectors at a time
tmp_vectors = insert_vectors[start:end]
ids = [i for i in range(start, end)]
if not isinstance(tmp_vectors, list):
......@@ -256,7 +265,9 @@ class AccAccuracyRunner(AccuracyRunner):
top_k = case_metric.search["topk"]
query_res = self.milvus.query(case_param["vector_query"], filter_query=case_param["filter_query"])
result_ids = self.milvus.get_ids(query_res)
# Calculate the accuracy (recall) of the query results
acc_value = utils.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids)
tmp_result = {"acc": acc_value}
# Return accuracy results for reporting
return tmp_result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册