metric_op.py 5.4 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
All layers related to evaluation metrics.
"""

18 19
from __future__ import print_function

D
dzhwinter 已提交
20
import warnings
F
fengjiayi 已提交
21 22 23 24
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
25
from . import nn
F
fengjiayi 已提交
26

D
dzhwinter 已提交
27
# Public names exported by this module: the two metric layers defined below.
__all__ = ['accuracy', 'auc']
F
fengjiayi 已提交
28 29 30 31


def accuracy(input, label, k=1, correct=None, total=None):
    """
    accuracy layer.
    Refer to the https://en.wikipedia.org/wiki/Precision_and_recall

    This function computes the accuracy using the input and label.
    If the correct label occurs in the top k predictions of a sample, then
    correct will increment by one.

    Note: the output accuracy Variable is always created with dtype
    ``float32``, independent of the dtype of ``input``. The input and label
    dtypes can be different.

    Args:
        input(Variable): The input of accuracy layer, which is the predictions
            of network. Carrying LoD information is supported.
        label(Variable): The label of dataset.
        k(int): The number of top predictions of each sample that are checked
            against the label. Default 1.
        correct(Variable|None): The correct predictions count. If None, a
            temporary int64 Variable is created to hold it.
        total(Variable|None): The total entries count. If None, a temporary
            int64 Variable is created to hold it.

    Returns:
        Variable: The correct rate, a ``float32`` Variable.

    Examples:
        .. code-block:: python

           data = fluid.layers.data(name="data", shape=[-1, 32, 32], dtype="float32")
           label = fluid.layers.data(name="label", shape=[-1,1], dtype="int32")
           predict = fluid.layers.fc(input=data, size=10)
           acc = fluid.layers.accuracy(input=predict, label=label, k=5)

    """
    # Keep this as the first statement: LayerHelper receives exactly this
    # function's arguments through locals(), so no other local may exist yet.
    helper = LayerHelper("accuracy", **locals())
    # Top-k scores and their indices are what the accuracy op consumes
    # alongside the label.
    topk_out, topk_indices = nn.topk(input, k=k)
    acc_out = helper.create_tmp_variable(dtype="float32")
    # Fall back to temporaries when the caller does not provide variables to
    # receive the correct/total counts.
    if correct is None:
        correct = helper.create_tmp_variable(dtype="int64")
    if total is None:
        total = helper.create_tmp_variable(dtype="int64")
    helper.append_op(
        type="accuracy",
        inputs={
            "Out": [topk_out],
            "Indices": [topk_indices],
            "Label": [label]
        },
        outputs={
            "Accuracy": [acc_out],
            "Correct": [correct],
            "Total": [total],
        })
    return acc_out
D
dzhwinter 已提交
79 80


W
Wu Yi 已提交
81
def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
    """
    **Area Under the Curve (AUC) Layer**

    This implementation computes the AUC according to forward output and label.
    It is used very widely in binary classification evaluation.

    Note: If input label contains values other than 0 and 1, it will be cast
    to `bool`. Find the relevant definitions `here <https://en.wikipedia.org\
    /wiki/Receiver_operating_characteristic#Area_under_the_curve>`_.

    There are two types of possible curves:

        1. ROC: Receiver operating characteristic;
        2. PR: Precision Recall

    The tp/tn/fp/fn statistics are persistable global int64 Variables that
    are zero-initialized once and updated in place by the auc op, so their
    values accumulate over all batches.

    Args:
        input(Variable): A floating-point 2D Variable, values are in the range
                         [0, 1]. Each row is sorted in descending order. This
                         input should be the output of topk. Typically, this
                         Variable indicates the probability of each label.
        label(Variable): A 2D int Variable indicating the label of the training
                         data. The height is batch size and width is always 1.
        curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
        num_thresholds(int): The number of thresholds to use when discretizing
                             the curve. Default 200.
        topk(int): Not used in the computation of this layer (it only reaches
                   LayerHelper through ``locals()``); kept for API
                   compatibility. Apply topk to ``input`` yourself before
                   calling if needed.

    Returns:
        tuple: ``(auc_out, [tp, tn, fp, fn])``. ``auc_out`` is a float64
        Variable holding the current AUC; the list holds the four
        persistable int64 statistic Variables, returned so callers can
        access the accumulated state.

    Examples:
        .. code-block:: python

            # network is a binary classification model and label the ground truth
            prediction = network(image, is_infer=True)
            auc_out, auc_states = fluid.layers.auc(input=prediction, label=label)
    """
    # Keep this as the first statement: LayerHelper receives exactly this
    # function's arguments through locals(), so no other local may exist yet.
    helper = LayerHelper("auc", **locals())
    auc_out = helper.create_tmp_variable(dtype="float64")
    # Make tp, tn, fp, fn persistable so their values survive across runs and
    # accumulate over all batches. They are created in the order tp, tn, fp, fn.
    tp, tn, fp, fn = [
        helper.create_global_variable(persistable=True, dtype='int64')
        for _ in range(4)
    ]
    stat_vars = [tp, tn, fp, fn]
    for var in stat_vars:
        helper.set_variable_initializer(
            var, Constant(
                value=0.0, force_cpu=True))

    helper.append_op(
        type="auc",
        inputs={
            "Predict": [input],
            "Label": [label],
            "TP": [tp],
            "TN": [tn],
            "FP": [fp],
            "FN": [fn]
        },
        attrs={"curve": curve,
               "num_thresholds": num_thresholds},
        # The updated statistics are written back into the same variables,
        # which is what carries them over to the next batch.
        outputs={
            "AUC": [auc_out],
            "TPOut": [tp],
            "TNOut": [tn],
            "FPOut": [fp],
            "FNOut": [fn]
        })
    return auc_out, stat_vars