未验证 提交 b042c226 编写于 作者: W wangzhen38 提交者: GitHub

fix auc instag (#40672)

* fix auc instag

* update ins_tag_w default

* update ins_tag_weight default style

* update ins_tag_weight default style

* optimize

* update by reviews of zwh

* optimize default value

* fix auc_op order
上级 f2bc1576
...@@ -25,6 +25,8 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -25,6 +25,8 @@ class AucOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Predict"), "Input", "Predict", "Auc"); OP_INOUT_CHECK(ctx->HasInput("Predict"), "Input", "Predict", "Auc");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Auc"); OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Auc");
OP_INOUT_CHECK(ctx->HasInput("InsTagWeight"), "Input", "InsTagWeight",
"Auc");
auto predict_width = ctx->GetInputDim("Predict")[1]; auto predict_width = ctx->GetInputDim("Predict")[1];
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_LE(predict_width, 2, PADDLE_ENFORCE_LE(predict_width, 2,
...@@ -83,10 +85,11 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,10 +85,11 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label", AddInput("Label",
"A 2D int tensor indicating the label of the training data. " "A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]"); "shape: [batch_size, 1]");
// TODO(typhoonzero): support weight input // TODO(typhoonzero): support weight input
AddInput("StatPos", "Statistic value when label = 1"); AddInput("StatPos", "Statistic value when label = 1");
AddInput("StatNeg", "Statistic value when label = 0"); AddInput("StatNeg", "Statistic value when label = 0");
AddInput("InsTagWeight",
"(Tensor) instag weight, 1 means real data, 0 means false data");
AddOutput("AUC", AddOutput("AUC",
"A scalar representing the " "A scalar representing the "
......
...@@ -112,6 +112,12 @@ class AucCUDAKernel : public framework::OpKernel<T> { ...@@ -112,6 +112,12 @@ class AucCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict"); auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label"); auto *label = ctx.Input<Tensor>("Label");
auto *ins_tag_weight = ctx.Input<Tensor>("InsTagWeight");
const auto *ins_tag_weight_value = ins_tag_weight->data<float>();
bool is_fake_data = 0;
if (ins_tag_weight_value[0] == 0) {
is_fake_data = 1;
}
int num_thresholds = ctx.Attr<int>("num_thresholds"); int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps"); int slide_steps = ctx.Attr<int>("slide_steps");
...@@ -145,8 +151,12 @@ class AucCUDAKernel : public framework::OpKernel<T> { ...@@ -145,8 +151,12 @@ class AucCUDAKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice); cudaMemcpyDeviceToDevice);
} }
if (slide_steps == 0 && is_fake_data) {
return;
}
statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos, statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg); origin_stat_neg, is_fake_data);
int sum_offset = slide_steps * (num_thresholds + 1); int sum_offset = slide_steps * (num_thresholds + 1);
auto stream = auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream(); ctx.template device_context<platform::CUDADeviceContext>().stream();
...@@ -165,8 +175,8 @@ class AucCUDAKernel : public framework::OpKernel<T> { ...@@ -165,8 +175,8 @@ class AucCUDAKernel : public framework::OpKernel<T> {
const framework::Tensor *label, const framework::Tensor *label,
const framework::Tensor *predict, const framework::Tensor *predict,
const int num_thresholds, const int slide_steps, const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos, int64_t *origin_stat_pos, int64_t *origin_stat_neg,
int64_t *origin_stat_neg) { const bool is_fake_data) {
size_t batch_size = predict->dims()[0]; size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1]; size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>(); const T *inference_data = predict->data<T>();
...@@ -200,10 +210,12 @@ class AucCUDAKernel : public framework::OpKernel<T> { ...@@ -200,10 +210,12 @@ class AucCUDAKernel : public framework::OpKernel<T> {
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
label_data, inference_data, inference_width, num_thresholds, label_data, inference_data, inference_width, num_thresholds,
origin_stat_pos, origin_stat_neg, batch_size, slide_steps); origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) / if (!is_fake_data) {
PADDLE_CUDA_NUM_THREADS, UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS,
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps); PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
}
} }
}; };
......
...@@ -29,7 +29,12 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -29,7 +29,12 @@ class AucKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict"); auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label"); auto *label = ctx.Input<Tensor>("Label");
auto *ins_tag_weight = ctx.Input<Tensor>("InsTagWeight");
const auto *ins_tag_weight_value = ins_tag_weight->data<float>();
bool is_fake_data = 0;
if (ins_tag_weight_value[0] == 0) {
is_fake_data = 1;
}
int num_thresholds = ctx.Attr<int>("num_thresholds"); int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps"); int slide_steps = ctx.Attr<int>("slide_steps");
...@@ -60,8 +65,13 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -60,8 +65,13 @@ class AucKernel : public framework::OpKernel<T> {
(slide_steps > 0 ? 1 : 0)) * (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t)); sizeof(int64_t));
} }
// When calculating the global AUC (slide_steps == 0) on fake data, do nothing.
if (slide_steps == 0 && is_fake_data) {
return;
}
statAuc(label, predict, num_thresholds, slide_steps, origin_stat_pos, statAuc(label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg); origin_stat_neg, is_fake_data);
int sum_offset = slide_steps * (num_thresholds + 1); int sum_offset = slide_steps * (num_thresholds + 1);
calcAuc(origin_stat_pos + sum_offset, origin_stat_neg + sum_offset, calcAuc(origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
...@@ -81,8 +91,8 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -81,8 +91,8 @@ class AucKernel : public framework::OpKernel<T> {
inline static void statAuc(const framework::Tensor *label, inline static void statAuc(const framework::Tensor *label,
const framework::Tensor *predict, const framework::Tensor *predict,
const int num_thresholds, const int slide_steps, const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos, int64_t *origin_stat_pos, int64_t *origin_stat_neg,
int64_t *origin_stat_neg) { const bool is_fake_data) {
size_t batch_size = predict->dims()[0]; size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1]; size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>(); const T *inference_data = predict->data<T>();
...@@ -148,11 +158,13 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -148,11 +158,13 @@ class AucKernel : public framework::OpKernel<T> {
origin_stat_neg[cur_step_begin + binIdx] += 1; origin_stat_neg[cur_step_begin + binIdx] += 1;
} }
} }
for (int i = 0; i < bucket_length; ++i) { if (!is_fake_data) {
origin_stat_pos[sum_step_begin + i] += for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[cur_step_begin + i]; origin_stat_pos[sum_step_begin + i] +=
origin_stat_neg[sum_step_begin + i] += origin_stat_pos[cur_step_begin + i];
origin_stat_neg[cur_step_begin + i]; origin_stat_neg[sum_step_begin + i] +=
origin_stat_neg[cur_step_begin + i];
}
} }
} }
......
...@@ -24,6 +24,7 @@ from ..framework import Variable, in_dygraph_mode, _varbase_creator ...@@ -24,6 +24,7 @@ from ..framework import Variable, in_dygraph_mode, _varbase_creator
from .. import core from .. import core
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from . import nn from . import nn
from . import tensor
from ..data_feeder import check_variable_and_dtype from ..data_feeder import check_variable_and_dtype
__all__ = ['accuracy', 'auc'] __all__ = ['accuracy', 'auc']
...@@ -113,7 +114,8 @@ def auc(input, ...@@ -113,7 +114,8 @@ def auc(input,
curve='ROC', curve='ROC',
num_thresholds=2**12 - 1, num_thresholds=2**12 - 1,
topk=1, topk=1,
slide_steps=1): slide_steps=1,
ins_tag_weight=None):
""" """
**Area Under the Curve (AUC) Layer** **Area Under the Curve (AUC) Layer**
...@@ -143,7 +145,9 @@ def auc(input, ...@@ -143,7 +145,9 @@ def auc(input,
the roc curve. Default 200. the roc curve. Default 200.
topk(int): only topk number of prediction output will be used for auc. topk(int): only topk number of prediction output will be used for auc.
slide_steps: when calc batch auc, we can not only use step currently but the previous steps can be used. slide_steps=1 means use the current step, slide_steps=3 means use current step and the previous second steps, slide_steps=0 use all of the steps. slide_steps: when calc batch auc, we can not only use step currently but the previous steps can be used. slide_steps=1 means use the current step, slide_steps=3 means use current step and the previous second steps, slide_steps=0 use all of the steps.
ins_tag_weight(Variable): A 2D float Variable indicating the ins_tag_weight of the training
data. 1 means real data, 0 means fake data.
A LoDTensor or Tensor with type float32, float64.
Returns: Returns:
Variable: A tuple representing the current AUC. Variable: A tuple representing the current AUC.
...@@ -159,6 +163,7 @@ def auc(input, ...@@ -159,6 +163,7 @@ def auc(input,
data = fluid.data(name="input", shape=[-1, 32,32], dtype="float32") data = fluid.data(name="input", shape=[-1, 32,32], dtype="float32")
label = fluid.data(name="label", shape=[-1], dtype="int") label = fluid.data(name="label", shape=[-1], dtype="int")
ins_tag_weight = fluid.data(name="ins_tag_weight", shape=[-1], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=2) fc_out = fluid.layers.fc(input=data, size=2)
predict = fluid.layers.softmax(input=fc_out) predict = fluid.layers.softmax(input=fc_out)
result=fluid.layers.auc(input=predict, label=label) result=fluid.layers.auc(input=predict, label=label)
...@@ -169,14 +174,22 @@ def auc(input, ...@@ -169,14 +174,22 @@ def auc(input,
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
x = np.random.rand(3,32,32).astype("float32") x = np.random.rand(3,32,32).astype("float32")
y = np.array([1,0,1]) y = np.array([1,0,1])
output= exe.run(feed={"input": x,"label": y}, z = np.array([1,1,1]) #this means real data
output= exe.run(feed={"input": x,"label": y, "ins_tag_weight": z},
fetch_list=[result[0]]) fetch_list=[result[0]])
print(output) print(output)
#[array([0.5])] #[array([0.5])]
""" """
helper = LayerHelper("auc", **locals()) helper = LayerHelper("auc", **locals())
if ins_tag_weight is None:
ins_tag_weight = tensor.fill_constant(
shape=[1, 1], dtype="float32", value=1.0)
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'auc') check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'auc')
check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'auc') check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'auc')
check_variable_and_dtype(ins_tag_weight, 'ins_tag_weight',
['float32', 'float64'], 'auc')
auc_out = helper.create_variable_for_type_inference(dtype="float64") auc_out = helper.create_variable_for_type_inference(dtype="float64")
batch_auc_out = helper.create_variable_for_type_inference(dtype="float64") batch_auc_out = helper.create_variable_for_type_inference(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
...@@ -215,7 +228,8 @@ def auc(input, ...@@ -215,7 +228,8 @@ def auc(input,
"Predict": [input], "Predict": [input],
"Label": [label], "Label": [label],
"StatPos": [batch_stat_pos], "StatPos": [batch_stat_pos],
"StatNeg": [batch_stat_neg] "StatNeg": [batch_stat_neg],
"InsTagWeight": [ins_tag_weight]
}, },
attrs={ attrs={
"curve": curve, "curve": curve,
...@@ -234,7 +248,8 @@ def auc(input, ...@@ -234,7 +248,8 @@ def auc(input,
"Predict": [input], "Predict": [input],
"Label": [label], "Label": [label],
"StatPos": [stat_pos], "StatPos": [stat_pos],
"StatNeg": [stat_neg] "StatNeg": [stat_neg],
"InsTagWeight": [ins_tag_weight]
}, },
attrs={ attrs={
"curve": curve, "curve": curve,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册