未验证 提交 351ccb63 编写于 作者: L lzydev 提交者: GitHub

Autogen softmax_with_cross_entropy (#52515)

* autogen softmax_with_cross_entropy

* fix error in softmax_with_cross_entropy version
上级 f9aaa1e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
class SoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The input tensor of unscaled "
"log probabilities, whose dimension :attr:`axis` should be scaled "
"by softmax.");
AddInput(
"Label",
"(Tensor) The input tensor of groud truth label. If :attr:`soft_label` "
"is set to false, Label is a Tensor<int64> in same shape with "
"Input(Logits) except the shape in dimension :attr:`axis` as 1. If "
"soft_label is set to true, Label is a Tensor<float/double> in same "
"shape with Input(Logits).");
AddOutput(
"Softmax",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits). "
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.")
.AsIntermediate();
AddOutput("Loss",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits) "
"except the shape in dimension :attr:`axis` as 1. The cross "
"entropy loss.");
AddAttr<bool>(
"soft_label",
"(bool, default: false), A flag to indicate whether to interpretant "
"the given labels as soft labels.")
.SetDefault(false);
AddAttr<bool>(
"use_softmax",
"(bool, default: true), A flag to indicate whether to do softmax ")
.SetDefault(true);
AddAttr<bool>(
"numeric_stable_mode",
"(bool, default: true), A flag to indicate whether to use more "
"numerically stable algorithm. This flag is only valid when "
"soft_label is false and GPU is used.")
.SetDefault(true);
AddAttr<int>(
"ignore_index",
"(int, default -100), Specifies a target value that is ignored and"
"does not contribute to the input gradient. Only valid if soft_label"
"is set to False")
.SetDefault(-100);
AddAttr<int>("axis",
"The dimension index of Input(Logits) to perform softmax,"
"default -1 for last dimension")
.SetDefault(-1);
AddComment(R"DOC(
Softmax With Cross Entropy Operator.
Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
$$Loss_j = -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$
2) Soft label (each sample can have a distribution over all classes)
$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$
)DOC");
}
};
class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Logits"),
true,
platform::errors::InvalidArgument("Input(Logits) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"),
true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"),
true,
platform::errors::InvalidArgument(
"Output(Softmax) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Loss"),
true,
platform::errors::InvalidArgument("Output(Loss) should be not null."));
auto axis = ctx->Attrs().Get<int>("axis");
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
auto logits_rank = logits_dims.size();
PADDLE_ENFORCE_GE(axis,
-logits_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
PADDLE_ENFORCE_LT(axis,
logits_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
axis = phi::funcs::CanonicalAxis(axis, logits_rank);
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[i],
labels_dims[i],
platform::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."));
}
}
}
auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
if (axis != logits_rank - 1) {
PADDLE_ENFORCE_EQ(numeric_stable_mode,
true,
platform::errors::InvalidArgument(
"Attr(axis) can only be -1 "
"when not in numeric_stable_mode."));
}
bool soft_label = ctx->Attrs().Get<bool>("soft_label");
if (soft_label) {
if (ctx->IsRuntime() ||
(logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[axis],
labels_dims[axis],
platform::errors::InvalidArgument(
"If Attr(soft_label) == true, "
"the axis dimension of "
"Input(X) and Input(Label) should be equal."));
}
} else {
if (ctx->IsRuntime() || labels_dims[axis] > 0) {
PADDLE_ENFORCE_EQ(
labels_dims[axis],
1UL,
platform::errors::InvalidArgument("If Attr(soft_label) == false, "
"the axis dimension of "
"Input(Label) should be 1."));
}
}
ctx->SetOutputDim("Softmax", logits_dims);
logits_dims[axis] = 1;
ctx->SetOutputDim("Loss", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace());
}
};
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")),
true,
platform::errors::InvalidArgument(
"Input(Loss@Grad) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"),
true,
platform::errors::InvalidArgument(
"Input(Softmax) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"),
true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")),
true,
platform::errors::InvalidArgument(
"Output(Logits@Grad) should be not null."));
auto axis = ctx->Attrs().Get<int>("axis");
auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label");
auto softmax_rank = softmax_dims.size();
PADDLE_ENFORCE_GE(axis,
-softmax_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
PADDLE_ENFORCE_LT(axis,
softmax_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
for (int i = 0; i < softmax_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(
softmax_dims[i],
labels_dims[i],
platform::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in same shape in "
"dimensions except axis."));
}
}
}
bool soft_label = ctx->Attrs().Get<bool>("soft_label");
if (soft_label) {
if (ctx->IsRuntime() ||
(softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(softmax_dims[axis],
labels_dims[axis],
platform::errors::InvalidArgument(
"If Attr(soft_label) == true, "
"the axis dimension of "
"Input(X) and Input(Label) should be equal."));
}
} else {
if (ctx->IsRuntime() || labels_dims[axis] > 0) {
PADDLE_ENFORCE_EQ(
labels_dims[axis],
1UL,
platform::errors::InvalidArgument("If Attr(soft_label) == false, "
"the axis dimension of "
"Input(Label) should be 1."));
}
}
ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Softmax"));
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.GetPlace());
}
};
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("softmax_with_cross_entropy_grad");
grad_op->SetInput("Label", this->Input("Label"));
grad_op->SetInput("Softmax", this->Output("Softmax"));
grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
grad_op->SetOutput(framework::GradVarName("Logits"),
this->InputGrad("Logits"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
{"Logits", "Softmax"});
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
{"Softmax", framework::GradVarName("Logits")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker,
ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_VERSION(softmax_with_cross_entropy)
.AddCheckpoint(
R"ROC(
Add a new attribute [use_softmax] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_softmax", "A flag to indicate whether to do softmax", true));
......@@ -348,6 +348,17 @@
func : crop_grad
data_type : x
- backward_op : cross_entropy_with_softmax_grad
forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label=false, bool use_softmax=true, bool numeric_stable_mode=true, int ignore_index=-100, int axis=-1) -> Tensor(softmax), Tensor(loss)
args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
output : Tensor(input_grad)
infer_meta :
func : CrossEntropyWithSoftmaxGradInferMeta
kernel :
func : cross_entropy_with_softmax_grad
data_type : loss_grad
inplace : (softmax -> input_grad)
- backward_op : cross_grad
forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis)
......
......@@ -239,17 +239,6 @@
kernel :
func : conv3d_transpose_grad
- backward_op : cross_entropy_with_softmax_grad
forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss)
args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
output : Tensor(input_grad)
infer_meta :
func : CrossEntropyWithSoftmaxGradInferMeta
kernel :
func : cross_entropy_with_softmax_grad
data_type : softmax
inplace : (softmax -> input_grad)
- backward_op : cumsum_grad
forward : cumsum(Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
args : (Tensor x, Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse)
......
......@@ -322,17 +322,6 @@
output : Tensor(out)
invoke : copy_to_impl(x, place, blocking)
# Part of python API paddle.nn.functional.cross_entropy
- op : cross_entropy_with_softmax
args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
output : Tensor(softmax), Tensor(loss)
infer_meta :
func : CrossEntropyWithSoftmaxInferMeta
kernel :
func : cross_entropy_with_softmax
data_type : input
backward : cross_entropy_with_softmax_grad
- op : cumsum
args : (Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse)
output : Tensor(out)
......
......@@ -456,6 +456,13 @@
outputs :
out : Out
- op : cross_entropy_with_softmax (softmax_with_cross_entropy)
backward : cross_entropy_with_softmax_grad (softmax_with_cross_entropy_grad)
inputs :
{input : Logits, label : Label}
outputs :
{softmax : Softmax, loss : Loss}
- op : cumprod
backward : cumprod_grad
inputs :
......
......@@ -191,6 +191,14 @@
- add_input : ShiftsTensor
comment : The number of places by which the elements of the tensor are shifted.
- op : softmax_with_cross_entropy
version :
- checkpoint : Add a new attribute [use_softmax]
action :
- add_attr : use_softmax
comment : A flag to indicate whether to do softmax
default : "true"
- op : trace
version :
- checkpoint : Upgrade trace add a new attribute [axis2]
......
......@@ -380,6 +380,18 @@
data_type : x
backward : cross_grad
# Part of python API paddle.nn.functional.cross_entropy
- op : cross_entropy_with_softmax
args : (Tensor input, Tensor label, bool soft_label=false, bool use_softmax=true, bool numeric_stable_mode=true, int ignore_index=-100, int axis=-1)
output : Tensor(softmax), Tensor(loss)
inplace : (input -> softmax)
infer_meta :
func : CrossEntropyWithSoftmaxInferMeta
kernel :
func : cross_entropy_with_softmax
data_type : input
backward : cross_entropy_with_softmax_grad
- op : cumprod
args : (Tensor x, int dim)
output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SoftmaxWithCrossEntropyOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cross_entropy_with_softmax",
{"Logits", "Label"},
{"soft_label",
"use_softmax",
"numeric_stable_mode",
"ignore_index",
"axis"},
{"Softmax", "Loss"});
}
KernelSignature SoftmaxWithCrossEntropyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cross_entropy_with_softmax_grad",
{"Label", "Softmax", "Loss@GRAD"},
{"soft_label",
"use_softmax",
"numeric_stable_mode",
"ignore_index",
"axis"},
{"Logits@GRAD"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy,
cross_entropy_with_softmax);
PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy_grad,
cross_entropy_with_softmax_grad);
PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy,
phi::SoftmaxWithCrossEntropyOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy_grad,
phi::SoftmaxWithCrossEntropyGradOpArgumentMapping);
......@@ -313,13 +313,6 @@ def fluid_softmax_with_cross_entropy(
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
backprop = helper.create_variable_for_type_inference(
dtype=logits.dtype
)
outputs['Backprop'] = backprop
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits, 'Label': label},
......@@ -2768,13 +2761,6 @@ def cross_entropy(
out = helper.create_variable_for_type_inference(dtype=input.dtype)
outputs = {'Softmax': softmax, 'Loss': out}
if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
backprop = helper.create_variable_for_type_inference(
dtype=input.dtype
)
outputs['Backprop'] = backprop
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': input, 'Label': label},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册