metric_op.py 5.4 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
All layers just related to metric.
"""

D
dzhwinter 已提交
18
import warnings
F
fengjiayi 已提交
19 20 21 22
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
Q
qingqing01 已提交
23
import nn
F
fengjiayi 已提交
24

D
dzhwinter 已提交
25
__all__ = ['accuracy', 'auc']
F
fengjiayi 已提交
26 27 28 29


def accuracy(input, label, k=1, correct=None, total=None):
    """
    Top-k accuracy layer.
    Refer to https://en.wikipedia.org/wiki/Precision_and_recall

    This function appends a top-k selection on ``input`` followed by an
    ``accuracy`` op that consumes the top-k values, the top-k indices and
    ``label``. If the correct label occurs in the top k predictions, then
    correct will increment by one.

    Note: the returned accuracy variable is created with dtype ``float32``;
    the ``correct`` and ``total`` counters default to freshly created
    ``int64`` temporaries when the caller does not supply them. The input and
    label dtype can be different.

    Args:
        input(Variable): The input of accuracy layer, which is the predictions
          of network. Carry LoD information is supported.
        label(Variable): The label of dataset.
        k(int): The top k predictions for each class will be checked.
        correct(Variable): The correct predictions count. Created internally
          when None.
        total(Variable): The total entries count. Created internally when
          None.

    Returns:
        Variable: The correct rate.

    Examples:
        .. code-block:: python

           data = fluid.layers.data(name="data", shape=[-1, 32, 32], dtype="float32")
           label = fluid.layers.data(name="label", shape=[-1,1], dtype="int32")
           predict = fluid.layers.fc(input=data, size=10)
           acc = fluid.layers.accuracy(input=predict, label=label, k=5)

    """
    # Must remain the first statement: at this point locals() contains only
    # the call arguments, which are forwarded verbatim to the helper.
    helper = LayerHelper("accuracy", **locals())

    top_values, top_indices = nn.topk(input, k=k)

    # Creation order (accuracy, then correct, then total) is kept as-is so
    # the sequence of generated temporaries does not change.
    accuracy_var = helper.create_tmp_variable(dtype="float32")
    correct = (helper.create_tmp_variable(dtype="int64")
               if correct is None else correct)
    total = (helper.create_tmp_variable(dtype="int64")
             if total is None else total)

    op_inputs = {
        "Out": [top_values],
        "Indices": [top_indices],
        "Label": [label]
    }
    op_outputs = {
        "Accuracy": [accuracy_var],
        "Correct": [correct],
        "Total": [total],
    }
    helper.append_op(type="accuracy", inputs=op_inputs, outputs=op_outputs)
    return accuracy_var
D
dzhwinter 已提交
77 78


W
Wu Yi 已提交
79
def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
    """
    **Area Under the Curve (AUC) Layer**

    This implementation computes the AUC according to forward output and
    label. It is used very widely in binary classification evaluation.

    Note: If input label contains values other than 0 and 1, it will be cast
    to `bool`. Find the relevant definitions `here
    <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`_.

    There are two types of possible curves:

        1. ROC: Receiver operating characteristic;
        2. PR: Precision Recall

    Four persistable ``int64`` statistic variables (TP, TN, FP, FN) are
    created and wired as both inputs and outputs of the ``auc`` op, so that
    the statistics can accumulate over all batches.

    Args:
        input(Variable): A floating-point 2D Variable, values are in the range
                         [0, 1]. Each row is sorted in descending order. This
                         input should be the output of topk. Typically, this
                         Variable indicates the probability of each label.
        label(Variable): A 2D int Variable indicating the label of the
                         training data. The height is batch size and width is
                         always 1.
        curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
        num_thresholds(int): The number of thresholds to use when
                             discretizing the roc curve. Default 200.
        topk(int): only topk number of prediction output will be used for
                   auc. NOTE(review): this argument is not passed to the
                   ``auc`` op as an attribute in this function; confirm
                   whether it is still meant to have an effect.

    Returns:
        tuple: ``(auc_out, [tp, tn, fp, fn])`` where ``auc_out`` is a
        ``float64`` Variable representing the current AUC and the list holds
        the four persistable statistic Variables.

    Examples:
        .. code-block:: python

            # network is a binary classification model and label the ground truth
            prediction = network(image, is_infer=True)
            auc_out, stat_vars = fluid.layers.auc(input=prediction, label=label)
    """
    # Must remain the first statement: at this point locals() contains only
    # the call arguments, which are forwarded verbatim to the helper.
    helper = LayerHelper("auc", **locals())
    auc_out = helper.create_tmp_variable(dtype="float64")

    # (input slot, output slot) pairs for the accumulated statistics, in the
    # order the variables are created and returned: tp, tn, fp, fn.
    stat_slots = (
        ("TP", "TPOut"),
        ("TN", "TNOut"),
        ("FP", "FPOut"),
        ("FN", "FNOut"),
    )

    # make tp, tn, fp, fn persistable, so that can accumulate all batches.
    # All four are created before any initializer is attached, as before.
    stat_vars = [
        helper.create_global_variable(
            persistable=True, dtype='int64') for _ in stat_slots
    ]
    for stat in stat_vars:
        zero_init = Constant(value=0.0, force_cpu=True)
        helper.set_variable_initializer(stat, zero_init)

    op_inputs = {"Predict": [input], "Label": [label]}
    op_outputs = {"AUC": [auc_out]}
    for (in_slot, out_slot), stat in zip(stat_slots, stat_vars):
        # Each statistic is read and written back by the op.
        op_inputs[in_slot] = [stat]
        op_outputs[out_slot] = [stat]

    op_attrs = {"curve": curve, "num_thresholds": num_thresholds}
    helper.append_op(
        type="auc", inputs=op_inputs, attrs=op_attrs, outputs=op_outputs)
    return auc_out, stat_vars