diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ede458e01147ab70444c4d158973cfb0818b9bdd --- /dev/null +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -0,0 +1,150 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/sigmoid_cross_entropy_with_logits_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Labels"), + "Input(Labels) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto labels_dims = ctx->GetInputDim("Labels"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(labels_dims.size(), 2, + "Input(Labels)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], + "The 1st dimension of Input(X) and Input(Labels) should " + "be equal."); + PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], + "The 2nd dimension of Input(X) and Input(Labels) should " + "be equal."); + + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class SigmoidCrossEntropyWithLogitsGradOp + : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Labels"), + "Input(Labels) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shoudl be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto labels_dims = ctx->GetInputDim("Labels"); + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(labels_dims.size(), 2, + "Input(Labels)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dout_dims.size(), 2, + "Input(Out@Grad)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], + "The 1st dimension of Input(X) and Input(Labels) should " + "be equal."); + PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], + "The 2nd dimension of Input(X) and Input(Labels) should " + "be equal."); + PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0], + "The 1st dimension of Input(X) and Input(Out@Grad) " + "should be equal."); + PADDLE_ENFORCE_EQ(x_dims[1], dout_dims[1], + "The 2nd dimension of Input(X) and Input(Out@Grad) " + "should be equal."); + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } +}; + +class SigmoidCrossEntropyWithLogitsOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + SigmoidCrossEntropyWithLogitsOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape N x D, " + "where N is the batch size and D is the number of classes. " + "This input is a tensor of logits computed by the previous " + " operator. Logits are unscaled log probabilities given as " + "log(p/(1-p))."); + AddInput("Labels", + "(Tensor, default Tensor), a 2-D tensor of the same type " + "and shape as X. This input is a tensor of probabalistic labels " + "for each logit"); + AddOutput("Out", + "(Tensor, default Tensor), a 2-D tensor with shape N x D " + " of elementwise logistic losses."); + AddComment(R"DOC( +SigmoidCrossEntropyWithLogits Operator. + +This measures the elementwise probability error in discrete classification tasks +in which each class is independent. This can be thought of as predicting labels +for a data-point that are not mutually exclusive. For example, a news article +can be about politics, technology or sports at the same time or none of these. + +The logistic loss is given as follows: + + loss = -Labels * log(sigmoid(X)) - (1 - Labels) * log(1 - sigmoid(X)) + +We know that sigmoid(X) = (1 / (1 + exp(-X))). By substituting this we get + + loss = X - X * Labels + log(1 + exp(-X)) + +For stability and to prevent overflow of exp(-X) when X < 0, +we can reformulate the loss as follows: + + loss = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) + +Both the input `X` and `Labels` can carry the LoD (Level of Details) information. +However the output only shares the LoD with input `X`. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sigmoid_cross_entropy_with_logits, + ops::SigmoidCrossEntropyWithLogitsOp, + ops::SigmoidCrossEntropyWithLogitsOpMaker, + sigmoid_cross_entropy_with_logits_grad, + ops::SigmoidCrossEntropyWithLogitsGradOp); +REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits, + ops::SigmoidCrossEntropyWithLogitsKernel< + paddle::platform::CPUPlace, float>); +REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits_grad, + ops::SigmoidCrossEntropyWithLogitsGradKernel< + paddle::platform::CPUPlace, float>); diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..32a39956a14a206373b7b4c141dad19577d171f0 --- /dev/null +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/sigmoid_cross_entropy_with_logits_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(sigmoid_cross_entropy_with_logits, + ops::SigmoidCrossEntropyWithLogitsKernel< + paddle::platform::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL(sigmoid_cross_entropy_with_logits_grad, + ops::SigmoidCrossEntropyWithLogitsGradKernel< + paddle::platform::GPUPlace, float>); diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.h b/paddle/operators/sigmoid_cross_entropy_with_logits_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a6de9043fdbcdcca47407aac0b4892cbad3a9a42 --- /dev/null +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) +template +class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const framework::Tensor *X = context.Input("X"); + const framework::Tensor *Labels = + context.Input("Labels"); + framework::Tensor *Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto labels = framework::EigenVector::Flatten(*Labels); + auto out = framework::EigenVector::Flatten(*Out); + auto place = context.GetEigenDevice(); + + // term1 = max(x, 0) + auto term1 = x.cwiseMax(static_cast(0)); + // term2 = x * labels + auto term2 = x * labels; + // term3 = log(1 + exp(-abs(x))) + auto term3 = (static_cast(1) + (-(x.abs())).exp()).log(); + + out.device(place) = term1 - term2 + term3; + } +}; + +// dX = sigmoid(X) - labels +template +class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const framework::Tensor *X = context.Input("X"); + const framework::Tensor *Labels = + context.Input("Labels"); + const framework::Tensor *dOut = + context.Input(framework::GradVarName("Out")); + framework::Tensor *dX = + context.Output(framework::GradVarName("X")); + dX->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto labels = framework::EigenVector::Flatten(*Labels); + auto dout = framework::EigenVector::Flatten(*dOut); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + auto sigmoid_x = static_cast(1) / (static_cast(1) + (-x).exp()); + dx.device(place) = dout * (sigmoid_x - labels); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/v2/framework/tests/test_sigmoid_cross_entropy_with_logits_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e53856b38aa5ddd6061b350a66e9fe86bc23923c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_sigmoid_cross_entropy_with_logits_op.py @@ -0,0 +1,66 @@ +import numpy as np +from op_test import OpTest +from scipy.special import logit +from scipy.special import expit + + +class TestSigmoidCrossEntropyWithLogitsOp1(OpTest): + '''Test sigmoid_cross_entropy_with_logit_op with binary labels + ''' + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32")), + 'Labels': np.random.randint(0, 2, (batch_size, num_classes)) + .astype("float32") + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Labels'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): + '''Test sigmoid_cross_entropy_with_logit_op with probabalistic labels + ''' + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32")), + 'Labels': np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32") + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Labels'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out')