提交 3e4af044 编写于 作者: W weishengyu

add default metrics

上级 5dc93ac0
...@@ -93,37 +93,17 @@ class Trainer(object): ...@@ -93,37 +93,17 @@ class Trainer(object):
self.train_metric_func = None self.train_metric_func = None
self.eval_metric_func = None self.eval_metric_func = None
def _build_metric_info(self, metric_config, mode="train"):
"""
_build_metric_info: build metrics according to current mode
Return:
metric: dict of the metrics info
"""
metric = None
mode = mode.capitalize()
if mode in metric_config and metric_config[mode] is not None:
metric = build_metrics(metric_config[mode])
return metric
def _build_loss_info(self, loss_config, mode="train"):
"""
_build_loss_info: build loss according to current mode
Return:
loss_dict: dict of the loss info
"""
loss = None
mode = mode.capitalize()
if mode in loss_config and loss_config[mode] is not None:
loss = build_loss(loss_config[mode])
return loss
def train(self): def train(self):
# build train loss and metric info # build train loss and metric info
if self.train_loss_func is None: if self.train_loss_func is None:
self.train_loss_func = self._build_loss_info(self.config["Loss"]) self.train_loss_func = build_loss(self.config["Loss"])
if "Metric" in self.config and self.train_metric_func is None: if self.train_metric_func is None:
self.train_metric_func = self._build_metric_info(self.config[ metric_config = self.config.get("Metric", None)
"Metric"]) if metric_config is None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
metric_config = metric_config["Train"]
self.train_metric_func = build_metrics(metric_config)
if self.train_dataloader is None: if self.train_dataloader is None:
self.train_dataloader = build_dataloader(self.config["DataLoader"], self.train_dataloader = build_dataloader(self.config["DataLoader"],
...@@ -241,10 +221,26 @@ class Trainer(object): ...@@ -241,10 +221,26 @@ class Trainer(object):
@paddle.no_grad() @paddle.no_grad()
def eval(self, epoch_id=0): def eval(self, epoch_id=0):
self.model.eval() self.model.eval()
if self.eval_loss_func is None:
loss_info = self.config.get("Loss", None)
if loss_info is None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
loss_info = loss_info["Eval"]
self.eval_loss_func = build_loss(loss_info)
if self.eval_mode == "classification": if self.eval_mode == "classification":
if self.eval_dataloader is None: if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader( self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device) self.config["DataLoader"], "Eval", self.device)
if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None)
if metric_config is None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_cls(epoch_id) eval_result = self.eval_cls(epoch_id)
elif self.eval_mode == "retrieval": elif self.eval_mode == "retrieval":
...@@ -255,13 +251,14 @@ class Trainer(object): ...@@ -255,13 +251,14 @@ class Trainer(object):
if self.query_dataloader is None: if self.query_dataloader is None:
self.query_dataloader = build_dataloader( self.query_dataloader = build_dataloader(
self.config["DataLoader"], "Query", self.device) self.config["DataLoader"], "Query", self.device)
# build train loss and metric info # build metric info
if self.eval_loss_func is None:
self.eval_loss_func = self._build_loss_info(
self.config["Loss"], "eval")
if self.eval_metric_func is None: if self.eval_metric_func is None:
self.eval_metric_func = self._build_metric_info( metric_config = self.config.get("Metric", None)
self.config["Metric"], "eval") if metric_config is None:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
else:
metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_retrieval(epoch_id) eval_result = self.eval_retrieval(epoch_id)
else: else:
logger.warning("Invalid eval mode: {}".format(self.eval_mode)) logger.warning("Invalid eval mode: {}".format(self.eval_mode))
......
...@@ -16,7 +16,7 @@ from paddle import nn ...@@ -16,7 +16,7 @@ from paddle import nn
import copy import copy
from collections import OrderedDict from collections import OrderedDict
from .metrics import Topk, mAP, mINP, Recallk from .metrics import TopkAcc, mAP, mINP, Recallk
class CombinedMetrics(nn.Layer): class CombinedMetrics(nn.Layer):
......
...@@ -18,7 +18,7 @@ import paddle.nn as nn ...@@ -18,7 +18,7 @@ import paddle.nn as nn
# TODO: fix the format # TODO: fix the format
class Topk(nn.Layer): class TopkAcc(nn.Layer):
def __init__(self, topk=(1, 5)): def __init__(self, topk=(1, 5)):
super().__init__() super().__init__()
assert isinstance(topk, (int, list, tuple)) assert isinstance(topk, (int, list, tuple))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册