From 436808c6981be3fb808bb22794ee2885d7cd257e Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Tue, 23 Nov 2021 18:14:58 +0800
Subject: [PATCH] elu support alpha < 0 (#37316) (#37437)

---
 paddle/fluid/operators/activation_op.cc       | 37 ++++++--
 paddle/fluid/operators/activation_op.cu       | 91 ++++++++++++++++---
 paddle/fluid/operators/activation_op.h        | 73 ++++++++++++---
 paddle/fluid/operators/inplace_abn_op.h       |  2 +-
 .../tests/unittests/test_activation_op.py     | 18 +++-
 python/paddle/nn/functional/activation.py     |  9 +-
 python/paddle/nn/layer/activation.py          |  8 +-
 7 files changed, 196 insertions(+), 42 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 5a498e617a..cacb6dd8fe 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -560,6 +560,22 @@ $$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$
   }
 };
 
+template <typename T>
+class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("elu_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput("X", this->Input("X"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
 class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -1233,13 +1249,11 @@ REGISTER_OP_CPU_KERNEL(
 /* ========================================================================== */
 
 /* ======================== elu register ============================ */
-REGISTER_OPERATOR(
-    elu, ops::ActivationOp, ops::ELUOpMaker, ops::ActivationOpInferVarType,
-    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
-                               paddle::framework::OpDesc>,
-    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
-                               paddle::imperative::OpBase>,
-    ops::ActFwdInplaceInferer);
+REGISTER_OPERATOR(elu, ops::ActivationOp, ops::ELUOpMaker,
+                  ops::ActivationOpInferVarType,
+                  ops::ELUGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ELUGradOpMaker<paddle::imperative::OpBase>,
+                  ops::ActFwdInplaceInferer);
 REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
                   ops::ActivationGradOpInplaceInferer,
                   ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
@@ -1249,7 +1263,14 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
+REGISTER_OP_CPU_KERNEL(elu,
+                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
+                                             ops::ELUFunctor<float>>,
+                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
+                                             ops::ELUFunctor<double>>);
+REGISTER_OP_CPU_KERNEL(
+    elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                                             ops::ELUGradGradFunctor<float>>,
diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 72f10bf19e..3ff164a869 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -1161,11 +1161,12 @@ struct CudaELUFunctor : public BaseActivationFunctor<T> {
     return {{"alpha", &alpha}};
   }
 
-  // elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
+  // elu(x) = x, if x > 0
+  // elu(x) = alpha * (e^x - 1), if x <= 0
   __device__ __forceinline__ T operator()(const T& arg_x) const {
     CT x = static_cast<CT>(arg_x);
     CT temp = static_cast<CT>(alpha) * (exp(x) - one);
-    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
+    CT res = x > zero ? x : temp;
     return static_cast<T>(res);
   }
 };
@@ -1174,34 +1175,84 @@ template <typename T>
 struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename details::MPTypeTrait<T>::Type;
   MPType zero = static_cast<MPType>(0.0f);
-  MPType one = static_cast<MPType>(1.0f);
   float alpha;
 
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"alpha", &alpha}};
   }
 
-  // dx = dout, if alpha > 0 and x > 0
-  // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
-  // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
-  // dx = 0, if alpha <= 0 and x <=0
+  // case 1: alpha >= 0
+  // dx = dout, if out > 0
+  // dx = dout * (out + alpha), if out <= 0
   __device__ __forceinline__ T operator()(const T& arg_dout,
+                                          const T& arg_out) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType out = static_cast<MPType>(arg_out);
+    MPType a = static_cast<MPType>(alpha);
+    MPType out_pos = static_cast<MPType>(out > zero);
+    MPType out_neg = static_cast<MPType>(out <= zero);
+    return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
+template <typename T>
+struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType zero = static_cast<MPType>(0.0f);
+  float alpha;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  // case 2: alpha < 0
+  // dx = dout, if x > 0
+  // dx = dout * (out + alpha), if x <=0
+  __device__ __forceinline__ T operator()(const T& arg_dout, const T& arg_out,
                                           const T& arg_x) const {
     MPType dout = static_cast<MPType>(arg_dout);
+    MPType out = static_cast<MPType>(arg_out);
     MPType x = static_cast<MPType>(arg_x);
     MPType a = static_cast<MPType>(alpha);
-    MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
-    MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
-    MPType temp_x_pos = static_cast<MPType>(x > zero);
-    MPType temp_x_neg = static_cast<MPType>(x <= zero);
-    return static_cast<T>(
-        dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
-                temp_a_neg * temp_x_pos * (one + a * exp(x))));
+    MPType x_pos = static_cast<MPType>(x > zero);
+    MPType x_neg = static_cast<MPType>(x <= zero);
+    return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
+template <typename DeviceContext, typename T>
+class ELUGradCudaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out = ctx.Input<framework::Tensor>("Out");
+    auto* x = ctx.Input<framework::Tensor>("X");
+    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    d_x->mutable_data<T>(ctx.GetPlace());
+    const float alpha = ctx.Attr<float>("alpha");
+
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+    std::vector<const framework::Tensor*> ins = {d_out, out};
+    std::vector<framework::Tensor*> outs = {d_x};
+    if (alpha > 0) {
+      CudaELUGradFunctor<T> functor;
+      functor.alpha = alpha;
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
+    } else {
+      CudaELUGradNegativeAlphaFunctor<T> functor;
+      functor.alpha = alpha;
+      ins.push_back(x);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
+          dev_ctx, ins, &outs, functor);
+    }
+  }
+};
+
 template <typename DeviceContext, typename Functor>
 class ActivationCudaKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1330,7 +1381,17 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ======================== elu register ============================ */
-REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);
+REGISTER_OP_CUDA_KERNEL(
+    elu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
+                                   ops::CudaELUFunctor<float>>,
+    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
+                              ops::CudaELUFunctor<double>>,
+    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
+                              ops::CudaELUFunctor<plat::float16>>);
+REGISTER_OP_CUDA_KERNEL(
+    elu_grad, ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
+    ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
+    ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                                             ops::ELUGradGradFunctor<float>>,
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ ... @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp_a_pos = static_cast<T>(alpha > 0);
-    auto temp_a_neg = static_cast<T>(alpha <= 0);
-    auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
-    auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
-
-    // dx = dout, if alpha > 0 and x > 0
-    // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
-    // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
-    // dx = 0, if alpha <= 0 and x <=0
-    dx.device(d) =
-        dout * temp_a_pos * temp_x_pos +
-        dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
-        dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
-            temp_a_neg * temp_x_pos;
+    // case 1: alpha >= 0
+    // dx = dout, if out > 0
+    // dx = dout * (out + alpha), if out <= 0
+    dx.device(d) = (out > static_cast<T>(0))
+                       .select(dout, dout * (out + static_cast<T>(alpha)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
+  float alpha;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    // case 2: alpha < 0
+    // dx = dout, if x > 0
+    // dx = dout * (out + alpha), if x <=0
+    dx.device(d) = (x > static_cast<T>(0))
+                       .select(dout, dout * static_cast<T>(alpha) * x.exp());
+  }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
+template <typename DeviceContext, typename T>
+class ELUGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<framework::Tensor>("X");
+    auto* Out = context.Input<framework::Tensor>("Out");
+    auto* dOut =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    const float alpha = context.Attr<float>("alpha");
+    dX->mutable_data<T>(context.GetPlace());
+
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
+    auto dout = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
+    auto dx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
+    auto* place =
+        context.template device_context<DeviceContext>().eigen_device();
+
+    if (alpha > 0) {
+      ELUGradFunctor<T> functor;
+      functor.alpha = alpha;
+      functor(*place, x, out, dout, dx);
+    } else {
+      ELUGradNegativeAlphaFunctor<T> functor;
+      functor.alpha = alpha;
+      functor(*place, x, out, dout, dx);
+    }
+  }
+};
+
 // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
diff --git a/paddle/fluid/operators/inplace_abn_op.h b/paddle/fluid/operators/inplace_abn_op.h
index 1c90a645bf..9c3727ab90 100644
--- a/paddle/fluid/operators/inplace_abn_op.h
+++ b/paddle/fluid/operators/inplace_abn_op.h
@@ -104,7 +104,7 @@ class InplaceABNActivation {
       auto temp2 = (y * temp / static_cast<T>(alpha) + static_cast<T>(1)).log();
       x.device(d) = (y * temp1 + temp2).template cast<T>();
 
-      ELUGradFunctor<T> functor;
+      ELUGradNegativeAlphaFunctor<T> functor;
       compute(ctx, &functor, d, x, y, dy, dx);
     } else {
       PADDLE_THROW(
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 346accac01..23eb42c5a9 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1742,7 +1742,7 @@ class TestSoftReluOpError(unittest.TestCase):
 
 
 def elu(x, alpha):
-    out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
+    out_ref = np.where(x > 0, x, alpha * (np.exp(x) - 1))
     return out_ref.astype(x.dtype)
 
 
@@ -1753,7 +1753,7 @@ class TestELU(TestActivation):
 
         np.random.seed(1024)
         x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
-        alpha = 1.
+        alpha = self.get_alpha()
        out = elu(x, alpha)
         # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
         # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
@@ -1766,6 +1766,14 @@ class TestELU(TestActivation):
             return
         self.check_grad(['X'], 'Out')
 
+    def get_alpha(self):
+        return 1.
+
+
+class TestELUAlpha(TestELU):
+    def get_alpha(self):
+        return -0.2
+
 
 class TestELUAPI(unittest.TestCase):
     # test paddle.nn.ELU, paddle.nn.functional.elu
@@ -1832,6 +1840,12 @@ class TestELUInplaceAPI(TestELUAPI):
     def executed_api(self):
         self.elu = F.elu_
 
+    def test_alpha_error(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        self.assertRaises(Exception, F.elu_, x, -0.2)
+        paddle.enable_static()
+
 
 class TestReciprocal(TestActivation):
     def setUp(self):
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 67be64c01c..830a6c9363 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -37,7 +37,13 @@ def elu(x, alpha=1.0, name=None):
 
     .. math::
 
-        elu(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
+        elu(x)=
+            \left\{
+                \begin{array}{lcl}
+                x,& &\text{if } \ x > 0 \\
+                alpha * (e^{x} - 1),& &\text{if } \ x <= 0
+                \end{array}
+            \right.
 
     Parameters:
         x (Tensor): The input Tensor with data type float32, float64.
@@ -80,6 +86,7 @@ def elu_(x, alpha=1.0, name=None):
     Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_nn_cn_elu`.
     """
+    assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
     return _C_ops.elu_(x, 'alpha', alpha)
 
 
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index abfeff0641..7d98125f86 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -31,7 +31,13 @@ class ELU(Layer):
 
     .. math::
 
-        ELU(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
+        ELU(x)=
+            \left\{
+                \begin{array}{lcl}
+                x,& &\text{if } \ x > 0 \\
+                alpha * (e^{x} - 1),& &\text{if } \ x <= 0
+                \end{array}
+            \right.
 
     Parameters:
         alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
-- 
GitLab
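
Usage sketch (not part of the patch): a minimal Python example of the behavior this change enables, namely calling the public paddle ELU API with a negative alpha and checking it against the same piecewise formula the updated unit test uses. The tensor values below are illustrative assumptions; only paddle.nn.functional.elu, paddle.nn.functional.elu_, and numpy are relied on.

import numpy as np
import paddle
import paddle.nn.functional as F

# elu with alpha < 0, as enabled by this patch:
#   elu(x) = x                      if x > 0
#   elu(x) = alpha * (exp(x) - 1)   if x <= 0
x = paddle.to_tensor([-2.0, -0.5, 0.0, 1.5])
y = F.elu(x, alpha=-0.2)

# Reference result computed with the formula used by the updated test's elu() helper.
x_np = x.numpy()
ref = np.where(x_np > 0, x_np, -0.2 * (np.exp(x_np) - 1)).astype(x_np.dtype)
np.testing.assert_allclose(y.numpy(), ref, rtol=1e-5)

# The inplace variant still rejects negative alpha (see the assert added to elu_).
try:
    F.elu_(paddle.to_tensor([-1.0, 1.0]), alpha=-0.2)
except Exception as e:
    print("elu_ with alpha < 0 raises:", type(e).__name__)

For example, at x = -2 with alpha = -0.2 the old formula max(0, x) + min(0, alpha * (e^x - 1)) evaluates to 0, while the new piecewise definition returns -0.2 * (e^{-2} - 1) ≈ 0.173, which is why the forward functors, gradient functors, and the NumPy reference in the test all switch to the piecewise form.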