提交 b56237dc 编写于 作者: W weishengyu

update trainer

上级 ed098b3c
......@@ -158,7 +158,6 @@ class Trainer(object):
for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1):
acc = 0.0
self.model.train()
for iter_id, batch in enumerate(self.train_dataloader()):
batch_size = batch[0].shape[0]
batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
......@@ -241,34 +240,34 @@ class Trainer(object):
@paddle.no_grad()
def eval(self, epoch_id=0):
self.model.eval()
if self.eval_mode == "classification":
if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader(self.config["DataLoader"],
"Eval", self.device)
self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device)
eval_result = self.eval_cls(epoch_id)
elif self.eval_mode == "retrieval":
if self.gallery_dataloader is None:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"], "Gallery", self.device)
if self.query_dataloader is None:
self.query_dataloader = build_dataloader(self.config["DataLoader"],
"Query", self.device)
self.query_dataloader = build_dataloader(
self.config["DataLoader"], "Query", self.device)
# build train loss and metric info
if self.eval_loss_func is None:
self.eval_loss_func = self._build_loss_info(self.config["Loss"],
"eval")
self.eval_loss_func = self._build_loss_info(
self.config["Loss"], "eval")
if self.eval_metric_func is None:
self.eval_metric_func = self._build_metric_info(
self.config["Metric"], "eval")
self.model.eval()
if self.eval_mode == "classification":
self.eval_cls(epoch_id)
elif self.eval_mode == "retrieval":
self.eval_retrieval(epoch_id)
eval_result = self.eval_retrieval(epoch_id)
else:
logger.warning("Invalid eval mode: {}".format(self.eval_mode))
eval_result = None
self.model.train()
return eval_result
def eval_cls(self, epoch_id=0):
output_info = dict()
......@@ -332,9 +331,8 @@ class Trainer(object):
return output_info[metric_key].avg
def eval_retrieval(self, epoch_id=0):
output_info = dict()
self.model.eval()
cum_similarity_matrix = None
# step1. build gallery
gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature(
name='gallery')
......@@ -342,7 +340,7 @@ class Trainer(object):
name='query')
gallery_img_id = paddle.to_tensor([gallery_img_id]).t()
if gallery_camera_id is not None:
gallery_camera_id = paddle.to_tensor(gallery_camera_id).t()
gallery_camera_id = paddle.to_tensor([gallery_camera_id]).t()
query_img_id = paddle.to_tensor(query_img_id)
if query_camera_id is not None:
query_camera_id = paddle.to_tensor(query_camera_id)
......@@ -352,35 +350,37 @@ class Trainer(object):
if not len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_camera_id is not None:
camera_id_blocks = paddle.split(
query_camera_id, num_or_sections=sections)
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
metric_key = None
for block_idx, block_fea in enumerate(fea_blocks):
similarities_matrix = paddle.matmul(
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block == gallery_img_id)
similarities_matrix = similarities_matrix.masked_select(
image_id_mask)
image_id_mask = (image_id_block != gallery_img_id)
similarity_matrix = similarity_matrix.masked_select(image_id_mask)
if query_camera_id is not None:
camera_id_block = camera_id_blocks[block_idx]
camera_id_mask = (camera_id_block == gallery_camera_id)
similarities_matrix = similarities_matrix.masked_select(
camera_id_mask = (camera_id_block != gallery_camera_id)
similarity_matrix = similarity_matrix.masked_select(
camera_id_mask)
if similarity_matrix is None:
cum_similarity_matrix = similarity_matrix
else:
cum_similarity_matrix = paddle.concat(cum_similarity_matrix,
similarity_matrix)
# calc metric
if self.eval_metric_func is not None:
metric_dict = self.eval_metric_func(similarities_matrix,
image_id_block)
for key in metric_dict:
if metric_key is None:
metric_key = key
if not key in output_info:
output_info[key] = AverageMeter(key, '7.5f')
metric_dict = self.eval_metric_func(cum_similarity_matrix,
query_img_id, gallery_img_id)
else:
metric_dict = {metric_key: 0.}
output_info[key].update(metric_dict[key].numpy()[0],
len(image_id_block))
return metric_dict[metric_key]
def _cal_feature(self, name='gallery'):
all_feas = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册