Unverified · Commit 7bd7b188 authored by huangjun12, committed by GitHub

add alpha_dropout in nn.functional and nn.layer, test=develop (#26365)

* add alpha_dropout in nn.functional and nn.layer, test=develop

* refine Interface and assertion, test=develop

* fix ci import error, test=develop

* fix alias and use layers.scale, test=develop

* fix doc and training params, test=develop

* refine details in doc, test=develop

* fix doc details, test=develop
Parent 9b14117c
......@@ -635,5 +635,103 @@ class TestDropout3DCAPI(unittest.TestCase):
                self.assertTrue(np.allclose(result.numpy(), result_np))


class TestAlphaDropoutFAPI(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))

    def check_static_result(self, place):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input = fluid.data(name="input", shape=[40, 40], dtype="float32")
            res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.)
            res2 = paddle.nn.functional.alpha_dropout(
                x=input, p=0., training=False)

            in_np = np.random.random([40, 40]).astype("float32")
            res_np = in_np

            exe = fluid.Executor(place)
            res_list = [res1, res2]
            for res in res_list:
                fetches = exe.run(fluid.default_main_program(),
                                  feed={"input": in_np},
                                  fetch_list=[res])
                self.assertTrue(np.allclose(fetches[0], res_np))

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                in_np = np.random.random([40, 40]).astype("float32")
                res_np = in_np
                input = fluid.dygraph.to_variable(in_np)
                res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.)
                res2 = paddle.nn.functional.alpha_dropout(
                    x=input, p=0., training=False)

                res_list = [res1, res2]
                for res in res_list:
                    self.assertTrue(np.allclose(res.numpy(), res_np))


class TestAlphaDropoutFAPIError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):

            def test_Variable():
                # the input of alpha_dropout must be a Variable.
                x1 = fluid.create_lod_tensor(
                    np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
                paddle.nn.functional.alpha_dropout(x1, p=0.5)

            self.assertRaises(TypeError, test_Variable)

            def test_dtype():
                # the input dtype of alpha_dropout must be float32 or float64
                xr = fluid.data(name='xr', shape=[3, 4, 5, 6], dtype="int32")
                paddle.nn.functional.alpha_dropout(xr)

            self.assertRaises(TypeError, test_dtype)

            def test_pdtype():
                # p should be int or float
                x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32")
                paddle.nn.functional.alpha_dropout(x2, p='0.5')

            self.assertRaises(TypeError, test_pdtype)

            def test_pvalue():
                # p should satisfy 0. <= p <= 1.
                x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32")
                paddle.nn.functional.alpha_dropout(x2, p=1.2)

            self.assertRaises(ValueError, test_pvalue)


class TestAlphaDropoutCAPI(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                input_np = np.random.random([40, 40]).astype("float32")
                result_np = input_np
                input = fluid.dygraph.to_variable(input_np)
                m = paddle.nn.AlphaDropout(p=0.)
                m.eval()
                result = m(input)
                self.assertTrue(np.allclose(result.numpy(), result_np))


if __name__ == '__main__':
    unittest.main()
......@@ -91,6 +91,7 @@ from .layer.common import UpSample #DEFINE_ALIAS
from .layer.common import Dropout #DEFINE_ALIAS
from .layer.common import Dropout2D #DEFINE_ALIAS
from .layer.common import Dropout3D #DEFINE_ALIAS
from .layer.common import AlphaDropout #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .layer.conv import Conv1d #DEFINE_ALIAS
......
......@@ -57,6 +57,7 @@ from .activation import log_softmax #DEFINE_ALIAS
from .common import dropout #DEFINE_ALIAS
from .common import dropout2d #DEFINE_ALIAS
from .common import dropout3d #DEFINE_ALIAS
from .common import alpha_dropout #DEFINE_ALIAS
# from .common import embedding #DEFINE_ALIAS
# from .common import fc #DEFINE_ALIAS
from .common import label_smooth #DEFINE_ALIAS
......
......@@ -40,6 +40,7 @@ __all__ = [
'dropout',
'dropout2d',
'dropout3d',
'alpha_dropout',
# 'embedding',
# 'fc',
'label_smooth',
......@@ -476,7 +477,6 @@ def dropout(x,
p (float | int): Probability of setting units to zero. Default 0.5.
axis (int | list): The axis along which the dropout is performed. Default None.
training (bool): A flag indicating whether it is in train phase or not. Default True.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
......@@ -488,6 +488,7 @@ def dropout(x,
- train: out = input * mask
- inference: out = input * (1.0 - dropout_prob)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the dropout, has same shape and data type as `x` .
......@@ -549,7 +550,7 @@ def dropout(x,
[4 0 6]]
(3) What about ``axis=[0, 1]`` ? This means the dropout is performed in all axes of x,
which is the same case as default setting ``axis=None`` .
(4) You may note that logically `axis=None` means the dropout is performed in no axis of x,
(4) You may note that logically `axis=None` means the dropout is performed in none axis of x,
We generate mask with the shape 1*1. Whole input is randomly selected or dropped.
For example, we may get such mask:
[[0]]
......@@ -563,8 +564,7 @@ def dropout(x,
When x is a 4d tensor with shape `NCHW`, we can set ``axis=[0,1]`` and the dropout will be performed
in channel `N` and `C`, `H` and `W` is tied, i.e.
paddle.nn.dropout(x, p, axis=[0,1])
This is something we called dropout2d. Please refer to ``paddle.nn.functional.dropout2d``
for more details.
Please refer to ``paddle.nn.functional.dropout2d`` for more details.
Similarly, when x is a 5d tensor with shape `NCDHW`, we can set ``axis=[0,1]`` to perform
dropout3d. Please refer to ``paddle.nn.functional.dropout3d`` for more details.
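Note (not part of the diff): a minimal sketch of the ``axis`` behaviour described above, assuming ``paddle.nn.functional.dropout`` accepts the ``axis`` argument as documented here. With a 4-D NCHW input and ``axis=[0, 1]``, each (H, W) slice is kept or dropped as a whole, which is the dropout2d behaviour.

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.ones([2, 3, 4, 5]).astype('float32'))  # NCHW input
y = paddle.nn.functional.dropout(x, p=0.5, axis=[0, 1])
# Each (H, W) slice is either zeroed or kept (and upscaled by 1/(1-p)) as a unit.
print(y.numpy()[0, 0])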
......@@ -795,6 +795,80 @@ def dropout3d(x, p=0.5, training=True, data_format='NCDHW', name=None):
        name=name)


def alpha_dropout(x, p=0.5, training=True, name=None):
    """
    Alpha Dropout is a type of Dropout that maintains the self-normalizing property.
    For an input with zero mean and unit standard deviation, the output of Alpha Dropout
    maintains the original mean and standard deviation of the input.
    Alpha Dropout fits well with the SELU activation function by randomly setting activations to the negative saturation value.

    Args:
        x (Tensor): The input tensor. The data type is float32 or float64.
        p (float | int): Probability of setting units to zero. Default 0.5.
        training (bool): A flag indicating whether it is in train phase or not. Default True.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor representing the dropout, has same shape and data type as `x`.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()
            x = np.array([[-1, 1], [-1, 1]]).astype('float32')
            x = paddle.to_tensor(x)
            y_train = paddle.nn.functional.alpha_dropout(x, 0.5)
            y_test = paddle.nn.functional.alpha_dropout(x, 0.5, training=False)
            print(x.numpy())
            print(y_train.numpy())
            # [[-0.10721093, 1.6655989 ], [-0.7791938, -0.7791938]] (randomly)
            print(y_test.numpy())
    """
    if not isinstance(p, (float, int)):
        raise TypeError("p argument should be a float or int")
    if p < 0 or p > 1:
        raise ValueError("p argument should be between 0 and 1")

    if not in_dygraph_mode():
        check_variable_and_dtype(x, 'x', ['float32', 'float64'],
                                 'alpha_dropout')

    if training:
        # get transformation params
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        alpha_p = -alpha * scale
        a = ((1 - p) * (1 + p * alpha_p**2))**-0.5
        b = -a * alpha_p * p

        dtype = x.dtype
        input_shape = x.shape

        # get mask
        random_tensor = layers.uniform_random(
            input_shape, dtype='float32', min=0., max=1.0)
        p = layers.fill_constant(shape=[1], dtype='float32', value=p)
        keep_mask = layers.greater_equal(random_tensor, p)
        keep_mask = layers.cast(keep_mask, dtype)
        drop_mask = layers.elementwise_sub(
            layers.fill_constant(
                shape=input_shape, dtype=dtype, value=1.),
            keep_mask)

        # apply mask
        b = layers.fill_constant(shape=[1], dtype=dtype, value=b)
        y = layers.elementwise_add(
            paddle.multiply(x, keep_mask),
            layers.scale(
                drop_mask, scale=alpha_p))
        res = layers.elementwise_add(layers.scale(y, scale=a), b, name=name)
        return res
    else:  # test
        return x
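Note (not part of the diff): the constants ``a`` and ``b`` above are chosen so that, for input with zero mean and unit variance, the output keeps zero mean and unit variance. A NumPy-only sanity check of that derivation is sketched below; it is hypothetical verification code, not taken from the PR.

import numpy as np

p = 0.5
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale                         # SELU negative saturation value
a = ((1 - p) * (1 + p * alpha_p**2))**-0.5       # same formulas as in alpha_dropout
b = -a * alpha_p * p

x = np.random.randn(1000000).astype('float32')   # zero-mean, unit-std input
keep = (np.random.rand(x.size) >= p).astype('float32')
y = a * (x * keep + alpha_p * (1.0 - keep)) + b  # the transform applied above
print(y.mean(), y.std())                         # both should be close to 0 and 1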
def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
"""
Pad tensor according to 'pad' and 'mode'.
......
......@@ -55,6 +55,7 @@ from .common import UpSample #DEFINE_ALIAS
from .common import Dropout #DEFINE_ALIAS
from .common import Dropout2D #DEFINE_ALIAS
from .common import Dropout3D #DEFINE_ALIAS
from .common import AlphaDropout #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .conv import Conv1d #DEFINE_ALIAS
......
......@@ -23,25 +23,11 @@ from .. import functional as F
from ...fluid.framework import _dygraph_tracer
__all__ = [
'BilinearTensorProduct',
'Pool2D',
'Embedding',
'Linear',
'UpSample',
'Pad2D',
'ReflectionPad1d',
'ReplicationPad1d',
'ConstantPad1d',
'ReflectionPad2d',
'ReplicationPad2d',
'ConstantPad2d',
'ZeroPad2d',
'ConstantPad3d',
'ReplicationPad3d',
'CosineSimilarity',
'Dropout',
'Dropout2D',
'Dropout3D',
'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample',
'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d',
'ReflectionPad2d', 'ReplicationPad2d', 'ConstantPad2d', 'ZeroPad2d',
'ConstantPad3d', 'ReplicationPad3d', 'CosineSimilarity', 'Dropout',
'Dropout2D', 'Dropout3D', 'AlphaDropout'
]
......@@ -361,12 +347,12 @@ class Dropout(layers.Layer):
according to the given dropout probability.
See ``paddle.nn.functional.dropout`` for more details.
In dygraph mode, please use ``eval()`` to indicate whether it is in test phrase or not.
In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.
Parameters:
p (float | int): Probability of setting units to zero. Default: 0.5
axis (int | list): The axis along which the dropout is performed. Default None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
......@@ -378,6 +364,7 @@ class Dropout(layers.Layer):
- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: N-D tensor.
......@@ -404,7 +391,6 @@ class Dropout(layers.Layer):
        super(Dropout, self).__init__()

        self.p = p
        self.training = _dygraph_tracer()._train_mode
        self.axis = axis
        self.mode = mode
        self.name = name
......@@ -430,7 +416,8 @@ class Dropout2D(layers.Layer):
See ``paddle.nn.functional.dropout2d`` for more details.
Please use ``eval()`` to indicate whether it is in test phrase or not.
In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.
Parameters:
p (float, optional): Probability of setting units to zero. Default: 0.5
data_format (str, optional): Specify the data format of the input, and the data format of the output
......@@ -487,7 +474,8 @@ class Dropout3D(layers.Layer):
See ``paddle.nn.functional.dropout3d`` for more details.
Please use ``eval()`` to indicate whether it is in test phrase or not.
In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.
Parameters:
p (float | int): Probability of setting units to zero. Default: 0.5
data_format (str, optional): Specify the data format of the input, and the data format of the output
......@@ -521,7 +509,6 @@ class Dropout3D(layers.Layer):
        super(Dropout3D, self).__init__()

        self.p = p
        self.training = _dygraph_tracer()._train_mode
        self.data_format = data_format
        self.name = name
......@@ -535,6 +522,55 @@ class Dropout3D(layers.Layer):
        return out


class AlphaDropout(layers.Layer):
    """
    Alpha Dropout is a type of Dropout that maintains the self-normalizing property. For an input with
    zero mean and unit standard deviation, the output of Alpha Dropout maintains the original mean and
    standard deviation of the input. Alpha Dropout fits well with the SELU activation function by randomly
    setting activations to the negative saturation value.

    For more information, please refer to:
    `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_

    In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float | int): Probability of setting units to zero. Default: 0.5
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()
            x = np.array([[-1, 1], [-1, 1]]).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.AlphaDropout(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x.numpy())
            print(y_train.numpy())
            # [[-0.10721093, 1.6655989 ], [-0.7791938, -0.7791938]] (randomly)
            print(y_test.numpy())
    """

    def __init__(self, p=0.5, name=None):
        super(AlphaDropout, self).__init__()
        self.p = p
        self.name = name

    def forward(self, input):
        out = F.alpha_dropout(
            input, p=self.p, training=self.training, name=self.name)
        return out
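Note (not part of the diff): since the docstring recommends pairing Alpha Dropout with SELU, a usage sketch is given below. It assumes ``paddle.nn.SELU``, ``paddle.nn.Linear`` and ``paddle.nn.Sequential`` are available in this release; only ``paddle.nn.AlphaDropout`` comes from this PR.

import numpy as np
import paddle

paddle.disable_static()
net = paddle.nn.Sequential(
    paddle.nn.Linear(16, 32),
    paddle.nn.SELU(),               # self-normalizing activation
    paddle.nn.AlphaDropout(p=0.1),  # preserves the mean/std of the SELU outputs
    paddle.nn.Linear(32, 2), )

x = paddle.to_tensor(np.random.randn(4, 16).astype('float32'))
net.train()   # AlphaDropout active
y_train = net(x)
net.eval()    # AlphaDropout disabled; output is the plain forward pass
y_test = net(x)
print(y_train.numpy().shape, y_test.numpy().shape)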
class ReflectionPad1d(layers.Layer):
"""
This interface is used to construct a callable object of the ``ReflectionPad1d`` class.
......