未验证 提交 47039c82 编写于 作者: F Felix 提交者: GitHub

Update trainer.py

上级 38201d26
......@@ -404,37 +404,40 @@ class Trainer(object):
if query_query_id is not None:
query_id_blocks = paddle.split(
query_query_id, num_or_sections=sections)
image_id_blocks = paddle.split(
query_img_id, num_or_sections=sections)
image_id_blocks = paddle.split(
query_img_id, num_or_sections=sections)
metric_key = None
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
if query_query_id is not None:
query_id_block = query_id_blocks[block_idx]
query_id_mask = (query_id_block != gallery_unique_id.t())
image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype(
"float32")
if cum_similarity_matrix is None:
cum_similarity_matrix = similarity_matrix
else:
cum_similarity_matrix = paddle.concat(
[cum_similarity_matrix, similarity_matrix], axis=0)
# calc metric
if self.eval_metric_func is not None:
metric_dict = self.eval_metric_func(cum_similarity_matrix,
query_img_id, gallery_img_id)
else:
if self.eval_metric_func is None:
metric_dict = {metric_key: 0.}
else:
metric_dict = dict()
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
if query_query_id is not None:
query_id_block = query_id_blocks[block_idx]
query_id_mask = (query_id_block != gallery_unique_id.t())
image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype("float32")
metric_tmp = self.eval_metric_func(similarity_matrix,image_id_blocks[block_idx], gallery_img_id)
for key in metric_tmp:
if key not in metric_dict:
metric_dict[key] = metric_tmp[key]
else:
metric_dict[key] += metric_tmp[key]
num_sections = len(fea_blocks)
for key in metric_dict:
metric_dict[key] = metric_dict[key]/num_sections
metric_info_list = []
for key in metric_dict:
if metric_key is None:
metric_key = key
......@@ -442,7 +445,8 @@ class Trainer(object):
metric_msg = ", ".join(metric_info_list)
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
return metric_dict[metric_key]
return metric_dict[metric_key]
def _cal_feature(self, name='gallery'):
all_feas = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册