From 3e4af0448e6193d8a2c4c55963ca9e98c6558f4b Mon Sep 17 00:00:00 2001
From: weishengyu
Date: Fri, 4 Jun 2021 22:56:12 +0800
Subject: [PATCH] add default metrics

---
 ppcls/engine/trainer.py  | 65 +++++++++++++++++++---------------------
 ppcls/metric/__init__.py |  2 +-
 ppcls/metric/metrics.py  |  2 +-
 3 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py
index a5a941ec..81fe515d 100644
--- a/ppcls/engine/trainer.py
+++ b/ppcls/engine/trainer.py
@@ -93,37 +93,17 @@ class Trainer(object):
         self.train_metric_func = None
         self.eval_metric_func = None
 
-    def _build_metric_info(self, metric_config, mode="train"):
-        """
-        _build_metric_info: build metrics according to current mode
-        Return:
-            metric: dict of the metrics info
-        """
-        metric = None
-        mode = mode.capitalize()
-        if mode in metric_config and metric_config[mode] is not None:
-            metric = build_metrics(metric_config[mode])
-        return metric
-
-    def _build_loss_info(self, loss_config, mode="train"):
-        """
-        _build_loss_info: build loss according to current mode
-        Return:
-            loss_dict: dict of the loss info
-        """
-        loss = None
-        mode = mode.capitalize()
-        if mode in loss_config and loss_config[mode] is not None:
-            loss = build_loss(loss_config[mode])
-        return loss
-
     def train(self):
         # build train loss and metric info
         if self.train_loss_func is None:
-            self.train_loss_func = self._build_loss_info(self.config["Loss"])
-        if "Metric" in self.config and self.train_metric_func is None:
-            self.train_metric_func = self._build_metric_info(self.config[
-                "Metric"])
+            self.train_loss_func = build_loss(self.config["Loss"])
+        if self.train_metric_func is None:
+            metric_config = self.config.get("Metric", None)
+            if metric_config is None:
+                metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
+            else:
+                metric_config = metric_config["Train"]
+            self.train_metric_func = build_metrics(metric_config)
 
         if self.train_dataloader is None:
             self.train_dataloader = build_dataloader(self.config["DataLoader"],
@@ -241,10 +221,26 @@ class Trainer(object):
     @paddle.no_grad()
     def eval(self, epoch_id=0):
         self.model.eval()
+        if self.eval_loss_func is None:
+            loss_info = self.config.get("Loss", None)
+            if loss_info is None:
+                loss_info = [{"CELoss": {"weight": 1.0}}]
+            else:
+                loss_info = loss_info["Eval"]
+            self.eval_loss_func = build_loss(loss_info)
         if self.eval_mode == "classification":
             if self.eval_dataloader is None:
                 self.eval_dataloader = build_dataloader(
                     self.config["DataLoader"], "Eval", self.device)
+
+            if self.eval_metric_func is None:
+                metric_config = self.config.get("Metric", None)
+                if metric_config is None:
+                    metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
+                else:
+                    metric_config = metric_config["Eval"]
+                self.eval_metric_func = build_metrics(metric_config)
+
             eval_result = self.eval_cls(epoch_id)
 
         elif self.eval_mode == "retrieval":
@@ -255,13 +251,14 @@ class Trainer(object):
             if self.query_dataloader is None:
                 self.query_dataloader = build_dataloader(
                     self.config["DataLoader"], "Query", self.device)
-            # build train loss and metric info
-            if self.eval_loss_func is None:
-                self.eval_loss_func = self._build_loss_info(
-                    self.config["Loss"], "eval")
+            # build metric info
             if self.eval_metric_func is None:
-                self.eval_metric_func = self._build_metric_info(
-                    self.config["Metric"], "eval")
+                metric_config = self.config.get("Metric", None)
+                if metric_config is None:
+                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
+                else:
+                    metric_config = metric_config["Eval"]
+                self.eval_metric_func = build_metrics(metric_config)
             eval_result = self.eval_retrieval(epoch_id)
         else:
             logger.warning("Invalid eval mode: {}".format(self.eval_mode))
diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py
index 40e2e17f..a1643efe 100644
--- a/ppcls/metric/__init__.py
+++ b/ppcls/metric/__init__.py
@@ -16,7 +16,7 @@ from paddle import nn
 import copy
 from collections import OrderedDict
 
-from .metrics import Topk, mAP, mINP, Recallk
+from .metrics import TopkAcc, mAP, mINP, Recallk
 
 
 class CombinedMetrics(nn.Layer):
diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py
index fca2f9fc..0301ac05 100644
--- a/ppcls/metric/metrics.py
+++ b/ppcls/metric/metrics.py
@@ -18,7 +18,7 @@ import paddle.nn as nn
 
 
 # TODO: fix the format
-class Topk(nn.Layer):
+class TopkAcc(nn.Layer):
     def __init__(self, topk=(1, 5)):
         super().__init__()
         assert isinstance(topk, (int, list, tuple))
--
GitLab
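
Note (illustration, not part of the patch): with this change, a config that omits the
"Metric" section falls back to a default TopkAcc top-1/top-5 metric (Recallk for
retrieval evaluation), and eval falls back to CELoss when no "Loss" section is
configured. The standalone Python sketch below shows that fallback pattern under the
assumption of a plain dict config; "resolve_metric_config" is a hypothetical helper
name used only here, since in the patch the logic is inlined in Trainer.train() and
Trainer.eval().

# Standalone sketch of the default-metric fallback added by this patch.
def resolve_metric_config(config, mode="Train", default=None):
    """Return the metric config list for `mode`, falling back to
    top-1/top-5 accuracy when the config has no "Metric" section."""
    if default is None:
        default = [{"name": "TopkAcc", "topk": (1, 5)}]
    metric_config = config.get("Metric", None)
    if metric_config is None:
        return default
    return metric_config[mode]

if __name__ == "__main__":
    # No "Metric" section -> the default TopkAcc(1, 5) config is returned.
    print(resolve_metric_config({}, mode="Train"))
    # An explicit "Eval" metric section (e.g. retrieval) is used as-is.
    cfg = {"Metric": {"Eval": [{"name": "Recallk", "topk": (1, 5)}]}}
    print(resolve_metric_config(cfg, mode="Eval"))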