From dcf3ffd98033ffa492932ed9ffb7880d0bf010a0 Mon Sep 17 00:00:00 2001
From: kavyasrinet
Date: Tue, 28 Nov 2017 18:02:28 -0800
Subject: [PATCH] Adding log loss operator (#5854)

* Adding log loss operator

* Removing comments
---
 paddle/operators/log_loss_op.cc                   | 115 ++++++++++++++++++
 paddle/operators/log_loss_op.cu                   |  22 ++++
 paddle/operators/log_loss_op.h                    |  75 ++++++++++++
 .../paddle/v2/fluid/tests/test_log_loss_op.py     |  33 +++++
 4 files changed, 245 insertions(+)
 create mode 100644 paddle/operators/log_loss_op.cc
 create mode 100644 paddle/operators/log_loss_op.cu
 create mode 100644 paddle/operators/log_loss_op.h
 create mode 100644 python/paddle/v2/fluid/tests/test_log_loss_op.py

diff --git a/paddle/operators/log_loss_op.cc b/paddle/operators/log_loss_op.cc
new file mode 100644
index 00000000000..257e5c8a49e
--- /dev/null
+++ b/paddle/operators/log_loss_op.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/log_loss_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LogLossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Predicted"),
+                   "Input(Predicted) must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Labels"),
+                   "Input(Labels) must be initialized.");
+
+    auto pred_dims = ctx->GetInputDim("Predicted");
+    auto label_dims = ctx->GetInputDim("Labels");
+
+    PADDLE_ENFORCE_EQ(pred_dims, label_dims);
+    PADDLE_ENFORCE_EQ(pred_dims.size(), 2,
+                      "The rank of Input(Predicted) must be 2 and the shape is "
+                      "[batch_size, 1].");
+    PADDLE_ENFORCE_EQ(pred_dims[1], 1,
+                      "Each row of Input(Predicted) contains a real value, "
+                      "so the 2nd dimension of Input(Predicted) must be 1.");
+
+    ctx->SetOutputDim("Loss", {pred_dims[0], 1});
+    ctx->ShareLoD("Predicted", "Loss");
+  }
+};
+
+template <typename AttrType>
+class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  LogLossOpMaker(framework::OpProto* proto,
+                 framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Predicted",
+             "The input value (Predicted) of Log loss op. "
+             "Predicted is a 2-D tensor with shape [batch_size, 1].");
+    AddInput("Labels",
+             "The target value (Labels) of Log loss op. "
+             "Labels is a 2-D tensor with shape [batch_size, 1].");
+    AddOutput("Loss",
+              "The output tensor with shape [batch_size, 1] "
+              "which represents the log loss.");
+    AddAttr<AttrType>("epsilon",
+                      "Epsilon in log loss, a small constant added inside "
+                      "log() to avoid taking the log of zero.");
+    AddComment(R"DOC(
+LogLoss Operator.
+
+Log loss is a loss function used for binary classification. It quantifies
+the accuracy of a classifier by penalising false classifications; minimising
+the log loss is equivalent to maximising the accuracy of the classifier. We
+define Predicted as the values predicted by our model and Labels as the
+target ground-truth values. Log loss measures how close the predicted values
+are to the target. The shapes of Predicted and Labels are both [batch_size, 1].
+The equation is:
+
+$$
+Loss = - Labels * log(Predicted + \epsilon) -
+       (1 - Labels) * log(1 - Predicted + \epsilon)
+$$
+
+)DOC");
+  }
+};
+
+class LogLossGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Predicted"),
+                   "Input(Predicted) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Labels"),
+                   "Input(Labels) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
+                   "Input(Loss@GRAD) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Predicted")),
+                   "Output(Predicted@GRAD) should not be null.");
+
+    auto pred_dims = ctx->GetInputDim("Predicted");
+    auto label_dims = ctx->GetInputDim("Labels");
+    auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
+    PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
+
+    auto pred_grad_name = framework::GradVarName("Predicted");
+    ctx->SetOutputDim(pred_grad_name, pred_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(log_loss, ops::LogLossOp, ops::LogLossOpMaker<float>, log_loss_grad,
+            ops::LogLossGradOp);
+REGISTER_OP_CPU_KERNEL(log_loss,
+                       ops::LogLossKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    log_loss_grad, ops::LogLossGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/log_loss_op.cu b/paddle/operators/log_loss_op.cu
new file mode 100644
index 00000000000..6c189ef3412
--- /dev/null
+++ b/paddle/operators/log_loss_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/log_loss_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(log_loss,
+                       ops::LogLossKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    log_loss_grad, ops::LogLossGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/log_loss_op.h b/paddle/operators/log_loss_op.h
new file mode 100644
index 00000000000..73404fce915
--- /dev/null
+++ b/paddle/operators/log_loss_op.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T, typename AttrType = T>
+class LogLossKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* loss_out = ctx.Output<Tensor>("Loss");
+
+    loss_out->mutable_data<T>(ctx.GetPlace());
+
+    auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
+
+    auto prediction = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Predicted"));
+    auto label = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Labels"));
+
+    auto loss = EigenVector<T>::Flatten(*loss_out);
+    auto place = ctx.GetEigenDevice<Place>();
+
+    loss.device(place) = (-(label * (prediction + epsilon).log()) -
+                          ((static_cast<T>(1) - label) *
+                           (static_cast<T>(1) - prediction + epsilon).log()));
+  }
+};
+
+template <typename Place, typename T, typename AttrType = T>
+class LogLossGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
+
+    auto prediction = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Predicted"));
+    auto label = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Labels"));
+
+    auto* dloss = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    auto* dpred = ctx.Output<Tensor>(framework::GradVarName("Predicted"));
+
+    auto dl = EigenVector<T>::Flatten(*dloss);
+    auto place = ctx.GetEigenDevice<Place>();
+
+    if (dpred) {
+      dpred->mutable_data<T>(ctx.GetPlace());
+      auto dx = framework::EigenVector<T>::Flatten(*dpred);
+      dx.device(place) = dl * (-(label / (prediction + epsilon)) +
+                               ((static_cast<T>(1) - label) /
+                                (static_cast<T>(1) - prediction + epsilon)));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/tests/test_log_loss_op.py b/python/paddle/v2/fluid/tests/test_log_loss_op.py
new file mode 100644
index 00000000000..2eeaa90758c
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_log_loss_op.py
@@ -0,0 +1,33 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestLogLossOp(OpTest):
+    def setUp(self):
+        self.op_type = 'log_loss'
+        samples_num = 32
+
+        predicted = np.random.uniform(0.1, 1.0,
+                                      (samples_num, 1)).astype("float32")
+        labels = np.random.randint(0, 2, (samples_num, 1)).astype("float32")
+        epsilon = 1e-4
+        self.inputs = {
+            'Predicted': predicted,
+            'Labels': labels,
+        }
+
+        self.attrs = {'epsilon': epsilon}
+        loss = -labels * np.log(predicted + epsilon) - (
+            1 - labels) * np.log(1 - predicted + epsilon)
+        self.outputs = {'Loss': loss}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Predicted'], 'Loss', max_relative_error=0.03)
+
+
+if __name__ == '__main__':
+    unittest.main()
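
As a quick illustration of the equation in the operator's DOC block, here is a
minimal standalone NumPy sketch of the forward computation. It is not part of
the patch; the helper name log_loss, the default epsilon of 1e-4 (chosen to
match the test above), and the sample values are illustrative assumptions only:

import numpy as np

# Elementwise log loss on [batch_size, 1] arrays, mirroring the DOC block:
# Loss = -Labels * log(Predicted + eps) - (1 - Labels) * log(1 - Predicted + eps)
def log_loss(predicted, labels, epsilon=1e-4):
    # epsilon keeps both log arguments strictly positive at predicted = 0 or 1
    return (-labels * np.log(predicted + epsilon)
            - (1 - labels) * np.log(1 - predicted + epsilon))

predicted = np.array([[0.9], [0.2], [0.7]], dtype=np.float32)
labels = np.array([[1.0], [0.0], [1.0]], dtype=np.float32)
print(log_loss(predicted, labels))  # confident correct predictions -> small loss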
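
The backward kernel applies the analytic derivative of that expression,
dLoss/dPredicted = -Labels / (Predicted + epsilon)
+ (1 - Labels) / (1 - Predicted + epsilon), scaled by the incoming gradient of
Loss; the test only exercises it through check_grad. The following
self-contained sketch sanity-checks the same formula against central finite
differences (again illustrative, with hypothetical helper names; the step h and
the tolerance are arbitrary small values):

import numpy as np

def log_loss(p, y, eps=1e-4):
    return -y * np.log(p + eps) - (1 - y) * np.log(1 - p + eps)

def log_loss_grad(p, y, eps=1e-4):
    # Analytic dLoss/dp, matching LogLossGradKernel with an all-ones
    # upstream gradient dl.
    return -y / (p + eps) + (1 - y) / (1 - p + eps)

rng = np.random.RandomState(0)
p = rng.uniform(0.1, 0.9, (8, 1))            # predictions away from {0, 1}
y = rng.randint(0, 2, (8, 1)).astype(float)  # binary labels

h = 1e-6  # central-difference step
numeric = (log_loss(p + h, y) - log_loss(p - h, y)) / (2 * h)
assert np.allclose(numeric, log_loss_grad(p, y), rtol=1e-4)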