未验证 提交 c5619bbc 编写于 作者: W Wu Yi 提交者: GitHub

fix auc op (#12087)

* fix auc

* update

* update

* fix compile

* fix param name

* add doc string

* fix test
上级 f02a4da6
...@@ -35,7 +35,14 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,14 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(inference_height, label_height, PADDLE_ENFORCE_EQ(inference_height, label_height,
"Out and Label should have same height."); "Out and Label should have same height.");
int num_thres = ctx->Attrs().Get<int>("num_thresholds");
ctx->SetOutputDim("AUC", {1}); ctx->SetOutputDim("AUC", {1});
ctx->SetOutputDim("TPOut", {num_thres});
ctx->SetOutputDim("TNOut", {num_thres});
ctx->SetOutputDim("FPOut", {num_thres});
ctx->SetOutputDim("FNOut", {num_thres});
ctx->ShareLoD("Out", /*->*/ "AUC"); ctx->ShareLoD("Out", /*->*/ "AUC");
} }
...@@ -63,10 +70,18 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -63,10 +70,18 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label", AddInput("Label",
"A 2D int tensor indicating the label of the training data." "A 2D int tensor indicating the label of the training data."
"The height is batch size and width is always 1."); "The height is batch size and width is always 1.");
AddInput("TP", "True-Positive value.");
AddInput("FP", "False-Positive value.");
AddInput("TN", "True-Negative value.");
AddInput("FN", "False-Negative value.");
// TODO(typhoonzero): support weight input // TODO(typhoonzero): support weight input
AddOutput("AUC", AddOutput("AUC",
"A scalar representing the " "A scalar representing the "
"current area-under-the-curve."); "current area-under-the-curve.");
AddOutput("TPOut", "True-Positive value.");
AddOutput("FPOut", "False-Positive value.");
AddOutput("TNOut", "True-Negative value.");
AddOutput("FNOut", "False-Negative value.");
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.") AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
.SetDefault("ROC"); .SetDefault("ROC");
......
...@@ -34,6 +34,12 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -34,6 +34,12 @@ class AucKernel : public framework::OpKernel<T> {
auto* inference = ctx.Input<Tensor>("Out"); auto* inference = ctx.Input<Tensor>("Out");
auto* label = ctx.Input<Tensor>("Label"); auto* label = ctx.Input<Tensor>("Label");
auto* auc = ctx.Output<Tensor>("AUC"); auto* auc = ctx.Output<Tensor>("AUC");
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto* true_positive = ctx.Output<Tensor>("TPOut");
auto* false_positive = ctx.Output<Tensor>("FPOut");
auto* true_negative = ctx.Output<Tensor>("TNOut");
auto* false_negative = ctx.Output<Tensor>("FNOut");
float* auc_data = auc->mutable_data<float>(ctx.GetPlace()); float* auc_data = auc->mutable_data<float>(ctx.GetPlace());
...@@ -54,19 +60,10 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -54,19 +60,10 @@ class AucKernel : public framework::OpKernel<T> {
const T* inference_data = inference->data<T>(); const T* inference_data = inference->data<T>();
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
// Create local tensor for storing the curve: TP, FN, TN, FP auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
// TODO(typhoonzero): use eigen op to caculate these values. auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
Tensor true_positive, false_positive, true_negative, false_negative; auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
true_positive.Resize({num_thresholds});
false_negative.Resize({num_thresholds});
true_negative.Resize({num_thresholds});
false_positive.Resize({num_thresholds});
int64_t* tp_data = true_positive.mutable_data<int64_t>(ctx.GetPlace());
int64_t* fn_data = false_negative.mutable_data<int64_t>(ctx.GetPlace());
int64_t* tn_data = true_negative.mutable_data<int64_t>(ctx.GetPlace());
int64_t* fp_data = false_positive.mutable_data<int64_t>(ctx.GetPlace());
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
// calculate TP, FN, TN, FP for current thresh // calculate TP, FN, TN, FP for current thresh
...@@ -91,10 +88,10 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -91,10 +88,10 @@ class AucKernel : public framework::OpKernel<T> {
} }
} }
// store rates // store rates
tp_data[idx_thresh] = tp; tp_data[idx_thresh] += tp;
fn_data[idx_thresh] = fn; fn_data[idx_thresh] += fn;
tn_data[idx_thresh] = tn; tn_data[idx_thresh] += tn;
fp_data[idx_thresh] = fp; fp_data[idx_thresh] += fp;
} }
// epsilon to avoid divide by zero. // epsilon to avoid divide by zero.
float epsilon = 1e-6; float epsilon = 1e-6;
......
...@@ -76,7 +76,7 @@ def accuracy(input, label, k=1, correct=None, total=None): ...@@ -76,7 +76,7 @@ def accuracy(input, label, k=1, correct=None, total=None):
return acc_out return acc_out
def auc(input, label, curve='ROC', num_thresholds=200): def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
""" """
**Area Under the Curve (AUC) Layer** **Area Under the Curve (AUC) Layer**
...@@ -102,6 +102,7 @@ def auc(input, label, curve='ROC', num_thresholds=200): ...@@ -102,6 +102,7 @@ def auc(input, label, curve='ROC', num_thresholds=200):
curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'. curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
num_thresholds(int): The number of thresholds to use when discretizing num_thresholds(int): The number of thresholds to use when discretizing
the roc curve. Default 200. the roc curve. Default 200.
topk(int): Only the top-k prediction outputs are used to compute the AUC. Default 1.
Returns: Returns:
Variable: A scalar representing the current AUC. Variable: A scalar representing the current AUC.
...@@ -115,7 +116,7 @@ def auc(input, label, curve='ROC', num_thresholds=200): ...@@ -115,7 +116,7 @@ def auc(input, label, curve='ROC', num_thresholds=200):
""" """
warnings.warn( warnings.warn(
"This interface not recommended, fluid.layers.auc compute the auc at every minibatch, \ "This interface is not recommended, fluid.layers.auc compute the auc at every minibatch, \
but can not aggregate them and get the pass AUC, because pass \ but can not aggregate them and get the pass AUC, because pass \
auc can not be averaged with weighted from the minibatch auc value. \ auc can not be averaged with weighted from the minibatch auc value. \
Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \ Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \
...@@ -125,14 +126,34 @@ def auc(input, label, curve='ROC', num_thresholds=200): ...@@ -125,14 +126,34 @@ def auc(input, label, curve='ROC', num_thresholds=200):
topk_indices = helper.create_tmp_variable(dtype="int64") topk_indices = helper.create_tmp_variable(dtype="int64")
topk_out, topk_indices = nn.topk(input, k=k) topk_out, topk_indices = nn.topk(input, k=k)
auc_out = helper.create_tmp_variable(dtype="float32") auc_out = helper.create_tmp_variable(dtype="float32")
# Make tp, tn, fp and fn persistable so that they can accumulate across all batches.
tp = helper.create_global_variable(persistable=True)
tn = helper.create_global_variable(persistable=True)
fp = helper.create_global_variable(persistable=True)
fn = helper.create_global_variable(persistable=True)
for var in [tp, tn, fp, fn]:
helper.set_variable_initializer(
var, Constant(
value=0.0, force_cpu=True))
helper.append_op( helper.append_op(
type="auc", type="auc",
inputs={ inputs={
"Out": [topk_out], "Out": [topk_out],
"Indices": [topk_indices], "Indices": [topk_indices],
"Label": [label] "Label": [label],
"TP": [tp],
"TN": [tn],
"FP": [fp],
"FN": [fn]
}, },
attrs={"curve": curve, attrs={"curve": curve,
"num_thresholds": num_thresholds}, "num_thresholds": num_thresholds},
outputs={"AUC": [auc_out], }) outputs={
"AUC": [auc_out],
"TPOut": [tp],
"TNOut": [tn],
"FPOut": [fp],
"FNOut": [fn]
})
return auc_out return auc_out
...@@ -24,7 +24,20 @@ class TestAucOp(OpTest): ...@@ -24,7 +24,20 @@ class TestAucOp(OpTest):
indices = np.random.randint(0, 2, (128, 2)) indices = np.random.randint(0, 2, (128, 2))
labels = np.random.randint(0, 2, (128, 1)) labels = np.random.randint(0, 2, (128, 1))
num_thresholds = 200 num_thresholds = 200
self.inputs = {'Out': pred, 'Indices': indices, 'Label': labels} tp = np.zeros((num_thresholds, )).astype("int64")
tn = np.zeros((num_thresholds, )).astype("int64")
fp = np.zeros((num_thresholds, )).astype("int64")
fn = np.zeros((num_thresholds, )).astype("int64")
self.inputs = {
'Out': pred,
'Indices': indices,
'Label': labels,
'TP': tp,
'TN': tn,
'FP': fp,
'FN': fn
}
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
# NOTE: sklearn use a different way to generate thresholds # NOTE: sklearn use a different way to generate thresholds
# which will cause the result differs slightly: # which will cause the result differs slightly:
...@@ -71,7 +84,13 @@ class TestAucOp(OpTest): ...@@ -71,7 +84,13 @@ class TestAucOp(OpTest):
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0 y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
auc_value = np.sum(x * y) auc_value = np.sum(x * y)
self.outputs = {'AUC': auc_value} self.outputs = {
'AUC': auc_value,
'TPOut': tp_list,
'FNOut': fn_list,
'TNOut': tn_list,
'FPOut': fp_list
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册