metric_op.py 6.8 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
All layers just related to metric.
"""

18 19
from __future__ import print_function

D
dzhwinter 已提交
20
import warnings
F
fengjiayi 已提交
21 22 23 24
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
25
from . import nn
F
fengjiayi 已提交
26

D
dzhwinter 已提交
27
# Public names exported by this module via `from <module> import *`.
__all__ = ['accuracy', 'auc']
F
fengjiayi 已提交
28 29 30 31


def accuracy(input, label, k=1, correct=None, total=None):
    """
    accuracy layer.
    Refer to https://en.wikipedia.org/wiki/Precision_and_recall

    This function computes the accuracy using the input and label.
    If the correct label occurs in the top k predictions, then correct will
    increment by one.

    Note: the input and label dtype can be different. The accuracy output
    variable is declared by this layer with dtype ``float32``; the
    ``correct`` and ``total`` counters, when not supplied by the caller, are
    declared with dtype ``int64``.

    Args:
        input(Variable): The input of accuracy layer, which is the predictions
          of network. Carry LoD information is supported.
        label(Variable): The label of dataset.
        k(int): The top k predictions for each class will be checked.
        correct(Variable): The correct predictions count. If None, a new
          variable is created to receive it.
        total(Variable): The total entries count. If None, a new variable is
          created to receive it.

    Returns:
        Variable: The correct rate.

    Examples:
        .. code-block:: python

           data = fluid.layers.data(name="data", shape=[-1, 32, 32], dtype="float32")
           label = fluid.layers.data(name="label", shape=[-1,1], dtype="int32")
           predict = fluid.layers.fc(input=data, size=10)
           acc = fluid.layers.accuracy(input=predict, label=label, k=5)

    """
    # NOTE: must stay the first statement -- ``locals()`` forwards exactly the
    # call arguments to LayerHelper, so no other local may exist yet.
    helper = LayerHelper("accuracy", **locals())

    # The accuracy op consumes the top-k scores and their indices rather than
    # the raw predictions.
    topk_out, topk_indices = nn.topk(input, k=k)

    acc_out = helper.create_variable_for_type_inference(dtype="float32")
    # Create the counter outputs only when the caller did not pass its own.
    if correct is None:
        correct = helper.create_variable_for_type_inference(dtype="int64")
    if total is None:
        total = helper.create_variable_for_type_inference(dtype="int64")

    helper.append_op(
        type="accuracy",
        inputs={
            "Out": [topk_out],
            "Indices": [topk_indices],
            "Label": [label]
        },
        outputs={
            "Accuracy": [acc_out],
            "Correct": [correct],
            "Total": [total],
        })
    return acc_out
D
dzhwinter 已提交
79 80


T
tangwei12 已提交
81 82 83 84 85 86
def auc(input,
        label,
        curve='ROC',
        num_thresholds=2**12 - 1,
        topk=1,
        slide_steps=1):
    """
    **Area Under the Curve (AUC) Layer**

    This implementation computes the AUC according to forward output and label.
    It is used very widely in binary classification evaluation.

    Two ``auc`` ops are appended: one that works on a sliding window of
    ``slide_steps`` steps (the "batch" AUC) and one that is appended with
    ``slide_steps`` fixed to 0 (the "global" AUC).

    Note: If input label contains values other than 0 and 1, it will be cast
    to `bool`. Find the relevant definitions
    `here <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`_.

    There are two types of possible curves:

        1. ROC: Receiver operating characteristic;
        2. PR: Precision Recall

    Args:
        input(Variable): A floating-point 2D Variable, values are in the range
                         [0, 1]. Each row is sorted in descending order. This
                         input should be the output of topk. Typically, this
                         Variable indicates the probability of each label.
        label(Variable): A 2D int Variable indicating the label of the training
                         data. The height is batch size and width is always 1.
        curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
        num_thresholds(int): The number of thresholds to use when discretizing
                             the roc curve. Default 4095 (``2**12 - 1``).
        topk(int): only topk number of prediction output will be used for auc.
                   NOTE(review): this argument is not referenced anywhere in
                   the body of this function.
        slide_steps(int): The number of steps used when computing the batch
                          AUC. slide_steps=1 uses only the current step,
                          slide_steps=3 uses the current step and the two
                          previous steps, slide_steps=0 uses all of the steps.
                          Default 1.

    Returns:
        tuple: ``(auc_out, batch_auc_out, states)`` where ``auc_out`` is the
        Variable holding the global AUC, ``batch_auc_out`` is the Variable
        holding the batch (sliding-window) AUC, and ``states`` is the list
        ``[batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]`` of the
        persistable statistic Variables used by the two ops.

    Examples:
        .. code-block:: python

            # network is a binary classification model and label the ground truth
            prediction = network(image, is_infer=True)
            auc_out, batch_auc_out, states = fluid.layers.auc(
                input=prediction, label=label)
    """
    # NOTE: must stay the first statement -- ``locals()`` forwards exactly the
    # call arguments to LayerHelper, so no other local may exist yet.
    helper = LayerHelper("auc", **locals())
    auc_out = helper.create_variable_for_type_inference(dtype="float64")
    batch_auc_out = helper.create_variable_for_type_inference(dtype="float64")

    # The positive/negative statistic buffers are persistable global
    # variables, so that they can accumulate over all batches.

    # for batch auc: one row of statistics per step in the sliding window
    batch_stat_pos = helper.create_global_variable(
        persistable=True,
        dtype='int64',
        shape=[slide_steps, num_thresholds + 1])
    batch_stat_neg = helper.create_global_variable(
        persistable=True,
        dtype='int64',
        shape=[slide_steps, num_thresholds + 1])

    # for global auc: a single row of statistics
    stat_pos = helper.create_global_variable(
        persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
    stat_neg = helper.create_global_variable(
        persistable=True, dtype='int64', shape=[1, num_thresholds + 1])

    # Every statistic buffer starts from zero.
    for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]:
        helper.set_variable_initializer(
            var, Constant(
                value=0.0, force_cpu=True))

    # Batch AUC: the op reads the window statistics and writes the updated
    # values back into the same variables.
    helper.append_op(
        type="auc",
        inputs={
            "Predict": [input],
            "Label": [label],
            "StatPos": [batch_stat_pos],
            "StatNeg": [batch_stat_neg]
        },
        attrs={
            "curve": curve,
            "num_thresholds": num_thresholds,
            "slide_steps": slide_steps
        },
        outputs={
            "AUC": [batch_auc_out],
            "StatPosOut": [batch_stat_pos],
            "StatNegOut": [batch_stat_neg]
        })
    # Global AUC: same op, with slide_steps fixed to 0 (use all of the steps).
    helper.append_op(
        type="auc",
        inputs={
            "Predict": [input],
            "Label": [label],
            "StatPos": [stat_pos],
            "StatNeg": [stat_neg]
        },
        attrs={
            "curve": curve,
            "num_thresholds": num_thresholds,
            "slide_steps": 0
        },
        outputs={
            "AUC": [auc_out],
            "StatPosOut": [stat_pos],
            "StatNegOut": [stat_neg]
        })
    return auc_out, batch_auc_out, [
        batch_stat_pos, batch_stat_neg, stat_pos, stat_neg
    ]