From 01ab8a061936aee264c8e9aeae360cf3de195e47 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 17 Mar 2020 23:25:28 +0800 Subject: [PATCH] add approximation for gelu, test=develop (#22961) add approximation for gelu, default value is False (only kernel with eigen is added, remove code for computing gelu with MKLDNN temporarily) --- paddle/fluid/operators/activation_op.cc | 8 - paddle/fluid/operators/activation_op.h | 85 ----------- paddle/fluid/operators/erf_op.cc | 1 + paddle/fluid/operators/gelu_op.cc | 144 ++++++++++++++++++ paddle/fluid/operators/gelu_op.cu | 28 ++++ paddle/fluid/operators/gelu_op.h | 119 +++++++++++++++ python/paddle/fluid/layers/ops.py | 7 +- .../tests/unittests/test_activation_op.py | 32 +++- .../fluid/tests/unittests/test_gelu_op.py | 28 ++-- 9 files changed, 347 insertions(+), 105 deletions(-) create mode 100644 paddle/fluid/operators/gelu_op.cc create mode 100644 paddle/fluid/operators/gelu_op.cu create mode 100644 paddle/fluid/operators/gelu_op.h diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 18951b8827..71f67466cb 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -185,13 +185,6 @@ $out = \max(x, 0)$ )DOC"; -UNUSED constexpr char GeluDoc[] = R"DOC( -Gelu Activation Operator. - -$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$ - -)DOC"; - UNUSED constexpr char TanhDoc[] = R"DOC( Tanh Activation Operator. @@ -635,7 +628,6 @@ REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc); REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc); -REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc); REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc); REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc); REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index fa6ec23ce8..8194b1ef44 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -304,90 +304,6 @@ struct ReluGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) -template -struct GeluFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { -// Because the execute or device context can not be deliver here, it keep the -// marco for NVCC. 
-#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) - auto x_data = x.data(); - auto out_data = out.data(); - int n = std::min(x.size(), out.size()); - - std::memset(out_data, 0, n * sizeof(T)); - math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, out_data, 1); - math::CBlas::VMERF(n, out_data, out_data, VML_LA); - for (int i = 0; i < n; i++) { - out_data[i] += static_cast(1); - } - math::CBlas::VMUL(n, x_data, out_data, out_data); - for (int i = 0; i < n; i++) { - out_data[i] *= static_cast(0.5); - } -#else - auto temp = (x * static_cast(M_SQRT1_2)).erf(); - out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); -#endif - } -}; - -// gelu_grad(x) = dout * (0.5 * (1 + erf(x / sqrt(2))) + 0.5 * 2 / sqrt(pi) / -// sqrt(2) * x * exp (-0.5 * x^2)) -template -struct GeluGradFunctor : BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { -#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) - auto x_data = x.data(); - auto dx_data = dx.data(); - auto dout_data = dout.data(); - int n = std::min(x.size(), dx.size()); - - auto first = static_cast(std::malloc(n * sizeof(T))); - std::memset(first, 0, n * sizeof(T)); - auto second = static_cast(std::malloc(n * sizeof(T))); - std::memset(second, 0, n * sizeof(T)); - - // first = (0.5 * (1 + erf(x / sqrt(2)))) - math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, first, 1); - math::CBlas::VMERF(n, first, first, VML_LA); - for (int i = 0; i < n; i++) { - first[i] += static_cast(1); - } - math::CBlas::SCAL(n, static_cast(0.5), first, 1); - - // second = (0.5 * 2/sqrt(pi) * 1/sqrt(2) * x * exp(-0.5 * x^2)) - math::CBlas::VSQUARE(n, x_data, second); - math::CBlas::SCAL(n, -static_cast(0.5), second, 1); - math::CBlas::VEXP(n, second, second); - math::CBlas::VMUL(n, x_data, second, second); - math::CBlas::SCAL(n, static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2), - second, 1); - - // dx = dout * (first + second); - math::CBlas::VADD(n, first, second, first); - math::CBlas::VMUL(n, dout_data, first, dx_data); - - std::free(first); - std::free(second); -#else - auto first = static_cast(0.5) * - (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); - - auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * - (-static_cast(0.5) * x.square()).exp(); - dx.device(d) = dout * (first + second); -#endif - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { @@ -1727,7 +1643,6 @@ class PowGradKernel #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ - __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \ __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \ __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ diff --git a/paddle/fluid/operators/erf_op.cc b/paddle/fluid/operators/erf_op.cc index 4caef66d4c..09cdf4d8b2 100644 --- a/paddle/fluid/operators/erf_op.cc +++ b/paddle/fluid/operators/erf_op.cc @@ -101,6 +101,7 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; + protected: void Apply(GradOpPtr grad_op) const override { 
grad_op->SetType("erf_grad"); grad_op->SetInput("X", this->Input("X")); diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc new file mode 100644 index 0000000000..07fa09acee --- /dev/null +++ b/paddle/fluid/operators/gelu_op.cc @@ -0,0 +1,144 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/operators/gelu_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +class GeluOp : public framework::OperatorWithKernel { + public: + GeluOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Input(%s) of GeluOp should not be null.", "X")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Output(%s) of GeluOp should not be null.", "Out")); + + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class GeluGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::InvalidArgument( + "Input(%s) of GeluGradOp should not be null.", "DOut")); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Input(%s) of GeluGradOp should not be null.", "X")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::InvalidArgument( + "Output(%s) of GeluGradOp should not be null.", "DX")); + auto x_grad_name = framework::GradVarName("X"); + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ x_grad_name); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class GeluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of Gelu operator"); + AddOutput("Out", "Output of Gelu operator"); + AddAttr("approximate", + "(bool, default false) use approximation of gelu") + .SetDefault(false); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("use_cudnn", + "(bool, default false) Only used in cudnn kernel, need " + "install cudnn") 
+ .SetDefault(false); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddComment(R"DOC( +Gelu Activation Operator. + +For more details, please refer to [Gaussian Error Linear Units](https://arxiv.org/pdf/1606.08415.pdf). + +when using approximation +$out = \\frac{1}{2}x(1+tanh(\\sqrt{\\frac{2}{\\pi}}(x+0.044715x^{3}))$ + +or else +$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$ + +)DOC"); + } +}; + +template +class GeluGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("gelu_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gelu, ops::GeluOp, ops::GeluOpMaker, + ops::GeluGradOpMaker, + ops::GeluGradOpMaker); +REGISTER_OPERATOR(gelu_grad, ops::GeluGradOp); +REGISTER_OP_CPU_KERNEL( + gelu, ops::GeluKernel, + ops::GeluKernel); +REGISTER_OP_CPU_KERNEL( + gelu_grad, ops::GeluGradKernel, + ops::GeluGradKernel); diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu new file mode 100644 index 0000000000..5bb2fd2479 --- /dev/null +++ b/paddle/fluid/operators/gelu_op.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gelu_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + gelu, ops::GeluKernel, + ops::GeluKernel, + ops::GeluKernel); +REGISTER_OP_CUDA_KERNEL( + gelu_grad, ops::GeluGradKernel, + ops::GeluGradKernel, + ops::GeluGradKernel); diff --git a/paddle/fluid/operators/gelu_op.h b/paddle/fluid/operators/gelu_op.h new file mode 100644 index 0000000000..ad38ec1cc5 --- /dev/null +++ b/paddle/fluid/operators/gelu_op.h @@ -0,0 +1,119 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/float16.h" + +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +struct GeluFunctor { + template + void operator()(Device d, X x, Out out, bool approximate) const { + if (approximate) { + // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) + auto temp = (static_cast(M_2_SQRTPI * M_SQRT1_2) * + (x + static_cast(0.044715) * x.cube())) + .tanh(); + out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + } else { + // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + auto temp = (x * static_cast(M_SQRT1_2)).erf(); + out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + } + } +}; + +template +struct GeluGradFunctor { + template + void operator()(Device d, X x, dOut dout, dX dx, bool approximate) const { + if (approximate) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = + (kAlpha * ((static_cast(0.044715) * x.cube()) + x)).tanh(); + dx.device(d) = static_cast(0.5) * dout * + (static_cast(1) + y + + (x - x * y.square()) * (kAlpha + kBeta * x.square())); + } else { + // gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2 / pi) * + // exp(- x^2 / 2) + auto first = + static_cast(0.5) * + (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); + + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * + (-static_cast(0.5) * x.square()).exp(); + dx.device(d) = dout * (first + second); + } + } +}; + +template +class GeluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* out = context.Output("Out"); + auto* in = context.Input("X"); + auto approximate = context.Attr("approximate"); + out->mutable_data(in->place()); + + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*in); + auto& place = + *context.template device_context().eigen_device(); + + GeluFunctor functor; + functor(place, eigen_in, eigen_out, approximate); + } +}; + +template +class GeluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* dout = + context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + auto approximate = context.Attr("approximate"); + dx->mutable_data(dout->place()); + + auto eigen_x = framework::EigenVector::Flatten(*x); + auto eigen_dout = framework::EigenVector::Flatten(*dout); + auto eigen_dx = framework::EigenVector::Flatten(*dx); + auto& place = + *context.template device_context().eigen_device(); + + GeluGradFunctor functor; + functor(place, eigen_x, eigen_dout, eigen_dx, approximate); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 5951d86998..54d3952c9c 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -245,7 +245,7 @@ __all__ += ['gelu'] _gelu_ = generate_layer_fn('gelu') -def gelu(x): +def gelu(x, approximate=False): locals_var = locals().copy() kwargs = dict() for name, val 
in locals_var.items(): @@ -259,6 +259,11 @@ gelu.__doc__ = """ For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415). Equation: + if approximate is True + .. math:: + out = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3}))) + + else .. math:: out = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}})) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index a4a8b76d07..48aec26fea 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -411,16 +411,44 @@ class TestLeakyRelu(TestActivation): self.check_grad(['X'], 'Out') -class TestGelu(TestActivation): +def gelu(x, approximate): + if approximate: + y_ref = 0.5 * x * (1.0 + np.tanh( + np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + else: + y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2))) + return y_ref.astype(x.dtype) + + +class TestGeluApproximate(TestActivation): def setUp(self): self.op_type = "gelu" self.init_dtype() + approximate = True + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = gelu(x, approximate) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + self.attrs = {"approximate": approximate} + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + +class TestGelu(TestActivation): + def setUp(self): + self.op_type = "gelu" + self.init_dtype() + approximate = False x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - out = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0))) + out = gelu(x, approximate) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.attrs = {"approximate": approximate} def test_check_grad(self): if self.dtype == np.float16: diff --git a/python/paddle/fluid/tests/unittests/test_gelu_op.py b/python/paddle/fluid/tests/unittests/test_gelu_op.py index 5f722ab8e0..13174edad8 100644 --- a/python/paddle/fluid/tests/unittests/test_gelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_gelu_op.py @@ -21,33 +21,43 @@ import paddle.fluid as fluid import paddle.fluid.dygraph as dg +def gelu(x, approximate): + if approximate: + y_ref = 0.5 * x * (1.0 + np.tanh( + np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + else: + y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2))) + return y_ref.astype(x.dtype) + + class TestGeluOp(unittest.TestCase): - def _test_case1_cpu(self): + def _test_case1_cpu(self, approximate): x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32) - y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2))) + y_ref = gelu(x, approximate) place = fluid.CPUPlace() with dg.guard(place) as g: x_var = dg.to_variable(x) - y_var = fluid.layers.gelu(x_var) + y_var = fluid.layers.gelu(x_var, approximate) y_test = y_var.numpy() self.assertTrue(np.allclose(y_ref, y_test, rtol=1e-05, atol=1e-08)) - def _test_case1_gpu(self): + def _test_case1_gpu(self, approximate): x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32) - y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2))) + y_ref = gelu(x, approximate) place = fluid.CUDAPlace(0) with dg.guard(place) as g: x_var = dg.to_variable(x) - y_var = fluid.layers.gelu(x_var) + y_var = fluid.layers.gelu(x_var, approximate) y_test = y_var.numpy() self.assertTrue(np.allclose(y_ref, y_test, rtol=1e-05, atol=1e-08)) def test_cases(self): - self._test_case1_cpu() - if fluid.is_compiled_with_cuda(): - self._test_case1_gpu() + for approximate in [True, 
False]: + self._test_case1_cpu(approximate) + if fluid.is_compiled_with_cuda(): + self._test_case1_gpu(approximate) if __name__ == '__main__': -- GitLab
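
Reviewer-side reference, not part of the patch: a minimal NumPy sketch of the two GELU variants this change implements, mirroring the formulas in the new GeluFunctor/GeluGradFunctor and the gelu() helper added to the Python tests. The names gelu_ref and gelu_grad_ref are illustrative only and do not appear in the patch.

    # Minimal NumPy sketch of the exact and tanh-approximate GELU
    # forward/backward formulas used by this patch (assumed helper names).
    import numpy as np
    from scipy.special import erf

    def gelu_ref(x, approximate=False):
        if approximate:
            # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            return 0.5 * x * (1.0 + np.tanh(
                np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))))
        # 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

    def gelu_grad_ref(x, dout, approximate=False):
        if approximate:
            # Matches GeluGradFunctor: kAlpha = sqrt(2/pi), kBeta = 3 * 0.044715 * kAlpha
            k_alpha = np.sqrt(2.0 / np.pi)
            k_beta = k_alpha * 0.044715 * 3.0
            y = np.tanh(k_alpha * (x + 0.044715 * np.power(x, 3)))
            return 0.5 * dout * (1.0 + y +
                                 (x - x * y * y) * (k_alpha + k_beta * x * x))
        # d/dx [0.5 * x * (1 + erf(x / sqrt(2)))]
        first = 0.5 * (1.0 + erf(x / np.sqrt(2.0)))
        second = x * np.exp(-0.5 * x * x) / np.sqrt(2.0 * np.pi)
        return dout * (first + second)

The unit tests in the patch exercise the same forward formulas through fluid.layers.gelu(x_var, approximate) in dygraph mode on CPU and CUDA places, comparing against the NumPy reference with np.allclose(rtol=1e-05, atol=1e-08).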