From 2763f3e32f7f37d52cbd6379b036958ad3d34ad1 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 6 Sep 2017 15:32:54 +0800 Subject: [PATCH] Complete smooth_l1_loss_op. --- paddle/operators/smooth_l1_loss_op.cc | 119 +++++++++++ paddle/operators/smooth_l1_loss_op.cu | 24 +++ paddle/operators/smooth_l1_loss_op.h | 184 ++++++++++++++++++ paddle/pybind/pybind.cc | 1 + .../paddle/v2/framework/tests/CMakeLists.txt | 1 + .../framework/tests/test_smooth_l1_loss_op.py | 106 ++++++++++ 6 files changed, 435 insertions(+) create mode 100644 paddle/operators/smooth_l1_loss_op.cc create mode 100644 paddle/operators/smooth_l1_loss_op.cu create mode 100644 paddle/operators/smooth_l1_loss_op.h create mode 100644 python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc new file mode 100644 index 0000000000..e9a3847417 --- /dev/null +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/smooth_l1_loss_op.h" + +namespace paddle { +namespace operators { + +class SmoothL1LossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input of SmoothL1LossOp must be initialized."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Target of SmoothL1LossOp must be initialized."); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + PADDLE_ENFORCE_EQ(x->dims(), y->dims(), + "Dimensions of SmoothL1LossOp's input and target " + "must be same."); + PADDLE_ENFORCE_GE(framework::arity(x->dims()), 2, + "Tensor rank of SmoothL1LossOp's input must be " + "at least 2."); + auto* inside_weight = ctx.Input("InsideWeight"); + if (inside_weight) { + auto* outside_weight = ctx.Input("OutsideWeight"); + PADDLE_ENFORCE_NOT_NULL(outside_weight, + "If weights are provided, must specify both " + "inside and outside weights."); + PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(), + "Dimensions of inside weight must be same with input."); + PADDLE_ENFORCE_EQ( + outside_weight->dims(), x->dims(), + "Dimensions of outside weight must be same with input."); + } + + auto* diff = ctx.Output("diff"); + auto* out = ctx.Output("Out"); + diff->Resize(x->dims()); + // loss is a two-rank tensor + out->Resize({x->dims()[0], 1}); + } +}; + +template +class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SmoothL1LossOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of SmoothL1LossOp."); + AddInput("Y", "Target of SmoothL1LossOp."); + AddInput("InsideWeight", "Optional input to scale (X-Y)."); + AddInput("OutsideWeight", "Optinal input to scale smooth l1 loss."); + AddOutput("diff", "Intermediate variable to cache Win*(X-Y).") + .AsIntermediate(); + AddOutput("Out", "Final smooth l1 loss of inputs."); + AddComment(R"DOC( +Compute SmoothL1Loss for input and target. + +The equation is: Out = 0.5 * (sigma * (X - Y)) ^ 2 if abs(X - Y) < 1 / sigma^2 + abs(X - Y) - 0.5 / sigma^2 otherwise +)DOC"); + AddAttr("sigma", "Hyper parameter, default value is 3.0 .") + .SetDefault(3.0); + } +}; + +class SmoothL1LossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + auto in_dims = ctx.Input("X")->dims(); + auto out_dims = + ctx.Input(framework::GradVarName("Out"))->dims(); + auto* x_grad = ctx.Output(framework::GradVarName("X")); + auto* y_grad = ctx.Output(framework::GradVarName("Y")); + + PADDLE_ENFORCE_GE(framework::arity(out_dims), 2, + "Tensor rank of output gradient should be 2."); + PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0], + "First dimension of ouptut gradient must be " + "same with input."); + PADDLE_ENFORCE_EQ(out_dims[1], 1, + "Second dimension of output gradient must be 1."); + + if (x_grad) x_grad->Resize(in_dims); + if (y_grad) y_grad->Resize(in_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, + ops::SmoothL1LossOpMaker, ops::SmoothL1LossGradOp); +REGISTER_OP_CPU_KERNEL( + smooth_l1_loss, ops::SmoothL1LossKernel); +REGISTER_OP_CPU_KERNEL( + smooth_l1_loss_grad, + ops::SmoothL1LossGradKernel); diff --git a/paddle/operators/smooth_l1_loss_op.cu b/paddle/operators/smooth_l1_loss_op.cu new file mode 100644 index 0000000000..1c3172f438 --- /dev/null +++ b/paddle/operators/smooth_l1_loss_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/smooth_l1_loss_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + smooth_l1_loss, ops::SmoothL1LossKernel); +REGISTER_OP_GPU_KERNEL( + smooth_l1_loss_grad, + ops::SmoothL1LossGradKernel); diff --git a/paddle/operators/smooth_l1_loss_op.h b/paddle/operators/smooth_l1_loss_op.h new file mode 100644 index 0000000000..ae91b9c893 --- /dev/null +++ b/paddle/operators/smooth_l1_loss_op.h @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; +template +using EigenMatrix = framework::EigenMatrix; + +template +struct SmoothL1LossFoward { + __host__ __device__ SmoothL1LossFoward(const T& sigma2) : sigma2(sigma2) {} + + __host__ __device__ T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val < 1.0 / sigma2) { + return 0.5 * val * val * sigma2; + } else { + return abs_val - 0.5 / sigma2; + } + } + + T sigma2; +}; + +template +class SmoothL1LossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* in1 = context.Input("Y"); + auto* in2 = context.Input("InsideWeight"); + auto* in3 = context.Input("OutsideWeight"); + auto* out0 = context.Output("diff"); + auto* out1 = context.Output("Out"); + + out0->mutable_data(context.GetPlace()); + out1->mutable_data(context.GetPlace()); + auto place = context.GetEigenDevice(); + + auto sigma = static_cast(context.op_.GetAttr("sigma")); + T sigma2 = sigma * sigma; + bool has_weight = (in2 != nullptr) && (in3 != nullptr); + + auto x = EigenVector::Flatten(*in0); + auto y = EigenVector::Flatten(*in1); + auto diff = EigenVector::Flatten(*out0); + + diff.device(place) = x - y; + // multiply inside weight + if (has_weight) { + auto inside_weight = EigenVector::Flatten(*in2); + // cache diff, reused in bp + diff.device(place) = diff * inside_weight; + } + + auto in_counts = framework::product(in0->dims()); + Tensor paddle_errors; + paddle_errors.mutable_data({static_cast(in_counts)}, + context.GetPlace()); + auto errors = EigenVector::Flatten(paddle_errors); + // apply smooth l1 forward + errors.device(place) = diff.unaryExpr(SmoothL1LossFoward(sigma2)); + + // multiply outside weight + if (has_weight) { + auto outside_weight = EigenVector::Flatten(*in3); + errors.device(place) = errors * outside_weight; + } + auto loss = EigenMatrix::From(*out1, {in0->dims()[0], 1}); + // first dimension of 'X' is the number of samples + auto errors_mat_view = EigenMatrix::From(paddle_errors, in0->dims()); + loss.device(place) = errors_mat_view.sum(Eigen::array({1})); + } +}; + +template +struct SmoothL1LossBackward { + __host__ __device__ SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {} + + __host__ __device__ T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val < 1.0 / sigma2) { + return sigma2 * val; + } else { + return (0 < val) - (val < 0); + } + } + + T sigma2; +}; + +template +class SmoothL1LossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("InsideWeight"); + auto* in1 = context.Input("OutsideWeight"); + auto* in2 = context.Input("diff"); + auto* og = context.Input(framework::GradVarName("Out")); + auto sigma = static_cast(context.op_.GetAttr("sigma")); + T sigma2 = sigma * sigma; + bool has_weight = (in0 != nullptr) && (in1 != nullptr); + + auto place = context.GetEigenDevice(); + + auto in_dims = in2->dims(); + auto counts = framework::product(in_dims); + auto cols = counts / in_dims[0]; + auto mat_dims = framework::make_ddim( + {static_cast(in_dims[0]), static_cast(cols)}); + + Tensor paddle_diff; + paddle_diff.mutable_data({static_cast(counts)}, context.GetPlace()); + auto diff = EigenVector::Flatten(paddle_diff); + // apply smooth l1 backwoard + diff.device(place) = EigenVector::Flatten(*in2).unaryExpr( + SmoothL1LossBackward(sigma2)); + + auto* out0 = context.Output(framework::GradVarName("X")); + auto* out1 = context.Output(framework::GradVarName("Y")); + + // compute weights + Tensor paddle_weights; + paddle_weights.mutable_data(mat_dims, context.GetPlace()); + auto weights = EigenMatrix::From(paddle_weights); + // initialize to 1.0 + if (platform::is_cpu_place(context.GetPlace())) { + weights.setConstant(static_cast(1.0)); + } else { + Tensor paddle_cpu_weights; + paddle_cpu_weights.mutable_data(mat_dims, platform::CPUPlace()); + EigenMatrix::From(paddle_cpu_weights).setConstant(static_cast(1.0)); + paddle_weights.CopyFrom(paddle_cpu_weights, context.GetPlace()); + } + if (has_weight) { + auto inside_weight = EigenMatrix::From(*in0, mat_dims); + auto outside_weight = EigenMatrix::From(*in1, mat_dims); + weights.device(place) = inside_weight * outside_weight; + } + + // compute gradients + auto out_grad = EigenMatrix::From(*og); + auto diff_mat_view = EigenMatrix::From(paddle_diff, mat_dims); + auto gradients = + out_grad.broadcast(Eigen::array({1, static_cast(cols)})) * + weights * diff_mat_view; + + if (out0) { + out0->mutable_data(context.GetPlace()); + auto x_grad = EigenMatrix::From(*out0, mat_dims); + x_grad.device(place) = gradients; + } + + if (out1) { + out1->mutable_data(context.GetPlace()); + auto y_grad = EigenMatrix::From(*out1, mat_dims); + y_grad.device(place) = -1 * gradients; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3bc150ccb7..5aaa372664 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -48,6 +48,7 @@ USE_OP_ITSELF(identity); USE_OP(minus); USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); +USE_OP(smooth_l1_loss); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 661ebd8964..763f3a9f95 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -32,3 +32,4 @@ py_test(test_gradient_checker SRCS test_gradient_checker.py) py_test(test_lookup_table SRCS test_lookup_table.py) py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py) py_test(mnist SRCS mnist.py) +py_test(test_smooth_l1_loss_op SRCS test_smooth_l1_loss_op.py) diff --git a/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py b/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py new file mode 100644 index 0000000000..b3432e703e --- /dev/null +++ b/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py @@ -0,0 +1,106 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import functools +import numpy as np +from paddle.v2.framework.op import Operator + + +def smooth_l1_loss_forward(val, sigma2): + abs_val = abs(val) + if abs_val < 1.0 / sigma2: + return 0.5 * val * val * sigma2 + else: + return abs_val - 0.5 / sigma2 + + +class TestSmoothL1LossOp_f0(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "smooth_l1_loss" + dims = (32, 64) + self.inputs = { + 'X': np.random.random(dims).astype("float32"), + 'Y': np.random.random(dims).astype("float32") + } + sigma = 3.0 + self.attrs = {'sigma': sigma} + sigma2 = sigma * sigma + diff = self.inputs['X'] - self.inputs['Y'] + loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2).sum(1) + loss = loss.reshape((dims[0], 1)) + self.outputs = {'diff': diff, 'Out': loss} + + +class TestSmoothL1LossOp_f1(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "smooth_l1_loss" + dims = (32, 64) + self.inputs = { + 'X': np.random.random(dims).astype("float32"), + 'Y': np.random.random(dims).astype("float32"), + 'InsideWeight': np.random.random(dims).astype("float32"), + 'OutsideWeight': np.random.random(dims).astype("float32") + } + sigma = 3.0 + self.attrs = {'sigma': sigma} + sigma2 = sigma * sigma + diff = self.inputs['X'] - self.inputs['Y'] + diff = diff * self.inputs['InsideWeight'] + loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2) + loss = loss * self.inputs['OutsideWeight'] + loss = loss.sum(1).reshape((dims[0], 1)) + self.outputs = {'diff': diff, 'Out': loss} + + +class SmoothL1LossGradOpTest(GradientChecker): + def test_smooth_l1_loss_b0(self): + dims = (5, 7) + X = np.random.random(dims).astype("float32") + Y = np.random.random(dims).astype("float32") + InsideWeight = np.random.random(dims).astype("float32") + OutsideWeight = np.random.random(dims).astype("float32") + inputs = { + 'X': X, + 'Y': Y, + 'InsideWeight': InsideWeight, + 'OutsideWeight': OutsideWeight + } + op = Operator( + "smooth_l1_loss", + X='X', + Y='Y', + InsideWeight='InsideWeight', + OutsideWeight='OutsideWeight', + diff="diff", + Out="Out", + sigma=3.0) + self.compare_grad( + op, inputs, no_grad_set=set(['InsideWeight', 'OutsideWeight'])) + self.check_grad( + op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.08) + + def test_smooth_l1_loss_b1(self): + dims = (5, 7) + X = np.random.random(dims).astype("float32") + Y = np.random.random(dims).astype("float32") + inputs = {'X': X, 'Y': Y} + op = Operator( + "smooth_l1_loss", + X='X', + Y='Y', + InsideWeight='InsideWeight', + OutsideWeight='OutsideWeight', + diff="diff", + Out="Out", + sigma=3.0) + self.compare_grad( + op, inputs, no_grad_set=set(['InsideWeight', 'OutsideWeight'])) + self.check_grad(op, inputs, set(["X", "Y"]), "Out") + + +if __name__ == '__main__': + unittest.main() -- GitLab