diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9ee6fff8db6a285a0314431e4e13b284c78c8a70
--- /dev/null
+++ b/paddle/operators/smooth_l1_loss_op.cc
@@ -0,0 +1,163 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/smooth_l1_loss_op.h"
+
+namespace paddle {
+namespace operators {
+
+// Forward operator for smooth L1 loss: validates the inputs and infers the
+// shapes of the "Diff" and "Out" outputs.
+// NOTE(review): angle-bracket argument lists (e.g. on ctx.Input / ctx.Output)
+// look stripped in this copy of the patch -- confirm against upstream.
+class SmoothL1LossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  // Requires X and Y to be present with identical shape and rank >= 2, and the
+  // optional weights to be given as a pair shaped like X. Resizes Diff to the
+  // shape of X and Out to [X.dims()[0], 1].
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized.");
+
+    auto* x = ctx.Input("X");
+    auto* y = ctx.Input("Y");
+    PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
+                      "The shape of X and Y must be the same.");
+    PADDLE_ENFORCE_GE(x->dims().size(), 2,
+                      "The tensor rank of X must be at least 2.");
+
+    // The weights are optional, but only as a pair: passing just one of them
+    // is rejected here rather than being silently ignored later.
+    auto* inside_weight = ctx.Input("InsideWeight");
+    auto* outside_weight = ctx.Input("OutsideWeight");
+    if (inside_weight || outside_weight) {
+      PADDLE_ENFORCE_NOT_NULL(inside_weight,
+                              "If weights are provided, must specify both "
+                              "inside and outside weights.");
+      PADDLE_ENFORCE_NOT_NULL(outside_weight,
+                              "If weights are provided, must specify both "
+                              "inside and outside weights.");
+      PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(),
+                        "The shape of InsideWeight must be same as X.");
+      PADDLE_ENFORCE_EQ(outside_weight->dims(), x->dims(),
+                        "The shape of OutsideWeight must be same as X.");
+    }
+
+    auto* diff = ctx.Output("Diff");
+    auto* out = ctx.Output("Out");
+    diff->Resize(x->dims());
+    // loss is a two-rank tensor
+    out->Resize({x->dims()[0], 1});
+  }
+};
+
+// Declares the inputs, outputs, the "sigma" attribute and the user-facing
+// documentation of the smooth_l1_loss operator.
+// NOTE(review): the template parameter list below is empty in this copy of the
+// patch (likely stripped in transit) -- restore it from upstream before use.
+template
+class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SmoothL1LossOpMaker(framework::OpProto* proto,
+                      framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "The input tensor of smooth l1 loss op. "
+             "The rank should be greater or equal to 2 with shape "
+             "[batch_size, value_dim1, value_dim2, ..., value_dimN]");
+    AddInput("Y",
+             "The target tensor of smooth l1 loss op "
+             "with the same shape as X.");
+    AddInput("InsideWeight",
+             "Optional input tensor of smooth l1 loss op with the same shape "
+             "as X. If provided, the result of (X - Y) will be multiplied "
+             "by this tensor element by element.");
+    AddInput("OutsideWeight",
+             "Optional input of smooth l1 loss op with the same shape as X. "
+             "If provided, the output smooth l1 loss will be multiplied by "
+             "this tensor element by element.");
+    AddOutput("Diff", "Intermediate variable to cache InsideWeight*(X-Y).")
+        .AsIntermediate();
+    AddOutput("Out", "Smooth l1 loss.");
+    AddAttr("sigma",
+            "Hyper parameter of smooth l1 loss op. "
+            "A float scalar with default value 3.0.")
+        .SetDefault(3.0);
+    AddComment(R"DOC(
+Compute smooth l1 loss for input and target. The operator take the 1st
+dimension of input as batch size. For each instance, it will compute
+smooth l1 loss element by element first and sum all losses to one value.
+So the output shape is [batch_size, 1].
+
+The equation is:
+loss = 0.5 * (sigma * (x-y))^2    if abs(x - y) < 1 / sigma^2
+       abs(x - y) - 0.5 / sigma^2 otherwise
+
+)DOC");
+  }
+};
+
+// Gradient operator for smooth L1 loss: validates Out@Grad against X and
+// infers the shapes of the X@Grad / Y@Grad outputs.
+class SmoothL1LossGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  // Requires Out@Grad to have shape [X.dims()[0], 1] and resizes whichever of
+  // X@Grad / Y@Grad is present to the shape of X.
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    // Fail with a clear message instead of dereferencing a missing input.
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@Grad) must be initialized.");
+
+    auto in_dims = ctx.Input("X")->dims();
+    auto out_dims =
+        ctx.Input(framework::GradVarName("Out"))->dims();
+    auto* x_grad =
+        ctx.Output(framework::GradVarName("X"));
+    auto* y_grad =
+        ctx.Output(framework::GradVarName("Y"));
+
+    // The forward op always emits a [batch, 1] loss, so its gradient must be
+    // exactly rank 2; this also makes the out_dims[1] access below safe.
+    PADDLE_ENFORCE_EQ(out_dims.size(), 2,
+                      "The tensor rank of Input(Out@Grad) should be 2.");
+    PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
+                      "The 1st dimension of Input(Out@Grad) must be "
+                      "same as input.");
+    PADDLE_ENFORCE_EQ(out_dims[1], 1,
+                      "The 2nd dimension of Input(Out@Grad) must be 1.");
+
+    if (x_grad) x_grad->Resize(in_dims);
+    if (y_grad) y_grad->Resize(in_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
+            ops::SmoothL1LossOpMaker, smooth_l1_loss_grad,
+            ops::SmoothL1LossGradOp);
+REGISTER_OP_CPU_KERNEL(
+    smooth_l1_loss, ops::SmoothL1LossKernel);
+REGISTER_OP_CPU_KERNEL(
+    smooth_l1_loss_grad,
+    ops::SmoothL1LossGradKernel);
diff --git a/paddle/operators/smooth_l1_loss_op.cu b/paddle/operators/smooth_l1_loss_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1c3172f43867741cd1f26979a366b2425f326321
--- /dev/null
+++ b/paddle/operators/smooth_l1_loss_op.cu
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU  // presumably selects Eigen's GPU device -- confirm
+
+#include "paddle/operators/smooth_l1_loss_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    smooth_l1_loss, ops::SmoothL1LossKernel);
+REGISTER_OP_GPU_KERNEL(
+    smooth_l1_loss_grad,
+    ops::SmoothL1LossGradKernel);
diff --git a/paddle/operators/smooth_l1_loss_op.h b/paddle/operators/smooth_l1_loss_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0604fb5e1c2f17c702208520a1d23bd5c3c65b5d
--- /dev/null
+++ b/paddle/operators/smooth_l1_loss_op.h
@@ -0,0 +1,202 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+
+// NOTE(review): every bare "template" line and several angle-bracket argument
+// lists in this header look stripped in this copy of the patch -- restore
+// them from upstream before use.
+using Tensor = framework::Tensor;
+template
+using EigenVector = framework::EigenVector;
+template
+using EigenMatrix = framework::EigenMatrix;
+
+// Element-wise smooth L1 loss. For an input val and sigma2 = sigma^2:
+//   0.5 * val^2 * sigma2   if |val| < 1 / sigma2
+//   |val| - 0.5 / sigma2   otherwise
+template
+struct SmoothL1LossForward {
+  HOSTDEVICE SmoothL1LossForward(const T& sigma2) : sigma2(sigma2) {}
+
+  HOSTDEVICE T operator()(const T& val) const {
+    T abs_val = std::abs(val);
+    if (abs_val < 1.0 / sigma2) {
+      return 0.5 * val * val * sigma2;
+    } else {
+      return abs_val - 0.5 / sigma2;
+    }
+  }
+
+  T sigma2;  // sigma squared
+};
+
+// Forward kernel: writes Diff = X - Y (scaled by InsideWeight when both
+// weights are given) and Out = per-row sum of the smooth L1 loss of Diff
+// (scaled by OutsideWeight when both weights are given).
+template
+class SmoothL1LossKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in0 = context.Input("X");
+    auto* in1 = context.Input("Y");
+    auto* in2 = context.Input("InsideWeight");
+    auto* in3 = context.Input("OutsideWeight");
+    auto* out0 = context.Output("Diff");
+    auto* out1 = context.Output("Out");
+
+    out0->mutable_data(context.GetPlace());
+    out1->mutable_data(context.GetPlace());
+    auto place = context.GetEigenDevice();
+
+    auto sigma = static_cast(context.Attr("sigma"));
+    T sigma2 = sigma * sigma;
+    // weights only take effect when both of them are supplied
+    bool has_weight = (in2 != nullptr) && (in3 != nullptr);
+
+    auto x = EigenVector::Flatten(*in0);
+    auto y = EigenVector::Flatten(*in1);
+    auto diff = EigenVector::Flatten(*out0);
+
+    diff.device(place) = x - y;
+    // multiply inside weight
+    if (has_weight) {
+      auto inside_weight = EigenVector::Flatten(*in2);
+      // cache diff, reused in bp
+      diff.device(place) = diff * inside_weight;
+    }
+
+    auto in_counts = in0->numel();
+    Tensor ptensor_errors;
+    ptensor_errors.mutable_data({static_cast(in_counts)},
+                                context.GetPlace());
+    auto errors = EigenVector::Flatten(ptensor_errors);
+    // apply smooth l1 forward
+    errors.device(place) = diff.unaryExpr(SmoothL1LossForward(sigma2));
+
+    // multiply outside weight
+    if (has_weight) {
+      auto outside_weight = EigenVector::Flatten(*in3);
+      errors.device(place) = errors * outside_weight;
+    }
+    auto loss = EigenVector::Flatten(*out1);
+    // first dimension of 'X' is the number of samples
+    auto mat_dims =
+        framework::make_ddim({static_cast(in0->dims()[0]),
+                              static_cast(in_counts / in0->dims()[0])});
+    auto errors_mat_view = EigenMatrix::From(ptensor_errors, mat_dims);
+    // sum along the second axis: one loss value per sample
+    loss.device(place) = errors_mat_view.sum(Eigen::array({{1}}));
+  }
+};
+
+// Element-wise derivative of the smooth L1 loss with respect to its input:
+//   sigma2 * val   if |val| < 1 / sigma2
+//   sign(val)      otherwise
+template
+struct SmoothL1LossBackward {
+  HOSTDEVICE SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {}
+
+  HOSTDEVICE T operator()(const T& val) const {
+    T abs_val = std::abs(val);
+    if (abs_val < 1.0 / sigma2) {
+      return sigma2 * val;
+    } else {
+      // sign of val: 1, -1 or 0
+      return (0 < val) - (val < 0);
+    }
+  }
+
+  T sigma2;  // sigma squared
+};
+
+// Backward kernel: from the cached Diff and Out@Grad it computes
+//   X@Grad = Out@Grad (broadcast along columns) * weights * loss'(Diff)
+//   Y@Grad = -X@Grad
+// where weights is InsideWeight * OutsideWeight when both are given, else 1.
+template
+class SmoothL1LossGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in0 = context.Input("InsideWeight");
+    auto* in1 = context.Input("OutsideWeight");
+    auto* in2 = context.Input("Diff");
+    auto* og = context.Input(framework::GradVarName("Out"));
+    auto sigma = static_cast(context.Attr("sigma"));
+    T sigma2 = sigma * sigma;
+    bool has_weight = (in0 != nullptr) && (in1 != nullptr);
+
+    auto place = context.GetEigenDevice();
+
+    // view every tensor as a [dims[0], numel / dims[0]] matrix
+    auto in_dims = in2->dims();
+    auto counts = in2->numel();
+    auto cols = counts / in_dims[0];
+    auto mat_dims = framework::make_ddim(
+        {static_cast(in_dims[0]), static_cast(cols)});
+
+    Tensor ptensor_diff;
+    ptensor_diff.mutable_data({static_cast(counts)},
+                              context.GetPlace());
+    auto diff = EigenVector::Flatten(ptensor_diff);
+    // apply smooth l1 backward
+    diff.device(place) = EigenVector::Flatten(*in2).unaryExpr(
+        SmoothL1LossBackward(sigma2));
+
+    // compute weights: InsideWeight * OutsideWeight, or all ones without them
+    Tensor ptensor_weights;
+    ptensor_weights.mutable_data(mat_dims, context.GetPlace());
+    auto weights = EigenMatrix::From(ptensor_weights);
+    if (has_weight) {
+      auto inside_weight = EigenMatrix::From(*in0, mat_dims);
+      auto outside_weight = EigenMatrix::From(*in1, mat_dims);
+      weights.device(place) = inside_weight * outside_weight;
+    } else {
+      weights.device(place) = weights.constant(static_cast(1.0));
+    }
+
+    // compute gradients
+    auto out_grad = EigenMatrix::From(*og);
+    auto diff_mat_view = EigenMatrix::From(ptensor_diff, mat_dims);
+    auto gradients = out_grad.broadcast(
+                         Eigen::array({{1, static_cast(cols)}})) *
+                     weights * diff_mat_view;
+
+    auto* out0 = context.Output(framework::GradVarName("X"));
+    auto* out1 = context.Output(framework::GradVarName("Y"));
+
+    // a gradient output is only filled when it is present
+    if (out0) {
+      out0->mutable_data(context.GetPlace());
+      auto x_grad = EigenMatrix::From(*out0, mat_dims);
+      x_grad.device(place) = gradients;
+    }
+
+    if (out1) {
+      out1->mutable_data(context.GetPlace());
+      auto y_grad = EigenMatrix::From(*out1, mat_dims);
+      // d(X - Y)/dY = -1, so Y receives the negated gradient
+      y_grad.device(place) = -1 * gradients;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py b/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..be940327ec910ccb9de59d45029513ff4779443b
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py
@@ -0,0 +1,95 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def smooth_l1_loss_forward(val, sigma2):
+    """Reference smooth L1 loss of a scalar val; sigma2 is sigma squared."""
+    abs_val = abs(val)
+    if abs_val < 1.0 / sigma2:
+        return 0.5 * val * val * sigma2
+    else:
+        return abs_val - 0.5 / sigma2
+
+
+class TestSmoothL1LossOp1(OpTest):
+    """Checks smooth_l1_loss without InsideWeight / OutsideWeight."""
+
+    def setUp(self):
+        self.op_type = "smooth_l1_loss"
+        dims = (5, 10)
+        self.inputs = {
+            'X': np.random.random(dims).astype("float32"),
+            'Y': np.random.random(dims).astype("float32")
+        }
+        sigma = 3.0
+        self.attrs = {'sigma': sigma}
+        sigma2 = sigma * sigma
+        diff = self.inputs['X'] - self.inputs['Y']
+        # expected loss: element-wise reference summed over each row
+        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2).sum(1)
+        loss = loss.reshape((dims[0], 1))
+        self.outputs = {'Diff': diff, 'Out': loss}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02)
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'], 'Out', max_relative_error=0.03, no_grad_set=set(['X']))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=0.03, no_grad_set=set(['Y']))
+
+
+class TestSmoothL1LossOp2(OpTest):
+    """Checks smooth_l1_loss with both InsideWeight and OutsideWeight."""
+
+    def setUp(self):
+        self.op_type = "smooth_l1_loss"
+        dims = (5, 10)
+        self.inputs = {
+            'X': np.random.random(dims).astype("float32"),
+            'Y': np.random.random(dims).astype("float32"),
+            'InsideWeight': np.random.random(dims).astype("float32"),
+            'OutsideWeight': np.random.random(dims).astype("float32")
+        }
+        sigma = 3.0
+        self.attrs = {'sigma': sigma}
+        sigma2 = sigma * sigma
+        diff = self.inputs['X'] - self.inputs['Y']
+        # inside weight scales the difference before the loss is applied
+        diff = diff * self.inputs['InsideWeight']
+        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2)
+        # outside weight scales the element-wise loss before the row sum
+        loss = loss * self.inputs['OutsideWeight']
+        loss = loss.sum(1).reshape((dims[0], 1))
+        self.outputs = {'Diff': diff, 'Out': loss}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.03)
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            max_relative_error=0.03,
+            no_grad_set=set(['X', 'InsideWeight', 'OutsideWeight']))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            max_relative_error=0.03,
+            no_grad_set=set(['Y', 'InsideWeight', 'OutsideWeight']))
+
+
+if __name__ == '__main__':
+    unittest.main()