From b26e9bd2326029de54901031ba93458f32a0db5b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 12 Mar 2019 03:48:33 +0000 Subject: [PATCH] refine code test=develop --- paddle/fluid/operators/cross_entropy2_op.cc | 117 ++---------- paddle/fluid/operators/cross_entropy2_op.h | 104 ++--------- paddle/fluid/operators/cross_entropy_op.cc | 137 ++------------ .../fluid/operators/cross_entropy_op_base.h | 169 ++++++++++++++++++ paddle/fluid/operators/expand_op.cc | 1 + paddle/fluid/operators/math.h | 42 +++++ paddle/fluid/operators/math/cross_entropy.cu | 13 +- paddle/fluid/operators/selu_op.h | 5 +- .../sequence_ops/sequence_softmax_op.cu | 4 +- .../sigmoid_cross_entropy_with_logits_op.cu | 6 +- 10 files changed, 259 insertions(+), 339 deletions(-) create mode 100644 paddle/fluid/operators/cross_entropy_op_base.h create mode 100644 paddle/fluid/operators/math.h diff --git a/paddle/fluid/operators/cross_entropy2_op.cc b/paddle/fluid/operators/cross_entropy2_op.cc index 03b217a974..181d373cfc 100644 --- a/paddle/fluid/operators/cross_entropy2_op.cc +++ b/paddle/fluid/operators/cross_entropy2_op.cc @@ -16,46 +16,22 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/operators/cross_entropy_op_base.h" namespace paddle { namespace operators { -class CrossEntropyOp2 : public framework::OperatorWithKernel { +class CrossEntropyOp2 : public CrossEntropyOpBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; + using CrossEntropyOpBase::CrossEntropyOpBase; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + CrossEntropyOpBase::InferShape(ctx); - PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("XShape"), "Output(XShape) should be not null."); auto x_dims = ctx->GetInputDim("X"); - auto label_dims = ctx->GetInputDim("Label"); - int rank = x_dims.size(); - PADDLE_ENFORCE_EQ(rank, label_dims.size(), - "Input(X) and Input(Label) shall have the same rank."); - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || - framework::product(label_dims) <= 0)) { - check = false; - } - if (check) { - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); - } - - PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, - "Last dimension of Input(Label) should be 1."); - auto y_dims = x_dims; - y_dims[rank - 1] = 1; - ctx->SetOutputDim("Y", y_dims); - ctx->ShareLoD("X", /*->*/ "Y"); - auto x_dims_vec = framework::vectorize(x_dims); x_dims_vec.push_back(0); ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec)); @@ -63,73 +39,25 @@ class CrossEntropyOp2 : public framework::OperatorWithKernel { } protected: - // Explicitly set that the data type of computation kernel of cross_entropy - // is determined by its input "X". 
- framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + bool IsSoftLabel(framework::InferShapeContext* ctx) const override { + return false; } }; -class CrossEntropyGradientOp2 : public framework::OperatorWithKernel { +class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; + using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("XShape"), - "Input(XShape) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); - - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), - "Input(Y@GRAD) shoudl be not null."); - - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), - "Output(X@GRAD) should be not null."); - - auto x_shapes = ctx->GetInputDim("XShape"); - framework::DDim x_dims(x_shapes.Get(), x_shapes.size() - 1); - auto label_dims = ctx->GetInputDim("Label"); - auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); - int rank = x_dims.size(); - PADDLE_ENFORCE_EQ(dy_dims.size(), rank, - "Input(Y@Grad) and Input(X) should have the same rank."); - PADDLE_ENFORCE_EQ(label_dims.size(), rank, - "Input(Label) and Input(X) should have the same rank."); - - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || - framework::product(label_dims) <= 0)) { - check = false; - } - - if (check) { - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "The Input(X) and Input(Label) should have the same " - "shape except the last dimension."); - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(dy_dims, 0, rank - 1), - "The Input(X) and Input(Y@Grad) should have the same " - "shape except the last dimension."); - } - PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, - "The last dimension of Input(Y@Grad) should be 1."); - PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, - "Last dimension of Input(Label) should be 1."); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - ctx->ShareLoD("XShape", framework::GradVarName("X")); + protected: + virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { + auto x_shape = ctx->GetInputDim("XShape"); + return framework::DDim(x_shape.Get(), x_shape.size() - 1); } - protected: - // Explicitly set that the data type of computation kernel of cross_entropy - // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Y"))->type(), - ctx.device_context()); + virtual const char* VarNameWithXLoD() const { return "XShape"; } + + virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { + return false; } }; @@ -156,7 +84,7 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker { "Only valid if soft_label is set to False") .SetDefault(-100); AddComment(R"DOC( -CrossEntropy Operator. +Hard-label CrossEntropy Operator. The input 'X' and 'Label' will first be logically flattened to 2-D matrixs. The matrix's second dimension(row length) is as same as the original last @@ -173,15 +101,6 @@ or not. 
But the output only shares the LoD information with input X. } }; -class CrossEntropyOpInferVarType2 - : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map GetInputOutputWithSameType() - const override { - return std::unordered_map{{"X", /*->*/ "Y"}}; - } -}; - class CrossEntropyGradOpMaker2 : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -207,7 +126,7 @@ namespace ops = paddle::operators; using CPUCtx = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2, - ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType2, + ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType, ops::CrossEntropyGradOpMaker2); REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2); REGISTER_OP_CPU_KERNEL(cross_entropy2, diff --git a/paddle/fluid/operators/cross_entropy2_op.h b/paddle/fluid/operators/cross_entropy2_op.h index 3d209f7c5c..3e9dc7ebce 100644 --- a/paddle/fluid/operators/cross_entropy2_op.h +++ b/paddle/fluid/operators/cross_entropy2_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" @@ -26,81 +27,6 @@ namespace operators { using Tensor = framework::Tensor; -HOSTDEVICE inline platform::float16 RealLog(platform::float16 x) { -#ifdef __NVCC__ - return static_cast(logf(static_cast(x))); -#else - return static_cast(std::log(static_cast(x))); -#endif -} - -HOSTDEVICE inline float RealLog(float x) { -#ifdef __NVCC__ - return logf(x); -#else - return std::log(x); -#endif -} - -HOSTDEVICE inline double RealLog(double x) { -#ifdef __NVCC__ - return log(x); -#else - return std::log(x); -#endif -} - -HOSTDEVICE inline platform::float16 RealExp(platform::float16 x) { -#ifdef __NVCC__ - return static_cast(expf(static_cast(x))); -#else - return static_cast(std::exp(static_cast(x))); -#endif -} - -HOSTDEVICE inline float RealExp(float x) { -#ifdef __NVCC__ - return expf(x); -#else - return std::exp(x); -#endif -} - -HOSTDEVICE inline double RealExp(double x) { -#ifdef __NVCC__ - return exp(x); -#else - return std::exp(x); -#endif -} - -template -struct CrossEntropyForwardFunctor { - CrossEntropyForwardFunctor(const T *x, T *y, const int64_t *label, - int64_t ignore_index, int64_t feature_size) - : x_(x), - y_(y), - label_(label), - ignore_index_(ignore_index), - feature_size_(feature_size) {} - - HOSTDEVICE void operator()(int64_t row_idx) const { - auto col_idx = label_[row_idx]; - if (col_idx != ignore_index_) { - y_[row_idx] = -math::TolerableValue()( - RealLog(x_[row_idx * feature_size_ + col_idx])); - } else { - y_[row_idx] = 0; - } - } - - const T *x_; - T *y_; - const int64_t *label_; - int64_t ignore_index_; - int64_t feature_size_; -}; - template struct CrossEntropyBackwardFunctor { CrossEntropyBackwardFunctor(T *dx, const T *y, const T *dy, @@ -118,7 +44,7 @@ struct CrossEntropyBackwardFunctor { auto col_idx = idx % feature_size_; auto label = label_[row_idx]; if (label == col_idx && label != ignore_index_) { - dx_[idx] = -dy_[row_idx] * RealExp(y_[row_idx]); + dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]); } else { dx_[idx] = 0; } @@ -136,24 +62,20 @@ template class CrossEntropyOpKernel2 : public framework::OpKernel { public: void Compute(const 
framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *label = ctx.Input("Label"); - auto *y = ctx.Output("Y"); + auto *x_original = ctx.Input("X"); + int rank = x_original->dims().size(); - auto *p_y = y->mutable_data(ctx.GetPlace()); - auto *p_x = x->data(); - auto *p_label = label->data(); + auto x = framework::ReshapeToMatrix(*x_original, rank - 1); + auto label = + framework::ReshapeToMatrix(*ctx.Input("Label"), rank - 1); + auto *y = ctx.Output("Y"); + y->mutable_data(ctx.GetPlace()); - int rank = x->dims().size(); - int64_t feature_size = x->dims()[rank - 1]; - int64_t batch_size = framework::product(x->dims()) / feature_size; + auto ignore_index = ctx.Attr("ignore_index"); - int64_t ignore_index = ctx.Attr("ignore_index"); - - platform::ForRange for_range( - ctx.template device_context(), batch_size); - for_range(CrossEntropyForwardFunctor(p_x, p_y, p_label, ignore_index, - feature_size)); + math::CrossEntropyFunctor()( + ctx.template device_context(), y, &x, &label, false, + ignore_index); } }; diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 3adc7baebd..1707f7078c 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -14,128 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/cross_entropy_op.h" #include +#include "paddle/fluid/operators/cross_entropy_op_base.h" namespace paddle { namespace operators { -class CrossEntropyOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); - PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); - - auto x_dims = ctx->GetInputDim("X"); - auto label_dims = ctx->GetInputDim("Label"); - int rank = x_dims.size(); - PADDLE_ENFORCE_EQ(rank, label_dims.size(), - "Input(X) and Input(Label) shall have the same rank."); - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || - framework::product(label_dims) <= 0)) { - check = false; - } - if (check) { - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); - } - if (ctx->Attrs().Get("soft_label")) { - if (check) { - PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], - "If Attr(soft_label) == true, the last dimension of " - "Input(X) and Input(Label) should be equal."); - } - } else { - PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, - "If Attr(softLabel) == false, the last dimension of " - "Input(Label) should be 1."); - } - - auto y_dims = x_dims; - y_dims[rank - 1] = 1; - ctx->SetOutputDim("Y", y_dims); - ctx->ShareLoD("X", /*->*/ "Y"); - } - - protected: - // Explicitly set that the data type of computation kernel of cross_entropy - // is determined by its input "X". 
- framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); - } -}; - -class CrossEntropyGradientOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), - "Input(Y@GRAD) shoudl be not null."); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), - "Output(X@GRAD) should be not null."); - - auto x_dims = ctx->GetInputDim("X"); - auto label_dims = ctx->GetInputDim("Label"); - auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); - int rank = x_dims.size(); - PADDLE_ENFORCE_EQ(dy_dims.size(), rank, - "Input(Y@Grad) and Input(X) should have the same rank."); - PADDLE_ENFORCE_EQ(label_dims.size(), rank, - "Input(Label) and Input(X) should have the same rank."); - - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || - framework::product(label_dims) <= 0)) { - check = false; - } - - if (check) { - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "The Input(X) and Input(Label) should have the same " - "shape except the last dimension."); - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(dy_dims, 0, rank - 1), - "The Input(X) and Input(Y@Grad) should have the same " - "shape except the last dimension."); - } - PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, - "The last dimension of Input(Y@Grad) should be 1."); - if (ctx->Attrs().Get("soft_label")) { - if (check) { - PADDLE_ENFORCE_EQ( - x_dims[rank - 1], label_dims[rank - 1], - "When Attr(soft_label) == true, the last dimension of " - "Input(X) and Input(Label) should be equal."); - } - } else { - PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, - "When Attr(soft_label) == false, the last dimension of " - "Input(Label) should be 1."); - } - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - ctx->ShareLoD("X", framework::GradVarName("X")); - } - - protected: - // Explicitly set that the data type of computation kernel of cross_entropy - // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); - } -}; - class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -200,22 +83,24 @@ or not. But the output only shares the LoD information with input X. 
} }; -class CrossEntropyOpInferVarType - : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map GetInputOutputWithSameType() - const override { - return std::unordered_map{{"X", /*->*/ "Y"}}; +class CrossEntropyGradientOp : public CrossEntropyGradientOpBase { + public: + using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + CrossEntropyGradientOpBase::InferShape(ctx); } }; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; using CPUCtx = paddle::platform::CPUDeviceContext; -REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, - ops::CrossEntropyOpInferVarType, +REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase, + ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, diff --git a/paddle/fluid/operators/cross_entropy_op_base.h b/paddle/fluid/operators/cross_entropy_op_base.h new file mode 100644 index 0000000000..c3e5254c37 --- /dev/null +++ b/paddle/fluid/operators/cross_entropy_op_base.h @@ -0,0 +1,169 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class CrossEntropyOpBase : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, label_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(label_dims) <= 0)) { + check = false; + } + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } + + if (IsSoftLabel(ctx)) { + if (check) { + PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], + "If Attr(soft_label) == true, the last dimension of " + "Input(X) and Input(Label) should be equal."); + } + } else { + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, + "If Attr(softLabel) == false, the last dimension of " + "Input(Label) should be 1."); + } + + auto y_dims = x_dims; + y_dims[rank - 1] = 1; + ctx->SetOutputDim("Y", y_dims); + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + // Explicitly set that the data type of computation kernel of cross_entropy + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } + + virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { + return ctx->Attrs().Get("soft_label"); + } +}; + +class CrossEntropyOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Y"}}; + } +}; + +class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) shoudl be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = GetXDim(ctx); + auto label_dims = ctx->GetInputDim("Label"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(dy_dims.size(), rank, + "Input(Y@Grad) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(label_dims.size(), rank, + "Input(Label) and Input(X) should have the same rank."); + + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(label_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "The Input(X) and Input(Label) should have the same " + "shape except the last 
dimension."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(dy_dims, 0, rank - 1), + "The Input(X) and Input(Y@Grad) should have the same " + "shape except the last dimension."); + } + if (IsSoftLabel(ctx)) { + if (check) { + PADDLE_ENFORCE_EQ( + x_dims[rank - 1], label_dims[rank - 1], + "When Attr(soft_label) == true, the last dimension of " + "Input(X) and Input(Label) should be equal."); + } + } else { + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, + "When Attr(soft_label) == false, the last dimension of " + "Input(Label) should be 1."); + } + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, + "The last dimension of Input(Y@Grad) should be 1."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of cross_entropy + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Y"))->type(), + ctx.device_context()); + } + + virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { + return ctx->GetInputDim("X"); + } + + virtual const char* VarNameWithXLoD() const { return "X"; } + + virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { + return ctx->Attrs().Get("soft_label"); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index ce3d9a7aac..fcb2be9363 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/expand_op.h" +#include #include namespace paddle { diff --git a/paddle/fluid/operators/math.h b/paddle/fluid/operators/math.h new file mode 100644 index 0000000000..8cc24200d3 --- /dev/null +++ b/paddle/fluid/operators/math.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/hostdevice.h" + +#include "math.h" // NOLINT + +namespace paddle { +namespace operators { + +inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) { + return static_cast(::expf(static_cast(x))); +} + +inline HOSTDEVICE float real_exp(float x) { return ::expf(x); } + +inline HOSTDEVICE double real_exp(double x) { return ::exp(x); } + +inline HOSTDEVICE platform::float16 real_log(platform::float16 x) { + return static_cast(::logf(static_cast(x))); +} + +inline HOSTDEVICE float real_log(float x) { return ::logf(x); } + +inline HOSTDEVICE double real_log(double x) { return ::log(x); } + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index cb200ec8d6..44cbdf2e98 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -20,17 +21,6 @@ namespace paddle { namespace operators { namespace math { -namespace { - -__device__ __forceinline__ float real_log(float x) { return logf(x); } - -__device__ __forceinline__ double real_log(double x) { return log(x); } - -__device__ __forceinline__ platform::float16 real_log( - const platform::float16& val) { - return static_cast(logf(static_cast(val))); -} - template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, const int N, const int D, @@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, Y[blockIdx.x] = -val; } } -} // namespace template class CrossEntropyFunctor { diff --git a/paddle/fluid/operators/selu_op.h b/paddle/fluid/operators/selu_op.h index bdb506885c..b2fc834c42 100644 --- a/paddle/fluid/operators/selu_op.h +++ b/paddle/fluid/operators/selu_op.h @@ -15,13 +15,12 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/platform/for_range.h" + namespace paddle { namespace operators { -static HOSTDEVICE float real_exp(float x) { return expf(x); } -static HOSTDEVICE float real_exp(double x) { return exp(x); } - template struct SeluFunctor { SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr) diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu index cc5e982190..a9dc0a4fda 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include #include // NOLINT +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h" namespace paddle { @@ -21,9 +22,6 @@ namespace operators { using LoDTensor = framework::LoDTensor; -__device__ __forceinline__ float real_exp(float x) { return expf(x); } -__device__ __forceinline__ double real_exp(double x) { return exp(x); } - template using BlockReduce = cub::BlockReduce; diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu index 2a4570ef5c..aea69de643 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "cub/cub.cuh" +#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/hostdevice.h" @@ -21,11 +22,6 @@ namespace operators { using Tensor = framework::Tensor; -static HOSTDEVICE float real_exp(float x) { return expf(x); } -static HOSTDEVICE float real_exp(double x) { return exp(x); } -static HOSTDEVICE float real_log(float x) { return logf(x); } -static HOSTDEVICE float real_log(double x) { return log(x); } - static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaxinumNumBlocks = 4096; -- GitLab
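
For reference, a minimal standalone sketch of the virtual-hook (template method) pattern that the new cross_entropy_op_base.h introduces: the base class keeps the shared InferShape checks and lets concrete operators override IsSoftLabel (and, in the gradient op, GetXDim / VarNameWithXLoD). The types below are simplified stand-ins that compile outside Paddle; they are not the real framework API.

#include <iostream>

// Simplified stand-in for framework::InferShapeContext.
struct ShapeContext {
  bool soft_label;
};

// Mirrors CrossEntropyOpBase: the shared checks live here, and behaviour that
// differs per operator is exposed as a protected virtual hook.
class CrossEntropyBaseSketch {
 public:
  virtual ~CrossEntropyBaseSketch() = default;

  void InferShape(const ShapeContext& ctx) const {
    if (IsSoftLabel(ctx)) {
      std::cout << "soft-label path: last dims of X and Label must match\n";
    } else {
      std::cout << "hard-label path: last dim of Label must be 1\n";
    }
  }

 protected:
  // Default behaviour: read the soft_label attribute, as CrossEntropyOpBase does.
  virtual bool IsSoftLabel(const ShapeContext& ctx) const {
    return ctx.soft_label;
  }
};

// Mirrors CrossEntropyOp2, which is always hard-label regardless of the attribute.
// The real gradient op additionally overrides GetXDim to rebuild the X shape
// from the saved XShape output and VarNameWithXLoD to share LoD from XShape.
class CrossEntropy2Sketch : public CrossEntropyBaseSketch {
 protected:
  bool IsSoftLabel(const ShapeContext&) const override { return false; }
};

int main() {
  ShapeContext ctx{/*soft_label=*/true};
  CrossEntropyBaseSketch base_op;
  CrossEntropy2Sketch hard_label_op;
  base_op.InferShape(ctx);        // takes the soft-label branch
  hard_label_op.InferShape(ctx);  // forced onto the hard-label branch
  return 0;
}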
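
Similarly, the per-file static real_exp / real_log helpers that were duplicated in selu_op.h, sequence_softmax_op.cu, sigmoid_cross_entropy_with_logits_op.cu and math/cross_entropy.cu are now collected in paddle/fluid/operators/math.h. Below is a host-only sketch of how those shared overloads are consumed by a functor along the lines of CrossEntropyBackwardFunctor; the real header also provides platform::float16 overloads, HOSTDEVICE qualifiers, and calls ::expf / ::logf directly, all of which are simplified away here.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Host-only stand-ins for the overload set in paddle/fluid/operators/math.h.
inline float real_exp(float x) { return std::exp(x); }
inline double real_exp(double x) { return std::exp(x); }
inline float real_log(float x) { return std::log(x); }
inline double real_log(double x) { return std::log(x); }

// Follows the shape of CrossEntropyBackwardFunctor: the forward pass stores
// y = -log(p_label), so real_exp(y) recovers 1 / p_label, and the gradient is
// non-zero only at the labelled column (unless that label is ignore_index).
template <typename T>
void cross_entropy_backward(std::vector<T>* dx, const std::vector<T>& y,
                            const std::vector<T>& dy,
                            const std::vector<int64_t>& label,
                            int64_t ignore_index, int64_t feature_size) {
  for (size_t idx = 0; idx < dx->size(); ++idx) {
    int64_t row = static_cast<int64_t>(idx) / feature_size;
    int64_t col = static_cast<int64_t>(idx) % feature_size;
    if (label[row] == col && label[row] != ignore_index) {
      (*dx)[idx] = -dy[row] * real_exp(y[row]);  // equals -dy / p_label
    } else {
      (*dx)[idx] = 0;
    }
  }
}

int main() {
  // One row of three classes, true class = 2, predicted probability p = 0.7.
  std::vector<double> y = {-real_log(0.7)};  // forward output: -log(p)
  std::vector<double> dy = {1.0};
  std::vector<int64_t> label = {2};
  std::vector<double> dx(3, 0.0);
  cross_entropy_backward(&dx, y, dy, label, /*ignore_index=*/-100, 3);
  std::cout << dx[0] << " " << dx[1] << " " << dx[2] << "\n";  // 0 0 -1.42857
  return 0;
}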