diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6fe018f9a8fd74479a2feed07379c3179b7c72bd --- /dev/null +++ b/paddle/operators/modified_huber_loss_op.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/modified_huber_loss_op.h" + +namespace paddle { +namespace operators { + +class ModifiedHuberLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& context) const override { + PADDLE_ENFORCE_NOT_NULL(context.InputVar("X"), "X must be initialized."); + PADDLE_ENFORCE_NOT_NULL(context.InputVar("Y"), "Y must be initialized."); + + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + + PADDLE_ENFORCE_EQ(x->dims(), y->dims(), + "The shape of X and Y must be the same."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2."); + PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1."); + + context.Output("IntermediateVal")->Resize(x->dims()); + context.Output("Out")->Resize({x->dims()[0], 1}); + } +}; + +class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ModifiedHuberLossOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input tensor of modified huber loss op." + "X is 2-D tensor with shape [batch_size, 1]."); + AddInput("Y", + "The target labels of modified huber loss op." + "The shape of Y is same as X. Values of Y must be 0 or 1."); + AddOutput("IntermediateVal", + "Variable to save intermediate result which will be reused in " + "backward processing.") + .AsIntermediate(); + AddOutput("Out", "Classification loss for X."); + AddComment(R"DOC( +Modified huber loss is used in binary classification problem. The shape of +input X and target Y are both [N, 1] and so is the shape of output loss. +Since target Y is not differentiable, cacluating gradient for Y is illegal. +The formulation of modified huber loss is: + +L(y, f(x)) = max(0, 1 - yf(x))^2 for yf(x) >= -1, + -4yf(x) otherwise. + +Make sure the values of target label Y are in {0, 1} here. The operator will +scale values of Y to {-1, +1} when computing losses and gradients. +)DOC"); + } +}; + +class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* intermediate_val = context.Input("IntermediateVal"); + auto* out_grad = context.Input(framework::GradVarName("Out")); + auto* x_grad = + context.Output(framework::GradVarName("X")); + + PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized."); + PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized."); + PADDLE_ENFORCE_NOT_NULL(intermediate_val, + "Intermediate value must not be null."); + PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@Grad) must not be null."); + + PADDLE_ENFORCE_EQ( + intermediate_val->dims(), x->dims(), + "The shape of X and intermediate value must be the same."); + PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims(), + "The shape of Input(Out@Grad) and X must be the same."); + + if (x_grad) x_grad->Resize(x->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(modified_huber_loss, ops::ModifiedHuberLossOp, + ops::ModifiedHuberLossOpMaker, modified_huber_loss_grad, + ops::ModifiedHuberLossGradOp); + +REGISTER_OP_CPU_KERNEL( + modified_huber_loss, + ops::ModifiedHuberLossKernel); +REGISTER_OP_CPU_KERNEL(modified_huber_loss_grad, + ops::ModifiedHuberLossGradCPUKernel); diff --git a/paddle/operators/modified_huber_loss_op.cu b/paddle/operators/modified_huber_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bce760f95e72cfec05b07591e0fa1250168b112f --- /dev/null +++ b/paddle/operators/modified_huber_loss_op.cu @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/operators/modified_huber_loss_op.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +struct ModifiedHuberLossBackward { + template + HOSTDEVICE void operator()(Tuple t) const { + auto inter_val = thrust::get<1>(t); + auto y_val = thrust::get<2>(t); + auto out_grad = thrust::get<3>(t); + if (inter_val < -1) { + thrust::get<0>(t) = -4 * (2 * y_val - 1) * out_grad; + } else if (inter_val < 1) { + thrust::get<0>(t) = -2 * (1 - inter_val) * (2 * y_val - 1) * out_grad; + } else { + thrust::get<0>(t) = 0; + } + } +}; + +template +class ModifiedHuberLossGradGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("Y"); + auto* in1 = context.Input("IntermediateVal"); + auto* in2 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + + if (out0) { + auto counts = framework::product(in1->dims()); + auto y_ptr = thrust::device_pointer_cast(in0->data()); + auto inter_val_ptr = thrust::device_pointer_cast(in1->data()); + auto out_grad_ptr = thrust::device_pointer_cast(in2->data()); + thrust::device_ptr x_grad_ptr( + out0->mutable_data(context.GetPlace())); + + auto iter_begin = thrust::make_zip_iterator( + thrust::make_tuple(x_grad_ptr, inter_val_ptr, y_ptr, out_grad_ptr)); + + auto iter_end = thrust::make_zip_iterator( + thrust::make_tuple(x_grad_ptr + counts, inter_val_ptr + counts, + y_ptr + counts, out_grad_ptr + counts)); + + thrust::for_each(iter_begin, iter_end, ModifiedHuberLossBackward()); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + modified_huber_loss, + ops::ModifiedHuberLossKernel); +REGISTER_OP_GPU_KERNEL(modified_huber_loss_grad, + ops::ModifiedHuberLossGradGPUKernel); diff --git a/paddle/operators/modified_huber_loss_op.h b/paddle/operators/modified_huber_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2b2aae17084992c4935a697763ff902e455dfcbd --- /dev/null +++ b/paddle/operators/modified_huber_loss_op.h @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +struct CheckLabelValue { + HOSTDEVICE T operator()(const T& val) const { + PADDLE_ASSERT(val == static_cast(0) || val == static_cast(1)); + } +}; + +template +struct ModifiedHuberLossForward { + HOSTDEVICE T operator()(const T& val) const { + if (val < -1) { + return -4 * val; + } else if (val < 1) { + return (1 - val) * (1 - val); + } else { + return static_cast(0); + } + } +}; + +template +class ModifiedHuberLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* in1 = context.Input("Y"); + auto* out0 = context.Output("IntermediateVal"); + auto* out1 = context.Output("Out"); + + out0->mutable_data(context.GetPlace()); + out1->mutable_data(context.GetPlace()); + auto place = context.GetEigenDevice(); + + auto x = EigenVector::Flatten(*in0); + auto y = EigenVector::Flatten(*in1); + // make sure value's of Y in {0, 1} + y.unaryExpr(CheckLabelValue()); + auto inter_val = EigenVector::Flatten(*out0); + // scale y to {-1, +1} and compute x * y + inter_val.device(place) = x * (2 * y - static_cast(1)); + auto loss = EigenVector::Flatten(*out1); + loss.device(place) = inter_val.unaryExpr(ModifiedHuberLossForward()); + } +}; + +// CPU backward kernel +template +class ModifiedHuberLossGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("Y"); + auto* in1 = context.Input("IntermediateVal"); + auto* in2 = + context.Input(framework::GradVarName("Out")); + auto* out0 = + context.Output(framework::GradVarName("X")); + + if (out0) { + const T* y_ptr = in0->data(); + const T* inter_val_ptr = in1->data(); + const T* out_grad_ptr = in2->data(); + size_t counts = static_cast(framework::product(in1->dims())); + T* x_grad_ptr = out0->mutable_data(context.GetPlace()); + for (size_t i = 0; i < counts; ++i) { + if (inter_val_ptr[i] < -1) { + x_grad_ptr[i] = -4 * (2 * y_ptr[i] - 1) * out_grad_ptr[i]; + } else if (inter_val_ptr[i] < 1) { + x_grad_ptr[i] = -2 * (1 - inter_val_ptr[i]) * (2 * y_ptr[i] - 1) * + out_grad_ptr[i]; + } else { + x_grad_ptr[i] = 0; + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py b/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e2b57529b0723b4ab18b73801cd2816d8025dd --- /dev/null +++ b/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py @@ -0,0 +1,39 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def modified_huber_loss_forward(val): + if val < -1: + return -4 * val + elif val < 1: + return (1 - val) * (1 - val) + else: + return 0 + + +class TestModifiedHuberLossOp(OpTest): + def setUp(self): + self.op_type = 'modified_huber_loss' + samples_num = 32 + self.inputs = { + 'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'), + 'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1)) + } + product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1) + loss = np.vectorize(modified_huber_loss_forward)(product_res) + + self.outputs = { + 'IntermediateVal': product_res, + 'Out': loss.reshape((samples_num, 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', max_relative_error=0.005) + + +if __name__ == '__main__': + unittest.main()