diff --git a/paddle/fluid/operators/cross_entropy2_op.cc b/paddle/fluid/operators/cross_entropy2_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..03b217a974c7bbc002865c98eb63ace9e88b1c3c --- /dev/null +++ b/paddle/fluid/operators/cross_entropy2_op.cc @@ -0,0 +1,218 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cross_entropy2_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { + +class CrossEntropyOp2 : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("XShape"), + "Output(XShape) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, label_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(label_dims) <= 0)) { + check = false; + } + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } + + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, + "Last dimension of Input(Label) should be 1."); + auto y_dims = x_dims; + y_dims[rank - 1] = 1; + ctx->SetOutputDim("Y", y_dims); + ctx->ShareLoD("X", /*->*/ "Y"); + + auto x_dims_vec = framework::vectorize(x_dims); + x_dims_vec.push_back(0); + ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec)); + ctx->ShareLoD("X", /*->*/ "XShape"); + } + + protected: + // Explicitly set that the data type of computation kernel of cross_entropy + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class CrossEntropyGradientOp2 : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("XShape"), + "Input(XShape) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) shoudl be not null."); + + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_shapes = ctx->GetInputDim("XShape"); + framework::DDim x_dims(x_shapes.Get(), x_shapes.size() - 1); + auto label_dims = ctx->GetInputDim("Label"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(dy_dims.size(), rank, + "Input(Y@Grad) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(label_dims.size(), rank, + "Input(Label) and Input(X) should have the same rank."); + + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(label_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "The Input(X) and Input(Label) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(dy_dims, 0, rank - 1), + "The Input(X) and Input(Y@Grad) should have the same " + "shape except the last dimension."); + } + PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, + "The last dimension of Input(Y@Grad) should be 1."); + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, + "Last dimension of Input(Label) should be 1."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("XShape", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of cross_entropy + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Y"))->type(), + ctx.device_context()); + } +}; + +class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a tensor whose last dimension " + "size is equal to the number of classes. This input is a " + "probability computed by the previous operator, which is almost " + "always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor), the tensor which represents the ground truth. It has the " + "same shape with 'X' except the last dimension. One hot Tensor."); + AddOutput("Y", + "(Tensor, default Tensor), a tensor whose shape is same " + "with 'X' except that the last dimension size is 1. It " + "represents the cross entropy loss."); + AddOutput("XShape", "Temporaily variable to save shape and LoD of X."); + AddAttr("ignore_index", + "(int, default -100), Specifies a target value that is" + "ignored and does not contribute to the input gradient." + "Only valid if soft_label is set to False") + .SetDefault(-100); + AddComment(R"DOC( +CrossEntropy Operator. + +The input 'X' and 'Label' will first be logically flattened to 2-D matrixs. +The matrix's second dimension(row length) is as same as the original last +dimension, and the first dimension(column length) is the product of all other +original dimensions. Then the softmax computation will take palce on each raw +of flattened matrixs. + +Only support hard label. + +Both the input X and Label can carry the LoD (Level of Details) information, +or not. But the output only shares the LoD information with input X. + +)DOC"); + } +}; + +class CrossEntropyOpInferVarType2 + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Y"}}; + } +}; + +class CrossEntropyGradOpMaker2 : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("cross_entropy_grad2"); + op->SetInput("Label", Input("Label")); + op->SetInput("Y", Output("Y")); + op->SetInput("XShape", Output("XShape")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPUCtx = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2, + ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType2, + ops::CrossEntropyGradOpMaker2); +REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2); +REGISTER_OP_CPU_KERNEL(cross_entropy2, + ops::CrossEntropyOpKernel2, + ops::CrossEntropyOpKernel2); +REGISTER_OP_CPU_KERNEL(cross_entropy_grad2, + ops::CrossEntropyGradientOpKernel2, + ops::CrossEntropyGradientOpKernel2); diff --git a/paddle/fluid/operators/cross_entropy2_op.cu b/paddle/fluid/operators/cross_entropy2_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..1868c1b866016d1ea51e28339847b6c890c5ec74 --- /dev/null +++ b/paddle/fluid/operators/cross_entropy2_op.cu @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cross_entropy2_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; +using CUDACtx = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(cross_entropy2, + ops::CrossEntropyOpKernel2, + ops::CrossEntropyOpKernel2, + ops::CrossEntropyOpKernel2); + +REGISTER_OP_CUDA_KERNEL( + cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2, + ops::CrossEntropyGradientOpKernel2, + ops::CrossEntropyGradientOpKernel2); diff --git a/paddle/fluid/operators/cross_entropy2_op.h b/paddle/fluid/operators/cross_entropy2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3d209f7c5c95c242bb5023a49f4841a718e4d1fb --- /dev/null +++ b/paddle/fluid/operators/cross_entropy2_op.h @@ -0,0 +1,188 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +HOSTDEVICE inline platform::float16 RealLog(platform::float16 x) { +#ifdef __NVCC__ + return static_cast(logf(static_cast(x))); +#else + return static_cast(std::log(static_cast(x))); +#endif +} + +HOSTDEVICE inline float RealLog(float x) { +#ifdef __NVCC__ + return logf(x); +#else + return std::log(x); +#endif +} + +HOSTDEVICE inline double RealLog(double x) { +#ifdef __NVCC__ + return log(x); +#else + return std::log(x); +#endif +} + +HOSTDEVICE inline platform::float16 RealExp(platform::float16 x) { +#ifdef __NVCC__ + return static_cast(expf(static_cast(x))); +#else + return static_cast(std::exp(static_cast(x))); +#endif +} + +HOSTDEVICE inline float RealExp(float x) { +#ifdef __NVCC__ + return expf(x); +#else + return std::exp(x); +#endif +} + +HOSTDEVICE inline double RealExp(double x) { +#ifdef __NVCC__ + return exp(x); +#else + return std::exp(x); +#endif +} + +template +struct CrossEntropyForwardFunctor { + CrossEntropyForwardFunctor(const T *x, T *y, const int64_t *label, + int64_t ignore_index, int64_t feature_size) + : x_(x), + y_(y), + label_(label), + ignore_index_(ignore_index), + feature_size_(feature_size) {} + + HOSTDEVICE void operator()(int64_t row_idx) const { + auto col_idx = label_[row_idx]; + if (col_idx != ignore_index_) { + y_[row_idx] = -math::TolerableValue()( + RealLog(x_[row_idx * feature_size_ + col_idx])); + } else { + y_[row_idx] = 0; + } + } + + const T *x_; + T *y_; + const int64_t *label_; + int64_t ignore_index_; + int64_t feature_size_; +}; + +template +struct CrossEntropyBackwardFunctor { + CrossEntropyBackwardFunctor(T *dx, const T *y, const T *dy, + const int64_t *label, int64_t ignore_index, + int64_t feature_size) + : dx_(dx), + y_(y), + dy_(dy), + label_(label), + ignore_index_(ignore_index), + feature_size_(feature_size) {} + + HOSTDEVICE void operator()(int64_t idx) const { + auto row_idx = idx / feature_size_; + auto col_idx = idx % feature_size_; + auto label = label_[row_idx]; + if (label == col_idx && label != ignore_index_) { + dx_[idx] = -dy_[row_idx] * RealExp(y_[row_idx]); + } else { + dx_[idx] = 0; + } + } + + T *dx_; + const T *y_; + const T *dy_; + const int64_t *label_; + int64_t ignore_index_; + int64_t feature_size_; +}; + +template +class CrossEntropyOpKernel2 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *label = ctx.Input("Label"); + auto *y = ctx.Output("Y"); + + auto *p_y = y->mutable_data(ctx.GetPlace()); + auto *p_x = x->data(); + auto *p_label = label->data(); + + int rank = x->dims().size(); + int64_t feature_size = x->dims()[rank - 1]; + int64_t batch_size = framework::product(x->dims()) / feature_size; + + int64_t ignore_index = ctx.Attr("ignore_index"); + + platform::ForRange for_range( + ctx.template device_context(), batch_size); + for_range(CrossEntropyForwardFunctor(p_x, p_y, p_label, ignore_index, + feature_size)); + } +}; + +template +class CrossEntropyGradientOpKernel2 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *y = ctx.Input("Y"); + auto *dy = ctx.Input(framework::GradVarName("Y")); + auto *label = ctx.Input("Label"); + + auto *p_dx = dx->mutable_data(ctx.GetPlace()); + auto *p_y = y->data(); + auto *p_dy = dy->data(); + auto *p_label = label->data(); + + int64_t ignore_index = ctx.Attr("ignore_index"); + int rank = dx->dims().size(); + int64_t feature_size = dx->dims()[rank - 1]; + int64_t batch_size = framework::product(dx->dims()) / feature_size; + + platform::ForRange for_range( + ctx.template device_context(), + batch_size * feature_size); + for_range(CrossEntropyBackwardFunctor(p_dx, p_y, p_dy, p_label, + ignore_index, feature_size)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9d1d5fe0932ea8a53e28bc18a776a430a53e9ef4..4f384ce37d7c8826257dbff12d07b1e8f86e5aaa 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1432,6 +1432,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): predict = fluid.layers.fc(input=net, size=classdim, act='softmax') cost = fluid.layers.cross_entropy(input=predict, label=label) """ + if not soft_label: + return cross_entropy2(input, label, ignore_index) helper = LayerHelper('cross_entropy', **locals()) out = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( @@ -1444,6 +1446,20 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): return out +def cross_entropy2(input, label, ignore_index=kIgnoreIndex): + helper = LayerHelper('cross_entropy2', **locals()) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + xshape = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='cross_entropy2', + inputs={'X': [input], + 'Label': [label]}, + outputs={'Y': [out], + 'XShape': [xshape]}, + attrs={'ignore_index': ignore_index}) + return out + + def bpr_loss(input, label, name=None): """ Bayesian Personalized Ranking Loss Operator.