diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 829b27a405970ee9bbf0a348246ae05b8c453925..e6cf36055f6aab77fca63065437d9095f812628a 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -474,6 +474,7 @@ __all__ += ['hard_shrink'] _hard_shrink_ = generate_layer_fn('hard_shrink') +@deprecated(since="2.0.0", update_to="paddle.nn.functional.hardshrink") def hard_shrink(x, threshold=None): check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'hard_shrink') @@ -487,10 +488,6 @@ def hard_shrink(x, threshold=None): hard_shrink.__doc__ = _hard_shrink_.__doc__ + """ - :alias_main: paddle.nn.functional.hard_shrink - :alias: paddle.nn.functional.hard_shrink,paddle.nn.functional.activation.hard_shrink - :old_api: paddle.fluid.layers.hard_shrink - Examples: >>> import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 124767a3364b078ea2c74795c03497f3dc24ba8c..fc5f1f26d8ffb9682c600a23ce592badea34571b 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -22,7 +22,7 @@ from scipy.special import expit, erf import paddle import paddle.fluid as fluid import paddle.nn as nn -import paddle.nn.functional as functional +import paddle.nn.functional as F from paddle.fluid import compiler, Program, program_guard @@ -344,6 +344,12 @@ class TestTanhShrink(TestActivation): self.check_grad(['X'], 'Out') +def ref_hardshrink(x, threshold): + out = np.copy(x) + out[(out >= -threshold) & (out <= threshold)] = 0 + return out + + class TestHardShrink(TestActivation): def setUp(self): self.op_type = "hard_shrink" @@ -351,11 +357,10 @@ class TestHardShrink(TestActivation): threshold = 0.5 x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) * 10 - out = np.copy(x) - out[(out >= -threshold) & (out <= threshold)] = 0 + out = ref_hardshrink(x, threshold) - self.attrs = {'lambda': threshold} - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'threshold': threshold} + self.inputs = {'X': x} self.outputs = {'Out': out} def test_check_grad(self): @@ -364,17 +369,62 @@ class TestHardShrink(TestActivation): self.check_grad(['X'], 'Out') -class TestHardShrinkOpError(unittest.TestCase): +class TestHardShrinkAPI(unittest.TestCase): + # test paddle.nn.Hardshrink, paddle.nn.functional.hardshrink + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12]) + out1 = F.hardshrink(x) + hd = paddle.nn.Hardshrink() + out2 = hd(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_hardshrink(self.x_np, 0.5) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_variable(self.x_np) + out1 = F.hardshrink(x) + hd = paddle.nn.Hardshrink() + out2 = hd(x) + out_ref = ref_hardshrink(self.x_np, 0.5) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = F.hardshrink(x, 0.6) + hd = paddle.nn.Hardshrink(0.6) + out2 = hd(x) + out_ref = ref_hardshrink(self.x_np, 0.6) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', [10, 12]) + out = fluid.layers.hard_shrink(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_hardshrink(self.x_np, 0.5) + self.assertEqual(np.allclose(out_ref, res[0]), True) + def test_errors(self): - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.hard_shrink, 1) + self.assertRaises(TypeError, F.hardshrink, 1) # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.hard_shrink, x_int32) + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.hardshrink, x_int32) # support the input dtype is float16 - x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') - fluid.layers.hard_shrink(x_fp16) + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.hardshrink(x_fp16) class TestSoftShrink(TestActivation): @@ -1435,7 +1485,7 @@ class TestNNFunctionalReluAPI(unittest.TestCase): main_program = Program() with fluid.program_guard(main_program): x = fluid.data(name='x', shape=self.x_shape) - y = functional.relu(x) + y = F.relu(x) exe = fluid.Executor(fluid.CPUPlace()) out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], self.y)) @@ -1501,7 +1551,7 @@ class TestNNFunctionalSigmoidAPI(unittest.TestCase): main_program = Program() with fluid.program_guard(main_program): x = fluid.data(name='x', shape=self.x_shape) - y = functional.sigmoid(x) + y = F.sigmoid(x) exe = fluid.Executor(fluid.CPUPlace()) out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], self.y)) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 9583d9a0a39b362ce4bda2c11cb976fbe705cbe3..9188c47eca7274713723b61e54cb8522c870b4af 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -51,6 +51,7 @@ from .decode import beam_search_decode #DEFINE_ALIAS from .decode import gather_tree #DEFINE_ALIAS from .input import data #DEFINE_ALIAS # from .input import Input #DEFINE_ALIAS +from .layer.activation import Hardshrink # from .layer.activation import PReLU #DEFINE_ALIAS from .layer.activation import ReLU #DEFINE_ALIAS from .layer.activation import LeakyReLU #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index e3426b22484e4cea764f92cc44cc641386b7f6e4..ded5cb462efcb898c2404c51101916a109286264 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -29,7 +29,7 @@ from .activation import brelu #DEFINE_ALIAS from .activation import elu #DEFINE_ALIAS from .activation import erf #DEFINE_ALIAS from .activation import gelu #DEFINE_ALIAS -from .activation import hard_shrink #DEFINE_ALIAS +from .activation import hardshrink #DEFINE_ALIAS from .activation import hard_sigmoid #DEFINE_ALIAS from .activation import hard_swish #DEFINE_ALIAS from .activation import hsigmoid #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index f524d74f408c033d6a7b2816aebf42a2525247cf..75ba7d2114a2b11c664f2062616c168369acf6bd 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -17,7 +17,6 @@ from ...fluid.layers import brelu #DEFINE_ALIAS from ...fluid.layers import elu #DEFINE_ALIAS from ...fluid.layers import erf #DEFINE_ALIAS from ...fluid.layers import gelu #DEFINE_ALIAS -from ...fluid.layers import hard_shrink #DEFINE_ALIAS from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS from ...fluid.layers import hard_swish #DEFINE_ALIAS from ...fluid.layers import leaky_relu #DEFINE_ALIAS @@ -38,7 +37,7 @@ __all__ = [ 'elu', 'erf', 'gelu', - 'hard_shrink', + 'hardshrink', 'hard_sigmoid', 'hard_swish', 'hsigmoid', @@ -69,6 +68,59 @@ from ...fluid.data_feeder import check_variable_and_dtype import paddle +def hardshrink(x, threshold=0.5, name=None): + """ + hard shrinkage activation + + .. math:: + + hardshrink(x)= + \left\{ + \begin{aligned} + &x, & & if \ x > threshold \\ + &x, & & if \ x < -threshold \\ + &0, & & if \ others + \end{aligned} + \right. + + Args: + x (Tensor): The input Tensor with data type float32, float64. + threshold (float, optional): The value of threshold for hardthrink. Default is 0.5 + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_variable(np.array([-1, 0.3, 2.5])) + out = F.hardshrink(x) # [-1., 0., 2.5] + + """ + if in_dygraph_mode(): + return core.ops.hard_shrink(x, 'threshold', threshold) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'hardshrink') + helper = LayerHelper('hardshrink', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='hard_shrink', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'threshold': threshold}) + return out + + def hsigmoid(input, label, weight, diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index d13f36a31854acbd990e4f9f26d71c046fc8848d..fd418300fa3451d7f7d540be88f76a07f0cc0f7a 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -15,6 +15,7 @@ # TODO: define activation functions of neural network __all__ = [ + 'Hardshrink', # 'PReLU', 'ReLU', 'LeakyReLU', @@ -30,6 +31,53 @@ from ...fluid.framework import in_dygraph_mode from .. import functional +class Hardshrink(layers.Layer): + """ + Hardshrink Activation + + .. math:: + + hardshrink(x)= + \left\{ + \begin{aligned} + &x, & & if \ x > threshold \\ + &x, & & if \ x < -threshold \\ + &0, & & if \ others + \end{aligned} + \right. + + Parameters: + threshold (float, optional): The value of threshold for hardthrink. Default is 0.5 + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_variable(np.array([-1, 0.3, 2.5])) + m = paddle.nn.Hardshrink() + out = m(x) # [-1., 0., 2.5] + """ + + def __init__(self, threshold=0.5, name=None): + super(Hardshrink, self).__init__() + self._threshold = threshold + self._name = name + + def forward(self, x): + return functional.hardshrink(x, self._threshold, self._name) + + class HSigmoid(layers.Layer): """ :alias_main: paddle.nn.HSigmoid