From 316eb3e968b8310f38a9308e813e15902f90f771 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 15 Jun 2018 03:03:28 -0700 Subject: [PATCH] Add doc for layers.auc --- python/paddle/fluid/layers/metric.py | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/python/paddle/fluid/layers/metric.py b/python/paddle/fluid/layers/metric.py index a1c64ce2771..15d7c50bf45 100644 --- a/python/paddle/fluid/layers/metric.py +++ b/python/paddle/fluid/layers/metric.py @@ -53,6 +53,43 @@ def accuracy(input, label, k=1, correct=None, total=None): def auc(input, label, curve='ROC', num_thresholds=200): + """ + **Area Under The Curve (AUC) Layer** + + This implementation computes the AUC according to forward output and label. + It is used very widely in binary classification evaluation. + + As a note: If input label contains values other than 0 and 1, it will be + cast to bool. You can find the relevant definitions `here + `_. + + There are two types of possible curves: + 1. ROC: Receiver operating characteristic + 2. PR: Precision Recall + + Args: + input(Variable): A floating-point 2D Variable, values are in the range + [0, 1]. Each row is sorted in descending order. This + input should be the output of topk. Typically, this + Variable indicates the probability of each label. + label(Variable): A 2D int Variable indicating the label of the training + data. The height is batch size and width is always 1. + curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'. + num_thresholds(int): The number of thresholds to use when discretizing + the roc curve. Default 200. + + Returns: + Variable: A scalar representing the current AUC. + + Examples: + .. code-block:: python + + # network is a binary classification model and label the ground truth + prediction = network(image, is_infer=True) + auc_out=fluid.layers.auc(input=prediction, label=label) + """ + warnings.warn( "This interface not recommended, fluid.layers.auc compute the auc at every minibatch, \ but can not aggregate them and get the pass AUC, because pass \ -- GitLab