未验证 提交 9561ccae 编写于 作者: B Bin Lu 提交者: GitHub

Merge pull request #809 from FredHuang16/patch-11

fix eval out-of-memory problem 
......@@ -408,6 +408,10 @@ class Trainer(object):
query_img_id, num_or_sections=sections)
metric_key = None
if self.eval_metric_func is None:
metric_dict = {metric_key: 0.}
else:
metric_dict = dict()
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
......@@ -419,22 +423,21 @@ class Trainer(object):
image_id_mask = (image_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype(
"float32")
if cum_similarity_matrix is None:
cum_similarity_matrix = similarity_matrix
else:
cum_similarity_matrix = paddle.concat(
[cum_similarity_matrix, similarity_matrix], axis=0)
similarity_matrix = similarity_matrix * keep_mask.astype("float32")
# calc metric
if self.eval_metric_func is not None:
metric_dict = self.eval_metric_func(cum_similarity_matrix,
query_img_id, gallery_img_id)
metric_tmp = self.eval_metric_func(similarity_matrix,image_id_blocks[block_idx], gallery_img_id)
for key in metric_tmp:
if key not in metric_dict:
metric_dict[key] = metric_tmp[key]
else:
metric_dict = {metric_key: 0.}
metric_info_list = []
metric_dict[key] += metric_tmp[key]
num_sections = len(fea_blocks)
for key in metric_dict:
metric_dict[key] = metric_dict[key]/num_sections
metric_info_list = []
for key in metric_dict:
if metric_key is None:
metric_key = key
......@@ -444,6 +447,7 @@ class Trainer(object):
return metric_dict[metric_key]
def _cal_feature(self, name='gallery'):
all_feas = None
all_image_id = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册