metric.py 4.2 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Layers for computing evaluation metrics.
"""

D
dzhwinter 已提交
18
import warnings
F
fengjiayi 已提交
19 20 21 22
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
Q
qingqing01 已提交
23
import nn
F
fengjiayi 已提交
24

D
dzhwinter 已提交
25
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = ['accuracy', 'auc']
F
fengjiayi 已提交
26 27 28 29 30


def accuracy(input, label, k=1, correct=None, total=None):
    """
    Append a top-k ``accuracy`` op for ``input`` against ``label``.

    The ``k`` top entries of ``input`` and their indices are obtained with
    ``nn.topk`` and fed, together with ``label``, to the ``accuracy`` op.

    Args:
        input(Variable): The prediction handed to ``nn.topk``.
        label(Variable): The ground truth, fed to the op as ``Label``.
        k(int): Number of top entries taken from ``input``. Default 1.
        correct(Variable|None): Variable bound to the op's ``Correct``
            output. A temporary int64 Variable is created when None.
        total(Variable|None): Variable bound to the op's ``Total`` output.
            A temporary int64 Variable is created when None.

    Returns:
        Variable: The float32 Variable bound to the op's ``Accuracy`` output.
    """
    # Keep this as the first statement: ``locals()`` must contain only the
    # call arguments when they are forwarded to LayerHelper.
    helper = LayerHelper("accuracy", **locals())

    top_values, top_indices = nn.topk(input, k=k)
    accuracy_var = helper.create_tmp_variable(dtype="float32")

    # Fall back to fresh temporaries for the counters the caller did not
    # supply.
    correct = (helper.create_tmp_variable(dtype="int64")
               if correct is None else correct)
    total = (helper.create_tmp_variable(dtype="int64")
             if total is None else total)

    op_inputs = {
        "Out": [top_values],
        "Indices": [top_indices],
        "Label": [label]
    }
    op_outputs = {
        "Accuracy": [accuracy_var],
        "Correct": [correct],
        "Total": [total],
    }
    helper.append_op(type="accuracy", inputs=op_inputs, outputs=op_outputs)
    return accuracy_var
D
dzhwinter 已提交
53 54 55


def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
    """
    **Area Under The Curve (AUC) Layer**

    This implementation computes the AUC according to forward output and label.
    It is used very widely in binary classification evaluation.

    As a note: If input label contains values other than 0 and 1, it will be
    cast to bool. You can find the relevant definitions `here
    <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`_.

    There are two types of possible curves:

    1. ROC: Receiver operating characteristic
    2. PR: Precision Recall

    Note:
        This interface emits a warning on every call: it computes the AUC of
        each minibatch only and cannot aggregate them into a per-pass AUC.
        The warning recommends ``fluid.metrics.Auc`` instead.

    Args:
        input(Variable): A floating-point 2D Variable, values are in the range
                         [0, 1]. Typically, this Variable indicates the
                         probability of each label. Its top ``topk`` entries
                         and their indices are computed internally with
                         ``nn.topk``.
        label(Variable): A 2D int Variable indicating the label of the training
                         data. The height is batch size and width is always 1.
        curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
        num_thresholds(int): The number of thresholds to use when discretizing
                             the roc curve. Default 200.
        topk(int): Number of top entries of ``input`` passed to the auc op.
                   Default 1.

    Returns:
        Variable: A scalar representing the current AUC.

    Examples:
        .. code-block:: python

            # network is a binary classification model and label the ground truth
            prediction = network(image, is_infer=True)
            auc_out = fluid.layers.auc(input=prediction, label=label)
    """

    warnings.warn(
        "This interface not recommended, fluid.layers.auc compute the auc at every minibatch, \
        but can not aggregate them and get the pass AUC, because pass \
        auc can not be averaged with weighted from the minibatch auc value. \
        Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \
        which can get every minibatch and every pass auc value.", Warning)
    helper = LayerHelper("auc", **locals())
    # ``nn.topk`` returns its own output variables, so no temporaries need to
    # be pre-allocated for them here.
    topk_out, topk_indices = nn.topk(input, k=topk)
    auc_out = helper.create_tmp_variable(dtype="float32")
    helper.append_op(
        # The "curve"/"num_thresholds" attrs and the "AUC" output belong to
        # the auc op, not to the accuracy op.
        type="auc",
        inputs={
            "Out": [topk_out],
            "Indices": [topk_indices],
            "Label": [label]
        },
        attrs={"curve": curve,
               "num_thresholds": num_thresholds},
        outputs={"AUC": [auc_out], })
    return auc_out