提交 3e4af044 编写于 作者: W weishengyu

add default metrics

上级 5dc93ac0
......@@ -93,37 +93,17 @@ class Trainer(object):
self.train_metric_func = None
self.eval_metric_func = None
def _build_metric_info(self, metric_config, mode="train"):
"""
_build_metric_info: build metrics according to current mode
Return:
metric: dict of the metrics info
"""
metric = None
mode = mode.capitalize()
if mode in metric_config and metric_config[mode] is not None:
metric = build_metrics(metric_config[mode])
return metric
def _build_loss_info(self, loss_config, mode="train"):
"""
_build_loss_info: build loss according to current mode
Return:
loss_dict: dict of the loss info
"""
loss = None
mode = mode.capitalize()
if mode in loss_config and loss_config[mode] is not None:
loss = build_loss(loss_config[mode])
return loss
def train(self):
# build train loss and metric info
if self.train_loss_func is None:
self.train_loss_func = self._build_loss_info(self.config["Loss"])
if "Metric" in self.config and self.train_metric_func is None:
self.train_metric_func = self._build_metric_info(self.config[
"Metric"])
self.train_loss_func = build_loss(self.config["Loss"])
if self.train_metric_func is None:
metric_config = self.config.get("Metric", None)
if metric_config is None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
metric_config = metric_config["Train"]
self.train_metric_func = build_metrics(metric_config)
if self.train_dataloader is None:
self.train_dataloader = build_dataloader(self.config["DataLoader"],
......@@ -241,10 +221,26 @@ class Trainer(object):
@paddle.no_grad()
def eval(self, epoch_id=0):
self.model.eval()
if self.eval_loss_func is None:
loss_info = self.config.get("Loss", None)
if loss_info is None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
loss_info = loss_info["Eval"]
self.eval_loss_func = build_loss(loss_info)
if self.eval_mode == "classification":
if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device)
if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None)
if metric_config is None:
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
else:
metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_cls(epoch_id)
elif self.eval_mode == "retrieval":
......@@ -255,13 +251,14 @@ class Trainer(object):
if self.query_dataloader is None:
self.query_dataloader = build_dataloader(
self.config["DataLoader"], "Query", self.device)
# build train loss and metric info
if self.eval_loss_func is None:
self.eval_loss_func = self._build_loss_info(
self.config["Loss"], "eval")
# build metric info
if self.eval_metric_func is None:
self.eval_metric_func = self._build_metric_info(
self.config["Metric"], "eval")
metric_config = self.config.get("Metric", None)
if metric_config is None:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
else:
metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_retrieval(epoch_id)
else:
logger.warning("Invalid eval mode: {}".format(self.eval_mode))
......
......@@ -16,7 +16,7 @@ from paddle import nn
import copy
from collections import OrderedDict
from .metrics import Topk, mAP, mINP, Recallk
from .metrics import TopkAcc, mAP, mINP, Recallk
class CombinedMetrics(nn.Layer):
......
......@@ -18,7 +18,7 @@ import paddle.nn as nn
# TODO: fix the format
class Topk(nn.Layer):
class TopkAcc(nn.Layer):
def __init__(self, topk=(1, 5)):
super().__init__()
assert isinstance(topk, (int, list, tuple))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册