classification.py 7.4 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import platform
import paddle

21 22 23 24 25 26 27 28 29 30 31
from ...utils.misc import AverageMeter
from ...utils import logger
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics


class ClassEval(object):
    """Callable evaluation loop for a classification model.

    On construction it builds the "Eval" metric function, dataloader and
    loss function from ``config``. Calling the instance runs one pass over
    the evaluation dataloader, logs loss/metric/timing information and
    returns a single scalar result.
    """

    def __init__(self, config, mode, model):
        """Store the model and build the evaluation components.

        Args:
            config: Nested configuration mapping; ``config["Global"]`` must
                provide ``print_batch_step`` and may provide ``use_dali``.
            mode: Accepted for interface compatibility; not used here.
            model: The model to evaluate. It is called with the whole batch.
        """
        self.config = config
        self.model = model
        self.print_batch_step = self.config["Global"]["print_batch_step"]
        self.use_dali = self.config["Global"].get("use_dali", False)
        self.eval_metric_func = build_metrics(self.config, "Eval")
        self.eval_dataloader = build_dataloader(self.config, "Eval")
        self.eval_loss_func = build_loss(self.config, "Eval")
        # Maps loss name -> AverageMeter; rebuilt at the start of every call.
        self.output_info = dict()

    def _uses_attr_metric(self):
        """Return True when the first configured Eval metric is ATTRMetric.

        A missing or empty ``Metric.Eval`` section is treated as "not
        ATTRMetric" instead of raising.
        """
        try:
            first_metric = self.config["Metric"]["Eval"][0]
        except (KeyError, IndexError, TypeError):
            return False
        return "ATTRMetric" in first_metric

    @staticmethod
    def _all_gather_concat(tensor):
        """All-gather ``tensor`` from every rank and concatenate on axis 0."""
        gathered = []
        paddle.distributed.all_gather(gathered, tensor)
        return paddle.concat(gathered, 0)

    @paddle.no_grad()
    def __call__(self, epoch_id=0):
        """Run one evaluation pass and return the primary metric.

        Args:
            epoch_id: Epoch number, used only in log messages.

        Returns:
            ``-1`` when no metric function is configured; the first element
            of ``attr_res()`` when ATTRMetric is configured; otherwise the
            metric function's ``avg``.
        """
        self.model.eval()

        # Start each pass with fresh loss meters so that averages from a
        # previous evaluation call do not leak into this one.
        self.output_info = dict()
        if hasattr(self.eval_metric_func, "reset"):
            self.eval_metric_func.reset()

        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }

        # Loop-invariant values, looked up once.
        world_size = paddle.distributed.get_world_size()
        use_multilabel = self.config["Global"].get("use_multilabel", False)
        is_attr_metric = self._uses_attr_metric()

        tic = time.time()
        total_samples = self.eval_dataloader.total_samples
        accum_samples = 0
        max_iter = self.eval_dataloader.max_iter
        for iter_id, batch in enumerate(self.eval_dataloader):
            if iter_id >= max_iter:
                break
            # Drop the timing collected during the first five iterations.
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()

            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0])
            if not use_multilabel:
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

            out = self.model(batch)

            # just for DistributedBatchSampler issue: repeat sampling
            current_samples = batch_size * world_size
            accum_samples += current_samples

            if isinstance(out, dict) and "Student" in out:
                out = out["Student"]
            if isinstance(out, dict) and "logits" in out:
                out = out["logits"]

            # gather Tensor when distributed
            if world_size > 1:
                device_id = paddle.distributed.ParallelEnv().device_id
                label = batch[1].cuda(device_id) if self.config["Global"][
                    "device"] == "gpu" else batch[1]
                labels = self._all_gather_concat(label)

                if isinstance(out, list):
                    preds = [self._all_gather_concat(x) for x in out]
                else:
                    preds = self._all_gather_concat(out)

                # The sampler may repeat samples to fill the last batch;
                # keep only the part that belongs to the real dataset.
                if accum_samples > total_samples and not self.use_dali:
                    keep = total_samples + current_samples - accum_samples
                    if isinstance(preds, list):
                        preds = [pred[:keep] for pred in preds]
                    else:
                        preds = preds[:keep]
                    labels = labels[:keep]
                    current_samples = keep
            else:
                labels = batch[1]
                preds = out

            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(preds, labels)

                for key in loss_dict:
                    if key not in self.output_info:
                        self.output_info[key] = AverageMeter(key, '7.5f')
                    self.output_info[key].update(
                        float(loss_dict[key]), current_samples)

            #  calc metric
            if self.eval_metric_func is not None:
                self.eval_metric_func(preds, labels)
            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % self.print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                if is_attr_metric:
                    metric_msg = ""
                else:
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, self.output_info[key].val)
                        for key in self.output_info
                    ])
                    # Metric info is only available when a metric exists.
                    if self.eval_metric_func is not None:
                        metric_msg += ", {}".format(
                            self.eval_metric_func.avg_info)
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id, max_iter, metric_msg, time_msg,
                    ips_msg))

            tic = time.time()
        if self.use_dali:
            self.eval_dataloader.reset()

        if self.eval_metric_func is None:
            # No metric configured: report the losses only and signal the
            # caller not to try to save a "best" model.
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, self.output_info[key].avg)
                for key in self.output_info
            ])
            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
            eval_result = -1
        elif is_attr_metric:
            attr_res = self.eval_metric_func.attr_res()
            metric_msg = "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".format(
                *attr_res)
            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
            # return 1st metric in the result
            eval_result = attr_res[0]
        else:
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, self.output_info[key].avg)
                for key in self.output_info
            ])
            metric_msg += ", {}".format(self.eval_metric_func.avg_info)
            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
            eval_result = self.eval_metric_func.avg

        # Switch the model back to training mode before handing control
        # back to the caller.
        self.model.train()
        return eval_result