From d7064f0435ce1c35c2b57bf6fcbef6b2597c5f4f Mon Sep 17 00:00:00 2001 From: yujun <50394665+JunnYu@users.noreply.github.com> Date: Wed, 13 Oct 2021 18:43:56 +0800 Subject: [PATCH] [PaddlePaddle hackathon] + ADD CELU (#36088) * update * update * update * try make CI pass * doc typo * update doc string --- paddle/fluid/operators/activation_op.cc | 74 ++++++++++++ paddle/fluid/operators/activation_op.cu | 66 +++++++++++ paddle/fluid/operators/activation_op.h | 111 ++++++++++++++++++ .../unittests/test_activation_nn_grad.py | 27 +++++ .../tests/unittests/test_activation_op.py | 89 ++++++++++++++ .../tests/unittests/test_imperative_layers.py | 3 + python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/activation.py | 44 +++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/activation.py | 42 +++++++ 11 files changed, 461 insertions(+) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index ac98e49b1c..3cdcfd7923 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -560,6 +560,28 @@ $$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$ } }; +class CELUOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input is a multi-dimensional Tensor. The data type is " + "float32 or float64."); + AddOutput("Out", + "The output is a multi-dimensional Tensor which has same " + "dimension and data type as the ``x``."); + AddAttr("alpha", "The alpha value of CELU").SetDefault(1.0f); + AddComment(R"DOC( +CELU Activation Operator. + +Applies the following element-wise computation on the input according to +https://arxiv.org/abs/1704.07483. + +$$out = \max(0, x) + \min(0, \alpha * (e^(x/\alpha) - 1))$$ + +)DOC"); + } +}; + class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -982,6 +1004,29 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { } }; +// celu grad: dx=dy if y>0 else dy*(x/alpha).exp() +// celu gradgrad: ddx=ddy if y>0 else ddy*(x/alpha).exp()/alpha +template +class CELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { + public: + using ::paddle::framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("celu_grad_grad"); + + op->SetInput("X", this->Input("X")); + op->SetInput("DOut", this->Input(framework::GradVarName("Out"))); + // X@GRAD@GRAD: ddx + op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); + op->SetAttrMap(this->Attrs()); + + // Out@GRAD@GRAD: ddy + op->SetOutput("DX", this->InputGrad("X")); + op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); + } +}; + // sqrt Grad: dx = 0.5 * dy / y // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx template @@ -1353,6 +1398,35 @@ REGISTER_OP_CPU_KERNEL( /* ========================================================================== */ +/* ======================== celu register ============================ + */ +REGISTER_OPERATOR( + celu, ops::ActivationOp, ops::CELUOpMaker, ops::ActivationOpInferVarType, + ops::ActivationGradOpMaker::FwdDeps(), + paddle::framework::OpDesc>, + ops::ActivationGradOpMaker::FwdDeps(), + paddle::imperative::OpBase>, + ops::ActFwdInplaceInferer); +REGISTER_OPERATOR(celu_grad, ops::ActivationOpGrad, + ops::ActivationGradOpInplaceInferer, + ops::CELUDoubleGradMaker, + ops::CELUDoubleGradMaker); +REGISTER_OPERATOR( + celu_grad_grad, + ops::ActivationOpDoubleGrad::FwdDeps()>, + ops::ActivationDoubleGradOpInplaceInferer); + +REGISTER_ACTIVATION_CPU_KERNEL(celu, CELU, CELUFunctor, CELUGradFunctor); +REGISTER_OP_CPU_KERNEL( + celu_grad_grad, ops::CELUDoubleGradKernel>, + ops::CELUDoubleGradKernel>, + ops::CELUDoubleGradKernel>); + +/* ========================================================================== */ + /* =========================== sqrt register ============================= */ REGISTER_OPERATOR( sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType, diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index f330f2d7e8..d83a63015c 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1202,6 +1202,59 @@ struct CudaELUGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct CudaCELUFunctor : public BaseActivationFunctor { + using CT = typename details::MPTypeTrait::Type; + CT zero = static_cast(0.0f); + CT one = static_cast(1.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1)) + __device__ __forceinline__ T operator()(const T& arg_x) const { + CT x = static_cast(arg_x); + CT temp = static_cast(alpha) * (exp(x / static_cast(alpha)) - one); + CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp); + return static_cast(res); + } +}; + +template +struct CudaCELUGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); + MPType one = static_cast(1.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // dx = dout, if alpha > 0 and x > 0 + // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0 + // dx = dout , if alpha < 0 and x > 0 + // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0 + __device__ __forceinline__ T operator()(const T& arg_dout, + const T& arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType a = static_cast(alpha); + MPType temp_a_pos = static_cast(alpha > 0.0f); + MPType temp_a_neg = static_cast(alpha <= 0.0f); + MPType temp_x_pos = static_cast(x > zero); + MPType temp_x_neg = static_cast(x <= zero); + return static_cast( + dout * + (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) + + temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template class ActivationCudaKernel : public framework::OpKernel { @@ -1341,6 +1394,19 @@ REGISTER_OP_CUDA_KERNEL( ops::ELUGradGradFunctor>); /* ========================================================================== */ +/* ======================== celu register ============================ */ +REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor, + CudaCELUGradFunctor); + +REGISTER_OP_CUDA_KERNEL( + celu_grad_grad, ops::CELUDoubleGradKernel>, + ops::CELUDoubleGradKernel>, + ops::CELUDoubleGradKernel>); +/* ========================================================================== */ + /* =========================== relu register ============================ */ #ifdef PADDLE_WITH_HIP REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor, diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 4f26cb095c..a6240c038b 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -1389,6 +1389,51 @@ struct ELUGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct CELUFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = + (x < static_cast(0)) + .select(static_cast(alpha) * + ((x / static_cast(alpha)).exp() - static_cast(1)), + x); + } +}; + +template +struct CELUGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto temp_a_pos = static_cast(alpha > 0); + auto temp_a_neg = static_cast(alpha <= 0); + auto temp_x_pos = (x > static_cast(0)).template cast(); + auto temp_x_neg = (x <= static_cast(0)).template cast(); + + // dx = dout, if alpha > 0 and x > 0 + // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0 + // dx = dout , if alpha < 0 and x > 0 + // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0 + dx.device(d) = + dout * temp_a_pos * temp_x_pos + + dout * (x / static_cast(alpha)).exp() * temp_a_pos * temp_x_neg + + dout * temp_a_neg * temp_x_pos + + dout * (x / static_cast(alpha)).exp() * temp_a_neg * temp_x_neg; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 template struct PowFunctor : public BaseActivationFunctor { @@ -1775,6 +1820,45 @@ struct ELUGradGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct CELUGradGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(const Device& dev, const framework::Tensor* X, + const framework::Tensor* ddX, framework::Tensor* ddOut, + const framework::Tensor* dOut, framework::Tensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad")); + auto x = framework::EigenVector::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad")); + + if (dX) { + auto dx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad")); + auto dout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad")); + dx.device(*d) = ddx * dout / static_cast(alpha) * + (x / static_cast(alpha)).exp() * + (x <= static_cast(0)).template cast(); + } + + if (ddOut) { + auto ddout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad")); + ddout.device(*d) = ddx * + ((x > static_cast(0)).template cast() + + (x / static_cast(alpha)).exp() * + (x <= static_cast(0)).template cast()) + .template cast(); + } + } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct SqrtGradGradFunctor : public BaseActivationFunctor { template @@ -2107,6 +2191,33 @@ class ELUDoubleGradKernel } }; +template +class CELUDoubleGradKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor *X, *ddX, *dOut; + X = ddX = dOut = nullptr; + framework::Tensor *dX, *ddOut; + dX = ddOut = nullptr; + + ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); + + if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); + if (ddOut) ddOut->mutable_data(ctx.GetPlace()); + + auto& place = ctx.template device_context(); + + Functor functor; + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + functor(place, X, ddX, ddOut, dOut, dX); + } +}; + template class SqrtDoubleGradKernel : public framework::OpKernel { diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py index 8f3353d115..c54f711c7c 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py @@ -22,6 +22,7 @@ import paddle import paddle.fluid.layers as layers import paddle.fluid.core as core import gradient_checker +import paddle.nn.functional as F from decorator_helper import prog_scope @@ -168,6 +169,32 @@ class TestELUDoubleGradCheck(unittest.TestCase): self.func(p) +class TestCELUDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 4, 4, 4] + eps = 1e-6 + alpha = 0.2 + dtype = np.float64 + SEED = 0 + + x = layers.data('x', shape, False, dtype) + x.persistable = True + + y = F.celu(x, alpha=alpha) + np.random.RandomState(SEED) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + gradient_checker.double_grad_check( + [x], y, x_init=x_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + class TestSqrtDoubleGradCheck(unittest.TestCase): @prog_scope() def func(self, place): diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 346accac01..b82dd631c6 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -1827,6 +1827,94 @@ class TestELUAPI(unittest.TestCase): self.elu(x_fp16) +def celu(x, alpha): + out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x / alpha) - 1)) + return out_ref.astype(x.dtype) + + +class TestCELU(TestActivation): + def setUp(self): + self.op_type = "celu" + self.init_dtype() + + np.random.seed(1024) + x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype) + alpha = 1.5 + out = celu(x, alpha) + self.inputs = {'X': x} + self.attrs = {'alpha': alpha} + self.outputs = {'Out': out} + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + +class TestCELUAPI(unittest.TestCase): + # test paddle.nn.CELU, paddle.nn.functional.celu + def setUp(self): + np.random.seed(1024) + self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32') + self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ + else paddle.CPUPlace() + self.executed_api() + + def executed_api(self): + self.celu = F.celu + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [10, 12]) + out1 = self.celu(x, 1.5) + m = paddle.nn.CELU(1.5) + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = celu(self.x_np, 1.5) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = self.celu(x, 1.5) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.CELU(1.5) + out2 = m(x) + out_ref = celu(self.x_np, 1.5) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = self.celu(x, 0.2) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.CELU(0.2) + out2 = m(x) + out_ref = celu(self.x_np, 0.2) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, self.celu, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.fluid.data( + name='x_int32', shape=[10, 12], dtype='int32') + self.assertRaises(TypeError, self.celu, x_int32) + # The alpha must be not equal 0 + x_fp32 = paddle.fluid.data( + name='x_fp32', shape=[10, 12], dtype='float32') + self.assertRaises(ZeroDivisionError, F.celu, x_fp32, 0) + # support the input dtype is float16 + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[10, 12], dtype='float16') + self.celu(x_fp16) + + class TestELUInplaceAPI(TestELUAPI): # test paddle.nn.functional.elu_ def executed_api(self): @@ -2791,6 +2879,7 @@ create_test_act_fp16_class(TestBRelu) create_test_act_fp16_class(TestRelu6) create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85) create_test_act_fp16_class(TestELU) +create_test_act_fp16_class(TestCELU) create_test_act_fp16_class(TestReciprocal) create_test_act_fp16_class(TestLog) if core.is_compiled_with_rocm(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py index dc15566f85..3561405ae0 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -22,6 +22,9 @@ class TestLayerPrint(unittest.TestCase): module = nn.ELU(0.2) self.assertEqual(str(module), 'ELU(alpha=0.2)') + module = nn.CELU(0.2) + self.assertEqual(str(module), 'CELU(alpha=0.2)') + module = nn.GELU(True) self.assertEqual(str(module), 'GELU(approximate=True)') diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 98444e69d0..064052c076 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -25,6 +25,7 @@ from .clip import ClipGradByNorm # noqa: F401 from .clip import ClipGradByValue # noqa: F401 from .decode import BeamSearchDecoder # noqa: F401 from .decode import dynamic_decode # noqa: F401 +from .layer.activation import CELU # noqa: F401 from .layer.activation import ELU # noqa: F401 from .layer.activation import GELU # noqa: F401 from .layer.activation import Tanh # noqa: F401 @@ -185,6 +186,7 @@ def weight_norm(*args): __all__ = [ #noqa 'BatchNorm', + 'CELU', 'GroupNorm', 'LayerNorm', 'SpectralNorm', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 4151f25b94..1af53e0826 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -15,6 +15,7 @@ # TODO: import all neural network related api under this directory, # including layers, linear, conv, rnn etc. +from .activation import celu # noqa: F401 from .activation import elu # noqa: F401 from .activation import elu_ # noqa: F401 from .activation import gelu # noqa: F401 @@ -115,6 +116,7 @@ from ...fluid.layers import temporal_shift # noqa: F401 from .sparse_attention import sparse_attention __all__ = [ #noqa + 'celu', 'conv1d', 'conv1d_transpose', 'conv2d', diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 67be64c01c..a39c00075a 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -31,6 +31,50 @@ from paddle import _C_ops __all__ = [] +def celu(x, alpha=1.0, name=None): + r""" + celu activation. + + .. math:: + + celu(x) = max(0, x) + min(0, \alpha * (e^{x/\alpha}-1)) + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + alpha (float, optional): The 'alpha' value of the CELU formulation. Default is 1.0. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + x = paddle.to_tensor([[-1., 6.], [1., 15.6]]) + out = F.celu(x, alpha=0.2) + # [[-0.19865242, 6. ], + # [ 1. , 15.60000038]] + """ + if alpha == 0: + raise ZeroDivisionError("alpha cannot be 0 for celu") + + if in_dygraph_mode(): + return _C_ops.celu(x, 'alpha', alpha) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'celu') + helper = LayerHelper("celu", **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='celu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'alpha': alpha}) + return out + + def elu(x, alpha=1.0, name=None): r""" elu activation. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 074dfac510..eb7535b16c 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -18,6 +18,7 @@ from . import rnn # noqa: F401 from . import transformer # noqa: F401 from . import container # noqa: F401 +from .activation import CELU # noqa: F401 from .activation import PReLU # noqa: F401 from .activation import ReLU # noqa: F401 from .activation import ReLU6 # noqa: F401 diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index abfeff0641..cf0ac79ca8 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -25,6 +25,48 @@ from paddle.nn import Layer __all__ = [] +class CELU(Layer): + r""" + CELU Activation. + + .. math:: + + CELU(x) = max(0, x) + min(0, \alpha * (e^{x/\alpha}-1)) + + Parameters: + alpha (float, optional): The 'alpha' value of the CELU formulation. Default is 1.0. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([[-1. ,6.], [1., 15.6]]) + m = paddle.nn.CELU(0.2) + out = m(x) + # [[-0.19865242, 6. ], + # [ 1. , 15.60000038]] + """ + + def __init__(self, alpha=1.0, name=None): + super(CELU, self).__init__() + self._alpha = alpha + self._name = name + + def forward(self, x): + return F.celu(x, self._alpha, self._name) + + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'alpha={}{}'.format(self._alpha, name_str) + + class ELU(Layer): r""" ELU Activation. -- GitLab