average.py 2.8 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
D
dzhwinter 已提交
16
import warnings
F
fengjiayi 已提交
17 18 19
"""
    Averaging helpers of all kinds.

    All averages are computed entirely in Python.
    They do not change Paddle's Program, nor do they
    modify the NN model's configuration. They are purely
    wrappers around Python functions.
"""

D
dzhwinter 已提交
26 27
# Names exported by ``from <this module> import *``; the underscore-prefixed
# helpers below are intentionally not listed.
__all__ = ["WeightedAverage"]

F
fengjiayi 已提交
28 29

def _is_number_(var):
30 31
    return isinstance(var, int) or isinstance(
        var, float) or (isinstance(var, np.ndarray) and var.shape == (1, ))
F
fengjiayi 已提交
32 33 34 35 36 37 38


def _is_number_or_matrix_(var):
    """Tell whether ``var`` is a numpy ndarray or a scalar number.

    Any numpy ndarray qualifies regardless of its shape; everything else
    is delegated to ``_is_number_``.
    """
    if isinstance(var, np.ndarray):
        return True
    return _is_number_(var)


class WeightedAverage(object):
    """
    Calculate a weighted average.

    The computation is carried out entirely in Python. It does not change
    Paddle's Program, nor does it modify the NN model's configuration;
    the class is purely a wrapper around Python arithmetic.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            avg = fluid.average.WeightedAverage()
            avg.add(value=2.0, weight=1)
            avg.add(value=4.0, weight=2)
            avg.eval()

            # The result is 3.333333333.
            # For (2.0 * 1 + 4.0 * 2) / (1 + 2) = 3.333333333
    """

    def __init__(self):
        """Emit a deprecation warning and start with an empty accumulator."""
        warnings.warn(
            "The %s is deprecated, please use fluid.metrics.Accuracy instead." %
            (self.__class__.__name__), Warning)
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        # None marks "no data yet"; eval() relies on it to detect an
        # empty accumulator.
        self.numerator = None
        self.denominator = None

    def add(self, value, weight):
        """Accumulate one weighted sample.

        Args:
            value: A number (int, float) or a numpy ndarray.
            weight: A number (int, float) or a numpy ndarray of
                shape ``(1,)``.

        Raises:
            ValueError: If ``value`` or ``weight`` has an unsupported type.
        """
        if not _is_number_or_matrix_(value):
            raise ValueError(
                "The 'value' must be a number(int, float) or a numpy ndarray.")
        if not _is_number_(weight):
            raise ValueError("The 'weight' must be a number(int, float).")

        weighted_value = value * weight
        if self.numerator is None or self.denominator is None:
            self.numerator = weighted_value
            self.denominator = weight
        else:
            # Accumulate out of place on purpose. ``self.denominator`` may
            # be the very ndarray object the caller passed as ``weight``
            # on the first call, so an in-place ``+=`` would silently
            # modify the caller's array. In-place addition on ndarrays
            # also fails when an integer accumulator receives a float
            # contribution.
            self.numerator = self.numerator + weighted_value
            self.denominator = self.denominator + weight

    def eval(self):
        """Return the weighted average of all samples added so far.

        Returns:
            The accumulated weighted sum divided by the accumulated weight.

        Raises:
            ValueError: If nothing has been added since construction or
                the last ``reset()``.
        """
        if self.numerator is None or self.denominator is None:
            raise ValueError(
                "There is no data to be averaged in WeightedAverage.")
        return self.numerator / self.denominator