diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index 200bc825e030688923e0ce22552266b1b35f948f..a617b9fb1d948340d25853252be79fdd08fe0438 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include <memory>
 #include <string>
 #include <unordered_map>
 
 namespace paddle {
 namespace operators {
 
-class CrossEntropyOp : public framework::OperatorWithKernel {
+class CrossEntropyOpBase : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+
     PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
 
     auto x_dims = ctx->GetInputDim("X");
@@ -44,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
           "Input(X) and Input(Label) shall have the same shape "
           "except the last dimension.");
     }
-    if (ctx->Attrs().Get<bool>("soft_label")) {
+
+    if (IsSoftLabel(ctx)) {
       if (check) {
         PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                           "If Attr(soft_label) == true, the last dimension of "
@@ -70,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                    ctx.device_context());
   }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
 };
 
-class CrossEntropyGradientOp : public framework::OperatorWithKernel {
+class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+  void InferShape(framework::InferShapeContext* ctx) const {
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                    "Input(Y@GRAD) shoudl be not null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Output(X@GRAD) should be not null.");
 
-    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims = GetXDim(ctx);
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
     int rank = x_dims.size();
@@ -109,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                       "The Input(X) and Input(Y@Grad) should have the same "
                       "shape except the last dimension.");
     }
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    if (ctx->Attrs().Get<bool>("soft_label")) {
+    if (IsSoftLabel(ctx)) {
       if (check) {
         PADDLE_ENFORCE_EQ(
             x_dims[rank - 1], label_dims[rank - 1],
@@ -124,7 +128,10 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
             "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
   }
 
  protected:
@@ -132,8 +139,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
+        ctx.device_context());
+  }
+
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    return ctx->GetInputDim("X");
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "X"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
+};
+
+class CrossEntropyOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
   }
 };
 
@@ -201,22 +228,132 @@ or not. But the output only shares the LoD information with input X.
   }
 };
 
-class CrossEntropyOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
+class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
+};
+
+class CrossEntropyOp2 : public CrossEntropyOpBase {
+ public:
+  using CrossEntropyOpBase::CrossEntropyOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    CrossEntropyOpBase::InferShape(ctx);
+
+    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
+                   "Output(XShape) should be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
+                   "Output(MatchX) should be not null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims_vec = framework::vectorize(x_dims);
+    x_dims_vec.push_back(0);
+    ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
+    x_dims[x_dims.size() - 1] = 1;
+    ctx->SetOutputDim("MatchX", x_dims);
+    ctx->ShareLoD("X", /*->*/ "XShape");
+  }
+
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
+    return false;
+  }
+};
+
+class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
+
+ protected:
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    auto x_shape = ctx->GetInputDim("XShape");
+    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "XShape"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return false;
   }
 };
+
+class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except that the last dimension is 1, each value "
+        "being a hard class index.");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
+    AddOutput("XShape", "Temporary variable to save shape and LoD of X.");
+    AddOutput("MatchX",
+              "X value that matches label, used for gradient computation.");
+    AddAttr<int>("ignore_index",
+                 "(int, default -100), Specifies a target value that is "
+                 "ignored and does not contribute to the input gradient. "
+                 "Only valid if soft_label is set to False")
+        .SetDefault(-100);
+    AddComment(R"DOC(
+Hard-label CrossEntropy Operator.
+
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the original last
+dimension, and the first dimension (column length) is the product of all other
+original dimensions. Then the cross entropy computation will take place on
+each row of the flattened matrices.
+
+Only hard labels are supported.
+
+Both the input X and Label can carry the LoD (Level of Details) information,
+or not. But the output only shares the LoD information with input X.
+
+)DOC");
+  }
+};
+
+class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("cross_entropy_grad2");
+    op->SetInput("Label", Input("Label"));
+    op->SetInput("MatchX", Output("MatchX"));
+    op->SetInput("XShape", Output("XShape"));
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;
 
-REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
-                  ops::CrossEntropyOpInferVarType,
+REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
+                  ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
@@ -224,3 +361,14 @@ REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
                        ops::CrossEntropyOpKernel<CPUCtx, double>);
 REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
+
+REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
+                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
+                  ops::CrossEntropyGradOpDescMaker2);
+REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
+REGISTER_OP_CPU_KERNEL(cross_entropy2,
+                       ops::CrossEntropyOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyOpKernel2<CPUCtx, double>);
+REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
+                       ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu
index fcd34383a85f6984a8f27ce0625364f8fd5e31d6..243e7f52c1e3c4c210e91f708ae5d6de97e4afbc 100644
--- a/paddle/fluid/operators/cross_entropy_op.cu
+++ b/paddle/fluid/operators/cross_entropy_op.cu
@@ -27,3 +27,13 @@ REGISTER_OP_CUDA_KERNEL(
     cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
     ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
     ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(cross_entropy2,
+                        ops::CrossEntropyOpKernel2<CUDACtx, float>,
+                        ops::CrossEntropyOpKernel2<CUDACtx, double>,
+                        ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
+    ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
+    ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h
index f123e11542d85c904a81fe2a87f59ab52511cc15..7eb663773ed072760c47a2914377b5306ceeb7af 100644
--- a/paddle/fluid/operators/cross_entropy_op.h
+++ b/paddle/fluid/operators/cross_entropy_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -137,5 +138,124 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+struct HardLabelCrossEntropyForwardFunctor {
+  HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x,
+                                      const int64_t* label,
+                                      int64_t ignore_index,
+                                      int64_t feature_size)
+      : x_(x),
+        y_(y),
+        match_x_(match_x),
+        label_(label),
+        ignore_index_(ignore_index),
+        feature_size_(feature_size) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    auto label = label_[idx];
+    if (label != ignore_index_) {
+      auto match_x = x_[idx * feature_size_ + label];
+      y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
+      match_x_[idx] = match_x;
+    } else {
+      y_[idx] = 0;
+      match_x_[idx] = 0;  // any value is ok
+    }
+  }
+
+  const T* x_;
+  T* y_;
+  T* match_x_;
+  const int64_t* label_;
+  int64_t ignore_index_;
+  int64_t feature_size_;
+};
+
+template <typename T>
+struct HardLabelCrossEntropyBackwardFunctor {
+  HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x,
+                                       const int64_t* label,
+                                       int64_t ignore_index,
+                                       int64_t feature_size)
+      : dx_(dx),
+        dy_(dy),
+        match_x_(match_x),
+        label_(label),
+        ignore_index_(ignore_index),
+        feature_size_(feature_size) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    auto row_idx = idx / feature_size_;
+    auto col_idx = idx % feature_size_;
+    auto label = label_[row_idx];
+    if (label == col_idx && label != ignore_index_) {
+      dx_[idx] = -dy_[row_idx] / match_x_[row_idx];
+    } else {
+      dx_[idx] = 0;
+    }
+  }
+
+  T* dx_;
+  const T* dy_;
+  const T* match_x_;
+  const int64_t* label_;
+  int64_t ignore_index_;
+  int64_t feature_size_;
+};
+
+template <typename DeviceContext, typename T>
+class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* label = ctx.Input<Tensor>("Label");
+    auto* y = ctx.Output<Tensor>("Y");
+    auto* match_x = ctx.Output<Tensor>("MatchX");
+
+    auto& x_dims = x->dims();
+    auto feature_size = x_dims[x_dims.size() - 1];
+    auto batch_size = framework::product(x->dims()) / feature_size;
+
+    auto* p_x = x->data<T>();
+    auto* p_label = label->data<int64_t>();
+    auto* p_y = y->mutable_data<T>(ctx.GetPlace());
+    auto* p_match_x = match_x->mutable_data<T>(ctx.GetPlace());
+
+    auto ignore_index = ctx.Attr<int>("ignore_index");
+
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(), batch_size);
+    for_range(HardLabelCrossEntropyForwardFunctor<T>(
+        p_x, p_y, p_match_x, p_label, ignore_index, feature_size));
+  }
+};
+
+template <typename DeviceContext, typename T>
+class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* match_x = ctx.Input<Tensor>("MatchX");
+    auto* label = ctx.Input<Tensor>("Label");
+
+    auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
+    auto* p_dy = dy->data<T>();
+    auto* p_match_x = match_x->data<T>();
+    auto* p_label = label->data<int64_t>();
+
+    int64_t ignore_index = ctx.Attr<int>("ignore_index");
+    int rank = dx->dims().size();
+    int64_t feature_size = dx->dims()[rank - 1];
+    int64_t batch_size = framework::product(dx->dims()) / feature_size;
+
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(),
+        batch_size * feature_size);
+    for_range(HardLabelCrossEntropyBackwardFunctor<T>(
+        p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc
index 44a2f37b66772425a835c26e94c37b500e8a5d19..fcb2be93635eeaeaae25c3a845fd06aa1a73e2e7 100644
--- a/paddle/fluid/operators/expand_op.cc
+++ b/paddle/fluid/operators/expand_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/expand_op.h"
+#include <memory>
 #include <vector>
 
 namespace paddle {
@@ -138,12 +139,28 @@ class ExpandGradOp : public framework::OperatorWithKernel {
   }
 };
 
+class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("expand_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::ExpandGradOpDescMaker);
 REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
 REGISTER_OP_CPU_KERNEL(
     expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/math.h b/paddle/fluid/operators/math.h
new file mode 100644
index 0000000000000000000000000000000000000000..8cc24200d37dffaff1deda2f0181e5875141add0
--- /dev/null
+++ b/paddle/fluid/operators/math.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+#include "math.h"  // NOLINT
+
+namespace paddle {
+namespace operators {
+
+inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) {
+  return static_cast<platform::float16>(::expf(static_cast<float>(x)));
+}
+
+inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }
+
+inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }
+
+inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
+  return static_cast<platform::float16>(::logf(static_cast<float>(x)));
+}
+
+inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
+
+inline HOSTDEVICE double real_log(double x) { return ::log(x); }
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu
index cb200ec8d6ea533d546f3e01a16a48c88b14f677..44cbdf2e9882195819bc3ca047dbac6e2fa4e631 100644
--- a/paddle/fluid/operators/math/cross_entropy.cu
+++ b/paddle/fluid/operators/math/cross_entropy.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -20,17 +21,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-namespace {
-
-__device__ __forceinline__ float real_log(float x) { return logf(x); }
-
-__device__ __forceinline__ double real_log(double x) { return log(x); }
-
-__device__ __forceinline__ platform::float16 real_log(
-    const platform::float16& val) {
-  return static_cast<platform::float16>(logf(static_cast<float>(val)));
-}
-
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                    const int N, const int D,
@@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
     Y[blockIdx.x] = -val;
   }
 }
-}  // namespace
 
 template <typename T>
 class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
diff --git a/paddle/fluid/operators/selu_op.h b/paddle/fluid/operators/selu_op.h
index bdb506885c932708803fe8d84ee705aee0fe02b4..b2fc834c42f65ff3521b6267ed2f32fabbab4e4d 100644
--- a/paddle/fluid/operators/selu_op.h
+++ b/paddle/fluid/operators/selu_op.h
@@ -15,13 +15,12 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/for_range.h"
+
 namespace paddle {
 namespace operators {
 
-static HOSTDEVICE float real_exp(float x) { return expf(x); }
-static HOSTDEVICE float real_exp(double x) { return exp(x); }
-
 template <typename T>
 struct SeluFunctor {
   SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr)
diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
index cc5e9821903fb7a726f52177df1d17757f697411..a9dc0a4fda253db9bb0d33c4a25fbba36492f35b 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <algorithm>
 #include <cub/cub.cuh>  // NOLINT
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
 
 namespace paddle {
@@ -21,9 +22,6 @@ namespace operators {
 
 using LoDTensor = framework::LoDTensor;
 
-__device__ __forceinline__ float real_exp(float x) { return expf(x); }
-__device__ __forceinline__ double real_exp(double x) { return exp(x); }
-
 template <typename T, int BlockDim>
 using BlockReduce = cub::BlockReduce<T, BlockDim>;
 
diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
index 2a4570ef5cec0bee07efd69a2efd1a079ff33df5..aea69de6434a38aa834ff14f6d3d15ad5bbfc3e6 100644
--- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
+++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "cub/cub.cuh"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/hostdevice.h"
@@ -21,11 +22,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-static HOSTDEVICE float real_exp(float x) { return expf(x); }
-static HOSTDEVICE float real_exp(double x) { return exp(x); }
-static HOSTDEVICE float real_log(float x) { return logf(x); }
-static HOSTDEVICE float real_log(double x) { return log(x); }
-
 static constexpr int kNumCUDAThreads = 512;
 static constexpr int kNumMaxinumNumBlocks = 4096;
 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index d0bff52e43470c399c86d4490931092142314501..9886f4e84c2ac9afba92ac7f98284b5a439c70b3 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -1432,6 +1432,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
             predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
             cost = fluid.layers.cross_entropy(input=predict, label=label)
     """
+    if not soft_label:
+        return cross_entropy2(input, label, ignore_index)
     helper = LayerHelper('cross_entropy', **locals())
     out = helper.create_variable_for_type_inference(dtype=input.dtype)
     helper.append_op(
@@ -1444,6 +1446,22 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
     return out
 
 
+def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
+    helper = LayerHelper('cross_entropy2', **locals())
+    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
+    match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='cross_entropy2',
+        inputs={'X': [input],
+                'Label': [label]},
+        outputs={'Y': [out],
+                 'MatchX': [match_x],
+                 'XShape': [xshape]},
+        attrs={'ignore_index': ignore_index})
+    return out
+
+
 def bpr_loss(input, label, name=None):
     """
     Bayesian Personalized Ranking Loss Operator.
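With the Python change above, fluid.layers.cross_entropy with the default
soft_label=False now silently lowers to a cross_entropy2 op. A minimal usage
sketch (layer sizes and variable names are illustrative, assuming the fluid
1.x API):

    import paddle.fluid as fluid

    image = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    predict = fluid.layers.fc(input=image, size=10, act='softmax')
    # soft_label defaults to False, so this creates a cross_entropy2 op.
    cost = fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = fluid.layers.mean(cost)

This dispatch is also why the expected op lists in test_dist_transpiler.py at
the end of this patch change from cross_entropy/cross_entropy_grad to
cross_entropy2/cross_entropy_grad2.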
diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..55029c18d6966ea1d139a1987ff90d46c8e81270
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from op_test import OpTest
+import unittest
+import numpy as np
+import six
+
+
+class CrossEntropy2OpTestBase(OpTest):
+    def initParameters(self):
+        return [32, 64], 'float32', -100
+
+    def calc_output(self, logits, label, ignore_index):
+        ret = np.zeros(shape=label.shape, dtype=logits.dtype)
+        match_x = np.zeros(shape=label.shape, dtype=logits.dtype)
+        for idx in six.moves.range(label.shape[0]):
+            if label[idx] == ignore_index:
+                continue
+            match_x[idx] = logits[idx][label[idx]]
+            ret[idx] = -np.log(match_x[idx])
+        return ret, match_x
+
+    def setUp(self):
+        self.shape, self.dtype, self.ignore_index = self.initParameters()
+        self.op_type = 'cross_entropy2'
+        feature_size = int(self.shape[-1])
+        batch_size = int(np.prod(self.shape) / feature_size)
+        logits = (np.random.random(size=self.shape) + 1).astype(self.dtype)
+        label = np.random.random_integers(
+            low=0, high=feature_size - 1,
+            size=self.shape[0:-1] + [1]).astype('int64')
+        outputs, match_x = self.calc_output(
+            np.reshape(logits, [batch_size, feature_size]),
+            np.reshape(label, [batch_size, 1]), self.ignore_index)
+        self.inputs = {'X': logits, 'Label': label}
+        self.outputs = {
+            'Y': np.reshape(outputs, label.shape),
+            'MatchX': np.reshape(match_x, label.shape),
+            'XShape': np.zeros(
+                shape=logits.shape, dtype=logits.dtype)
+        }
+        self.attrs = {'ignore_index': self.ignore_index}
+
+    def test_check_output(self):
+        self.check_output(no_check_set=['XShape'])
+
+    def test_check_grad(self):
+        self.check_grad(
+            inputs_to_check=['X'],
+            output_names=['Y'],
+            no_grad_set=['XShape', 'MatchX', 'Label'])
+
+
+class CrossEntropy2OpTest2(CrossEntropy2OpTestBase):
+    def initParameters(self):
+        return [32, 64], 'float64', 3
+
+
+class CrossEntropy2OpTest3(CrossEntropy2OpTestBase):
+    def initParameters(self):
+        return [4, 8, 16, 32], 'float32', -100
+
+
+class CrossEntropy2OpTest4(CrossEntropy2OpTestBase):
+    def initParameters(self):
+        return [4, 8, 16, 32], 'float32', 3
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index 12132477d28c74c7da718321140a3ddef784fc30..f81d4fda50be195b239fcb149382d997275405fd 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -524,8 +524,8 @@ class TestLocalLookupTable(TestDistLookupTableBase):
         ops = [
             'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
             'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
-            'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
-            'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
+            'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
+            'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
             'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
             'split_selected_rows', 'send', 'sequence_pool_grad',
             'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
@@ -564,8 +564,8 @@ class TestDistLookupTable(TestDistLookupTableBase):
         ops = [
             'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
             'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
-            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
-            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send',
             'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
             'lookup_table_grad', 'split_selected_rows', 'send',
             'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
@@ -612,8 +612,8 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
         ops = [
             'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
             'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
-            'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
-            'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
+            'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
+            'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
             'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
             'split_selected_rows', 'send', 'sequence_pool_grad',
             'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
@@ -652,8 +652,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
         ops = [
             'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
             'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
-            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
-            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send',
             'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
             'lookup_table_grad', 'split_selected_rows', 'send',
             'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
@@ -841,8 +841,8 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
         ops = [
            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
            'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
-           'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
-           'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
+           'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
+           'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
            'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
            'split_selected_rows', 'send', 'sequence_pool_grad',
            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
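As an out-of-band sanity check (not part of the patch or the unit test), the
analytic gradient from the sketches above can be compared against a finite
difference; all names refer to the illustrative NumPy helpers defined earlier:

    import numpy as np

    def finite_diff_check(eps=1e-5):
        x = np.random.random((4, 8)).astype('float64') + 1.0
        label = np.random.randint(0, 8, size=(4, 1)).astype('int64')
        y, match_x = cross_entropy2_forward(x, label)
        dx = cross_entropy2_backward(np.ones((4, 1)), match_x, label, 8)
        i, j = 2, int(label[2, 0])
        x2 = x.copy()
        x2[i, j] += eps  # perturb the matched entry only
        y2, _ = cross_entropy2_forward(x2, label)
        assert np.isclose(dx[i, j], (y2[i, 0] - y[i, 0]) / eps, rtol=1e-3)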