From 40d193ed1731b05260df402f7fc98af7b921587a Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Thu, 20 Aug 2020 19:28:16 +0800 Subject: [PATCH] Add the ReLU6, Tanhshrink, SELU, Softplus, Softshrink and Softsign for the api 2.0 (#26376) --- paddle/fluid/operators/activation_op.cc | 38 +- paddle/fluid/operators/activation_op.h | 40 +- python/paddle/fluid/layers/nn.py | 9 +- python/paddle/fluid/layers/ops.py | 42 ++- .../tests/unittests/test_activation_op.py | 344 +++++++++++++++--- .../fluid/tests/unittests/test_selu_op.py | 89 ++++- python/paddle/nn/__init__.py | 6 + python/paddle/nn/functional/__init__.py | 2 +- python/paddle/nn/functional/activation.py | 287 ++++++++++++++- python/paddle/nn/layer/activation.py | 252 +++++++++++++ 10 files changed, 993 insertions(+), 116 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 3d604fd335d..e061fd0ab93 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -317,13 +317,6 @@ $$out = x^2$$ )DOC"; -UNUSED constexpr char SoftplusDoc[] = R"DOC( -Softplus Activation Operator. - -$$out = \ln(1 + e^{x})$$ - -)DOC"; - UNUSED constexpr char SoftsignDoc[] = R"DOC( Softsign Activation Operator. @@ -396,6 +389,36 @@ $$out = \max(x, \alpha * x)$$ } }; +class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "Input of Softplus operator, an N-D Tensor, with data type " + "float32, float64 or float16."); + AddOutput( + "Out", + "Output of Softplus operator, a Tensor with shape same as input."); + AddAttr("beta", "The value of beta for Softplus.").SetDefault(1.0f); + AddAttr("threshold", "The value of threshold for Softplus.") + .SetDefault(20.0f); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel.") + .SetDefault(false); + AddAttr( + "use_cudnn", + "(bool, default false) Only used in cudnn kernel, need install cudnn.") + .SetDefault(false); + AddComment(R"DOC( +:strong:`Softplus Activation Operator` + +.. math:: + out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\ + \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold. + +)DOC"); + } +}; + class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -672,7 +695,6 @@ REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc); REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc); REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc); REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc); -REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc); REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); template diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 3aac7ae8a5e..94cbaf76e94 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -975,32 +975,46 @@ struct HardSwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; -// softplus(x) = log(1 + exp(x)) -// When x is a very large positive number, exp(x) may explode to inf, -// Using trick below for numerical stability -// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ -// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) +// For numerical stability, using the following formula instead of softplus(x) = +// log(1 + exp(x)) +// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = +// 1, threshold = 20 by default), otherwise x template struct SoftplusFunctor : public BaseActivationFunctor { + float beta; + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + template void operator()(Device d, X x, Out out) { - auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) - out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); + auto x_beta = static_cast(beta) * x; + out.device(d) = (x_beta > static_cast(threshold)) + .select(x, (static_cast(1) + x_beta.exp()).log() / + static_cast(beta)); } }; -// d(softplus(x))/dx = exp(x) / (1 + exp(x)) -// For numerical stability: -// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + -// exp(x - max(x, 0))) +// For numerical stability, using the following formula instead of +// d(softplus(x))/dx = 1 / (1 + exp(-x)) +// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta +// = 1, threshold = 20 by default), otherwise x template struct SoftplusGradFunctor : public BaseActivationFunctor { + float beta; + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + template void operator()(Device d, X x, Out out, dOut dout, dX dx) { - auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + auto x_beta = static_cast(beta) * x; dx.device(d) = - dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); + (x_beta > static_cast(threshold)) + .select(dout, dout / (static_cast(1) + (-x_beta).exp())); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9298664f188..efa60b70001 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8643,11 +8643,9 @@ def relu(x, name=None): return out +@deprecated(since="2.0.0", update_to="paddle.nn.functional.selu") def selu(x, scale=None, alpha=None, name=None): """ - :alias_main: paddle.nn.functional.selu - :alias: paddle.nn.functional.selu,paddle.nn.functional.activation.selu - :old_api: paddle.fluid.layers.selu Selu Operator. @@ -9304,12 +9302,9 @@ def elu(x, alpha=1.0, name=None): return out -@templatedoc() +@deprecated(since="2.0.0", update_to="paddle.nn.functional.relu6") def relu6(x, threshold=6.0, name=None): """ - :alias_main: paddle.nn.functional.relu6 - :alias: paddle.nn.functional.relu6,paddle.nn.functional.activation.relu6 - :old_api: paddle.fluid.layers.relu6 ${comment} diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 1c9aa97e29e..dbeffcd2803 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -20,6 +20,8 @@ from ..framework import convert_np_dtype_to_dtype_, Variable from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from paddle.utils import deprecated +__deprecated_func_name__ = {'tanh_shrink': 'tanhshrink', } + __activations_noattr__ = [ 'sigmoid', 'logsigmoid', @@ -64,14 +66,20 @@ __all__ += __activations_noattr__ __all__ += __unary_func__ for _OP in set(__activations_noattr__): + _new_OP = _OP + if _OP in __deprecated_func_name__: + _new_OP = __deprecated_func_name__[_OP] func = generate_activation_fn(_OP) func = deprecated( - since="2.0.0", update_to="paddle.nn.functional.%s" % (_OP))(func) + since="2.0.0", update_to="paddle.nn.functional.%s" % (_new_OP))(func) globals()[_OP] = func for _OP in set(__unary_func__): + _new_OP = _OP + if _OP in __deprecated_func_name__: + _new_OP = __deprecated_func_name__[_OP] func = generate_activation_fn(_OP) - func = deprecated(since="2.0.0", update_to="paddle.%s" % (_OP))(func) + func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func) globals()[_OP] = func add_sample_code(globals()["sigmoid"], r""" @@ -160,16 +168,14 @@ add_sample_code(globals()["tanh_shrink"], r""" Examples: .. code-block:: python - import numpy as np import paddle import paddle.nn.functional as F + import numpy as np + paddle.disable_static() - x_data = np.array([-0.4, -0.2, 0.1, 0.3]) - x = paddle.to_variable(x_data) - out = F.tanh_shrink(x) - print(out.numpy()) - # [-0.02005104 -0.00262468 0.00033201 0.00868739] + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739] """) @@ -401,16 +407,14 @@ add_sample_code(globals()["softplus"], r""" Examples: .. code-block:: python - import numpy as np import paddle import paddle.nn.functional as F + import numpy as np + paddle.disable_static() - x_data = np.array([-0.4, -0.2, 0.1, 0.3]) - x = paddle.to_variable(x_data) - out = F.softplus(x) - print(out.numpy()) - # [0.51301525 0.59813887 0.74439666 0.85435524] + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355] """) @@ -418,16 +422,14 @@ add_sample_code(globals()["softsign"], r""" Examples: .. code-block:: python - import numpy as np import paddle import paddle.nn.functional as F + import numpy as np + paddle.disable_static() - x_data = np.array([-0.4, -0.2, 0.1, 0.3]) - x = paddle.to_variable(x_data) - out = F.softsign(x) - print(out.numpy()) - # [-0.28571429 -0.16666667 0.09090909 0.23076923] + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769] """) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index aef49df577f..db32976f046 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -369,15 +369,20 @@ class TestCoshOpError(unittest.TestCase): fluid.layers.cosh(x_fp16) -class TestTanhShrink(TestActivation): +def ref_tanhshrink(x): + out = x - np.tanh(x) + return out + + +class TestTanhshrink(TestActivation): def setUp(self): self.op_type = "tanh_shrink" self.init_dtype() - x = np.random.uniform(0.1, 1, [10, 17]).astype(self.dtype) - out = x - np.tanh(x) + x = np.random.uniform(10, 20, [10, 17]).astype(self.dtype) + out = ref_tanhshrink(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = {'X': x} self.outputs = {'Out': out} def test_check_grad(self): @@ -386,6 +391,57 @@ class TestTanhShrink(TestActivation): self.check_grad(['X'], 'Out') +class TestTanhshrinkAPI(unittest.TestCase): + # test paddle.nn.Tanhshrink, paddle.nn.functional.tanhshrink + def setUp(self): + self.x_np = np.random.uniform(10, 20, [10, 17]).astype(np.float64) + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.tanhshrink(x) + tanhshrink = paddle.nn.Tanhshrink() + out2 = tanhshrink(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_tanhshrink(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.tanhshrink(x) + tanhshrink = paddle.nn.Tanhshrink() + out2 = tanhshrink(x) + out_ref = ref_tanhshrink(self.x_np) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.tanh_shrink(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_tanhshrink(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.tanhshrink, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.tanhshrink, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.tanhshrink(x_fp16) + + def ref_hardshrink(x, threshold): out = np.copy(x) out[(out >= -threshold) & (out <= threshold)] = 0 @@ -469,19 +525,24 @@ class TestHardShrinkAPI(unittest.TestCase): F.hardshrink(x_fp16) -class TestSoftShrink(TestActivation): +def ref_softshrink(x, threshold=0.5): + out = np.copy(x) + out = (out < -threshold) * (out + threshold) + (out > threshold) * ( + out - threshold) + return out + + +class TestSoftshrink(TestActivation): def setUp(self): self.op_type = "softshrink" self.init_dtype() - lambda_val = 0.1 - x = np.random.uniform(0.25, 10, [10, 12]).astype(self.dtype) - out = np.copy(x) - out = (out < -lambda_val) * (out + lambda_val) + (out > lambda_val) * ( - out - lambda_val) + threshold = 0.8 - self.attrs = {'lambda': lambda_val} - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + x = np.random.uniform(0.25, 10, [10, 12]).astype(self.dtype) + out = ref_softshrink(x, threshold) + self.inputs = {'X': x} + self.attrs = {"lambda": threshold} self.outputs = {'Out': out} def test_check_grad(self): @@ -490,17 +551,56 @@ class TestSoftShrink(TestActivation): self.check_grad(['X'], 'Out') -class TestSoftShrinkOpError(unittest.TestCase): +class TestSoftshrinkAPI(unittest.TestCase): + # test paddle.nn.Softshrink, paddle.nn.functional.softshrink + def setUp(self): + self.threshold = 0.8 + self.x_np = np.random.uniform(0.25, 10, [10, 12]).astype(np.float64) + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.softshrink(x, self.threshold) + softshrink = paddle.nn.Softshrink(self.threshold) + out2 = softshrink(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_softshrink(self.x_np, self.threshold) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.softshrink(x, self.threshold) + softshrink = paddle.nn.Softshrink(self.threshold) + out2 = softshrink(x) + out_ref = ref_softshrink(self.x_np, self.threshold) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.softshrink(x, self.threshold) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_softshrink(self.x_np, self.threshold) + self.assertEqual(np.allclose(out_ref, res[0]), True) + def test_errors(self): - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.softshrink, 1) + self.assertRaises(TypeError, F.softshrink, 1) # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.softshrink, x_int32) + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.softshrink, x_int32) # support the input dtype is float16 - x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') - fluid.layers.softshrink(x_fp16) + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.softshrink(x_fp16) class TestSqrt(TestActivation, TestParameter): @@ -903,20 +1003,24 @@ class TestBReluOpError(unittest.TestCase): fluid.layers.brelu(x_fp16) +def ref_relu6(x, threshold=6.0): + out = np.copy(x) + out[np.abs(x - threshold) < 0.005] = threshold + 0.02 + out = np.minimum(np.maximum(x, 0), threshold) + return out + + class TestRelu6(TestActivation): def setUp(self): self.op_type = "relu6" self.init_dtype() x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype) - threshold = 6.0 - # The same with TestAbs x[np.abs(x) < 0.005] = 0.02 - x[np.abs(x - threshold) < 0.005] = threshold + 0.02 - out = np.minimum(np.maximum(x, 0), threshold) + out = ref_relu6(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'threshold': threshold} + self.inputs = {'X': x} + self.attrs = {'threshold': 6.0} self.outputs = {'Out': out} def test_check_grad(self): @@ -925,17 +1029,56 @@ class TestRelu6(TestActivation): self.check_grad(['X'], 'Out') -class TestRelu6OpError(unittest.TestCase): +class TestRelu6API(unittest.TestCase): + # test paddle.nn.ReLU6, paddle.nn.functional.relu6 + def setUp(self): + self.x_np = np.random.uniform(-1, 10, [10, 12]).astype(np.float64) + self.x_np[np.abs(self.x_np) < 0.005] = 0.02 + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.relu6(x) + relu6 = paddle.nn.ReLU6() + out2 = relu6(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_relu6(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.relu6(x) + relu6 = paddle.nn.ReLU6() + out2 = relu6(x) + out_ref = ref_relu6(self.x_np) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.relu6(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_relu6(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + def test_errors(self): - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.relu6, 1) + self.assertRaises(TypeError, F.relu6, 1) # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.relu6, x_int32) + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.relu6, x_int32) # support the input dtype is float16 - x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') - fluid.layers.relu6(x_fp16) + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.relu6(x_fp16) class TestHardSwish(TestActivation): @@ -1318,16 +1461,25 @@ class TestSTanhOpError(unittest.TestCase): fluid.layers.stanh(x_fp16) +def ref_softplus(x, beta=1, threshold=20): + x_beta = beta * x + out = np.select([x_beta <= threshold, x_beta > threshold], + [np.log(1 + np.exp(x_beta)) / beta, x]) + return out + + class TestSoftplus(TestActivation): def setUp(self): self.op_type = "softplus" self.init_dtype() - self.dtype = np.float64 - x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - out = np.log(1 + np.exp(x)) + beta = 2 + threshold = 15 - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) + out = ref_softplus(x, beta, threshold) + self.inputs = {'X': x} + self.attrs = {'beta': beta, "threshold": threshold} self.outputs = {'Out': out} def test_check_grad(self): @@ -1336,15 +1488,72 @@ class TestSoftplus(TestActivation): self.check_grad(['X'], 'Out') +class TestSoftplusAPI(unittest.TestCase): + # test paddle.nn.Softplus, paddle.nn.functional.softplus + def setUp(self): + self.beta = 2 + self.threshold = 15 + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64) + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.softplus(x, self.beta, self.threshold) + softplus = paddle.nn.Softplus(self.beta, self.threshold) + out2 = softplus(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_softplus(self.x_np, self.beta, self.threshold) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.softplus(x, self.beta, self.threshold) + softplus = paddle.nn.Softplus(self.beta, self.threshold) + out2 = softplus(x) + out_ref = ref_softplus(self.x_np, self.beta, self.threshold) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.softplus(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_softplus(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.softplus, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.softplus, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.softplus(x_fp16) + + +def ref_softsign(x): + out = np.divide(x, 1 + np.abs(x)) + return out + + class TestSoftsign(TestActivation): def setUp(self): self.op_type = "softsign" self.init_dtype() - x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - out = np.divide(x, 1 + np.abs(x)) - - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) + out = ref_softsign(x) + self.inputs = {'X': x} self.outputs = {'Out': out} def test_check_grad(self): @@ -1353,6 +1562,57 @@ class TestSoftsign(TestActivation): self.check_grad(['X'], 'Out') +class TestSoftsignAPI(unittest.TestCase): + # test paddle.nn.Softsign, paddle.nn.functional.softsign + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64) + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.softsign(x) + softsign = paddle.nn.Softsign() + out2 = softsign(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_softsign(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.softsign(x) + softsign = paddle.nn.Softsign() + out2 = softsign(x) + out_ref = ref_softsign(self.x_np) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.softsign(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_softsign(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.softsign, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.softsign, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.softsign(x_fp16) + + class TestThresholdedRelu(TestActivation): def setUp(self): self.op_type = "thresholded_relu" @@ -1548,9 +1808,9 @@ create_test_act_fp16_class(TestActivation) create_test_act_fp16_class(TestSigmoid) create_test_act_fp16_class(TestLogSigmoid) create_test_act_fp16_class(TestTanh) -create_test_act_fp16_class(TestTanhShrink) +create_test_act_fp16_class(TestTanhshrink) create_test_act_fp16_class(TestHardShrink) -create_test_act_fp16_class(TestSoftShrink) +create_test_act_fp16_class(TestSoftshrink) create_test_act_fp16_class(TestSqrt) create_test_act_fp16_class(TestAbs) create_test_act_fp16_class(TestCeil, grad_check=False) diff --git a/python/paddle/fluid/tests/unittests/test_selu_op.py b/python/paddle/fluid/tests/unittests/test_selu_op.py index 6070c84ff23..590ef11e9cb 100644 --- a/python/paddle/fluid/tests/unittests/test_selu_op.py +++ b/python/paddle/fluid/tests/unittests/test_selu_op.py @@ -17,9 +17,26 @@ from __future__ import print_function import unittest import numpy as np import six +import paddle.fluid.core as core from op_test import OpTest +import paddle import paddle.fluid as fluid -from paddle.fluid import Program, program_guard +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.fluid import compiler, Program, program_guard + + +def ref_selu(x, + scale=1.0507009873554804934193349852946, + alpha=1.6732632423543772848170429916717): + out = np.copy(x) + out_flat = out.flatten() + for i in range(out_flat.size): + if out_flat[i] < 0: + out_flat[i] = alpha * np.exp(out_flat[i]) - alpha + out_flat[i] = scale * out_flat[i] + out = out_flat.reshape(x.shape) + return out class SeluTest(OpTest): @@ -39,17 +56,10 @@ class SeluTest(OpTest): # zero. x[np.abs(x) < 0.005] = 0.02 - x_flat = x.flatten() - - for i in range(x_flat.size): - if x_flat[i] < 0: - x_flat[i] = alpha * np.exp(x_flat[i]) - alpha - x_flat[i] = scale * x_flat[i] - - out_np = x_flat.reshape(self.x_shape) + out = ref_selu(x, scale, alpha) self.inputs = {'X': x} - self.outputs = {'Out': out_np} + self.outputs = {'Out': out} self.attrs = { 'alpha': alpha, @@ -69,17 +79,60 @@ class SeluTest(OpTest): self.check_grad(['X'], 'Out') -class TestSeluOpError(unittest.TestCase): +class TestSeluAPI(unittest.TestCase): + # test paddle.nn.SELU, paddle.nn.functional.selu + def setUp(self): + self.scale = 1.5 + self.alpha = 2.0 + self.x_np = np.random.normal(size=[3, 5, 5, 10]).astype(np.float64) + # Since zero point in selu is not differentiable, avoid randomize + # zero. + self.x_np[np.abs(self.x_np) < 0.005] = 0.02 + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.selu(x, self.scale, self.alpha) + selu = paddle.nn.SELU(self.scale, self.alpha) + out2 = selu(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_selu(self.x_np, self.scale, self.alpha) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.selu(x, self.scale, self.alpha) + selu = paddle.nn.SELU(self.scale, self.alpha) + out2 = selu(x) + out_ref = ref_selu(self.x_np, self.scale, self.alpha) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.selu(x, self.scale, self.alpha) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_selu(self.x_np, self.scale, self.alpha) + self.assertEqual(np.allclose(out_ref, res[0]), True) + def test_errors(self): - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.selu, 1) + self.assertRaises(TypeError, F.selu, 1) # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.selu, x_int32) - # support the input dtype is float32 - x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32') - fluid.layers.selu(x_fp32) + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.selu, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.selu(x_fp16) if __name__ == "__main__": diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 3b75629ede9..3dd1c1d94fb 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -57,10 +57,16 @@ from .layer.activation import GELU from .layer.activation import Hardshrink # from .layer.activation import PReLU #DEFINE_ALIAS from .layer.activation import ReLU +from .layer.activation import ReLU6 #DEFINE_ALIAS +from .layer.activation import SELU #DEFINE_ALIAS from .layer.activation import LeakyReLU #DEFINE_ALIAS from .layer.activation import Sigmoid #DEFINE_ALIAS from .layer.activation import LogSigmoid # from .layer.activation import Softmax #DEFINE_ALIAS +from .layer.activation import Softplus #DEFINE_ALIAS +from .layer.activation import Softshrink #DEFINE_ALIAS +from .layer.activation import Softsign #DEFINE_ALIAS +from .layer.activation import Tanhshrink #DEFINE_ALIAS from .layer.activation import LogSoftmax #DEFINE_ALIAS from .layer.activation import HSigmoid #DEFINE_ALIAS from .layer.common import BilinearTensorProduct #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 855935f6620..ff2b1edf672 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -47,7 +47,7 @@ from .activation import softplus #DEFINE_ALIAS from .activation import softshrink #DEFINE_ALIAS from .activation import softsign #DEFINE_ALIAS from .activation import swish #DEFINE_ALIAS -from .activation import tanh_shrink #DEFINE_ALIAS +from .activation import tanhshrink #DEFINE_ALIAS from .activation import thresholded_relu #DEFINE_ALIAS from .activation import log_softmax #DEFINE_ALIAS from .common import dropout #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 61476a61729..16a86ce2e8c 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -19,15 +19,9 @@ from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS from ...fluid.layers import hard_swish #DEFINE_ALIAS from ...fluid.layers import leaky_relu #DEFINE_ALIAS from ...fluid.layers import maxout #DEFINE_ALIAS -from ...fluid.layers import relu6 #DEFINE_ALIAS -from ...fluid.layers import selu #DEFINE_ALIAS from ...fluid.layers import soft_relu #DEFINE_ALIAS -from ...fluid.layers import softplus #DEFINE_ALIAS -from ...fluid.layers import softshrink #DEFINE_ALIAS -from ...fluid.layers import softsign #DEFINE_ALIAS from ...fluid.layers import swish #DEFINE_ALIAS from ...fluid.layers import sigmoid #DEFINE_ALIAS -from ...fluid.layers import tanh_shrink #DEFINE_ALIAS from ...fluid.layers import thresholded_relu #DEFINE_ALIAS __all__ = [ @@ -53,7 +47,7 @@ __all__ = [ 'softsign', 'sigmoid', 'swish', - 'tanh_shrink', + 'tanhshrink', 'thresholded_relu', 'log_softmax' ] @@ -423,6 +417,103 @@ def logsigmoid(x, name=None): return out +def relu6(x, name=None): + """ + relu6 activation + + .. math:: + + \text{relu6}(x) = \min(\max(0,x), 6) + + Args: + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-1, 0.3, 6.5])) + out = F.relu6(x) # [0, 0.3, 6] + + """ + threshold = 6.0 + if in_dygraph_mode(): + return core.ops.relu6(x, 'threshold', threshold) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu6') + helper = LayerHelper('relu6', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='relu6', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'threshold': threshold}) + return out + + +def selu(x, + scale=1.0507009873554804934193349852946, + alpha=1.6732632423543772848170429916717, + name=None): + """ + selu activation + + .. math:: + + \text{selu}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))), \\ + with\,alpha=1.6732632423543772848170429916717 and \\ + scale=1.0507009873554804934193349852946 + + Args: + x (Tensor): The input Tensor with data type float32, float64. + scale (float, optional): The value of scale for selu. Default is 1.0507009873554804934193349852946 + alpha (float, optional): The value of alpha for selu. Default is 1.6732632423543772848170429916717 + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([[0, 1],[2, 3]])) + out = F.selu(x) # [[0, 1.050701],[2.101402, 3.152103]] + + """ + if in_dygraph_mode(): + return core.ops.selu(x, 'scale', scale, 'alpha', alpha) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'selu') + helper = LayerHelper('selu', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='selu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'scale': scale, + 'alpha': alpha}) + return out + + def softmax(x, axis=-1, name=None): """ This operator implements the softmax layer. The calculation process is as follows: @@ -539,6 +630,188 @@ def softmax(x, axis=-1, name=None): return paddle.fluid.layers.softmax(input=x, axis=axis, name=name) +def softplus(x, beta=1, threshold=20, name=None): + """ + softplus activation + + .. math:: + + \text{softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\ + \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold. + + Args: + x (Tensor): The input Tensor with data type float32, float64. + beta (float, optional): The value of beta for softplus. Default is 1 + threshold (float, optional): The value of threshold for softplus. Default is 20 + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355] + + """ + if in_dygraph_mode(): + return core.ops.softplus(x, 'beta', beta, 'threshold', threshold) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'softplus') + helper = LayerHelper('softplus', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='softplus', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'beta': beta, + 'threshold': threshold}) + return out + + +def softshrink(x, threshold=0.5, name=None): + """ + softshrink activation + + .. math:: + + \text{softshrink}(x) = + \begin{cases} + x - threshold, & \text{ if } x > threshold \\ + x + threshold, & \text{ if } x < -threshold \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + x (Tensor): The input Tensor with data type float32, float64. + threshold (float, optional): The value of threshold(must be no less than zero) for softplus. Default is 0.5 + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8])) + out = F.softshrink(x) # [-0.4, 0, 0, 0.3] + + """ + if in_dygraph_mode(): + return core.ops.softshrink(x, 'lambda', threshold) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'softshrink') + helper = LayerHelper('softshrink', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='softshrink', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'lambda': threshold}) + return out + + +def softsign(x, name=None): + """ + softsign activation + + .. math:: + + \text{softsign}(x) = \frac{x}{1 + |x|} + + Args: + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769] + + """ + if in_dygraph_mode(): + return core.ops.softsign(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'softsign') + helper = LayerHelper('softsign', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='softsign', inputs={'X': x}, outputs={'Out': out}) + return out + + +def tanhshrink(x, name=None): + """ + tanhshrink activation + + .. math:: + + \text{tanhshrink}(x) = x - \text{tanh}(x) + + Args: + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739] + + """ + if in_dygraph_mode(): + return core.ops.tanh_shrink(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'tanhshrink') + helper = LayerHelper('tanh_shrink', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='tanh_shrink', inputs={'X': x}, outputs={'Out': out}) + return out + + def log_softmax(x, axis=-1, dtype=None, name=None): """ This operator implements the log_softmax layer. The calculation process is diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 6b965564fc9..bb294467bb3 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -20,9 +20,15 @@ __all__ = [ 'Hardshrink', # 'PReLU', 'ReLU', + 'ReLU6', + 'SELU', 'LeakyReLU', 'Sigmoid', # 'Softmax', + 'Softplus', + 'Softshrink', + 'Softsign', + 'Tanhshrink', 'LogSigmoid', 'LogSoftmax', 'HSigmoid' @@ -351,6 +357,91 @@ class ReLU(layers.Layer): return F.relu(x, self._name) +class ReLU6(layers.Layer): + """ + ReLU6 Activation + + .. math:: + + \text{ReLU6}(x) = \min(\max(0,x), 6) + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-1, 0.3, 6.5])) + m = paddle.nn.ReLU6() + out = m(x) # [0, 0.3, 6] + """ + + def __init__(self, name=None): + super(ReLU6, self).__init__() + self._name = name + + def forward(self, x): + return F.relu6(x, self._name) + + +class SELU(layers.Layer): + """ + SELU Activation + + .. math:: + + \text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))), \\ + with\,alpha=1.6732632423543772848170429916717 and \\ + scale=1.0507009873554804934193349852946 + + Parameters: + scale (float, optional): The value of scale for SELU. Default is 1.0507009873554804934193349852946 + alpha (float, optional): The value of alpha for SELU. Default is 1.6732632423543772848170429916717 + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([[0, 1],[2, 3]])) + m = paddle.nn.SELU() + out = m(x) # [[0, 1.050701],[2.101402, 3.152103]] + """ + + def __init__(self, + scale=1.0507009873554804934193349852946, + alpha=1.6732632423543772848170429916717, + name=None): + super(SELU, self).__init__() + self._scale = scale + self._alpha = alpha + self._name = name + + def forward(self, x): + return F.selu(x, self._scale, self._alpha, self._name) + + class LeakyReLU(layers.Layer): """ Leaky ReLU Activation. @@ -431,6 +522,167 @@ class Sigmoid(layers.Layer): return F.sigmoid(x, self.name) +class Softplus(layers.Layer): + """ + Softplus Activation + + .. math:: + + \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\ + \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold. + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + m = paddle.nn.Softplus() + out = m(x) # [0.513015, 0.598139, 0.744397, 0.854355] + + """ + + def __init__(self, beta=1, threshold=20, name=None): + super(Softplus, self).__init__() + self._beta = beta + self._threshold = threshold + self._name = name + + def forward(self, x): + return F.softplus(x, self._beta, self._threshold, self._name) + + +class Softshrink(layers.Layer): + """ + Softshrink Activation + + .. math:: + + \text{Softshrink}(x) = + \begin{cases} + x - threshold, & \text{ if } x > threshold \\ + x + threshold, & \text{ if } x < -threshold \\ + 0, & \text{ otherwise } + \end{cases} + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8])) + m = paddle.nn.Softshrink() + out = m(x) # [-0.4, 0, 0, 0.3] + """ + + def __init__(self, threshold=0.5, name=None): + super(Softshrink, self).__init__() + self._threshold = threshold + self._name = name + + def forward(self, x): + return F.softshrink(x, self._threshold, self._name) + + +class Softsign(layers.Layer): + """ + Softsign Activation + + .. math:: + + \text{Softsign}(x) = \frac{x}{1 + |x|} + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + m = paddle.nn.Softsign() + out = m(x) # [-0.285714, -0.166667, 0.0909091, 0.230769] + """ + + def __init__(self, name=None): + super(Softsign, self).__init__() + self._name = name + + def forward(self, x): + return F.softsign(x, self._name) + + +class Tanhshrink(layers.Layer): + """ + Tanhshrink Activation + + .. math:: + + \text{Tanhshrink}(x) = x - \text{Tanh}(x) + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + m = paddle.nn.Tanhshrink() + out = m(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739] + """ + + def __init__(self, name=None): + super(Tanhshrink, self).__init__() + self._name = name + + def forward(self, x): + return F.tanhshrink(x, self._name) + + class LogSigmoid(layers.Layer): """ LogSigmoid Activation. -- GitLab