average.py 2.9 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
D
dzhwinter 已提交
16
import warnings
17

F
fengjiayi 已提交
18 19 20
"""
    Classes for computing various kinds of averages.

    All averages are computed entirely in Python.
    They do not change Paddle's Program, nor do they
    modify the NN model's configuration. They are purely
    wrappers around Python functions.
"""

D
dzhwinter 已提交
27 28
# Names exported by a star-import of this module; only WeightedAverage is
# public, the underscore-prefixed helpers below are internal.
__all__ = ["WeightedAverage"]

F
fengjiayi 已提交
29 30

def _is_number_(var):
31 32 33 34 35
    return (
        isinstance(var, int)
        or isinstance(var, float)
        or (isinstance(var, np.ndarray) and var.shape == (1,))
    )
F
fengjiayi 已提交
36 37 38 39 40 41


def _is_number_or_matrix_(var):
    """Return True if ``var`` is any numpy ndarray or a scalar number.

    Scalar detection is delegated to ``_is_number_``.
    """
    if isinstance(var, np.ndarray):
        # Any ndarray qualifies, regardless of its shape.
        return True
    return _is_number_(var)


42
class WeightedAverage:
    """
    Calculate weighted average.

    The average is computed entirely in Python. It does not change
    Paddle's Program, nor does it modify the NN model's configuration.
    It is purely a wrapper around Python arithmetic.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            avg = fluid.average.WeightedAverage()
            avg.add(value=2.0, weight=1)
            avg.add(value=4.0, weight=2)
            avg.eval()

            # The result is 3.333333333.
            # For (2.0 * 1 + 4.0 * 2) / (1 + 2) = 3.333333333
    """

    def __init__(self):
        """Emit a deprecation warning and start with an empty accumulator."""
        # NOTE(review): the message points users at fluid.metrics.Accuracy,
        # which looks unrelated to weighted averaging -- confirm the intended
        # replacement before changing the text.
        warnings.warn(
            "The %s is deprecated, please use fluid.metrics.Accuracy instead."
            % (self.__class__.__name__),
            Warning,
        )
        self.reset()

    def reset(self):
        """Discard all accumulated data."""
        # None marks "nothing added yet"; eval() checks for it.
        self.numerator = None
        self.denominator = None

    def add(self, value, weight):
        """Accumulate ``value`` with the given ``weight``.

        Args:
            value: A number or a numpy ndarray, as accepted by
                ``_is_number_or_matrix_``.
            weight: A number, as accepted by ``_is_number_``.

        Raises:
            ValueError: If ``value`` or ``weight`` fails its type check.
        """
        if not _is_number_or_matrix_(value):
            raise ValueError(
                "The 'value' must be a number(int, float) or a numpy ndarray."
            )
        if not _is_number_(weight):
            raise ValueError("The 'weight' must be a number(int, float).")

        weighted_value = value * weight
        if self.numerator is None or self.denominator is None:
            self.numerator = weighted_value
            # Copy an array weight so the accumulator never shares storage
            # with the caller's object.
            self.denominator = (
                weight.copy() if isinstance(weight, np.ndarray) else weight
            )
        else:
            # Accumulate out-of-place: in-place "+=" on an ndarray would
            # mutate shared storage and fails when the running total has an
            # integer dtype but the new contribution is floating point.
            self.numerator = self.numerator + weighted_value
            self.denominator = self.denominator + weight

    def eval(self):
        """Return the weighted average of everything added so far.

        Returns:
            The accumulated weighted sum divided by the accumulated weight.

        Raises:
            ValueError: If nothing has been added since construction or the
                last ``reset()``.
        """
        if self.numerator is None or self.denominator is None:
            raise ValueError(
                "There is no data to be averaged in WeightedAverage."
            )
        return self.numerator / self.denominator