From 2f351ed563e2736a1c1444782eb3700cc270d2b5 Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Sun, 25 Apr 2021 11:34:05 +0800 Subject: [PATCH] add silu op, test=develop (#32384) --- paddle/fluid/framework/ir/is_test_pass.cc | 2 +- paddle/fluid/operators/activation_op.cc | 7 ++ paddle/fluid/operators/activation_op.h | 26 +++++++ python/paddle/fluid/layers/ops.py | 15 +++++ .../tests/unittests/test_activation_op.py | 67 +++++++++++++++++++ python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/activation.py | 34 ++++++++++ python/paddle/nn/layer/activation.py | 39 +++++++++++ 9 files changed, 191 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 0a70440765d..25bf03f426a 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -35,7 +35,7 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const { "hard_shrink", "hard_sigmoid", "relu6", "soft_relu", "swish", "thresholded_relu", "log", "square", "softplus", - "softsign"}; + "softsign", "silu"}; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 1cac9ed9f1d..055909ba6f4 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -162,6 +162,12 @@ $$out = \\frac{1}{1 + e^{-x}}$$ )DOC"; +UNUSED constexpr char SiluDoc[] = R"DOC( +Silu Activation Operator + +$$out = x * \\frac{1}{1 + e^{-x}}$$ +)DOC"; + UNUSED constexpr char LogSigmoidDoc[] = R"DOC( Logsigmoid Activation Operator @@ -697,6 +703,7 @@ It is recommended to use the defaults for this activation. }; REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc); +REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc); REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc); REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index fb9f956f17c..7245dea9cf9 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -258,6 +258,31 @@ struct SigmoidGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; +// silu(x) = x / (1 + exp(-x)) +template +struct SiluFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + auto temp = static_cast(1) / (static_cast(1) + (-x).exp()); + out.device(d) = x * temp; + } +}; + +// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})) +template +struct SiluGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto temp1 = static_cast(1) + (-x).exp(); // 1+e^(-x) + auto temp2 = x * (-x).exp(); // x*e^(-x) + dx.device(d) = dout * ((static_cast(1) / temp1) * + (static_cast(1) + (temp2 / temp1))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + // Originally: logsigmoid(x) = -log (1 + exp(-x)) // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ @@ -2129,6 +2154,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor { #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ + __macro(silu, Silu, SiluFunctor, SiluGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 841daf7a41d..67cdc6dce5a 100755 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -27,6 +27,7 @@ __deprecated_func_name__ = { __activations_noattr__ = [ 'sigmoid', + 'silu', 'logsigmoid', 'tanh_shrink', 'softplus', @@ -100,6 +101,20 @@ Examples: """) +add_sample_code(globals()["silu"], r""" +Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]) + out = F.silu(x) + print(out) + # [ 0.7310586 1.7615942 2.8577224, 3.9280552 ] + +""") + add_sample_code(globals()["logsigmoid"], r""" Examples: .. code-block:: python diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index ea183e94448..92465c3e284 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -119,6 +119,72 @@ class TestSigmoid(TestActivation): self.check_grad(['X'], 'Out', max_relative_error=0.01) +class TestSilu(TestActivation): + def setUp(self): + self.op_type = "silu" + self.init_dtype() + + np.random.seed(1024) + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = x / (np.exp(-x) + 1) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + +class TestSiluAPI(unittest.TestCase): + # test paddle.nn.Silu, paddle.nn.functional.silu + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [11, 17]) + out1 = F.silu(x) + m = paddle.nn.Silu() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = self.x_np / (1 + np.exp(-self.x_np)) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.silu(x) + m = paddle.nn.Silu() + out2 = m(x) + out_ref = self.x_np / (1 + np.exp(-self.x_np)) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.silu, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.fluid.data( + name='x_int32', shape=[11, 17], dtype='int32') + self.assertRaises(TypeError, F.silu, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[11, 17], dtype='float16') + F.silu(x_fp16) + + class TestLogSigmoid(TestActivation): def setUp(self): self.op_type = "logsigmoid" @@ -2629,6 +2695,7 @@ def create_test_act_fp16_class(parent, create_test_act_fp16_class(TestActivation) create_test_act_fp16_class(TestSigmoid) +create_test_act_fp16_class(TestSilu) create_test_act_fp16_class(TestLogSigmoid) create_test_act_fp16_class(TestTanh) create_test_act_fp16_class(TestTanhshrink) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 27d8f35234a..836d4008f7d 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -55,6 +55,7 @@ from .layer.activation import PReLU #DEFINE_ALIAS from .layer.activation import ReLU #DEFINE_ALIAS from .layer.activation import ReLU6 #DEFINE_ALIAS from .layer.activation import SELU #DEFINE_ALIAS +from .layer.activation import Silu #DEFINE_ALIAS from .layer.activation import LeakyReLU #DEFINE_ALIAS from .layer.activation import Sigmoid #DEFINE_ALIAS from .layer.activation import Hardsigmoid #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 268bdffeb36..98124be7288 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -46,6 +46,7 @@ from .activation import relu_ #DEFINE_ALIAS from .activation import relu6 #DEFINE_ALIAS from .activation import selu #DEFINE_ALIAS from .activation import sigmoid #DEFINE_ALIAS +from .activation import silu #DEFINE_ALIAS # from .activation import soft_relu #DEFINE_ALIAS from .activation import softmax #DEFINE_ALIAS from .activation import softmax_ #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 8119b0f45d9..d74308dc9aa 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -49,6 +49,7 @@ __all__ = [ 'softshrink', 'softsign', 'sigmoid', + 'silu' 'swish', 'tanh', 'tanh_', @@ -761,6 +762,39 @@ def selu(x, return out +def silu(x, name=None): + """ + silu activation. + .. math: + silu(x) = \frac{x}{1 + e^{-x}} + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + import paddle + import paddle.nn.functional as F + + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]) + out = F.silu(x) # [ 0.731059, 1.761594, 2.857722, 3.928055 ] + """ + + if in_dygraph_mode(): + return core.ops.silu(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'silu') + helper = LayerHelper("silu", **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='silu', inputs={'X': x}, outputs={'Out': out}) + return out + + def softmax(x, axis=-1, dtype=None, name=None): r""" This operator implements the softmax layer. The calculation process is as follows: diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 69cdb738171..2a9ae310615 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -27,6 +27,7 @@ __all__ = [ 'SELU', 'LeakyReLU', 'Sigmoid', + 'Silu', 'Hardsigmoid', 'Softmax', 'Softplus', @@ -919,6 +920,44 @@ class ThresholdedReLU(layers.Layer): return 'threshold={}{}'.format(self._threshold, name_str) +class Silu(layers.Layer): + """ + Silu Activation. + .. math:: + + Silu(x) = \frac{x}{1 + e^{-x}} + + Parameters: + x (Tensor): The input Tensor with data type float32, or float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]) + m = paddle.nn.Silu() + out = m(x) # [ 0.731059, 1.761594, 2.857722, 3.928055 ] + """ + + def __init__(self, name=None): + super(Silu, self).__init__() + self._name = name + + def forward(self, x): + return F.silu(x, self._name) + + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + + class LogSigmoid(layers.Layer): r""" LogSigmoid Activation. -- GitLab