From d736fc0e00108384853a996aef9d51dbe81f1564 Mon Sep 17 00:00:00 2001
From: qijun
Date: Wed, 13 Sep 2017 17:33:36 +0800
Subject: [PATCH] add activation macro

---
 paddle/framework/operator.h                   |   6 +-
 paddle/operators/activation_op.cc             | 115 ++++++++++++++++++
 .../{sigmoid_op.cu => activation_op.cu}       |  11 +-
 paddle/operators/activation_op.h              |  71 +++++++++++
 paddle/operators/math/activation.h            |  20 ---
 paddle/operators/math/activation_functor.h    |  96 +++++++++++++++
 paddle/operators/sigmoid_op.cc                |  61 ----------
 paddle/operators/sigmoid_op.h                 |  62 ----------
 paddle/pybind/pybind.cc                       |   4 +-
 .../paddle/v2/framework/tests/test_exp_op.py  |  22 ++++
 .../paddle/v2/framework/tests/test_relu_op.py |  22 ++++
 11 files changed, 342 insertions(+), 148 deletions(-)
 create mode 100644 paddle/operators/activation_op.cc
 rename paddle/operators/{sigmoid_op.cu => activation_op.cu} (66%)
 create mode 100644 paddle/operators/activation_op.h
 delete mode 100644 paddle/operators/math/activation.h
 create mode 100644 paddle/operators/math/activation_functor.h
 delete mode 100644 paddle/operators/sigmoid_op.cc
 delete mode 100644 paddle/operators/sigmoid_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_exp_op.py
 create mode 100644 python/paddle/v2/framework/tests/test_relu_op.py

diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index bfa2190557e..0970797e02b 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -139,9 +139,9 @@ class OperatorBase {
 // Macro for define a clone method.
 // If you are writing an kernel operator, `Clone` will be defined when you
 // register it. i.e. `Clone` method is not needed to define by yourself.
-#define DEFINE_OP_CLONE_METHOD(cls)                       \
-  std::unique_ptr<OperatorBase> Clone() const final {     \
-    return std::unique_ptr<OperatorBase>(new cls(*this)); \
+#define DEFINE_OP_CLONE_METHOD(cls)                                            \
+  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
+    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
   }
 
 // Macro for define a default constructor for Operator.
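The operator.h hunk replaces the unqualified OperatorBase inside DEFINE_OP_CLONE_METHOD with the absolute ::paddle::framework::OperatorBase. Macros expand at their point of use, so the unqualified name only resolved when the macro happened to be expanded inside namespace paddle::framework. A self-contained sketch of the same issue, using hypothetical names (fw::Base, ops::Derived) rather than Paddle's own types:

#include <memory>

namespace fw {
class Base {
 public:
  virtual ~Base() = default;
  virtual std::unique_ptr<Base> Clone() const = 0;
};
}  // namespace fw

// Absolute qualification, as in the fixed macro: the expansion names
// ::fw::Base explicitly instead of relying on the expansion site's namespace.
#define DEFINE_CLONE_METHOD(cls)                        \
  std::unique_ptr<::fw::Base> Clone() const final {     \
    return std::unique_ptr<::fw::Base>(new cls(*this)); \
  }

namespace ops {  // a different namespace, like paddle::operators
class Derived : public fw::Base {
 public:
  // With an unqualified `Base` in the macro body this would fail name
  // lookup here; the qualified form expands correctly in any namespace.
  DEFINE_CLONE_METHOD(Derived)
};
}  // namespace ops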
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
new file mode 100644
index 00000000000..d2c2378feff
--- /dev/null
+++ b/paddle/operators/activation_op.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/activation_op.h"
+
+#define FILL_ACTIVATION_OP                                                   \
+ public:                                                                     \
+  using framework::OperatorWithKernel::OperatorWithKernel;                   \
+                                                                             \
+ protected:                                                                  \
+  void InferShape(const framework::InferShapeContext &ctx) const override {  \
+    ctx.Output<framework::Tensor>("Y")->Resize(                              \
+        ctx.Input<framework::Tensor>("X")->dims());                          \
+  }
+
+#define FILL_ACTIVATION_GRAD_OP                                              \
+ public:                                                                     \
+  using framework::OperatorWithKernel::OperatorWithKernel;                   \
+                                                                             \
+ protected:                                                                  \
+  void InferShape(const framework::InferShapeContext &ctx) const override {  \
+    ctx.Output<framework::Tensor>(framework::GradVarName("X"))               \
+        ->Resize(ctx.Input<framework::Tensor>("Y")->dims());                 \
+  }
+
+namespace paddle {
+namespace operators {
+
+class SigmoidOp : public framework::OperatorWithKernel {
+  FILL_ACTIVATION_OP
+};
+
+class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SigmoidOpMaker(framework::OpProto *proto,
+                 framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Sigmoid operator");
+    AddOutput("Y", "Output of Sigmoid operator");
+    AddComment("Sigmoid activation operator");
+  }
+};
+
+class SigmoidOpGrad : public framework::OperatorWithKernel {
+  FILL_ACTIVATION_GRAD_OP
+};
+
+class ExpOp : public framework::OperatorWithKernel {
+  FILL_ACTIVATION_OP
+};
+
+class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Exp operator");
+    AddOutput("Y", "Output of Exp operator");
+    AddComment("Exp activation operator");
+  }
+};
+
+class ExpOpGrad : public framework::OperatorWithKernel {
+  FILL_ACTIVATION_GRAD_OP
+};
+
+class ReluOp : public framework::OperatorWithKernel {
+  FILL_ACTIVATION_OP
+};
+
+class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Relu operator");
+    AddOutput("Y", "Output of Relu operator");
+    AddComment("Relu activation operator");
+  }
+};
+
+class ReluOpGrad : public framework::OperatorWithKernel {
+  FILL_ACTIVATION_GRAD_OP
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
+            ops::SigmoidOpGrad);
+REGISTER_OP_CPU_KERNEL(sigmoid,
+                       ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
+
+REGISTER_OP(exp, ops::ExpOp, ops::ExpOpMaker, exp_grad, ops::ExpOpGrad);
+REGISTER_OP_CPU_KERNEL(exp, ops::ExpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(exp_grad,
+                       ops::ExpGradKernel<paddle::platform::CPUPlace, float>);
+
+REGISTER_OP(relu, ops::ReluOp, ops::ReluOpMaker, relu_grad, ops::ReluOpGrad);
+REGISTER_OP_CPU_KERNEL(relu,
+                       ops::ReluKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(relu_grad,
+                       ops::ReluGradKernel<paddle::platform::CPUPlace, float>);
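The two FILL_* macros are what make adding further activations cheap: each new op needs only a maker class plus registrations. As an illustration only, a hypothetical tanh operator added under this scheme would look like the following (Tanh is not part of this patch; a matching math::Tanh/math::TanhGrad functor pair and the corresponding kernel definitions would also be required):

class TanhOp : public framework::OperatorWithKernel {
  FILL_ACTIVATION_OP
};

class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Tanh operator");
    AddOutput("Y", "Output of Tanh operator");
    AddComment("Tanh activation operator");
  }
};

class TanhOpGrad : public framework::OperatorWithKernel {
  FILL_ACTIVATION_GRAD_OP
};

// Registration mirrors the existing ops:
REGISTER_OP(tanh, ops::TanhOp, ops::TanhOpMaker, tanh_grad, ops::TanhOpGrad);
REGISTER_OP_CPU_KERNEL(tanh,
                       ops::TanhKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(tanh_grad,
                       ops::TanhGradKernel<paddle::platform::CPUPlace, float>);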
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/activation_op.cu
similarity index 66%
rename from paddle/operators/sigmoid_op.cu
rename to paddle/operators/activation_op.cu
index 1a50dfe14a7..55d9f52124d 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/activation_op.cu
@@ -13,7 +13,7 @@
    limitations under the License. */
 
 #define EIGEN_USE_GPU
-#include "paddle/operators/sigmoid_op.h"
+#include "paddle/operators/activation_op.h"
 
 namespace ops = paddle::operators;
 
@@ -21,3 +21,12 @@ REGISTER_OP_GPU_KERNEL(sigmoid,
                        ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
+
+REGISTER_OP_GPU_KERNEL(exp, ops::ExpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(exp_grad,
+                       ops::ExpGradKernel<paddle::platform::GPUPlace, float>);
+
+REGISTER_OP_GPU_KERNEL(relu,
+                       ops::ReluKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(relu_grad,
+                       ops::ReluGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
new file mode 100644
index 00000000000..9e4101805e8
--- /dev/null
+++ b/paddle/operators/activation_op.h
@@ -0,0 +1,71 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/activation_functor.h"
+
+#define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
+
+#define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME)                              \
+  template <typename Place, typename T>                                        \
+  class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
+   public:                                                                     \
+    void Compute(const framework::ExecutionContext& context) const override {  \
+      auto* X = context.Input<framework::Tensor>("X");                         \
+      auto* Y = context.Output<framework::Tensor>("Y");                        \
+      Y->mutable_data<T>(context.GetPlace());                                  \
+      math::ACTIVATION_NAME<Place, T> functor;                                 \
+      auto* device_context = context.device_context();                         \
+      functor(*device_context, *X, Y);                                         \
+    }                                                                          \
+  };
+
+#define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME)                   \
+  template <typename Place, typename T>                                       \
+  class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME)                          \
+      : public framework::OpKernel {                                          \
+   public:                                                                    \
+    void Compute(const framework::ExecutionContext& context) const override { \
+      auto* X = context.Input<framework::Tensor>("X");                        \
+      auto* Y = context.Input<framework::Tensor>("Y");                        \
+      auto* dY =                                                              \
+          context.Input<framework::Tensor>(framework::GradVarName("Y"));      \
+      auto* dX =                                                              \
+          context.Output<framework::Tensor>(framework::GradVarName("X"));     \
+      dX->mutable_data<T>(context.GetPlace());                                \
+      math::ACTIVATION_GRAD_NAME<Place, T> functor;                           \
+      auto* device_context = context.device_context();                        \
+      functor(*device_context, *X, *Y, *dY, dX);                              \
+    }                                                                         \
+  };
+
+namespace paddle {
+namespace operators {
+
+DEFINE_ACTIVATION_KERNEL(Sigmoid);
+
+DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
+
+DEFINE_ACTIVATION_KERNEL(Exp);
+
+DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
+
+DEFINE_ACTIVATION_KERNEL(Relu);
+
+DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
+
+}  // namespace operators
+}  // namespace paddle
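ACTIVATION_KERNEL_NAME pastes tokens, so DEFINE_ACTIVATION_KERNEL(Sigmoid) declares a class literally named SigmoidKernel, which is why the registration sites above can refer to ops::SigmoidKernel. Expanded by hand, that definition is roughly:

template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");   // forward input
    auto* Y = context.Output<framework::Tensor>("Y");  // forward output
    Y->mutable_data<T>(context.GetPlace());
    math::Sigmoid<Place, T> functor;  // element-wise functor from math/
    auto* device_context = context.device_context();
    functor(*device_context, *X, Y);  // y = 1 / (1 + exp(-x))
  }
};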
diff --git a/paddle/operators/math/activation.h b/paddle/operators/math/activation.h
deleted file mode 100644
index b6af478d82d..00000000000
--- a/paddle/operators/math/activation.h
+++ /dev/null
@@ -1,20 +0,0 @@
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/tensor.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename Place, typename T>
-struct sigmoid {
-  void operator()(const platform::DeviceContext& deice_context,
-                  const framework::Tensor& input, framework::Tensor* output) {
-    auto x = framework::EigenVector<T>::Flatten(*output);
-    auto y = framework::EigenVector<T>::Flatten(input);
-    auto* place = device_context.get_eigen_device<Place>();
-    y.device(*place) = 1. / (1. + (-x).exp());
-  }
-};
-}
-}
-}
diff --git a/paddle/operators/math/activation_functor.h b/paddle/operators/math/activation_functor.h
new file mode 100644
index 00000000000..7e15607f462
--- /dev/null
+++ b/paddle/operators/math/activation_functor.h
@@ -0,0 +1,96 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename Place, typename T>
+struct Sigmoid {
+  void operator()(const platform::DeviceContext& device_context,
+                  const framework::Tensor& X, framework::Tensor* Y) {
+    auto x = framework::EigenVector<T>::Flatten(X);
+    auto y = framework::EigenVector<T>::Flatten(*Y);
+    auto* place = device_context.template get_eigen_device<Place>();
+    y.device(*place) = 1. / (1. + (-x).exp());
+  }
+};
+
+template <typename Place, typename T>
+struct SigmoidGrad {
+  void operator()(const platform::DeviceContext& device_context,
+                  const framework::Tensor& X, const framework::Tensor& Y,
+                  const framework::Tensor& dY, framework::Tensor* dX) {
+    auto dx = framework::EigenVector<T>::Flatten(*dX);
+    auto y = framework::EigenVector<T>::Flatten(Y);
+    auto dy = framework::EigenVector<T>::Flatten(dY);
+    auto* place = device_context.template get_eigen_device<Place>();
+    dx.device(*place) = dy * y * (1. - y);
+  }
+};
+
+template <typename Place, typename T>
+struct Exp {
+  void operator()(const platform::DeviceContext& device_context,
+                  const framework::Tensor& input, framework::Tensor* output) {
+    auto x = framework::EigenVector<T>::Flatten(input);
+    auto y = framework::EigenVector<T>::Flatten(*output);
+    auto* place = device_context.template get_eigen_device<Place>();
+    y.device(*place) = x.exp();
+  }
+};
+
+template <typename Place, typename T>
+struct ExpGrad {
+  void operator()(const platform::DeviceContext& device_context,
+                  const framework::Tensor& X, const framework::Tensor& Y,
+                  const framework::Tensor& dY, framework::Tensor* dX) {
+    auto dx = framework::EigenVector<T>::Flatten(*dX);
+    auto y = framework::EigenVector<T>::Flatten(Y);
+    auto dy = framework::EigenVector<T>::Flatten(dY);
+    auto* place = device_context.template get_eigen_device<Place>();
+    // d(e^x)/dx = e^x, which is exactly the forward output y.
+    dx.device(*place) = dy * y;
+  }
+};
+
+template <typename Place, typename T>
+struct Relu {
+  void operator()(const platform::DeviceContext& device_context,
+                  const framework::Tensor& input, framework::Tensor* output) {
+    auto x = framework::EigenVector<T>::Flatten(input);
+    auto y = framework::EigenVector<T>::Flatten(*output);
+    auto* place = device_context.template get_eigen_device<Place>();
+    y.device(*place) = x.cwiseMax(static_cast<T>(0));
+  }
+};
+
+template <typename Place, typename T>
+struct ReluGrad {
+  void operator()(const platform::DeviceContext& device_context,
+                  const framework::Tensor& X, const framework::Tensor& Y,
+                  const framework::Tensor& dY, framework::Tensor* dX) {
+    auto dx = framework::EigenVector<T>::Flatten(*dX);
+    auto dy = framework::EigenVector<T>::Flatten(dY);
+    auto x = framework::EigenVector<T>::Flatten(X);
+    auto* place = device_context.template get_eigen_device<Place>();
+    dx.device(*place) = dy * (x > static_cast<T>(0)).template cast<T>();
+  }
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
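Each *Grad functor above is the element-wise chain rule $dx = dy \cdot f'(x)$, rewritten where possible in terms of the forward output $y$ so the backward pass reuses already-computed values:

\[
\begin{aligned}
\text{sigmoid:}\quad & y = \frac{1}{1 + e^{-x}}, & \frac{dy}{dx} &= y(1 - y) & \Rightarrow\; dx &= dy \cdot y(1 - y),\\
\text{exp:}\quad     & y = e^{x},                & \frac{dy}{dx} &= e^{x} = y & \Rightarrow\; dx &= dy \cdot y,\\
\text{relu:}\quad    & y = \max(x, 0),           & \frac{dy}{dx} &= \mathbf{1}[x > 0] & \Rightarrow\; dx &= dy \cdot \mathbf{1}[x > 0].
\end{aligned}
\]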
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
deleted file mode 100644
index 761c6de8d4d..00000000000
--- a/paddle/operators/sigmoid_op.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/operators/sigmoid_op.h"
-
-namespace paddle {
-namespace operators {
-
-class SigmoidOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
-  }
-};
-
-class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  SigmoidOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "sigmoid input");
-    AddOutput("Y", "sigmoid output");
-    AddComment("Sigmoid function");
-  }
-};
-
-class SigmoidOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    ctx.Output<Tensor>(framework::GradVarName("X"))
-        ->Resize(ctx.Input<Tensor>("Y")->dims());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
-            ops::SigmoidOpGrad);
-REGISTER_OP_CPU_KERNEL(sigmoid,
-                       ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
deleted file mode 100644
index b01a9b3f232..00000000000
--- a/paddle/operators/sigmoid_op.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#pragma once
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
-template <typename Place, typename T>
-class SigmoidKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto input = context.Input<Tensor>("X");
-    auto output = context.Output<Tensor>("Y");
-    output->mutable_data<T>(context.GetPlace());
-
-    // The clipping is used in Paddle's raw implenmention
-    auto X = EigenVector<T>::Flatten(*input);
-    auto Y = EigenVector<T>::Flatten(*output);
-    auto place = context.GetEigenDevice<Place>();
-
-    Y.device(place) = 1. / (1. + (-X).exp());
-  }
-};
-
-template <typename Place, typename T>
-class SigmoidGradKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto Y_t = context.Input<Tensor>("Y");
-    auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
-    auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
-
-    dX_t->mutable_data<T>(context.GetPlace());
-
-    auto dX = EigenVector<T>::Flatten(*dX_t);
-    auto Y = EigenVector<T>::Flatten(*Y_t);
-    auto dY = EigenVector<T>::Flatten(*dY_t);
-    dX.device(context.GetEigenDevice<Place>()) = dY * Y * (1. - Y);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 16a2368aae5..bd964c5d079 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -36,7 +36,6 @@ USE_OP(onehot_cross_entropy);
 USE_OP(sgd);
 USE_OP(mul);
 USE_OP(mean);
-USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
 USE_OP(fill_zeros_like);
@@ -55,6 +54,9 @@ USE_OP(top_k);
 USE_OP(squared_l2_distance);
 USE_OP(sum);
 USE_OP(reshape);
+USE_OP(sigmoid);
+USE_OP(exp);
+USE_OP(relu);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/test_exp_op.py b/python/paddle/v2/framework/tests/test_exp_op.py
new file mode 100644
index 00000000000..5a004f6fe2f
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_exp_op.py
@@ -0,0 +1,22 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestExp(OpTest):
+    def setUp(self):
+        self.op_type = "exp"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.exp(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", max_relative_error=0.007)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_relu_op.py b/python/paddle/v2/framework/tests/test_relu_op.py
new file mode 100644
index 00000000000..07b7113d791
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_relu_op.py
@@ -0,0 +1,22 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestRelu(OpTest):
+    def setUp(self):
+        self.op_type = "relu"
+        self.inputs = {
+            'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", max_relative_error=0.007)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab