From ed102ea1b5f2f2822bb7883e7639c587d9a911be Mon Sep 17 00:00:00 2001 From: WangXi Date: Sat, 22 Aug 2020 14:12:01 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90API=E3=80=91Add=20sign=20and=20tanh=20?= =?UTF-8?q?api=20(#26357)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/layers/nn.py | 5 +- python/paddle/fluid/layers/ops.py | 2 +- .../tests/unittests/test_activation_op.py | 53 +++++++++++++ .../fluid/tests/unittests/test_sign_op.py | 28 +++++++ python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/activation.py | 2 + python/paddle/nn/layer/activation.py | 40 ++++++++++ python/paddle/tensor/math.py | 77 ++++++++++++++++++- 9 files changed, 202 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 34e0850991e..be3988b1849 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14028,12 +14028,9 @@ def where(condition): return out +@deprecated(since="2.0.0", update_to="paddle.sign") def sign(x): """ - :alias_main: paddle.sign - :alias: paddle.sign,paddle.tensor.sign,paddle.tensor.math.sign - :old_api: paddle.fluid.layers.sign - This OP returns sign of every element in `x`: 1 for positive, -1 for negative and 0 for zero. Args: diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index dbeffcd2803..84cacea6ba5 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -28,11 +28,11 @@ __activations_noattr__ = [ 'tanh_shrink', 'softplus', 'softsign', + 'tanh', ] __unary_func__ = [ 'exp', - 'tanh', 'atan', 'sqrt', 'rsqrt', diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 174fff4acde..533f1081cd5 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -191,6 +191,59 @@ class TestTanh(TestActivation, TestParameter): self.dtype = np.float32 +class TestTanhAPI(unittest.TestCase): + # test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh + def setUp(self): + self.dtype = 'float32' + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12], self.dtype) + out1 = F.tanh(x) + th = paddle.nn.Tanh() + out2 = th(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = np.tanh(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_variable(self.x_np) + out1 = F.tanh(x) + out2 = paddle.tanh(x) + th = paddle.nn.Tanh() + out3 = th(x) + out_ref = np.tanh(self.x_np) + for r in [out1, out2, out3]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', [10, 12], self.dtype) + out = fluid.layers.tanh(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = np.tanh(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.tanh, 1) + # The input dtype must be float16, float32. + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.tanh, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.tanh(x_fp16) + + class TestAtan(TestActivation, TestParameter): def setUp(self): self.op_type = "atan" diff --git a/python/paddle/fluid/tests/unittests/test_sign_op.py b/python/paddle/fluid/tests/unittests/test_sign_op.py index b84e3b5377f..da5080eabdd 100644 --- a/python/paddle/fluid/tests/unittests/test_sign_op.py +++ b/python/paddle/fluid/tests/unittests/test_sign_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard @@ -54,5 +55,32 @@ class TestSignOpError(unittest.TestCase): fluid.layers.sign(input4) +class TestSignAPI(unittest.TestCase): + def test_dygraph(self): + with fluid.dygraph.guard(): + np_x = np.array([-1., 0., -0., 1.2, 1.5], dtype='float64') + x = paddle.to_tensor(np_x) + z = paddle.sign(x) + np_z = z.numpy() + z_expected = np.sign(np_x) + self.assertEqual((np_z == z_expected).all(), True) + + def test_static(self): + with program_guard(Program(), Program()): + # The input type of sign_op must be Variable or numpy.ndarray. + input1 = 12 + self.assertRaises(TypeError, paddle.tensor.math.sign, input1) + # The input dtype of sign_op must be float16, float32, float64. + input2 = fluid.layers.data( + name='input2', shape=[12, 10], dtype="int32") + input3 = fluid.layers.data( + name='input3', shape=[12, 10], dtype="int64") + self.assertRaises(TypeError, paddle.tensor.math.sign, input2) + self.assertRaises(TypeError, paddle.tensor.math.sign, input3) + input4 = fluid.layers.data( + name='input4', shape=[4], dtype="float16") + paddle.sign(input4) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 0e09eeb6a0d..8c7c677366c 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -54,6 +54,7 @@ from .decode import gather_tree #DEFINE_ALIAS # from .input import Input #DEFINE_ALIAS from .layer.activation import ELU from .layer.activation import GELU +from .layer.activation import Tanh from .layer.activation import Hardshrink from .layer.activation import Hardtanh from .layer.activation import PReLU diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index f91caade8f9..9d790ae8883 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -50,6 +50,7 @@ from .activation import softplus #DEFINE_ALIAS from .activation import softshrink #DEFINE_ALIAS from .activation import softsign #DEFINE_ALIAS from .activation import swish #DEFINE_ALIAS +from .activation import tanh #DEFINE_ALIAS from .activation import tanhshrink #DEFINE_ALIAS from .activation import thresholded_relu #DEFINE_ALIAS from .activation import log_softmax #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 44e322c6d4b..5c6f8139ca9 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -22,6 +22,7 @@ from ...fluid.layers import soft_relu #DEFINE_ALIAS from ...fluid.layers import swish #DEFINE_ALIAS from ...fluid.layers import sigmoid #DEFINE_ALIAS from ...fluid.layers import thresholded_relu #DEFINE_ALIAS +from ...tensor.math import tanh #DEFINE_ALIAS __all__ = [ 'brelu', @@ -47,6 +48,7 @@ __all__ = [ 'softsign', 'sigmoid', 'swish', + 'tanh', 'tanhshrink', 'thresholded_relu', 'log_softmax', diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index b9cc13fa85b..ed5913565e9 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -18,6 +18,7 @@ __all__ = [ 'ELU', 'GELU', 'Hardshrink', + 'Tanh', 'Hardtanh', 'PReLU', 'ReLU', @@ -182,6 +183,45 @@ class Hardshrink(layers.Layer): return F.hardshrink(x, self._threshold, self._name) +class Tanh(layers.Layer): + """ + Tanh Activation. + + .. math:: + Tanh(x) = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}} + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3])) + m = paddle.nn.Tanh() + out = m(x) + print(out.numpy()) + # [-0.37994896 -0.19737532 0.09966799 0.29131261] + """ + + def __init__(self, name=None): + super(Tanh, self).__init__() + self._name = name + + def forward(self, x): + return F.tanh(x, self._name) + + class Hardtanh(layers.Layer): """ Hardtanh Activation diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index bb08333c2b9..cacfb0e8e75 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -51,14 +51,12 @@ from ..fluid.layers import reduce_sum #DEFINE_ALIAS from ..fluid.layers import round #DEFINE_ALIAS from ..fluid.layers import rsqrt #DEFINE_ALIAS from ..fluid.layers import scale #DEFINE_ALIAS -from ..fluid.layers import sign #DEFINE_ALIAS from ..fluid.layers import square #DEFINE_ALIAS from ..fluid.layers import stanh #DEFINE_ALIAS from ..fluid.layers import atan #DEFINE_ALIAS from ..fluid.layers import erf #DEFINE_ALIAS from ..fluid.layers import sqrt #DEFINE_ALIAS from ..fluid.layers import sin #DEFINE_ALIAS -from ..fluid.layers import tanh #DEFINE_ALIAS from ..fluid.layers import increment #DEFINE_ALIAS from ..fluid.layers import multiplex #DEFINE_ALIAS @@ -1747,3 +1745,78 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): x = layers.cast(x, dtype) return layers.reduce_prod(input=x, dim=axis, keep_dim=keepdim, name=name) + + +def sign(x, name=None): + """ + This OP returns sign of every element in `x`: 1 for positive, -1 for negative and 0 for zero. + + Args: + x(Tensor): The input tensor. The data type can be float16, float32 or float64. + name (str, optional): The default value is None. Normally there is no need for user to + set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor: The output sign tensor with identical shape and data type to the input :attr:`x`. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + data = np.array([3.0, 0.0, -2.0, 1.7], dtype='float32') + paddle.disable_static() + x = paddle.to_tensor(data) + out = paddle.sign(x=x) + print(out) # [1.0, 0.0, -1.0, 1.0] + """ + if in_dygraph_mode(): + return core.ops.sign(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'sign') + helper = LayerHelper("sign", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op(type='sign', inputs={'X': [x]}, outputs={'Out': [out]}) + + return out + + +def tanh(x, name=None): + """ + Tanh Activation Operator. + + .. math:: + out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}} + + Args: + x (Tensor): Input of Tanh operator, an N-D Tensor, with data type float32, float64 or float16. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Output of Tanh operator, a Tensor with same data type and shape as input. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x_data = np.array([-0.4, -0.2, 0.1, 0.3]) + x = paddle.to_tensor(x_data) + out = paddle.tanh(x) + print(out.numpy()) + # [-0.37994896 -0.19737532 0.09966799 0.29131261] + """ + if in_dygraph_mode(): + return core.ops.tanh(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'tanh') + helper = LayerHelper('tanh', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out}) + return out -- GitLab