未验证 提交 74d3a550 编写于 作者: H hong19860320 提交者: GitHub

Add Swish and ThresholdedReLU for API 2.0 (#27758)

上级 a2d08aa9
......@@ -1979,22 +1979,24 @@ class TestSoftsignAPI(unittest.TestCase):
F.softsign(x_fp16)
def ref_thresholded_relu(x, threshold=1.0):
out = (x > threshold) * x
return out
class TestThresholdedRelu(TestActivation):
def setUp(self):
self.op_type = "thresholded_relu"
self.init_dtype()
threshold = 0.25
self.delta = 0.005
np.random.seed(1024)
X = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# Same reason as TestAbs
X[np.abs(X - threshold) < self.delta] = threshold + 0.2
out = (X > threshold) * X
threshold = 15
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
self.attrs = {'threshold': threshold}
np.random.seed(1024)
x = np.random.uniform(-20, 20, [10, 12]).astype(self.dtype)
x[np.abs(x) < 0.005] = 0.02
out = ref_thresholded_relu(x, threshold)
self.inputs = {'X': x}
self.attrs = {"threshold": threshold}
self.outputs = {'Out': out}
def test_check_grad(self):
......@@ -2003,17 +2005,61 @@ class TestThresholdedRelu(TestActivation):
self.check_grad(['X'], 'Out')
class TestThresholdedReluOpError(unittest.TestCase):
class TestThresholdedReluAPI(unittest.TestCase):
# test paddle.nn.ThresholdedReLU, paddle.nn.functional.thresholded_relu
def setUp(self):
self.threshold = 15
np.random.seed(1024)
self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
self.x_np[np.abs(self.x_np) < 0.005] = 0.02
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.thresholded_relu(x, self.threshold)
thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
out2 = thresholded_relu(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_thresholded_relu(self.x_np, self.threshold)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.thresholded_relu(x, self.threshold)
thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
out2 = thresholded_relu(x)
out_ref = ref_thresholded_relu(self.x_np, self.threshold)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_fluid_api(self):
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
out = fluid.layers.thresholded_relu(x, self.threshold)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = ref_thresholded_relu(self.x_np, self.threshold)
self.assertEqual(np.allclose(out_ref, res[0]), True)
def test_errors(self):
with program_guard(Program()):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.thresholded_relu, 1)
self.assertRaises(TypeError, F.thresholded_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.thresholded_relu, x_int32)
x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.thresholded_relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.thresholded_relu(x_fp16)
x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
F.thresholded_relu(x_fp16)
def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
......@@ -2115,37 +2161,82 @@ class TestHardsigmoidAPI(unittest.TestCase):
F.hardsigmoid(x_fp16)
def ref_swish(x):
out = x * expit(x)
return out
class TestSwish(TestActivation):
def setUp(self):
self.op_type = "swish"
self.init_dtype()
np.random.seed(1024)
X = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
beta = 2.3
out = X * expit(beta * X)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
self.attrs = {'beta': beta}
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
out = ref_swish(x)
self.inputs = {'X': x}
self.attrs = {'slope': 1.0}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.008)
self.check_grad(['X'], 'Out')
class TestSwishAPI(unittest.TestCase):
# test paddle.nn.Swish, paddle.nn.functional.swish
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.swish(x)
swish = paddle.nn.Swish()
out2 = swish(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_swish(self.x_np)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.swish(x)
swish = paddle.nn.Swish()
out2 = swish(x)
out_ref = ref_swish(self.x_np)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_fluid_api(self):
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
out = fluid.layers.swish(x)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = ref_swish(self.x_np)
self.assertEqual(np.allclose(out_ref, res[0]), True)
class TestSwishOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.swish, 1)
self.assertRaises(TypeError, F.swish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.swish, x_int32)
x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.swish, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.swish(x_fp16)
x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
F.swish(x_fp16)
#------------------ Test Error Activation----------------------
......
......@@ -69,7 +69,9 @@ from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import Softplus #DEFINE_ALIAS
from .layer.activation import Softshrink #DEFINE_ALIAS
from .layer.activation import Softsign #DEFINE_ALIAS
from .layer.activation import Swish #DEFINE_ALIAS
from .layer.activation import Tanhshrink #DEFINE_ALIAS
from .layer.activation import ThresholdedReLU #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
from .layer.activation import HSigmoid #DEFINE_ALIAS
from .layer.activation import Maxout #DEFINE_ALIAS
......
......@@ -15,9 +15,7 @@
# TODO: define activation functions of neural network
from ...fluid.layers import erf #DEFINE_ALIAS
from ...fluid.layers import soft_relu #DEFINE_ALIAS
from ...fluid.layers import swish #DEFINE_ALIAS
from ...fluid.layers import sigmoid #DEFINE_ALIAS
from ...fluid.layers import thresholded_relu #DEFINE_ALIAS
from ...tensor.math import tanh #DEFINE_ALIAS
__all__ = [
......@@ -787,8 +785,6 @@ def relu6(x, name=None):
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
out = F.relu6(x) # [0, 0.3, 6]
"""
......@@ -839,8 +835,6 @@ def selu(x,
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([[0.0, 1.0],[2.0, 3.0]]))
out = F.selu(x) # [[0, 1.050701],[2.101402, 3.152103]]
"""
......@@ -1054,8 +1048,6 @@ def softplus(x, beta=1, threshold=20, name=None):
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355]
"""
......@@ -1103,8 +1095,6 @@ def softshrink(x, threshold=0.5, name=None):
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
out = F.softshrink(x) # [-0.4, 0, 0, 0.3]
"""
......@@ -1151,8 +1141,6 @@ def softsign(x, name=None):
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
"""
......@@ -1167,6 +1155,47 @@ def softsign(x, name=None):
return out
def swish(x, name=None):
"""
swish activation.
.. math::
swish(x) = \\frac{x}{1 + e^{-x}}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
x = paddle.to_tensor(np.array([-2., 0., 1.]))
out = F.swish(x) # [-0.238406, 0., 0.731059]
"""
if in_dygraph_mode():
return core.ops.swish(x, 'slop', 1.0)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')
helper = LayerHelper('swish', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='swish',
inputs={'X': x},
outputs={'Out': out},
attrs={'slope': 1.0})
return out
def tanhshrink(x, name=None):
"""
tanhshrink activation
......@@ -1190,8 +1219,6 @@ def tanhshrink(x, name=None):
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
"""
......@@ -1206,6 +1233,52 @@ def tanhshrink(x, name=None):
return out
def thresholded_relu(x, threshold=1.0, name=None):
"""
thresholded relu activation.
.. math::
thresholded\\_relu(x) = \\begin{cases}
x, \\text{if } x > threshold \\\\
0, \\text{otherwise}
\\end{cases}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
x = paddle.to_tensor(np.array([2., 0., 1.]))
out = F.thresholded_relu(x) # [2., 0., 0.]
"""
if in_dygraph_mode():
return core.ops.thresholded_relu(x, 'threshold', threshold)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'thresholded_relu')
helper = LayerHelper('thresholded_relu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='thresholded_relu',
inputs={'X': x},
outputs={'Out': out},
attrs={'threshold': threshold})
return out
def log_softmax(x, axis=-1, dtype=None, name=None):
"""
This operator implements the log_softmax layer. The calculation process is
......
......@@ -32,7 +32,9 @@ __all__ = [
'Softplus',
'Softshrink',
'Softsign',
'Swish',
'Tanhshrink',
'ThresholdedReLU',
'LogSigmoid',
'LogSoftmax',
'Maxout',
......@@ -580,8 +582,6 @@ class ReLU6(layers.Layer):
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
m = paddle.nn.ReLU6()
out = m(x) # [0, 0.3, 6]
......@@ -623,8 +623,6 @@ class SELU(layers.Layer):
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([[0.0, 1.0],[2.0, 3.0]]))
m = paddle.nn.SELU()
out = m(x) # [[0, 1.050701],[2.101402, 3.152103]]
......@@ -801,8 +799,6 @@ class Softplus(layers.Layer):
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Softplus()
out = m(x) # [0.513015, 0.598139, 0.744397, 0.854355]
......@@ -845,8 +841,6 @@ class Softshrink(layers.Layer):
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
m = paddle.nn.Softshrink()
out = m(x) # [-0.4, 0, 0, 0.3]
......@@ -883,8 +877,6 @@ class Softsign(layers.Layer):
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Softsign()
out = m(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
......@@ -898,6 +890,41 @@ class Softsign(layers.Layer):
return F.softsign(x, self._name)
class Swish(layers.Layer):
"""
Swish Activation.
.. math::
Swish(x) = \\frac{x}{1 + e^{-x}}
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(np.array([-2., 0., 1.]))
m = paddle.nn.Swish()
out = m(x) # [-0.238406, 0., 0.731059]
"""
def __init__(self, name=None):
super(Swish, self).__init__()
self._name = name
def forward(self, x):
return F.swish(x, self._name)
class Tanhshrink(layers.Layer):
"""
Tanhshrink Activation
......@@ -920,8 +947,6 @@ class Tanhshrink(layers.Layer):
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Tanhshrink()
out = m(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
......@@ -935,6 +960,46 @@ class Tanhshrink(layers.Layer):
return F.tanhshrink(x, self._name)
class ThresholdedReLU(layers.Layer):
"""
Thresholded ReLU Activation
.. math::
ThresholdedReLU(x) = \\begin{cases}
x, \\text{if } x > threshold \\\\
0, \\text{otherwise}
\\end{cases}
Parameters:
threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(np.array([2., 0., 1.]))
m = paddle.nn.ThresholdedReLU()
out = m(x) # [2., 0., 0.]
"""
def __init__(self, threshold=1.0, name=None):
super(ThresholdedReLU, self).__init__()
self._threshold = threshold
self._name = name
def forward(self, x):
return F.thresholded_relu(x, self._threshold, self._name)
class LogSigmoid(layers.Layer):
"""
LogSigmoid Activation.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册