提交 0a8b1773 编写于 作者: W weishengyu

dbg

上级 2a033328
...@@ -97,16 +97,12 @@ class Trainer(object): ...@@ -97,16 +97,12 @@ class Trainer(object):
# build train loss and metric info # build train loss and metric info
if self.train_loss_func is None: if self.train_loss_func is None:
loss_info = self.config.get("Loss", None) loss_info = self.config.get("Loss", None)
if loss_info is None: if loss_info is not None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
loss_info = loss_info["Train"] loss_info = loss_info["Train"]
self.train_loss_func = build_loss(loss_info) self.train_loss_func = build_loss(loss_info)
if self.train_metric_func is None: if self.train_metric_func is None:
metric_config = self.config.get("Metric", None) metric_config = self.config.get("Metric", None)
if metric_config is None: if metric_config is not None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
metric_config = metric_config["Train"] metric_config = metric_config["Train"]
self.train_metric_func = build_metrics(metric_config) self.train_metric_func = build_metrics(metric_config)
...@@ -228,11 +224,9 @@ class Trainer(object): ...@@ -228,11 +224,9 @@ class Trainer(object):
self.model.eval() self.model.eval()
if self.eval_loss_func is None: if self.eval_loss_func is None:
loss_info = self.config.get("Loss", None) loss_info = self.config.get("Loss", None)
if loss_info is None: if loss_info is not None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
loss_info = loss_info["Eval"] loss_info = loss_info["Eval"]
self.eval_loss_func = build_loss(loss_info) self.eval_loss_func = build_loss(loss_info)
if self.eval_mode == "classification": if self.eval_mode == "classification":
if self.eval_dataloader is None: if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader( self.eval_dataloader = build_dataloader(
...@@ -240,11 +234,9 @@ class Trainer(object): ...@@ -240,11 +234,9 @@ class Trainer(object):
if self.eval_metric_func is None: if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None) metric_config = self.config.get("Metric", None)
if metric_config is None: if metric_config is not None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
metric_config = metric_config["Eval"] metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config) self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_cls(epoch_id) eval_result = self.eval_cls(epoch_id)
...@@ -358,14 +350,11 @@ class Trainer(object): ...@@ -358,14 +350,11 @@ class Trainer(object):
for block_idx, block_fea in enumerate(fea_blocks): for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul( similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True) block_fea, gallery_feas, transpose_y=True)
# image_id_block = image_id_blocks[block_idx] if query_camera_id is not None:
# image_id_mask = (image_id_block != gallery_img_id) camera_id_block = camera_id_blocks[block_idx]
# similarity_matrix = similarity_matrix.masked_select(image_id_mask) camera_id_mask = (camera_id_block != gallery_camera_id)
# if query_camera_id is not None: similarity_matrix = similarity_matrix.masked_select(
# camera_id_block = camera_id_blocks[block_idx] camera_id_mask)
# camera_id_mask = (camera_id_block != gallery_camera_id)
# similarity_matrix = similarity_matrix.masked_select(
# camera_id_mask)
if cum_similarity_matrix is None: if cum_similarity_matrix is None:
cum_similarity_matrix = similarity_matrix cum_similarity_matrix = similarity_matrix
else: else:
......
...@@ -26,7 +26,6 @@ class CombinedMetrics(nn.Layer): ...@@ -26,7 +26,6 @@ class CombinedMetrics(nn.Layer):
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
for config in config_list: for config in config_list:
print(config)
assert isinstance(config, assert isinstance(config,
dict) and len(config) == 1, "yaml format error" dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0] metric_name = list(config)[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册