metrics_util.py 7.1 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

15 16 17 18 19 20 21 22 23 24
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division

import logging

import numpy as np
from metrics.youtube8m import eval_util as youtube8m_metrics
from metrics.kinetics import accuracy_metrics as kinetics_metrics
25
from metrics.multicrop_test import multicrop_test_metrics as multicrop_test_metrics
26 27 28 29 30

logger = logging.getLogger(__name__)


class Metrics(object):
    """Base interface for the metrics helpers kept in ``MetricsZoo``.

    Every method here is a no-op stub that returns ``None``; the concrete
    classes in this module (Youtube8m / Kinetics400 / Multicrop) override
    them. ``MetricsZoo.regist`` checks registered classes against this base.
    """

    def __init__(self, name, mode, metrics_args):
        """Build a metrics helper.

        Args:
            name: registered model name (e.g. ``"TSN"``).
            mode: run-mode string from the caller (``'test'`` is one value
                used by subclasses; TODO confirm the full set).
            metrics_args: model configuration object.
        """
        pass

    def calculate_and_log_out(self, loss, pred, label, info=''):
        """Compute metrics for one batch and log them (no-op here)."""
        pass

    def accumulate(self, loss, pred, label, info=''):
        """Fold one batch into the running epoch statistics (no-op here)."""
        pass

    def finalize_and_log_out(self, info=''):
        """Finalize the accumulated statistics and log them (no-op here)."""
        pass

    def reset(self):
        """Clear the accumulated statistics (no-op here)."""
        pass


class Youtube8mMetrics(Metrics):
    """Metrics for the YouTube-8M style models.

    Registered for NEXTVLAD, ATTENTIONLSTM and ATTENTIONCLUSTER. Per-batch
    Hit@1 / PERR / GAP are computed with the ``youtube8m_metrics`` helpers;
    epoch-level numbers come from a ``youtube8m_metrics.EvaluationMetrics``
    accumulator.
    """

    def __init__(self, name, mode, metrics_args):
        """Create the epoch accumulator.

        Args:
            name: registered model name.
            mode: run-mode string (stored; not otherwise used by this class).
            metrics_args: config supporting ``['MODEL']['num_classes']`` and
                ``['MODEL']['topk']`` lookups.
        """
        self.name = name
        self.mode = mode
        self.num_classes = metrics_args['MODEL']['num_classes']
        self.topk = metrics_args['MODEL']['topk']
        self.calculator = youtube8m_metrics.EvaluationMetrics(self.num_classes,
                                                              self.topk)

    def calculate_and_log_out(self, loss, pred, label, info=''):
        """Compute and log loss, Hit@1, PERR and GAP for a single batch.

        Args:
            loss: batch loss value(s), averaged with ``np.mean``, or ``None``.
            pred: model predictions for the batch.
            label: ground-truth labels for the batch.
            info: prefix prepended to the log line.
        """
        # Match Kinetics400Metrics: a missing loss is reported as 0 instead of
        # failing inside np.mean(np.array(None)).
        loss = 0. if loss is None else np.mean(np.array(loss))
        hit_at_one = youtube8m_metrics.calculate_hit_at_one(pred, label)
        perr = youtube8m_metrics.calculate_precision_at_equal_recall_rate(pred,
                                                                          label)
        gap = youtube8m_metrics.calculate_gap(pred, label)
        logger.info(info + ' , loss = {0}, Hit@1 = {1}, PERR = {2}, GAP = {3}'.format(
            '%.6f' % loss, '%.2f' % hit_at_one, '%.2f' % perr, '%.2f' % gap))

    def accumulate(self, loss, pred, label, info=''):
        """Add one batch to the epoch accumulator.

        ``info`` is accepted for interface compatibility and ignored.
        """
        self.calculator.accumulate(loss, pred, label)

    def finalize_and_log_out(self, info=''):
        """Log the epoch-level statistics held by the accumulator."""
        epoch_info_dict = self.calculator.get()
        logger.info(info + '\tavg_hit_at_one: {0},\tavg_perr: {1},\tavg_loss :{2},\taps: {3},\tgap:{4}'.format(
            epoch_info_dict['avg_hit_at_one'], epoch_info_dict['avg_perr'],
            epoch_info_dict['avg_loss'], epoch_info_dict['aps'],
            epoch_info_dict['gap']))

    def reset(self):
        """Clear the epoch accumulator."""
        self.calculator.clear()


class Kinetics400Metrics(Metrics):
    """Top-1 / top-5 accuracy metrics, registered for TSN, TSM and STNET."""

    def __init__(self, name, mode, metrics_args):
        """Create the accuracy calculator.

        Args:
            name: registered model name.
            mode: run-mode string; it is lower-cased before being passed to
                the calculator.
            metrics_args: model config (not used by this class).
        """
        self.name = name
        self.mode = mode
        self.calculator = kinetics_metrics.MetricsCalculator(name, mode.lower())

    def _log(self, info, loss, acc1, acc5):
        """Log loss / top-1 / top-5 in the format shared by batch and epoch output."""
        logger.info(info + '\tLoss: {},\ttop1_acc: {}, \ttop5_acc: {}'.format(
            '%.6f' % loss, '%.2f' % acc1, '%.2f' % acc5))

    def calculate_and_log_out(self, loss, pred, label, info=''):
        """Compute and log top-1/top-5 accuracy for one batch.

        A ``None`` loss is treated as 0; otherwise its mean is used.
        """
        loss = 0. if loss is None else np.mean(np.array(loss))
        acc1, acc5 = self.calculator.calculate_metrics(loss, pred, label)
        self._log(info, loss, acc1, acc5)

    def accumulate(self, loss, pred, label, info=''):
        """Add one batch to the calculator; ``info`` is ignored."""
        self.calculator.accumulate(loss, pred, label)

    def finalize_and_log_out(self, info=''):
        """Finalize the epoch statistics and log average loss / accuracies."""
        self.calculator.finalize_metrics()
        metrics_dict = self.calculator.get_computed_metrics()
        self._log(info, metrics_dict['avg_loss'], metrics_dict['avg_acc1'],
                  metrics_dict['avg_acc5'])

    def reset(self):
        """Clear the calculator's accumulated statistics."""
        self.calculator.reset()


class MulticropMetrics(Metrics):
    """Metrics for multi-crop evaluated models (registered for NONLOCAL).

    In ``'test'`` mode results go to a ``multicrop_test_metrics`` calculator
    and nothing is logged per batch; in any other mode this behaves like the
    Kinetics top-1/top-5 accuracy metrics.
    """

    def __init__(self, name, mode, metrics_args):
        """Create the calculator appropriate for ``mode``.

        Args:
            name: registered model name.
            mode: run-mode string; ``'test'`` selects the multi-crop test
                calculator.
            metrics_args: config object with attribute access. In ``'test'``
                mode ``TEST.num_test_clips``, ``TEST.dataset_size``,
                ``TEST.filename_gt``, ``TEST.checkpoint_dir`` and
                ``MODEL.num_classes`` are read.
        """
        self.name = name
        self.mode = mode
        # NOTE(review): mode is compared case-sensitively here while the
        # calculators receive mode.lower(); e.g. 'TEST' takes the else branch.
        if mode == 'test':
            args = {
                'num_test_clips': metrics_args.TEST.num_test_clips,
                'dataset_size': metrics_args.TEST.dataset_size,
                'filename_gt': metrics_args.TEST.filename_gt,
                'checkpoint_dir': metrics_args.TEST.checkpoint_dir,
                'num_classes': metrics_args.MODEL.num_classes,
            }
            self.calculator = multicrop_test_metrics.MetricsCalculator(
                name, mode.lower(), **args)
        else:
            self.calculator = kinetics_metrics.MetricsCalculator(name,
                                                                 mode.lower())

    def _log(self, info, loss, acc1, acc5):
        """Log loss / top-1 / top-5 in the format shared by batch and epoch output."""
        logger.info(info + '\tLoss: {},\ttop1_acc: {}, \ttop5_acc: {}'.format(
            '%.6f' % loss, '%.2f' % acc1, '%.2f' % acc5))

    def calculate_and_log_out(self, loss, pred, label, info=''):
        """Compute and log per-batch accuracy; does nothing in ``'test'`` mode."""
        if self.mode == 'test':
            return
        loss = 0. if loss is None else np.mean(np.array(loss))
        acc1, acc5 = self.calculator.calculate_metrics(loss, pred, label)
        self._log(info, loss, acc1, acc5)

    def accumulate(self, loss, pred, label, info=''):
        """Add one batch to the calculator.

        ``info`` was missing from this override; it is accepted (and ignored)
        so the signature matches ``Metrics.accumulate`` and its siblings.
        """
        self.calculator.accumulate(loss, pred, label)

    def finalize_and_log_out(self, info=''):
        """Finalize the calculator; outside ``'test'`` mode also log averages."""
        self.calculator.finalize_metrics()
        if self.mode == 'test':
            # Only finalize here; any test-result reporting is presumably done
            # by the multi-crop calculator itself -- verify.
            return
        metrics_dict = self.calculator.get_computed_metrics()
        self._log(info, metrics_dict['avg_loss'], metrics_dict['avg_acc1'],
                  metrics_dict['avg_acc5'])

    def reset(self):
        """Clear the calculator's accumulated statistics."""
        self.calculator.reset()


class MetricsNotFoundError(Exception):
    """Raised by ``MetricsZoo.get`` when no metrics class is registered for a name."""

    def __init__(self, name, available):
        self.name = name
        self.available = list(available)
        super(MetricsNotFoundError, self).__init__(
            "Metrics for model {} not found, available: {}".format(
                name, self.available))


class MetricsZoo(object):
    """Registry mapping a model name to its ``Metrics`` subclass."""

    def __init__(self):
        self.metrics_zoo = {}

    def regist(self, name, metrics):
        """Register the class ``metrics`` under ``name``.

        ``metrics`` must be a proper (direct or indirect) subclass of
        ``Metrics``; the base class itself is rejected. A later registration
        under the same name replaces the earlier one.
        """
        assert (isinstance(metrics, type) and issubclass(metrics, Metrics)
                and metrics is not Metrics), \
            "Unknown metrics type {}".format(metrics)
        self.metrics_zoo[name] = metrics

    def get(self, name, mode, cfg):
        """Instantiate the class registered under ``name`` as ``cls(name, mode, cfg)``.

        Raises:
            MetricsNotFoundError: if nothing is registered under ``name``.
        """
        if name not in self.metrics_zoo:
            raise MetricsNotFoundError(name, self.metrics_zoo.keys())
        return self.metrics_zoo[name](name, mode, cfg)


# Module-level singleton registry shared by regist_metrics / get_metrics.
metrics_zoo = MetricsZoo()


def regist_metrics(name, metrics):
    """Register the metrics class ``metrics`` under ``name`` in the shared zoo."""
    metrics_zoo.regist(name, metrics)


def get_metrics(name, mode, cfg):
    """Instantiate the metrics class registered under ``name``.

    Args:
        name: registered model name, e.g. ``"TSN"``.
        mode: run-mode string forwarded to the metrics constructor.
        cfg: model configuration forwarded to the metrics constructor.

    Returns:
        A new instance built as ``cls(name, mode, cfg)``. Unknown names are
        rejected by ``MetricsZoo.get``.
    """
    return metrics_zoo.get(name, mode, cfg)


# Register every supported model name with the metrics class that evaluates
# it, in a fixed order.
for _model_name, _metrics_cls in (
        ("NEXTVLAD", Youtube8mMetrics),
        ("ATTENTIONLSTM", Youtube8mMetrics),
        ("ATTENTIONCLUSTER", Youtube8mMetrics),
        ("TSN", Kinetics400Metrics),
        ("TSM", Kinetics400Metrics),
        ("STNET", Kinetics400Metrics),
        ("NONLOCAL", MulticropMetrics),
):
    regist_metrics(_model_name, _metrics_cls)
# Keep the loop variables out of the module namespace.
del _model_name, _metrics_cls