未验证 提交 d1e2efae 编写于 作者: T tangwei12 提交者: GitHub

reimplement auc in fluid (#13167)

* reimplement auc in pyton

* reimplement auc in fluid

* add auc unittest


* replace new auc in layers


* add batch Auc in Fluid

* name formated
上级 f94fdeaa
......@@ -312,7 +312,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kw
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 200, 1))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/auc_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -36,15 +35,12 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height.");
int num_thres = ctx->Attrs().Get<int>("num_thresholds");
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
ctx->SetOutputDim("AUC", {1});
ctx->SetOutputDim("TPOut", {num_thres});
ctx->SetOutputDim("TNOut", {num_thres});
ctx->SetOutputDim("FPOut", {num_thres});
ctx->SetOutputDim("FNOut", {num_thres});
ctx->ShareLoD("Predict", /*->*/ "AUC");
ctx->SetOutputDim("BatchAUC", {1});
ctx->SetOutputDim("StatPosOut", {num_pred_buckets});
ctx->SetOutputDim("StatNegOut", {num_pred_buckets});
}
protected:
......@@ -66,25 +62,24 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label",
"A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]");
AddInput("TP", "True-Positive value.");
AddInput("FP", "False-Positive value.");
AddInput("TN", "True-Negative value.");
AddInput("FN", "False-Negative value.");
// TODO(typhoonzero): support weight input
AddInput("StatPos", "Statistic value when label = 1");
AddInput("StatNeg", "Statistic value when label = 0");
AddOutput("AUC",
"A scalar representing the "
"current area-under-the-curve.");
AddOutput("TPOut", "True-Positive value.");
AddOutput("FPOut", "False-Positive value.");
AddOutput("TNOut", "True-Negative value.");
AddOutput("FNOut", "False-Negative value.");
AddOutput("BatchAUC", "The AUC for current batch");
AddOutput("StatPosOut", "Statistic value when label = 1");
AddOutput("StatNegOut", "Statistic value when label = 0");
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
.SetDefault("ROC");
AddAttr<int>("num_thresholds",
"The number of thresholds to use when discretizing the"
" roc curve.")
.SetDefault(200);
.SetDefault((2 << 12) - 1);
AddComment(R"DOC(
Area Under The Curve (AUC) Operator.
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -23,106 +23,85 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class AucKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* predict = ctx.Input<Tensor>("Predict");
auto* label = ctx.Input<Tensor>("Label");
auto* auc = ctx.Output<Tensor>("AUC");
void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds");
int num_pred_buckets = num_thresholds + 1;
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto* true_positive = ctx.Output<Tensor>("TPOut");
auto* false_positive = ctx.Output<Tensor>("FPOut");
auto* true_negative = ctx.Output<Tensor>("TNOut");
auto* false_negative = ctx.Output<Tensor>("FNOut");
auto *auc = ctx.Output<Tensor>("AUC");
auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
auto *stat_pos_data = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *stat_neg_data = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds,
auc);
std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds");
std::vector<double> thresholds_list;
thresholds_list.reserve(num_thresholds);
for (int i = 1; i < num_thresholds - 1; i++) {
thresholds_list[i] = static_cast<double>(i) / (num_thresholds - 1);
}
const double kEpsilon = 1e-7;
thresholds_list[0] = 0.0f - kEpsilon;
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
auto *batch_auc = ctx.Output<Tensor>("BatchAUC");
std::vector<int64_t> stat_pos_batch(num_pred_buckets, 0);
std::vector<int64_t> stat_neg_batch(num_pred_buckets, 0);
calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(),
num_thresholds, batch_auc);
}
private:
inline static double trapezoidArea(double X1, double X2, double Y1,
double Y2) {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
}
inline static void calcAuc(const framework::ExecutionContext &ctx,
const framework::Tensor *label,
const framework::Tensor *predict,
int64_t *stat_pos, int64_t *stat_neg,
int num_thresholds,
framework::Tensor *auc_tensor) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>();
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
const T* inference_data = predict->data<T>();
const auto* label_data = label->data<int64_t>();
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
// calculate TP, FN, TN, FP for current thresh
int64_t tp = 0, fn = 0, tn = 0, fp = 0;
for (size_t i = 0; i < batch_size; i++) {
// NOTE: label_data used as bool, labels > 0 will be treated as true.
if (label_data[i]) {
if (inference_data[i * inference_width + 1] >=
(thresholds_list[idx_thresh])) {
tp++;
} else {
fn++;
}
} else {
if (inference_data[i * inference_width + 1] >=
(thresholds_list[idx_thresh])) {
fp++;
} else {
tn++;
}
}
for (size_t i = 0; i < batch_size; i++) {
uint32_t binIdx = static_cast<uint32_t>(
inference_data[i * inference_width + 1] * num_thresholds);
if (label_data[i]) {
stat_pos[binIdx] += 1.0;
} else {
stat_neg[binIdx] += 1.0;
}
// store rates
tp_data[idx_thresh] += tp;
fn_data[idx_thresh] += fn;
tn_data[idx_thresh] += tn;
fp_data[idx_thresh] += fp;
}
// epsilon to avoid divide by zero.
double epsilon = 1e-6;
// Riemann sum to caculate auc.
Tensor tp_rate, fp_rate, rec_rate;
tp_rate.Resize({num_thresholds});
fp_rate.Resize({num_thresholds});
rec_rate.Resize({num_thresholds});
auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
for (int i = 0; i < num_thresholds; i++) {
tp_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
(tp_data[i] + fn_data[i] + epsilon);
fp_rate_data[i] =
static_cast<double>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
rec_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
(tp_data[i] + fp_data[i] + epsilon);
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
int idx = num_thresholds;
while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
--idx;
}
*auc_data = 0.0f;
if (curve == "ROC") {
for (int i = 0; i < num_thresholds - 1; i++) {
auto dx = fp_rate_data[i] - fp_rate_data[i + 1];
auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f;
*auc_data = *auc_data + dx * y;
}
} else if (curve == "PR") {
for (int i = 1; i < num_thresholds; i++) {
auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
*auc_data = *auc_data + dx * y;
}
if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
}
}
};
......
......@@ -78,7 +78,7 @@ def accuracy(input, label, k=1, correct=None, total=None):
return acc_out
def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
"""
**Area Under the Curve (AUC) Layer**
......@@ -118,16 +118,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
"""
helper = LayerHelper("auc", **locals())
auc_out = helper.create_tmp_variable(dtype="float64")
batch_auc_out = helper.create_tmp_variable(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
tp = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
tn = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
fp = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
fn = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
for var in [tp, tn, fp, fn]:
stat_pos = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1])
stat_neg = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1])
for var in [stat_pos, stat_neg]:
helper.set_variable_initializer(
var, Constant(
value=0.0, force_cpu=True))
......@@ -137,18 +135,15 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
inputs={
"Predict": [input],
"Label": [label],
"TP": [tp],
"TN": [tn],
"FP": [fp],
"FN": [fn]
"StatPos": [stat_pos],
"StatNeg": [stat_neg]
},
attrs={"curve": curve,
"num_thresholds": num_thresholds},
outputs={
"AUC": [auc_out],
"TPOut": [tp],
"TNOut": [tn],
"FPOut": [fp],
"FNOut": [fn]
"BatchAUC": [batch_auc_out],
"StatPosOut": [stat_pos],
"StatNegOut": [stat_neg]
})
return auc_out, [tp, tn, fp, fn]
return auc_out, batch_auc_out, [stat_pos, stat_neg]
......@@ -558,8 +558,6 @@ class Auc(MetricBase):
name: metric name
curve: Specifies the name of the curve to be computed, 'ROC' [default] or
'PR' for the Precision-Recall-curve.
num_thresholds: The number of thresholds to use when discretizing the roc
curve.
"NOTE: only implement the ROC curve type via Python now."
......@@ -574,15 +572,14 @@ class Auc(MetricBase):
numpy_auc = metric.eval()
"""
def __init__(self, name, curve='ROC', num_thresholds=200):
def __init__(self, name, curve='ROC', num_thresholds=4095):
super(Auc, self).__init__(name=name)
self._curve = curve
self._num_thresholds = num_thresholds
self._epsilon = 1e-6
self.tp_list = np.zeros((num_thresholds, ))
self.fn_list = np.zeros((num_thresholds, ))
self.tn_list = np.zeros((num_thresholds, ))
self.fp_list = np.zeros((num_thresholds, ))
_num_pred_buckets = num_thresholds + 1
self._stat_pos = [0] * _num_pred_buckets
self._stat_neg = [0] * _num_pred_buckets
def update(self, preds, labels):
if not _is_numpy_(labels):
......@@ -590,41 +587,32 @@ class Auc(MetricBase):
if not _is_numpy_(preds):
raise ValueError("The 'predictions' must be a numpy ndarray.")
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (self._num_thresholds - 1)
for i in range(self._num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# calculate TP, FN, TN, FP count
for idx_thresh, thresh in enumerate(thresholds):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if preds[i, 1] >= thresh:
tp += 1
else:
fn += 1
else:
if preds[i, 1] >= thresh:
fp += 1
else:
tn += 1
self.tp_list[idx_thresh] += tp
self.fn_list[idx_thresh] += fn
self.tn_list[idx_thresh] += tn
self.fp_list[idx_thresh] += fp
for i, lbl in enumerate(labels):
value = preds[i, 1]
bin_idx = int(value * self._num_thresholds)
assert bin_idx <= self._num_thresholds
if lbl:
self._stat_pos[bin_idx] += 1.0
else:
self._stat_neg[bin_idx] += 1.0
@staticmethod
def trapezoid_area(x1, x2, y1, y2):
return abs(x1 - x2) * (y1 + y2) / 2.0
def eval(self):
epsilon = self._epsilon
num_thresholds = self._num_thresholds
tpr = (self.tp_list.astype("float32") + epsilon) / (
self.tp_list + self.fn_list + epsilon)
fpr = self.fp_list.astype("float32") / (
self.fp_list + self.tn_list + epsilon)
rec = (self.tp_list.astype("float32") + epsilon) / (
self.tp_list + self.fp_list + epsilon)
x = fpr[:num_thresholds - 1] - fpr[1:]
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
auc_value = np.sum(x * y)
return auc_value
tot_pos = 0.0
tot_neg = 0.0
auc = 0.0
idx = self._num_thresholds
while idx >= 0:
tot_pos_prev = tot_pos
tot_neg_prev = tot_neg
tot_pos += self._stat_pos[idx]
tot_neg += self._stat_neg[idx]
auc += self.trapezoid_area(tot_neg, tot_neg_prev, tot_pos,
tot_pos_prev)
idx -= 1
return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
......@@ -26,18 +26,15 @@ class TestAucOp(OpTest):
pred = np.random.random((128, 2)).astype("float32")
labels = np.random.randint(0, 2, (128, 1))
num_thresholds = 200
tp = np.zeros((num_thresholds, )).astype("int64")
tn = np.zeros((num_thresholds, )).astype("int64")
fp = np.zeros((num_thresholds, )).astype("int64")
fn = np.zeros((num_thresholds, )).astype("int64")
stat_pos = np.zeros((num_thresholds + 1, )).astype("int64")
stat_neg = np.zeros((num_thresholds + 1, )).astype("int64")
self.inputs = {
'Predict': pred,
'Label': labels,
'TP': tp,
'TN': tn,
'FP': fp,
'FN': fn
"StatPos": stat_pos,
"StatNeg": stat_neg
}
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
......@@ -47,11 +44,10 @@ class TestAucOp(OpTest):
python_auc.update(pred, labels)
self.outputs = {
'AUC': python_auc.eval(),
'TPOut': python_auc.tp_list,
'FNOut': python_auc.fn_list,
'TNOut': python_auc.tn_list,
'FPOut': python_auc.fp_list
'AUC': np.array(python_auc.eval()),
'BatchAUC': np.array(python_auc.eval()),
'StatPosOut': np.array(python_auc._stat_pos),
'StatNegOut': np.array(python_auc._stat_neg)
}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册