提交 0a8b1773 编写于 作者: W weishengyu

dbg

上级 2a033328
......@@ -97,16 +97,12 @@ class Trainer(object):
# build train loss and metric info
if self.train_loss_func is None:
loss_info = self.config.get("Loss", None)
if loss_info is None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
if loss_info is not None:
loss_info = loss_info["Train"]
self.train_loss_func = build_loss(loss_info)
if self.train_metric_func is None:
metric_config = self.config.get("Metric", None)
if metric_config is None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
if metric_config is not None:
metric_config = metric_config["Train"]
self.train_metric_func = build_metrics(metric_config)
......@@ -228,11 +224,9 @@ class Trainer(object):
self.model.eval()
if self.eval_loss_func is None:
loss_info = self.config.get("Loss", None)
if loss_info is None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
if loss_info is not None:
loss_info = loss_info["Eval"]
self.eval_loss_func = build_loss(loss_info)
self.eval_loss_func = build_loss(loss_info)
if self.eval_mode == "classification":
if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader(
......@@ -240,11 +234,9 @@ class Trainer(object):
if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None)
if metric_config is None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
if metric_config is not None:
metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config)
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_cls(epoch_id)
......@@ -358,14 +350,11 @@ class Trainer(object):
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
# image_id_block = image_id_blocks[block_idx]
# image_id_mask = (image_id_block != gallery_img_id)
# similarity_matrix = similarity_matrix.masked_select(image_id_mask)
# if query_camera_id is not None:
# camera_id_block = camera_id_blocks[block_idx]
# camera_id_mask = (camera_id_block != gallery_camera_id)
# similarity_matrix = similarity_matrix.masked_select(
# camera_id_mask)
if query_camera_id is not None:
camera_id_block = camera_id_blocks[block_idx]
camera_id_mask = (camera_id_block != gallery_camera_id)
similarity_matrix = similarity_matrix.masked_select(
camera_id_mask)
if cum_similarity_matrix is None:
cum_similarity_matrix = similarity_matrix
else:
......
......@@ -26,7 +26,6 @@ class CombinedMetrics(nn.Layer):
assert isinstance(config_list, list), (
'operator config should be a list')
for config in config_list:
print(config)
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册