diff --git a/paddle/fluid/operators/cross_entropy2_op.cc b/paddle/fluid/operators/cross_entropy2_op.cc
deleted file mode 100644
index 181d373cfc3d2853b1c111667c2eb4789bbe5104..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/cross_entropy2_op.cc
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/cross_entropy2_op.h"
-#include <memory>
-#include <vector>
-#include "paddle/fluid/operators/cross_entropy_op_base.h"
-
-namespace paddle {
-namespace operators {
-
-class CrossEntropyOp2 : public CrossEntropyOpBase {
- public:
-  using CrossEntropyOpBase::CrossEntropyOpBase;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    CrossEntropyOpBase::InferShape(ctx);
-
-    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
-                   "Output(XShape) should be not null.");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto x_dims_vec = framework::vectorize(x_dims);
-    x_dims_vec.push_back(0);
-    ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
-    ctx->ShareLoD("X", /*->*/ "XShape");
-  }
-
- protected:
-  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
-    return false;
-  }
-};
-
-class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
- public:
-  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
-
- protected:
-  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
-    auto x_shape = ctx->GetInputDim("XShape");
-    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
-  }
-
-  virtual const char* VarNameWithXLoD() const { return "XShape"; }
-
-  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
-    return false;
-  }
-};
-
-class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(Tensor, default Tensor<float>), a tensor whose last dimension "
-             "size is equal to the number of classes. This input is a "
-             "probability computed by the previous operator, which is almost "
-             "always the result of a softmax operator.");
-    AddInput(
-        "Label",
-        "(Tensor), the tensor which represents the ground truth. It has the "
-        "same shape with 'X' except the last dimension. One hot Tensor.");
-    AddOutput("Y",
-              "(Tensor, default Tensor<float>), a tensor whose shape is same "
-              "with 'X' except that the last dimension size is 1. It "
-              "represents the cross entropy loss.");
-    AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
-    AddAttr<int>("ignore_index",
-                 "(int, default -100), Specifies a target value that is"
-                 "ignored and does not contribute to the input gradient."
-                 "Only valid if soft_label is set to False")
-        .SetDefault(-100);
-    AddComment(R"DOC(
-Hard-label CrossEntropy Operator.
-
-The input 'X' and 'Label' will first be logically flattened to 2-D matrixs.
-The matrix's second dimension(row length) is as same as the original last
-dimension, and the first dimension(column length) is the product of all other
-original dimensions. Then the softmax computation will take palce on each raw
-of flattened matrixs.
-
-Only support hard label.
-
-Both the input X and Label can carry the LoD (Level of Details) information,
-or not. But the output only shares the LoD information with input X.
-
-)DOC");
-  }
-};
-
-class CrossEntropyGradOpMaker2 : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
-    op->SetType("cross_entropy_grad2");
-    op->SetInput("Label", Input("Label"));
-    op->SetInput("Y", Output("Y"));
-    op->SetInput("XShape", Output("XShape"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
-    return op;
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-using CPUCtx = paddle::platform::CPUDeviceContext;
-
-REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
-                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
-                  ops::CrossEntropyGradOpMaker2);
-REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
-REGISTER_OP_CPU_KERNEL(cross_entropy2,
-                       ops::CrossEntropyOpKernel2<CPUCtx, float>,
-                       ops::CrossEntropyOpKernel2<CPUCtx, double>);
-REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
-                       ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
-                       ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
diff --git a/paddle/fluid/operators/cross_entropy2_op.cu b/paddle/fluid/operators/cross_entropy2_op.cu
deleted file mode 100644
index 1868c1b866016d1ea51e28339847b6c890c5ec74..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/cross_entropy2_op.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/cross_entropy2_op.h"
-#include "paddle/fluid/platform/float16.h"
-
-namespace plat = paddle::platform;
-namespace ops = paddle::operators;
-using CUDACtx = paddle::platform::CUDADeviceContext;
-REGISTER_OP_CUDA_KERNEL(cross_entropy2,
-                        ops::CrossEntropyOpKernel2<CUDACtx, float>,
-                        ops::CrossEntropyOpKernel2<CUDACtx, double>,
-                        ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
-
-REGISTER_OP_CUDA_KERNEL(
-    cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
-    ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
-    ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
diff --git a/paddle/fluid/operators/cross_entropy2_op.h b/paddle/fluid/operators/cross_entropy2_op.h
deleted file mode 100644
index 3e9dc7ebce263d5c22fa42a8b529a411a76f1ad7..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/cross_entropy2_op.h
+++ /dev/null
@@ -1,110 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <string>
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math.h"
-#include "paddle/fluid/operators/math/cross_entropy.h"
-#include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/platform/for_range.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-struct CrossEntropyBackwardFunctor {
-  CrossEntropyBackwardFunctor(T *dx, const T *y, const T *dy,
-                              const int64_t *label, int64_t ignore_index,
-                              int64_t feature_size)
-      : dx_(dx),
-        y_(y),
-        dy_(dy),
-        label_(label),
-        ignore_index_(ignore_index),
-        feature_size_(feature_size) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    auto row_idx = idx / feature_size_;
-    auto col_idx = idx % feature_size_;
-    auto label = label_[row_idx];
-    if (label == col_idx && label != ignore_index_) {
-      dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
-    } else {
-      dx_[idx] = 0;
-    }
-  }
-
-  T *dx_;
-  const T *y_;
-  const T *dy_;
-  const int64_t *label_;
-  int64_t ignore_index_;
-  int64_t feature_size_;
-};
-
-template <typename DeviceContext, typename T>
-class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *x_original = ctx.Input<Tensor>("X");
-    int rank = x_original->dims().size();
-
-    auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
-    auto label =
-        framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
-    auto *y = ctx.Output<Tensor>("Y");
-    y->mutable_data<T>(ctx.GetPlace());
-
-    auto ignore_index = ctx.Attr<int>("ignore_index");
-
-    math::CrossEntropyFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), y, &x, &label, false,
-        ignore_index);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *y = ctx.Input<Tensor>("Y");
-    auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto *label = ctx.Input<Tensor>("Label");
-
-    auto *p_dx = dx->mutable_data<T>(ctx.GetPlace());
-    auto *p_y = y->data<T>();
-    auto *p_dy = dy->data<T>();
-    auto *p_label = label->data<int64_t>();
-
-    int64_t ignore_index = ctx.Attr<int>("ignore_index");
-    int rank = dx->dims().size();
-    int64_t feature_size = dx->dims()[rank - 1];
-    int64_t batch_size = framework::product(dx->dims()) / feature_size;
-
-    platform::ForRange<DeviceContext> for_range(
-        ctx.template device_context<DeviceContext>(),
-        batch_size * feature_size);
-    for_range(CrossEntropyBackwardFunctor<T>(p_dx, p_y, p_dy, p_label,
-                                             ignore_index, feature_size));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index 1707f7078cad82f0ee5d739b4df6e3aa433ecb3e..dd1b48cecfdc5b7a1244d943eb0ef82418fdde56 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -14,11 +14,154 @@ limitations under the License.
 */
 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include <memory>
 #include <string>
+#include <unordered_map>
-#include "paddle/fluid/operators/cross_entropy_op_base.h"
 
 namespace paddle {
 namespace operators {
 
+class CrossEntropyOpBase : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) shall have the same rank.");
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                framework::product(label_dims) <= 0)) {
+      check = false;
+    }
+    if (check) {
+      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                        framework::slice_ddim(label_dims, 0, rank - 1),
+                        "Input(X) and Input(Label) shall have the same shape "
+                        "except the last dimension.");
+    }
+
+    if (IsSoftLabel(ctx)) {
+      if (check) {
+        PADDLE_ENFORCE_EQ(
+            x_dims[rank - 1], label_dims[rank - 1],
+            "If Attr(soft_label) == true, the last dimension of "
+            "Input(X) and Input(Label) should be equal.");
+      }
+    } else {
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                        "If Attr(soft_label) == false, the last dimension of "
+                        "Input(Label) should be 1.");
+    }
+
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
+    ctx->ShareLoD("X", /*->*/ "Y");
+  }
+
+ protected:
+  // Explicitly set that the data type of computation kernel of cross_entropy
+  // is determined by its input "X".
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
+  }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
+};
+
+class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should be not null.");
+
+    auto x_dims = GetXDim(ctx);
+    auto label_dims = ctx->GetInputDim("Label");
+    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                framework::product(label_dims) <= 0)) {
+      check = false;
+    }
+
+    if (check) {
+      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                        framework::slice_ddim(label_dims, 0, rank - 1),
+                        "The Input(X) and Input(Label) should have the same "
+                        "shape except the last dimension.");
+      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                        framework::slice_ddim(dy_dims, 0, rank - 1),
+                        "The Input(X) and Input(Y@Grad) should have the same "
+                        "shape except the last dimension.");
+    }
+    if (IsSoftLabel(ctx)) {
+      if (check) {
+        PADDLE_ENFORCE_EQ(
+            x_dims[rank - 1], label_dims[rank - 1],
+            "When Attr(soft_label) == true, the last dimension of "
+            "Input(X) and Input(Label) should be equal.");
+      }
+    } else {
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                        "When Attr(soft_label) == false, the last dimension of "
+                        "Input(Label) should be 1.");
+    }
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
+  }
+
+ protected:
+  // Explicitly set that the data type of the computation kernel of
+  // cross_entropy_grad is determined by its input "Y@GRAD".
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
+        ctx.device_context());
+  }
+
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    return ctx->GetInputDim("X");
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "X"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
+};
+
+class CrossEntropyOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+  }
+};
+
 class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -87,12 +230,110 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
  public:
   using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
     CrossEntropyGradientOpBase::InferShape(ctx);
   }
 };
 
+class CrossEntropyOp2 : public CrossEntropyOpBase {
+ public:
+  using CrossEntropyOpBase::CrossEntropyOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    CrossEntropyOpBase::InferShape(ctx);
+
+    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
+                   "Output(XShape) should be not null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims_vec = framework::vectorize(x_dims);
+    x_dims_vec.push_back(0);
+    ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
+    ctx->ShareLoD("X", /*->*/ "XShape");
+  }
+
+ protected:
+  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
+    return false;
+  }
+};
+
+class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+
+ protected:
+  framework::DDim GetXDim(framework::InferShapeContext* ctx) const override {
+    auto x_shape = ctx->GetInputDim("XShape");
+    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
+  }
+
+  const char* VarNameWithXLoD() const override { return "XShape"; }
+
+  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
+    return false;
+  }
+};
+
+class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except for the last dimension. One-hot Tensor.");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
+    AddOutput("XShape", "Temporary variable to save the shape and LoD of X.");
+    AddAttr<int>("ignore_index",
+                 "(int, default -100), Specifies a target value that is "
+                 "ignored and does not contribute to the input gradient. "
+                 "Only valid if soft_label is set to False.")
+        .SetDefault(-100);
+    AddComment(R"DOC(
+Hard-label CrossEntropy Operator.
+
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the last dimension
+of the original input, and the first dimension (column length) is the product
+of all the other original dimensions. The cross entropy computation then
+takes place on each row of the flattened matrices.
+
+Only hard labels are supported.
+
+Both the input X and Label can carry the LoD (Level of Details) information,
+or not. But the output only shares the LoD information with input X.
+
+)DOC");
+  }
+};
+
+class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("cross_entropy_grad2");
+    op->SetInput("Label", Input("Label"));
+    op->SetInput("Y", Output("Y"));
+    op->SetInput("XShape", Output("XShape"));
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -108,3 +349,14 @@ REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
 REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
+
+REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
+                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
+                  ops::CrossEntropyGradOpDescMaker2);
+REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
+REGISTER_OP_CPU_KERNEL(cross_entropy2,
+                       ops::CrossEntropyOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyOpKernel2<CPUCtx, double>);
+REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
+                       ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu
index fcd34383a85f6984a8f27ce0625364f8fd5e31d6..243e7f52c1e3c4c210e91f708ae5d6de97e4afbc 100644
--- a/paddle/fluid/operators/cross_entropy_op.cu
+++ b/paddle/fluid/operators/cross_entropy_op.cu
@@ -27,3 +27,13 @@
     cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
     ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
     ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(cross_entropy2,
+                        ops::CrossEntropyOpKernel2<CUDACtx, float>,
+                        ops::CrossEntropyOpKernel2<CUDACtx, double>,
+                        ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
+    ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
+    ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h
index f123e11542d85c904a81fe2a87f59ab52511cc15..05609e4bc20b1c75872be38e057de221a0188b88 100644
--- a/paddle/fluid/operators/cross_entropy_op.h
+++ b/paddle/fluid/operators/cross_entropy_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -137,5 +138,85 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+struct HardLabelCrossEntropyBackwardFunctor {
+  HardLabelCrossEntropyBackwardFunctor(T* dx, const T* y, const T* dy,
+                                       const int64_t* label,
+                                       int64_t ignore_index,
+                                       int64_t feature_size)
+      : dx_(dx),
+        y_(y),
+        dy_(dy),
+        label_(label),
+        ignore_index_(ignore_index),
+        feature_size_(feature_size) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    auto row_idx = idx / feature_size_;
+    auto col_idx = idx % feature_size_;
+    auto label = label_[row_idx];
+    // Only the label column receives a gradient. Since y = -log(x[label]),
+    // exp(y) equals 1 / x[label], so -dy * exp(y) is the chain-rule gradient.
+    if (label == col_idx && label != ignore_index_) {
+      dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
+    } else {
+      dx_[idx] = 0;
+    }
+  }
+
+  T* dx_;
+  const T* y_;
+  const T* dy_;
+  const int64_t* label_;
+  int64_t ignore_index_;
+  int64_t feature_size_;
+};
+
+template <typename DeviceContext, typename T>
+class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x_original = ctx.Input<Tensor>("X");
+    int rank = x_original->dims().size();
+
+    // Flatten X and Label to 2-D matrices that keep the last dimension.
+    auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
+    auto label =
+        framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
+    auto* y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());
+
+    auto ignore_index = ctx.Attr<int>("ignore_index");
+
+    math::CrossEntropyFunctor<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), y, &x, &label, false,
+        ignore_index);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* label = ctx.Input<Tensor>("Label");
+
+    auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
+    auto* p_y = y->data<T>();
+    auto* p_dy = dy->data<T>();
+    auto* p_label = label->data<int64_t>();
+
+    int64_t ignore_index = ctx.Attr<int>("ignore_index");
+    int rank = dx->dims().size();
+    int64_t feature_size = dx->dims()[rank - 1];
+    int64_t batch_size = framework::product(dx->dims()) / feature_size;
+
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(),
+        batch_size * feature_size);
+    for_range(HardLabelCrossEntropyBackwardFunctor<T>(
+        p_dx, p_y, p_dy, p_label, ignore_index, feature_size));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/cross_entropy_op_base.h b/paddle/fluid/operators/cross_entropy_op_base.h
deleted file mode 100644
index c3e5254c37e0293072a67982004aa57b91de4c36..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/cross_entropy_op_base.h
+++ /dev/null
@@ -1,169 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <string>
-#include <unordered_map>
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-class CrossEntropyOpBase : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto label_dims = ctx->GetInputDim("Label");
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
-                      "Input(X) and Input(Label) shall have the same rank.");
-    bool check = true;
-    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
-                                framework::product(label_dims) <= 0)) {
-      check = false;
-    }
-    if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "Input(X) and Input(Label) shall have the same shape "
-                        "except the last dimension.");
-    }
-
-    if (IsSoftLabel(ctx)) {
-      if (check) {
-        PADDLE_ENFORCE_EQ(
-            x_dims[rank - 1], label_dims[rank - 1],
-            "If Attr(soft_label) == true, the last dimension of "
-            "Input(X) and Input(Label) should be equal.");
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
-                        "If Attr(softLabel) == false, the last dimension of "
-                        "Input(Label) should be 1.");
-    }
-
-    auto y_dims = x_dims;
-    y_dims[rank - 1] = 1;
-    ctx->SetOutputDim("Y", y_dims);
-    ctx->ShareLoD("X", /*->*/ "Y");
-  }
-
- protected:
-  // Explicitly set that the data type of computation kernel of cross_entropy
-  // is determined by its input "X".
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.device_context());
-  }
-
-  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
-    return ctx->Attrs().Get<bool>("soft_label");
-  }
-};
-
-class CrossEntropyOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
-  }
-};
-
-class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const {
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) shoudl be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@GRAD) should be not null.");
-
-    auto x_dims = GetXDim(ctx);
-    auto label_dims = ctx->GetInputDim("Label");
-    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
-                      "Input(Y@Grad) and Input(X) should have the same rank.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
-                      "Input(Label) and Input(X) should have the same rank.");
-
-    bool check = true;
-    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
-                                framework::product(label_dims) <= 0)) {
-      check = false;
-    }
-
-    if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "The Input(X) and Input(Label) should have the same "
-                        "shape except the last dimension.");
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(dy_dims, 0, rank - 1),
-                        "The Input(X) and Input(Y@Grad) should have the same "
-                        "shape except the last dimension.");
-    }
-    if (IsSoftLabel(ctx)) {
-      if (check) {
-        PADDLE_ENFORCE_EQ(
-            x_dims[rank - 1], label_dims[rank - 1],
-            "When Attr(soft_label) == true, the last dimension of "
-            "Input(X) and Input(Label) should be equal.");
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
-                        "When Attr(soft_label) == false, the last dimension of "
-                        "Input(Label) should be 1.");
-    }
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
-  }
-
- protected:
-  // Explicitly set that the data type of computation kernel of cross_entropy
-  // is determined by its input "X".
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>(framework::GradVarName("Y"))->type(),
-        ctx.device_context());
-  }
-
-  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
-    return ctx->GetInputDim("X");
-  }
-
-  virtual const char* VarNameWithXLoD() const { return "X"; }
-
-  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
-    return ctx->Attrs().Get<bool>("soft_label");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
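
Reviewer note (addendum, not part of the patch): the key step in
HardLabelCrossEntropyBackwardFunctor is dx = -dy * exp(y). This works because
the forward pass stores y = -log(x[label]), so exp(y) = 1 / x[label], and the
derivative of -log(x[label]) with respect to x[label] is -1 / x[label]; every
other column of dx is zero. Below is a minimal standalone C++ sketch (the file
name and all variable names are illustrative, with no Paddle dependencies)
that checks this identity on one row:

    // check_hard_label_ce_grad.cc -- hypothetical standalone check.
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    int main() {
      // One row of softmax probabilities over 3 classes and a hard label.
      const std::vector<double> x = {0.2, 0.5, 0.3};
      const std::int64_t label = 1;
      const double dy = 1.5;  // upstream gradient for this row's loss

      // Forward, as math::CrossEntropyFunctor does for hard labels.
      const double y = -std::log(x[label]);

      // Backward, as in the functor: only the label column is non-zero.
      std::vector<double> dx(x.size(), 0.0);
      dx[label] = -dy * std::exp(y);

      // Analytic gradient: d(-log x_l)/dx_l = -1 / x_l, scaled by dy.
      const double expected = -dy / x[label];
      assert(std::fabs(dx[label] - expected) < 1e-12);
      return 0;
    }

The same reasoning explains the ignore_index branch: a row whose label equals
ignore_index contributes no loss, so its entire gradient row stays zero.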