未验证 提交 b042c226 编写于 作者: W wangzhen38 提交者: GitHub

fix auc instag (#40672)

* fix auc instag

* update ins_tag_w default

* update ins_tag_weight default style

* update ins_tag_weight default style

* optimize

* update by reviews of zwh

* optimize default value

* fix auc_op order
上级 f2bc1576
......@@ -25,6 +25,8 @@ class AucOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Predict"), "Input", "Predict", "Auc");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Auc");
OP_INOUT_CHECK(ctx->HasInput("InsTagWeight"), "Input", "InsTagWeight",
"Auc");
auto predict_width = ctx->GetInputDim("Predict")[1];
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_LE(predict_width, 2,
......@@ -83,10 +85,11 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label",
"A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]");
// TODO(typhoonzero): support weight input
AddInput("StatPos", "Statistic value when label = 1");
AddInput("StatNeg", "Statistic value when label = 0");
AddInput("InsTagWeight",
"(Tensor) instag weight, 1 means real data, 0 means false data");
AddOutput("AUC",
"A scalar representing the "
......
......@@ -112,6 +112,12 @@ class AucCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
auto *ins_tag_weight = ctx.Input<Tensor>("InsTagWeight");
const auto *ins_tag_weight_value = ins_tag_weight->data<float>();
bool is_fake_data = 0;
if (ins_tag_weight_value[0] == 0) {
is_fake_data = 1;
}
int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps");
......@@ -145,8 +151,12 @@ class AucCUDAKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice);
}
if (slide_steps == 0 && is_fake_data) {
return;
}
statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg);
origin_stat_neg, is_fake_data);
int sum_offset = slide_steps * (num_thresholds + 1);
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
......@@ -165,8 +175,8 @@ class AucCUDAKernel : public framework::OpKernel<T> {
const framework::Tensor *label,
const framework::Tensor *predict,
const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
int64_t *origin_stat_pos, int64_t *origin_stat_neg,
const bool is_fake_data) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
......@@ -200,10 +210,12 @@ class AucCUDAKernel : public framework::OpKernel<T> {
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
label_data, inference_data, inference_width, num_thresholds,
origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
if (!is_fake_data) {
UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
}
}
};
......
......@@ -29,7 +29,12 @@ class AucKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
auto *ins_tag_weight = ctx.Input<Tensor>("InsTagWeight");
const auto *ins_tag_weight_value = ins_tag_weight->data<float>();
bool is_fake_data = 0;
if (ins_tag_weight_value[0] == 0) {
is_fake_data = 1;
}
int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps");
......@@ -60,8 +65,13 @@ class AucKernel : public framework::OpKernel<T> {
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
// When computing the global AUC (slide_steps == 0) on fake data, do nothing.
if (slide_steps == 0 && is_fake_data) {
return;
}
statAuc(label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg);
origin_stat_neg, is_fake_data);
int sum_offset = slide_steps * (num_thresholds + 1);
calcAuc(origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
......@@ -81,8 +91,8 @@ class AucKernel : public framework::OpKernel<T> {
inline static void statAuc(const framework::Tensor *label,
const framework::Tensor *predict,
const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
int64_t *origin_stat_pos, int64_t *origin_stat_neg,
const bool is_fake_data) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
......@@ -148,11 +158,13 @@ class AucKernel : public framework::OpKernel<T> {
origin_stat_neg[cur_step_begin + binIdx] += 1;
}
}
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] +=
origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] +=
origin_stat_neg[cur_step_begin + i];
if (!is_fake_data) {
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] +=
origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] +=
origin_stat_neg[cur_step_begin + i];
}
}
}
......
......@@ -24,6 +24,7 @@ from ..framework import Variable, in_dygraph_mode, _varbase_creator
from .. import core
from ..param_attr import ParamAttr
from . import nn
from . import tensor
from ..data_feeder import check_variable_and_dtype
__all__ = ['accuracy', 'auc']
......@@ -113,7 +114,8 @@ def auc(input,
curve='ROC',
num_thresholds=2**12 - 1,
topk=1,
slide_steps=1):
slide_steps=1,
ins_tag_weight=None):
"""
**Area Under the Curve (AUC) Layer**
......@@ -143,7 +145,9 @@ def auc(input,
the roc curve. Default 4095 (2**12 - 1).
topk(int): only topk number of prediction output will be used for auc.
slide_steps(int): when calculating batch AUC, not only the current step but also previous steps can be used. slide_steps=1 means use only the current step, slide_steps=3 means use the current step and the previous two steps, slide_steps=0 means use all of the steps.
ins_tag_weight(Variable, optional): A 2D float Variable indicating the ins_tag_weight of the training
data. 1 means real data, 0 means fake data. Default None, in which case a constant weight of 1.0 (real data) is used.
A LoDTensor or Tensor with type float32,float64.
Returns:
Variable: A tuple representing the current AUC.
......@@ -159,6 +163,7 @@ def auc(input,
data = fluid.data(name="input", shape=[-1, 32,32], dtype="float32")
label = fluid.data(name="label", shape=[-1], dtype="int")
ins_tag_weight = fluid.data(name="ins_tag_weight", shape=[-1], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=2)
predict = fluid.layers.softmax(input=fc_out)
result=fluid.layers.auc(input=predict, label=label)
......@@ -169,14 +174,22 @@ def auc(input,
exe.run(fluid.default_startup_program())
x = np.random.rand(3,32,32).astype("float32")
y = np.array([1,0,1])
output= exe.run(feed={"input": x,"label": y},
z = np.array([1,1,1]) #this means real data
output= exe.run(feed={"input": x,"label": y, "ins_tag_weight": z},
fetch_list=[result[0]])
print(output)
#[array([0.5])]
"""
helper = LayerHelper("auc", **locals())
if ins_tag_weight is None:
ins_tag_weight = tensor.fill_constant(
shape=[1, 1], dtype="float32", value=1.0)
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'auc')
check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'auc')
check_variable_and_dtype(ins_tag_weight, 'ins_tag_weight',
['float32', 'float64'], 'auc')
auc_out = helper.create_variable_for_type_inference(dtype="float64")
batch_auc_out = helper.create_variable_for_type_inference(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
......@@ -215,7 +228,8 @@ def auc(input,
"Predict": [input],
"Label": [label],
"StatPos": [batch_stat_pos],
"StatNeg": [batch_stat_neg]
"StatNeg": [batch_stat_neg],
"InsTagWeight": [ins_tag_weight]
},
attrs={
"curve": curve,
......@@ -234,7 +248,8 @@ def auc(input,
"Predict": [input],
"Label": [label],
"StatPos": [stat_pos],
"StatNeg": [stat_neg]
"StatNeg": [stat_neg],
"InsTagWeight": [ins_tag_weight]
},
attrs={
"curve": curve,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册