diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 6b729e6297bedb07ddb25dbb3a1cbf69ad168c33..ac3d0a3a7856290907bba8116c6e9b1c01e7f34d 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1979,22 +1979,24 @@ class TestSoftsignAPI(unittest.TestCase):
             F.softsign(x_fp16)
 
 
+def ref_thresholded_relu(x, threshold=1.0):
+    out = (x > threshold) * x
+    return out
+
+
 class TestThresholdedRelu(TestActivation):
     def setUp(self):
         self.op_type = "thresholded_relu"
         self.init_dtype()
-        threshold = 0.25
-        self.delta = 0.005
-        np.random.seed(1024)
 
-        X = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
-
-        # Same reason as TestAbs
-        X[np.abs(X - threshold) < self.delta] = threshold + 0.2
-        out = (X > threshold) * X
+        threshold = 15
 
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
-        self.attrs = {'threshold': threshold}
+        np.random.seed(1024)
+        x = np.random.uniform(-20, 20, [10, 12]).astype(self.dtype)
+        x[np.abs(x) < 0.005] = 0.02
+        out = ref_thresholded_relu(x, threshold)
+        self.inputs = {'X': x}
+        self.attrs = {"threshold": threshold}
         self.outputs = {'Out': out}
 
     def test_check_grad(self):
@@ -2003,17 +2005,61 @@ class TestThresholdedRelu(TestActivation):
         self.check_grad(['X'], 'Out')
 
 
-class TestThresholdedReluOpError(unittest.TestCase):
+class TestThresholdedReluAPI(unittest.TestCase):
+    # test paddle.nn.ThresholdedReLU, paddle.nn.functional.thresholded_relu
+    def setUp(self):
+        self.threshold = 15
+        np.random.seed(1024)
+        self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
+        self.x_np[np.abs(self.x_np) < 0.005] = 0.02
+        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+    def test_static_api(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
+            out1 = F.thresholded_relu(x, self.threshold)
+            thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
+            out2 = thresholded_relu(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+        out_ref = ref_thresholded_relu(self.x_np, self.threshold)
+        for r in res:
+            self.assertEqual(np.allclose(out_ref, r), True)
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        out1 = F.thresholded_relu(x, self.threshold)
+        thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
+        out2 = thresholded_relu(x)
+        out_ref = ref_thresholded_relu(self.x_np, self.threshold)
+        for r in [out1, out2]:
+            self.assertEqual(np.allclose(out_ref, r.numpy()), True)
+        paddle.enable_static()
+
+    def test_fluid_api(self):
+        paddle.enable_static()
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
+            out = fluid.layers.thresholded_relu(x, self.threshold)
+            exe = fluid.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
+        out_ref = ref_thresholded_relu(self.x_np, self.threshold)
+        self.assertEqual(np.allclose(out_ref, res[0]), True)
+
     def test_errors(self):
-        with program_guard(Program()):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
             # The input type must be Variable.
-            self.assertRaises(TypeError, fluid.layers.thresholded_relu, 1)
+            self.assertRaises(TypeError, F.thresholded_relu, 1)
             # The input dtype must be float16, float32, float64.
-            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
-            self.assertRaises(TypeError, fluid.layers.thresholded_relu, x_int32)
+            x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
+            self.assertRaises(TypeError, F.thresholded_relu, x_int32)
             # support the input dtype is float16
-            x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
-            fluid.layers.thresholded_relu(x_fp16)
+            x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
+            F.thresholded_relu(x_fp16)
 
 
 def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
@@ -2115,37 +2161,82 @@ class TestHardsigmoidAPI(unittest.TestCase):
             F.hardsigmoid(x_fp16)
 
 
+def ref_swish(x):
+    out = x * expit(x)
+    return out
+
+
 class TestSwish(TestActivation):
     def setUp(self):
         self.op_type = "swish"
         self.init_dtype()
 
         np.random.seed(1024)
-        X = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
-        beta = 2.3
-        out = X * expit(beta * X)
-
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
-        self.attrs = {'beta': beta}
+        x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
+        out = ref_swish(x)
+        self.inputs = {'X': x}
+        self.attrs = {'slope': 1.0}
         self.outputs = {'Out': out}
 
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', max_relative_error=0.008)
+        self.check_grad(['X'], 'Out')
+
+class TestSwishAPI(unittest.TestCase):
+    # test paddle.nn.Swish, paddle.nn.functional.swish
+    def setUp(self):
+        np.random.seed(1024)
+        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
+        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+    def test_static_api(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
+            out1 = F.swish(x)
+            swish = paddle.nn.Swish()
+            out2 = swish(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+        out_ref = ref_swish(self.x_np)
+        for r in res:
+            self.assertEqual(np.allclose(out_ref, r), True)
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        out1 = F.swish(x)
+        swish = paddle.nn.Swish()
+        out2 = swish(x)
+        out_ref = ref_swish(self.x_np)
+        for r in [out1, out2]:
+            self.assertEqual(np.allclose(out_ref, r.numpy()), True)
+        paddle.enable_static()
+
+    def test_fluid_api(self):
+        paddle.enable_static()
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
+            out = fluid.layers.swish(x)
+            exe = fluid.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
+        out_ref = ref_swish(self.x_np)
+        self.assertEqual(np.allclose(out_ref, res[0]), True)
 
 
-class TestSwishOpError(unittest.TestCase):
     def test_errors(self):
-        with program_guard(Program()):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
             # The input type must be Variable.
-            self.assertRaises(TypeError, fluid.layers.swish, 1)
+            self.assertRaises(TypeError, F.swish, 1)
             # The input dtype must be float16, float32, float64.
-            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
-            self.assertRaises(TypeError, fluid.layers.swish, x_int32)
+            x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
+            self.assertRaises(TypeError, F.swish, x_int32)
             # support the input dtype is float16
-            x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
-            fluid.layers.swish(x_fp16)
+            x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
+            F.swish(x_fp16)
 
 
 #------------------ Test Error Activation----------------------
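The rewritten op tests still go through check_grad, so the analytic gradients matter even though the numpy references above only cover the forward pass. A standalone sketch of the gradients that check approximates (pure numpy/scipy, not Paddle code; the fd_grad, swish_grad and thresholded_relu_grad helpers are hypothetical names used only here):

# Sketch: verify the analytic gradients of the two activations against a
# central finite difference, mirroring what check_grad exercises.
import numpy as np
from scipy.special import expit  # sigmoid, same helper the tests import

def ref_swish(x):
    return x * expit(x)

def swish_grad(x):
    s = expit(x)
    return s + x * s * (1.0 - s)  # d/dx [x * sigmoid(x)]

def ref_thresholded_relu(x, threshold=1.0):
    return (x > threshold) * x

def thresholded_relu_grad(x, threshold=1.0):
    return (x > threshold).astype(x.dtype)  # 0/1 mask away from the threshold

def fd_grad(f, x, eps=1e-6):
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

np.random.seed(1024)
x = np.random.uniform(-20, 20, [10, 12])
x[np.abs(x - 15) < 0.05] = 15.2  # keep points away from the kink at threshold=15

assert np.allclose(swish_grad(x), fd_grad(ref_swish, x), atol=1e-5)
assert np.allclose(thresholded_relu_grad(x, 15),
                   fd_grad(lambda v: ref_thresholded_relu(v, 15), x), atol=1e-5)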
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index c788727ab97e541eff2fd1c474fe0ab1c10b70a6..b16e95b7130f9fea68ce275339fc2940c698d6ea 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -69,7 +69,9 @@ from .layer.activation import Softmax  #DEFINE_ALIAS
 from .layer.activation import Softplus  #DEFINE_ALIAS
 from .layer.activation import Softshrink  #DEFINE_ALIAS
 from .layer.activation import Softsign  #DEFINE_ALIAS
+from .layer.activation import Swish  #DEFINE_ALIAS
 from .layer.activation import Tanhshrink  #DEFINE_ALIAS
+from .layer.activation import ThresholdedReLU  #DEFINE_ALIAS
 from .layer.activation import LogSoftmax  #DEFINE_ALIAS
 from .layer.activation import HSigmoid  #DEFINE_ALIAS
 from .layer.activation import Maxout  #DEFINE_ALIAS
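With the two aliases exported from paddle.nn, the class form and the functional form should agree. A dygraph smoke check, hedged as a sketch that assumes a build already containing this patch, compared against the same numpy references the tests use:

# Sketch: exercise the newly exported names in dygraph mode and compare
# them with the pure-numpy references from the test file.
import numpy as np
from scipy.special import expit
import paddle
import paddle.nn.functional as F

paddle.disable_static()
np.random.seed(1024)
x_np = np.random.uniform(-3, 3, [4, 5]).astype('float32')
x = paddle.to_tensor(x_np)

# functional and Layer forms should both match x * sigmoid(x)
np.testing.assert_allclose(F.swish(x).numpy(), x_np * expit(x_np),
                           rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(paddle.nn.Swish()(x).numpy(), x_np * expit(x_np),
                           rtol=1e-5, atol=1e-6)

# thresholded_relu keeps values strictly above the threshold, zeros the rest
ref = (x_np > 1.0) * x_np
np.testing.assert_allclose(F.thresholded_relu(x).numpy(), ref,
                           rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(paddle.nn.ThresholdedReLU()(x).numpy(), ref,
                           rtol=1e-5, atol=1e-6)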
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 2c65acb6f05a44cd21445dcff0db32592e4cedf0..53fa9814e6ef0dddfb727f89197049ca0e1ec1ad 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -15,9 +15,7 @@
 # TODO: define activation functions of neural network
 from ...fluid.layers import erf  #DEFINE_ALIAS
 from ...fluid.layers import soft_relu  #DEFINE_ALIAS
-from ...fluid.layers import swish  #DEFINE_ALIAS
 from ...fluid.layers import sigmoid  #DEFINE_ALIAS
-from ...fluid.layers import thresholded_relu  #DEFINE_ALIAS
 from ...tensor.math import tanh  #DEFINE_ALIAS
 
 __all__ = [
@@ -787,8 +785,6 @@ def relu6(x, name=None):
             import paddle.nn.functional as F
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
             out = F.relu6(x) # [0, 0.3, 6]
     """
@@ -839,8 +835,6 @@ def selu(x,
             import paddle.nn.functional as F
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([[0.0, 1.0],[2.0, 3.0]]))
             out = F.selu(x) # [[0, 1.050701],[2.101402, 3.152103]]
     """
@@ -1054,8 +1048,6 @@ def softplus(x, beta=1, threshold=20, name=None):
             import paddle.nn.functional as F
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
             out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355]
     """
@@ -1103,8 +1095,6 @@ def softshrink(x, threshold=0.5, name=None):
             import paddle.nn.functional as F
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
             out = F.softshrink(x) # [-0.4, 0, 0, 0.3]
     """
@@ -1151,8 +1141,6 @@ def softsign(x, name=None):
             import paddle.nn.functional as F
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
             out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
     """
@@ -1167,6 +1155,47 @@ def softsign(x, name=None):
     return out
 
 
+def swish(x, name=None):
+    """
+    swish activation.
+
+    .. math::
+
+        swish(x) = \\frac{x}{1 + e^{-x}}
+
+    Parameters:
+        x (Tensor): The input Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+            import numpy as np
+
+            x = paddle.to_tensor(np.array([-2., 0., 1.]))
+            out = F.swish(x) # [-0.238406, 0., 0.731059]
+    """
+
+    if in_dygraph_mode():
+        return core.ops.swish(x, 'slope', 1.0)
+
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')
+    helper = LayerHelper('swish', **locals())
+    out = helper.create_variable_for_type_inference(x.dtype)
+    helper.append_op(
+        type='swish',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'slope': 1.0})
+    return out
+
+
 def tanhshrink(x, name=None):
     """
     tanhshrink activation
@@ -1190,8 +1219,6 @@ def tanhshrink(x, name=None):
             import paddle.nn.functional as F
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
             out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
     """
@@ -1206,6 +1233,52 @@ def tanhshrink(x, name=None):
     return out
 
 
+def thresholded_relu(x, threshold=1.0, name=None):
+    """
+    thresholded relu activation.
+
+    .. math::
+
+        thresholded\\_relu(x) = \\begin{cases}
+                                    x, \\text{if } x > threshold \\\\
+                                    0, \\text{otherwise}
+                                \\end{cases}
+
+    Parameters:
+        x (Tensor): The input Tensor with data type float32, float64.
+        threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+            import numpy as np
+
+            x = paddle.to_tensor(np.array([2., 0., 1.]))
+            out = F.thresholded_relu(x) # [2., 0., 0.]
+    """
+
+    if in_dygraph_mode():
+        return core.ops.thresholded_relu(x, 'threshold', threshold)
+
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'thresholded_relu')
+    helper = LayerHelper('thresholded_relu', **locals())
+    out = helper.create_variable_for_type_inference(x.dtype)
+    helper.append_op(
+        type='thresholded_relu',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'threshold': threshold})
+    return out
+
+
 def log_softmax(x, axis=-1, dtype=None, name=None):
     """
     This operator implements the log_softmax layer. The calculation process is
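The literal values quoted in the two new docstring examples follow directly from the formulas; a small numpy-only sketch (the local sigmoid helper is just for illustration) reproduces them:

# Sketch: reproduce the numbers shown in the swish/thresholded_relu
# docstring examples using nothing but numpy.
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

x = np.array([-2., 0., 1.])
print(np.round(x * sigmoid(x), 6))   # [-0.238406  0.        0.731059]

y = np.array([2., 0., 1.])
print((y > 1.0) * y)                 # [2. 0. 0.]  (1.0 is not strictly > threshold)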
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index b3b7bd259c74409714b1c44c7e37e2e7696f668e..cd17f26e09e37546dee753c591d02ef9327661cd 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -32,7 +32,9 @@ __all__ = [
     'Softplus',
     'Softshrink',
     'Softsign',
+    'Swish',
     'Tanhshrink',
+    'ThresholdedReLU',
     'LogSigmoid',
     'LogSoftmax',
     'Maxout',
@@ -580,8 +582,6 @@ class ReLU6(layers.Layer):
             import paddle
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
             m = paddle.nn.ReLU6()
             out = m(x) # [0, 0.3, 6]
@@ -623,8 +623,6 @@ class SELU(layers.Layer):
             import paddle
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([[0.0, 1.0],[2.0, 3.0]]))
             m = paddle.nn.SELU()
             out = m(x) # [[0, 1.050701],[2.101402, 3.152103]]
@@ -801,8 +799,6 @@ class Softplus(layers.Layer):
             import paddle
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
             m = paddle.nn.Softplus()
             out = m(x) # [0.513015, 0.598139, 0.744397, 0.854355]
@@ -845,8 +841,6 @@ class Softshrink(layers.Layer):
             import paddle
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
             m = paddle.nn.Softshrink()
             out = m(x) # [-0.4, 0, 0, 0.3]
@@ -883,8 +877,6 @@ class Softsign(layers.Layer):
             import paddle
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
             m = paddle.nn.Softsign()
             out = m(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
@@ -898,6 +890,41 @@ class Softsign(layers.Layer):
         return F.softsign(x, self._name)
 
 
+class Swish(layers.Layer):
+    """
+    Swish Activation.
+
+    .. math::
+
+        Swish(x) = \\frac{x}{1 + e^{-x}}
+
+    Parameters:
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - input: Tensor with any shape.
+        - output: Tensor with the same shape as input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            x = paddle.to_tensor(np.array([-2., 0., 1.]))
+            m = paddle.nn.Swish()
+            out = m(x) # [-0.238406, 0., 0.731059]
+    """
+
+    def __init__(self, name=None):
+        super(Swish, self).__init__()
+        self._name = name
+
+    def forward(self, x):
+        return F.swish(x, self._name)
+
+
 class Tanhshrink(layers.Layer):
     """
     Tanhshrink Activation
@@ -920,8 +947,6 @@ class Tanhshrink(layers.Layer):
             import paddle
             import numpy as np
 
-            paddle.disable_static()
-
             x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
             m = paddle.nn.Tanhshrink()
             out = m(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
@@ -935,6 +960,46 @@ class Tanhshrink(layers.Layer):
         return F.tanhshrink(x, self._name)
 
 
+class ThresholdedReLU(layers.Layer):
+    """
+    Thresholded ReLU Activation
+
+    .. math::
+
+        ThresholdedReLU(x) = \\begin{cases}
+                                x, \\text{if } x > threshold \\\\
+                                0, \\text{otherwise}
+                             \\end{cases}
+
+    Parameters:
+        threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - input: Tensor with any shape.
+        - output: Tensor with the same shape as input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            x = paddle.to_tensor(np.array([2., 0., 1.]))
+            m = paddle.nn.ThresholdedReLU()
+            out = m(x) # [2., 0., 0.]
+ """ + + def __init__(self, threshold=1.0, name=None): + super(ThresholdedReLU, self).__init__() + self._threshold = threshold + self._name = name + + def forward(self, x): + return F.thresholded_relu(x, self._threshold, self._name) + + class LogSigmoid(layers.Layer): """ LogSigmoid Activation.