未验证 提交 672bb07e 编写于 作者: X xiaoyuanzi914 提交者: GitHub

add autogen code support for auc_op (#52437)

* add autogen code support for auc_op

* update

---------
Co-authored-by: Nwqgo <1552367872@qq.com>
上级 bb48b596
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class AucOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Predict"),
ctx.GetPlace());
}
};
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Predict",
"A floating point 2D tensor with shape [batch_size, 2], values "
"are in the range [0, 1]."
"Typically, this tensor indicates the probability of each label");
AddInput("Label",
"A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]");
// TODO(typhoonzero): support weight input
AddInput("StatPos", "Statistic value when label = 1");
AddInput("StatNeg", "Statistic value when label = 0");
AddInput("InsTagWeight",
"(Tensor, optional) If provided, auc Op will use this "
"1 means real data, 0 means false data")
.AsDispensable();
AddOutput("AUC",
"A scalar representing the "
"current area-under-the-curve.");
AddOutput("StatPosOut", "Statistic value when label = 1");
AddOutput("StatNegOut", "Statistic value when label = 0");
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
.SetDefault("ROC");
AddAttr<int>(
"num_thresholds",
"The number of thresholds to use when discretizing the roc curve.")
.SetDefault((2 << 12) - 1);
AddAttr<int>("slide_steps", "Use slide steps to calc batch auc.")
.SetDefault(1);
AddComment(R"DOC(
Area Under The Curve (AUC) Operator.
This implementation computes the AUC according to forward output and label.
It is used very widely in binary classification evaluation. As a note:
If input label contains values other than 0 and 1, it will be cast
to bool. You can find the relevant definitions here:
https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
There are two types of possible curves:
1. ROC: Receiver operating characteristic
2. PR: Precision Recall
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(auc,
AucInferShapeFunctor,
PD_INFER_META(phi::AucInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(auc,
ops::AucOp,
ops::AucOpMaker,
AucInferShapeFunctor);
REGISTER_OP_VERSION(auc).AddCheckpoint(
R"ROC(
Upgrade auc, add a new input [InsTagWeight].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"ValueTensor", "In order to support multi-tag task"));
......@@ -202,15 +202,6 @@
data_type : dtype
backend : place > output
- op : auc
args : (Tensor x, Tensor label, Tensor stat_pos, Tensor stat_neg, Tensor ins_tag_weight, str curve, int num_thresholds, int slide_steps)
output : Tensor(auc), Tensor(stat_pos_out), Tensor(stat_neg_out)
infer_meta :
func : AucInferMeta
kernel :
func : auc
optional : ins_tag_weight
- op : average_accumulates_
args : (Tensor param, Tensor in_sum_1, Tensor in_sum_2, Tensor in_sum_3, Tensor in_num_accumulates, Tensor in_old_num_accumulates, Tensor in_num_updates, float average_window, int64_t max_average_window, int64_t min_average_window)
output : Tensor(out_sum_1), Tensor(out_sum_2), Tensor(out_sum_3), Tensor(out_num_accumulates), Tensor(out_old_num_accumulates), Tensor(out_num_updates)
......
......@@ -168,6 +168,12 @@
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : auc
inputs :
{x : Predict, label : Label, stat_pos : StatPos, stat_neg : StatNeg, ins_tag_weight : InsTagWeight}
outputs :
{auc : AUC, stat_pos_out : StatPosOut, stat_neg_out : StatNegOut}
- op : batch_norm
backward : batch_norm_grad
inputs:
......
......@@ -27,6 +27,13 @@
comment : The attribute 'atol' is deleted. The reason why it is deleted is that
attributes do not support a float64 value and it is changed to a tensor.
- op : auc
version :
- checkpoint : Upgrade auc, add a new input [InsTagWeight].
action :
- add_input : ValueTensor
comment : In order to support multi-tag task.
- op : clip
version :
- checkpoint : Upgrade clip add a new input [Min]
......
......@@ -124,6 +124,16 @@
func : atanh
backward : atanh_grad
- op : auc
args : (Tensor x, Tensor label, Tensor stat_pos, Tensor stat_neg, Tensor ins_tag_weight, str curve = "ROC", int num_thresholds = (2 << 12) - 1, int slide_steps = 1)
output : Tensor(auc), Tensor(stat_pos_out), Tensor(stat_neg_out)
infer_meta :
func : AucInferMeta
kernel :
func : auc
data_type : x
optional : ins_tag_weight
- op : bernoulli
args : (Tensor x)
output : Tensor(out)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// we have to return every specific KernelSignature for infrt now
KernelSignature AucOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"auc",
{"Predict", "Label", "StatPos", "StatNeg", "InsTagWeight"},
{"curve", "num_thresholds", "slide_steps"},
{"AUC", "StatPosOut", "StatNegOut"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(auc, phi::AucOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册