From 1c81d57938c55001c58336f29ed07ea4f1247cb9 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Sat, 9 Sep 2017 19:01:24 +0800 Subject: [PATCH] Add huber loss operator. --- paddle/operators/huber_loss_op.cc | 108 ++++++++++++++++ paddle/operators/huber_loss_op.cu | 23 ++++ paddle/operators/huber_loss_op.h | 120 ++++++++++++++++++ paddle/pybind/pybind.cc | 1 + .../paddle/v2/framework/tests/CMakeLists.txt | 1 + .../v2/framework/tests/test_huber_loss_op.py | 56 ++++++++ 6 files changed, 309 insertions(+) create mode 100644 paddle/operators/huber_loss_op.cc create mode 100644 paddle/operators/huber_loss_op.cu create mode 100644 paddle/operators/huber_loss_op.h create mode 100644 python/paddle/v2/framework/tests/test_huber_loss_op.py diff --git a/paddle/operators/huber_loss_op.cc b/paddle/operators/huber_loss_op.cc new file mode 100644 index 00000000000..461409b0323 --- /dev/null +++ b/paddle/operators/huber_loss_op.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/huber_loss_op.h" + +namespace paddle { +namespace operators { + +class HuberLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized."); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + + PADDLE_ENFORCE_EQ(x->dims(), y->dims(), + "Dimensions of X and Y must be the same."); + // we constraint shape of X to (N, 1), may expand to (N, x, ...) if needed + PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2, + "Tensor rank of X must be 2."); + PADDLE_ENFORCE_EQ(x->dims()[1], 1, "Second dimension of X must be 1."); + + ctx.Output("residual")->Resize(x->dims()); + ctx.Output("Out")->Resize({x->dims()[0], 1}); + } +}; + +template +class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + HuberLossOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input value of HuberLossOp."); + AddInput("Y", "Target value of HuberLossOp."); + AddOutput("residual", + "Save residual value between Y and X. " + "Will be reused in backward.") + .AsIntermediate(); + AddOutput("Out", "Huber loss between input and target."); + AddAttr("delta", "Hyper parameter in huber loss."); + AddComment(R"DOC( +Huber loss is a loss function used in robust regression. We constrain shape of +input to (N, 1). The formulation is: + +L_delta(y, f(x)) = 0.5 * (y - f(x))^2 for |y - f(x)| <= delta, + delta * (|y - f(x)| - 0.5 * delta) otherwise. + +)DOC"); + } +}; + +class HuberLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* residual = ctx.Input("residual"); + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* x_grad = ctx.Output(framework::GradVarName("X")); + auto* y_grad = ctx.Output(framework::GradVarName("Y")); + + PADDLE_ENFORCE_NOT_NULL(x, "Input X must not be null."); + PADDLE_ENFORCE_NOT_NULL(y, "Target Y must not be null."); + PADDLE_ENFORCE_NOT_NULL(residual, "Residual value must not be null."); + PADDLE_ENFORCE_NOT_NULL(out_grad, "Out gradient must not be null."); + + PADDLE_ENFORCE_EQ(residual->dims(), x->dims(), + "Dimension of X and residual value must be the same."); + PADDLE_ENFORCE_EQ( + out_grad->dims(), x->dims(), + "Dimension of Out gradient and X must be the same (N*1)."); + + if (x_grad) x_grad->Resize(x->dims()); + if (y_grad) y_grad->Resize(y->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker, + huber_loss_grad, ops::HuberLossGradOp); +REGISTER_OP_CPU_KERNEL(huber_loss, + ops::HuberLossKernel); +REGISTER_OP_CPU_KERNEL( + huber_loss_grad, + ops::HuberLossGradKernel); diff --git a/paddle/operators/huber_loss_op.cu b/paddle/operators/huber_loss_op.cu new file mode 100644 index 00000000000..317321dc6c4 --- /dev/null +++ b/paddle/operators/huber_loss_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/huber_loss_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(huber_loss, + ops::HuberLossKernel); +REGISTER_OP_GPU_KERNEL( + huber_loss_grad, + ops::HuberLossGradKernel); diff --git a/paddle/operators/huber_loss_op.h b/paddle/operators/huber_loss_op.h new file mode 100644 index 00000000000..61c64ea3572 --- /dev/null +++ b/paddle/operators/huber_loss_op.h @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +struct HuberLossForward { + HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return 0.5 * val * val; + } else { + return delta * (abs_val - 0.5 * delta); + } + } + + T delta; +}; + +template +class HuberLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* in1 = context.Input("Y"); + auto* out0 = context.Output("residual"); + auto* out1 = context.Output("Out"); + auto delta = static_cast(context.op().Attr("delta")); + auto place = context.GetEigenDevice(); + + auto x = EigenVector::Flatten(*in0); + auto y = EigenVector::Flatten(*in1); + out0->mutable_data(context.GetPlace()); + auto residual = EigenVector::Flatten(*out0); + residual.device(place) = y - x; + out1->mutable_data(context.GetPlace()); + auto loss = EigenVector::Flatten(*out1); + loss.device(place) = residual.unaryExpr(HuberLossForward(delta)); + } +}; + +template +struct HuberLossBackward { + HOSTDEVICE HuberLossBackward(const T& delta, bool is_x) + : is_x(is_x), delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T sign = is_x ? -1.0 : 1.0; + T abs_val = std::abs(val); + if (abs_val <= delta) { + return sign * val; + } else { + if (val > 0) { + return sign * delta; + } else { + return -1 * sign * delta; + } + } + } + + bool is_x; + T delta; +}; + +template +class HuberLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("residual"); + auto* in1 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + auto* out1 = context.Output(framework::GradVarName("Y")); + auto delta = static_cast(context.op().Attr("delta")); + auto place = context.GetEigenDevice(); + + auto residual = EigenVector::Flatten(*in0); + auto out_grad = EigenVector::Flatten(*in1); + + if (out0) { + out0->mutable_data(context.GetPlace()); + auto x_grad = EigenVector::Flatten(*out0); + x_grad.device(place) = + out_grad * residual.unaryExpr(HuberLossBackward(delta, true)); + } + + if (out1) { + out1->mutable_data(context.GetPlace()); + auto y_grad = EigenVector::Flatten(*out1); + y_grad.device(place) = + out_grad * residual.unaryExpr(HuberLossBackward(delta, false)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 53985933ed1..130cf140aa8 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -51,6 +51,7 @@ USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); USE_OP(top_k); USE_OP(squared_l2_distance); +USE_OP(huber_loss); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ef910f939be..5b9f4084ec0 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -35,3 +35,4 @@ py_test(test_lookup_table SRCS test_lookup_table.py) py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py) py_test(mnist SRCS mnist.py) py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py) +py_test(test_huber_loss_op SRCS test_huber_loss_op.py) diff --git a/python/paddle/v2/framework/tests/test_huber_loss_op.py b/python/paddle/v2/framework/tests/test_huber_loss_op.py new file mode 100644 index 00000000000..540dedc3577 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_huber_loss_op.py @@ -0,0 +1,56 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +from paddle.v2.framework.op import Operator +import numpy as np + + +def huber_loss_forward(val, delta): + abs_val = abs(val) + if abs_val <= delta: + return 0.5 * val * val + else: + return delta * (abs_val - 0.5 * delta) + + +class TestHuberLossOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'huber_loss' + samples_num = 64 + delta = 1.0 + self.inputs = { + 'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), + 'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), + } + residual = self.inputs['Y'] - self.inputs['X'] + loss = np.vectorize(huber_loss_forward)(residual, delta) + self.attrs = {'delta': delta} + self.outputs = { + 'residual': residual, + 'Out': loss.reshape((samples_num, 1)) + } + + +class TestHuberLossGradOp(GradientChecker): + def test_huber_loss(self): + samples_num = 10 + delta = 1.0 + inputs = { + 'X': np.random.uniform(-1, 1, (samples_num, 1)).astype('float32'), + 'Y': np.random.uniform(-1, 1, (samples_num, 1)).astype('float32') + } + op = Operator( + "huber_loss", + X='X', + Y='Y', + residual='residual', + delta=delta, + Out='Out') + self.compare_grad(op, inputs, no_grad_set=set(['residual'])) + self.check_grad(op, inputs, set(["X", "Y"]), "Out") + + +if __name__ == '__main__': + unittest.main() -- GitLab