metrics.py 2.6 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import six
import abc
19
import numpy as np
D
dengkaipeng 已提交
20

D
dengkaipeng 已提交
21 22 23 24 25
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

26
__all__ = ['Metric', 'Accuracy']
D
dengkaipeng 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40


@six.add_metaclass(abc.ABCMeta)
class Metric(object):
    """
    Base class for metric, encapsulates metric logic and APIs

    Usage:
    m = SomeMetric()
    for prediction, label in ...:
        m.update(prediction, label)
    m.accumulate()
    """

D
dengkaipeng 已提交
41
    @abc.abstractmethod
D
dengkaipeng 已提交
42 43 44 45
    def reset(self):
        """
        Reset states and result
        """
D
dengkaipeng 已提交
46
        raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))
D
dengkaipeng 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61

    @abc.abstractmethod
    def update(self, *args, **kwargs):
        """
        Update states for metric
        """
        raise NotImplementedError("function 'update' not implemented in {}.".format(self.__class__.__name__))

    @abc.abstractmethod
    def accumulate(self):
        """
        Accumulates statistics, computes and returns the metric value
        """
        raise NotImplementedError("function 'accumulate' not implemented in {}.".format(self.__class__.__name__))

62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89

class Accuracy(Metric):
    """
    Encapsulates accuracy metric logic
    """

    def __init__(self, topk=(1, ), *args, **kwargs):
       super(Accuracy, self).__init__(*args, **kwargs) 
       self.topk = topk
       self.maxk = max(topk)
       self.reset()

    def update(self, pred, label, *args, **kwargs):
        pred = np.argsort(pred[0])[:, ::-1][:, :self.maxk]
        corr = (pred == np.repeat(label[0], self.maxk, 1))
        self.correct = np.append(self.correct, corr, axis=0)

    def reset(self):
       self.correct = np.empty((0, self.maxk), dtype="int32")

    def accumulate(self):
        res = []
        num_samples = self.correct.shape[0]
        for k in self.topk:
            correct_k = self.correct[:, :k].sum()
            res.append(round(100.0 * correct_k / num_samples, 2))
        return res