From 412eca679fc213d64ef3772f20a54f03809abb87 Mon Sep 17 00:00:00 2001 From: huangjun12 <2399845970@qq.com> Date: Sat, 22 Aug 2020 10:30:16 +0800 Subject: [PATCH] [API2.0] add dropout, dropout2d and dropout3d in nn.functional and nn.layer (#26111) * [API2.0] add dropout, dropout2d and dropout3d, test=develop * refine Interface and assertion after review, test=develop * fix alias p=1 and use scale, test=develop * fix doc and training, test=develop * fix doc in Dropout2D, test=develop --- python/paddle/fluid/layers/nn.py | 4 +- .../fluid/tests/unittests/test_dropout_op.py | 399 ++++++++++++++++++ python/paddle/nn/__init__.py | 3 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/common.py | 345 ++++++++++++++- python/paddle/nn/layer/__init__.py | 3 + python/paddle/nn/layer/common.py | 187 ++++++++ 7 files changed, 937 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 955217a46c..056e63cff8 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -932,6 +932,7 @@ def cos_sim(X, Y): return out +@deprecated(since="2.0.0", update_to="paddle.nn.functional.dropout") def dropout(x, dropout_prob, is_test=False, @@ -939,9 +940,6 @@ def dropout(x, name=None, dropout_implementation="downgrade_in_infer"): """ - :alias_main: paddle.nn.functional.dropout - :alias: paddle.nn.functional.dropout,paddle.nn.functional.common.dropout - :old_api: paddle.fluid.layers.dropout Computes dropout. diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index cc3910d1b0..816bb263ce 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -18,6 +18,7 @@ import unittest import numpy as np import paddle.fluid.core as core from op_test import OpTest, skip_check_grad_ci +import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard @@ -236,5 +237,403 @@ class TestDropoutOpError(unittest.TestCase): self.assertRaises(TypeError, test_dtype) +class TestDropoutFAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[40, 40], dtype="float32") + res1 = paddle.nn.functional.dropout(x=input, p=0., training=False) + res2 = paddle.nn.functional.dropout( + x=input, p=0., axis=0, training=True, mode='upscale_in_train') + res3 = paddle.nn.functional.dropout( + x=input, p=0., axis=0, training=True, mode='downscale_in_infer') + res4 = paddle.nn.functional.dropout( + x=input, p=0., axis=0, training=False, mode='upscale_in_train') + res5 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=0, + training=False, + mode='downscale_in_infer') + res6 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=True, + mode='upscale_in_train') + res7 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=True, + mode='downscale_in_infer') + res8 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=False, + mode='upscale_in_train') + res9 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=False, + mode='downscale_in_infer') + res10 = paddle.nn.functional.dropout(x=input, p=1., training=True) + + in_np = 
np.random.random([40, 40]).astype("float32") + res_np = in_np + res_np2 = np.zeros_like(in_np) + + exe = fluid.Executor(place) + res_list = [res1, res2, res3, res4, res5, res6, res7, res8, res9] + for res in res_list: + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res]) + self.assertTrue(np.allclose(fetches[0], res_np)) + fetches2 = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res10]) + self.assertTrue(np.allclose(fetches2[0], res_np2)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + in_np = np.random.random([40, 40]).astype("float32") + res_np = in_np + res_np2 = np.zeros_like(in_np) + input = fluid.dygraph.to_variable(in_np) + + res1 = paddle.nn.functional.dropout( + x=input, p=0., training=False) + res2 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=0, + training=True, + mode='upscale_in_train') + res3 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=0, + training=True, + mode='downscale_in_infer') + res4 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=0, + training=False, + mode='upscale_in_train') + res5 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=0, + training=False, + mode='downscale_in_infer') + res6 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=True, + mode='upscale_in_train') + res7 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=True, + mode='downscale_in_infer') + res8 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=False, + mode='upscale_in_train') + res9 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=False, + mode='downscale_in_infer') + res10 = paddle.nn.functional.dropout( + x=input, p=1., training=True) + + res_list = [res1, res2, res3, res4, res5, res6, res7, res8, res9] + for res in res_list: + self.assertTrue(np.allclose(res.numpy(), res_np)) + self.assertTrue(np.allclose(res10.numpy(), res_np2)) + + +class TestDropoutFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of dropout must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.nn.functional.dropout(x1, p=0.5) + + self.assertRaises(TypeError, test_Variable) + + def test_Variable2(): + # the input of dropout must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.nn.functional.dropout(x1, p=0.5, axis=0) + + self.assertRaises(TypeError, test_Variable2) + + def test_dtype(): + # the input dtype of dropout must be float32 or float64 + # float16 only can be set on GPU place + xr = fluid.data(name='xr', shape=[3, 4, 5, 6], dtype="int32") + paddle.nn.functional.dropout(xr, p=0.5) + + self.assertRaises(TypeError, test_dtype) + + def test_pdtype(): + # p should be int or float + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.dropout(x2, p='0.5') + + self.assertRaises(TypeError, test_pdtype) + + def test_pvalue(): + # p should be 0.<=p<=1. 
+ x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.dropout(x2, p=1.2) + + self.assertRaises(ValueError, test_pvalue) + + def test_mode(): + # mode should be 'downscale_in_infer' or 'upscale_in_train' + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.dropout(x2, mode='abc') + + self.assertRaises(ValueError, test_mode) + + def test_axis(): + # axis should be int or list + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.dropout(x2, axis=1.2) + + self.assertRaises(TypeError, test_axis) + + def test_axis_max(): + # maximum of axis should less than dimensions of x + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.dropout(x2, axis=[0, 5]) + + self.assertRaises(ValueError, test_axis_max) + + def test_axis_len(): + # length of axis should not greater than dimensions of x + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.dropout(x2, axis=[0, 1, 2, 3, 4]) + + self.assertRaises(ValueError, test_axis_len) + + +class TestDropoutCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([40, 40]).astype("float32") + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + m = paddle.nn.Dropout(p=0.) + m.eval() + result = m(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + +class TestDropout2dFAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 4, 5], dtype="float32") + res1 = paddle.nn.functional.dropout2d( + x=input, p=0., training=False, data_format='NCHW') + res2 = paddle.nn.functional.dropout2d( + x=input, p=0., training=False, data_format='NHWC') + + in_np = np.random.random([2, 3, 4, 5]).astype("float32") + res_np = in_np + + exe = fluid.Executor(place) + res_list = [res1, res2] + for res in res_list: + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res]) + self.assertTrue(np.allclose(fetches[0], res_np)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + in_np = np.random.random([2, 3, 4, 5]).astype("float32") + res_np = in_np + input = fluid.dygraph.to_variable(in_np) + + res1 = paddle.nn.functional.dropout2d( + x=input, p=0., training=False, data_format='NCHW') + res2 = paddle.nn.functional.dropout2d( + x=input, p=0., training=False, data_format='NHWC') + + res_list = [res1, res2] + for res in res_list: + self.assertTrue(np.allclose(res.numpy(), res_np)) + + +class TestDropout2dFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_xdim(): + # dimentions of x should be 4 + x = fluid.data(name='x1', shape=[2, 3, 4, 5, 6], dtype="int32") + paddle.nn.functional.dropout2d(x) + + self.assertRaises(ValueError, test_xdim) + + def test_dataformat(): + # data_format should be 'NCHW' or 'NHWC' + x = fluid.data(name='x2', shape=[2, 3, 4, 5], 
dtype="int32") + paddle.nn.functional.dropout2d(x, data_format='CNHW') + + self.assertRaises(ValueError, test_dataformat) + + +class TestDropout2DCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 4, 5]).astype("float32") + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + m = paddle.nn.Dropout2D(p=0.) + m.eval() + result = m(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + +class TestDropout3dFAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 4, 5, 6], dtype="float32") + res1 = paddle.nn.functional.dropout3d( + x=input, p=0., training=False, data_format='NCDHW') + res2 = paddle.nn.functional.dropout3d( + x=input, p=0., training=False, data_format='NDHWC') + + in_np = np.random.random([2, 3, 4, 5, 6]).astype("float32") + res_np = in_np + + exe = fluid.Executor(place) + res_list = [res1, res2] + for res in res_list: + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res]) + self.assertTrue(np.allclose(fetches[0], res_np)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + in_np = np.random.random([2, 3, 4, 5, 6]).astype("float32") + res_np = in_np + input = fluid.dygraph.to_variable(in_np) + + res1 = paddle.nn.functional.dropout3d( + x=input, p=0., training=False, data_format='NCDHW') + res2 = paddle.nn.functional.dropout3d( + x=input, p=0., training=False, data_format='NDHWC') + + res_list = [res1, res2] + for res in res_list: + self.assertTrue(np.allclose(res.numpy(), res_np)) + + +class TestDropout3dFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_xdim(): + # dimentions of x should be 5 + x = fluid.data(name='x1', shape=[2, 3, 4, 5], dtype="int32") + paddle.nn.functional.dropout3d(x) + + self.assertRaises(ValueError, test_xdim) + + def test_dataformat(): + # data_format should be 'NCDHW' or 'NDHWC' + x = fluid.data(name='x2', shape=[2, 3, 4, 5, 6], dtype="int32") + paddle.nn.functional.dropout3d(x, data_format='CNDHW') + + self.assertRaises(ValueError, test_dataformat) + + +class TestDropout3DCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 4, 5, 6]).astype("float32") + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + m = paddle.nn.Dropout3D(p=0.) 
+ m.eval() + result = m(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 861260fedd..7fbf26df96 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -87,6 +87,9 @@ from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS from .layer.common import Flatten #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS +from .layer.common import Dropout #DEFINE_ALIAS +from .layer.common import Dropout2D #DEFINE_ALIAS +from .layer.common import Dropout3D #DEFINE_ALIAS from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .layer.conv import Conv2D #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 744452c17c..a135aea98c 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -54,6 +54,8 @@ from .activation import tanhshrink #DEFINE_ALIAS from .activation import thresholded_relu #DEFINE_ALIAS from .activation import log_softmax #DEFINE_ALIAS from .common import dropout #DEFINE_ALIAS +from .common import dropout2d #DEFINE_ALIAS +from .common import dropout3d #DEFINE_ALIAS # from .common import embedding #DEFINE_ALIAS # from .common import fc #DEFINE_ALIAS from .common import label_smooth #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 9e54c62ab6..8091549c08 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -13,13 +13,12 @@ # limitations under the License. import warnings -import paddle.fluid.core as core -from ...fluid.framework import in_dygraph_mode, core +import paddle +from ...fluid.framework import in_dygraph_mode, default_main_program from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers.tensor import Variable, fill_constant, zeros, concat # TODO: define the common functions to build a neural network -from ...fluid.layers import dropout #DEFINE_ALIAS from ...fluid.layers import label_smooth #DEFINE_ALIAS from ...fluid import one_hot #DEFINE_ALIAS from ...fluid.layers import pad2d #DEFINE_ALIAS @@ -34,9 +33,13 @@ from ...tensor import sqrt #from ...fluid.layers import fc #DEFINE_ALIAS from ...fluid.layers import pad_constant_like #DEFINE_ALIAS +from ...fluid import core, layers +from ...fluid.data_feeder import check_variable_and_dtype __all__ = [ 'dropout', + 'dropout2d', + 'dropout3d', # 'embedding', # 'fc', 'label_smooth', @@ -456,6 +459,342 @@ def interpolate(input, return out +def dropout(x, + p=0.5, + axis=None, + training=True, + mode="upscale_in_train", + name=None): + """ + Dropout is a regularization technique for reducing overfitting by preventing + neuron co-adaption during training. The dropout operator randomly sets the + outputs of some units to zero, while upscale others according to the given + dropout probability. + + Args: + x (Tensor): The input tensor. The data type is float32 or float64. + p (float | int): Probability of setting units to zero. Default 0.5. + axis (int | list): The axis along which the dropout is performed. Default None. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
+ mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - dropout_prob ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - dropout_prob) + + Returns: + A Tensor representing the dropout, has same shape and data type as `x` . + + Examples: + We use ``p=0.5`` in the following description for simplicity. + 1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly. + Let's see a simple case when x is a 2d tensor with shape 2*3: + [[1 2 3] + [4 5 6]] + we generate mask with the same shape as x, which is 2*3. The value of mask is + sampled from a Bernoulli distribution randomly. For example, we may get such mask: + [[0 1 0] + [1 0 1]] + So the output is obtained from elementwise multiply of x and mask: + [[0 2 0] + [4 0 6]] + Using default setting, i.e. ``mode='upscale_in_train'`` , + if in training phase, the final upscale output is: + [[0 4 0 ] + [8 0 12]] + if in test phase, the output is the same as input: + [[1 2 3] + [4 5 6]] + we can also set ``mode='downscale_in_infer'`` , then + if in training phase, the final output is: + [[0 2 0] + [4 0 6]] + if in test phase, the scale output is: + [[0.5 1. 1.5] + [2. 2.5 3. ]] + + 2. When ``axis!=None`` , this is useful for dropping whole channels from an image or sequence. + Let's see the simple case when x is a 2d tensor with shape 2*3 again: + [[1 2 3] + [4 5 6]] + (1) If ``axis=0`` , this means the dropout is only performed in axis `0` . + we generate mask with the shape 2*1. Only in axis `0` the value is randomly selected. + For example, we may get such mask: + [[1] + [0]] + The output is obtained from elementwise multiply of x and mask. Doing that the mask will be + broadcast from 2*1 to 2*3: + [[1 1 1] + [0 0 0]] + and the result after elementwise multiply is: + [[1 2 3] + [0 0 0]] + then we can do upscale or downscale according to the setting of other arguments. + (2) If ``axis=1`` , this means the dropout is only performed in axis `1` . + we generate mask with the shape 1*3. Only in axis `1` the value is randomly selected. + For example, we may get such mask: + [[1 0 1]] + Doing elementwise multiply the mask will be broadcast from 1*3 to 2*3: + [[1 0 1] + [1 0 1]] + and the result after elementwise multiply is: + [[1 0 3] + [4 0 6]] + (3) What about ``axis=[0, 1]`` ? This means the dropout is performed in all axes of x, + which is the same case as default setting ``axis=None`` . + (4) You may note that logically `axis=None` means the dropout is performed in no axis of x, + We generate mask with the shape 1*1. Whole input is randomly selected or dropped. + For example, we may get such mask: + [[0]] + Doing elementwise multiply the mask will be broadcast from 1*1 to 2*3: + [[0 0 0] + [0 0 0]] + and the result after elementwise multiply is: + [[0 0 0] + [0 0 0]] + Actually this is not what we want because all elements may set to zero~ + When x is a 4d tensor with shape `NCHW`, we can set ``axis=[0,1]`` and the dropout will be performed + in channel `N` and `C`, `H` and `W` is tied, i.e. + paddle.nn.dropout(x, p, axis=[0,1]) + This is something we called dropout2d. Please refer to ``paddle.nn.functional.dropout2d`` + for more details. + Similarly, when x is a 5d tensor with shape `NCDHW`, we can set ``axis=[0,1]`` to perform + dropout3d. 
Please refer to ``paddle.nn.functional.dropout3d`` for more details. + + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.array([[1,2,3], [4,5,6]]).astype('float32') + x = paddle.to_tensor(x) + y_train = paddle.nn.functional.dropout(x, 0.5) + y_test = paddle.nn.functional.dropout(x, 0.5, training=False) + y_0 = paddle.nn.functional.dropout(x, axis=0) + y_1 = paddle.nn.functional.dropout(x, axis=1) + y_01 = paddle.nn.functional.dropout(x, axis=[0,1]) + print(x.numpy()) + print(y_train.numpy()) + print(y_test.numpy()) + print(y_0.numpy()) + print(y_1.numpy()) + print(y_01.numpy()) + + """ + if not isinstance(p, (float, int)): + raise TypeError("p argument should be a number") + if p < 0 or p > 1: + raise ValueError("p argument should between 0 and 1") + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + if axis and not isinstance(axis, (int, list)): + raise TypeError("datatype of axis argument should be int or list") + + if axis == None: # commonly used dropout + seed = None + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + + def get_attrs(prog, dropout_prob, is_test, seed): + if (seed is None or seed == 0) and prog.random_seed != 0: + seed = prog.random_seed + attrs = { + 'dropout_prob': dropout_prob, + 'is_test': is_test, + 'fix_seed': seed is not None, + 'seed': seed if seed is not None else 0, + 'dropout_implementation': mode, + } + return attrs + + if in_dygraph_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed + out, mask = core.ops.dropout( + x, 'dropout_prob', p, 'is_test', not training, 'fix_seed', + seed is not None, 'seed', seed + if seed is not None else 0, 'dropout_implementation', mode) + return out + + helper = LayerHelper('dropout', **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'dropout') + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + mask = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + + attrs = get_attrs(helper.main_program, p, not training, seed) + + helper.append_op( + type='dropout', + inputs={'X': [x]}, + outputs={'Out': [out], + 'Mask': [mask]}, + attrs=attrs) + return out + else: #sometimes called dropout_nd #TODO: optimize with c++ + if not in_dygraph_mode(): + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'dropout') + dtype = x.dtype + keep_prob = 1 - p + if training: + if p == 1.: + return layers.scale(x, scale=0.) + + scale_input = layers.scale( + x, scale=1 / keep_prob) if mode == 'upscale_in_train' else x + + #get mask shape + input_shape = x.shape + drop_axes = [axis] if isinstance(axis, int) else axis + if max(drop_axes) > len(input_shape) - 1: + raise ValueError("axis value should less than dimensions of x:{}, but get drop_axes value:{} " \ + .format(len(input_shape), max(drop_axes))) + if len(drop_axes) > len(input_shape): + raise ValueError( + "length of axis should not greater than dimensions of x:{}, but get length of drop axes: {}". 
+ format(len(input_shape), len(drop_axes))) + mask_shape = [1] * len(input_shape) + for i in drop_axes: + mask_shape[i] = input_shape[i] + + #get mask + random_tensor = layers.uniform_random( + mask_shape, dtype='float32', min=0., max=1.0) + p = layers.fill_constant(shape=[1], dtype='float32', value=p) + keep_mask = layers.greater_equal(random_tensor, p) + + scale_input = layers.cast(scale_input, dtype) + keep_mask = layers.cast(keep_mask, dtype) + ret = paddle.multiply(scale_input, keep_mask, name=name) + return ret + else: # test + ret = layers.scale( + x, scale=keep_prob) if mode == 'downscale_in_infer' else x + return ret + + +def dropout2d(x, p=0.5, training=True, data_format='NCHW', name=None): + """ + Randomly zero out entire channels (in the batched input 4d tensor with the shape `NCHW` , + a channel is a 2D feature map with the shape `HW` ). Each channel will be zeroed out independently + on every forward call with probability `p` using samples from a Bernoulli distribution. + + See ``paddle.nn.functional.dropout`` for more details. + + Args: + x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C]. + The data type is float32 or float64. + p (float): Probability of setting units to zero. Default 0.5. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: + `NCHW` , `NHWC` . The default is `NCHW` . When it is `NCHW` , the data is + stored in the order of: [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor representing the dropout2d, has same shape and data type as `x` . + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.random.random(size=(2, 3, 4, 5)).astype('float32') + x = paddle.to_tensor(x) + y_train = paddle.nn.functional.dropout2d(x) #train + y_test = paddle.nn.functional.dropout2d(x, training=False) #test + for i in range(2): + for j in range(3): + print(x.numpy()[i,j,:,:]) + print(y_train.numpy()[i,j,:,:]) # may all 0 + print(y_test.numpy()[i,j,:,:]) + """ + input_shape = x.shape + if len(input_shape) != 4: + raise ValueError("dimensions of x should be 4, but received {} != 4"\ + .format(len(input_shape))) + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + + return dropout( + x, + p=p, + axis=[0, 1] if data_format == 'NCHW' else [0, 3], + training=training, + mode="upscale_in_train", + name=name) + + +def dropout3d(x, p=0.5, training=True, data_format='NCDHW', name=None): + """ + Randomly zero out entire channels (in the batched input 5d tensor with the shape `NCDHW` , + a channel is a 3D feature map with the shape `DHW` ). Each channel will be zeroed out independently + on every forward call with probability `p` using samples from a Bernoulli distribution. + + See ``paddle.nn.functional.dropout`` for more details. + + Args: + x (Tensor): The input is 5-D Tensor with shape [N, C, D, H, W] or [N, D, H, W, C]. + The data type is float32 or float64. + p (float): Probability of setting units to zero. Default 0.5. + training (bool): A flag indicating whether it is in train phrase or not. Default True. 
+ data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: + ``NCDHW``, ``NDHWC``. The default is ``NCDHW`` . When it is ``NCDHW`` , the data is + stored in the order of: [batch_size, input_channels, input_depth, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor representing the dropout3d, has same shape and data type with `x` . + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.random.random(size=(2, 3, 4, 5, 6)).astype('float32') + x = paddle.to_tensor(x) + y_train = paddle.nn.functional.dropout3d(x) #train + y_test = paddle.nn.functional.dropout3d(x, training=False) #test + print(x.numpy()[0,0,:,:,:]) + print(y_train.numpy()[0,0,:,:,:]) # may all 0 + print(y_test.numpy()[0,0,:,:,:]) + """ + + input_shape = x.shape + if len(input_shape) != 5: + raise ValueError("dimensions of x should be 5, but received {} != 5" \ + .format(len(input_shape))) + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + + return dropout( + x, + p=p, + axis=[0, 1] if data_format == 'NCDHW' else [0, 4], + training=training, + mode="upscale_in_train", + name=name) + + def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): """ Pad tensor according to 'pad' and 'mode'. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 4a84f57d6c..342a684c04 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -52,6 +52,9 @@ from .common import Embedding #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS from .common import Flatten #DEFINE_ALIAS from .common import UpSample #DEFINE_ALIAS +from .common import Dropout #DEFINE_ALIAS +from .common import Dropout2D #DEFINE_ALIAS +from .common import Dropout3D #DEFINE_ALIAS from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .conv import Conv2D #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index abe6d57260..3034880533 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -20,6 +20,7 @@ from ...fluid.dygraph import Linear #DEFINE_ALIAS from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers from .. import functional as F +from ...fluid.framework import _dygraph_tracer __all__ = [ 'BilinearTensorProduct', @@ -38,6 +39,9 @@ __all__ = [ 'ConstantPad3d', 'ReplicationPad3d', 'CosineSimilarity', + 'Dropout', + 'Dropout2D', + 'Dropout3D', ] @@ -348,6 +352,189 @@ class Pad2D(layers.Layer): data_format=self._data_format) +class Dropout(layers.Layer): + """ + Dropout is a regularization technique for reducing overfitting by preventing + neuron co-adaption during training as described in the paper: + `Improving neural networks by preventing co-adaptation of feature detectors `_ + The dropout operator randomly sets the outputs of some units to zero, while upscale others + according to the given dropout probability. + + See ``paddle.nn.functional.dropout`` for more details. + In dygraph mode, please use ``eval()`` to indicate whether it is in test phrase or not. 
+ + Parameters: + p (float | int): Probability of setting units to zero. Default: 0.5 + axis (int | list): The axis along which the dropout is performed. Default None. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - p ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - p) + + Shape: + - input: N-D tensor. + - output: N-D tensor, the same shape as input. + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.array([[1,2,3], [4,5,6]]).astype('float32') + x = paddle.to_tensor(x) + m = paddle.nn.Dropout(p=0.5) + y_train = m(x) + m.eval() # switch the model to test phase + y_test = m(x) + print(x.numpy()) + print(y_train.numpy()) + print(y_test.numpy()) + """ + + def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None): + super(Dropout, self).__init__() + + self.p = p + self.training = _dygraph_tracer()._train_mode + self.axis = axis + self.mode = mode + self.name = name + + def forward(self, input): + out = F.dropout( + input, + p=self.p, + axis=self.axis, + training=self.training, + mode=self.mode, + name=self.name) + return out + + +class Dropout2D(layers.Layer): + """ + Randomly zero out entire channels (in the batched input 4d tensor with the shape `NCHW` , + a channel is a 2D feature map with the shape `HW`). Each channel will be zeroed out independently + on every forward call with probability `p` using samples from a Bernoulli distribution. + Dropout2d will help promote independence between feature maps as described in the paper: + `Efficient Object Localization Using Convolutional Networks `_ + + See ``paddle.nn.functional.dropout2d`` for more details. + + Please use ``eval()`` to indicate whether it is in test phrase or not. + Parameters: + p (float, optional): Probability of setting units to zero. Default: 0.5 + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: + `NCHW`, `NHWC`. The default is `NCHW`. When it is `NCHW`, the data is + stored in the order of: [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: 4-D tensor. + - output: 4-D tensor, the same shape as input. + + Examples: + .. 
code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.random.random(size=(2, 3, 4, 5)).astype('float32') + x = paddle.to_tensor(x) + m = paddle.nn.Dropout2D(p=0.5) + y_train = m(x) + m.eval() # switch the model to test phase + y_test = m(x) + print(x.numpy()) + print(y_train.numpy()) + print(y_test.numpy()) + """ + + def __init__(self, p=0.5, data_format='NCHW', name=None): + super(Dropout2D, self).__init__() + + self.p = p + self.data_format = data_format + self.name = name + + def forward(self, input): + out = F.dropout2d( + input, + p=self.p, + training=self.training, + data_format=self.data_format, + name=self.name) + return out + + +class Dropout3D(layers.Layer): + """ + Randomly zero out entire channels (in the batched input 5d tensor with the shape `NCDHW` , + a channel is a 3D feature map with the shape `DHW` ). Each channel will be zeroed out independently + on every forward call with probability `p` using samples from a Bernoulli distribution. + Dropout3d will help promote independence between feature maps as described in the paper: + `Efficient Object Localization Using Convolutional Networks `_ + + See ``paddle.nn.functional.dropout3d`` for more details. + + Please use ``eval()`` to indicate whether it is in test phrase or not. + Parameters: + p (float | int): Probability of setting units to zero. Default: 0.5 + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: + `NCDHW`, `NDHWC`. The default is `NCDHW`. When it is `NCDHW`, the data is + stored in the order of: [batch_size, input_channels, input_depth, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: 5-D tensor. + - output: 5-D tensor, the same shape as input. + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.random.random(size=(2, 3, 4, 5, 6)).astype('float32') + x = paddle.to_tensor(x) + m = paddle.nn.Dropout3D(p=0.5) + y_train = m(x) + m.eval() # switch the model to test phase + y_test = m(x) + print(x.numpy()) + print(y_train.numpy()) + print(y_test.numpy()) + """ + + def __init__(self, p=0.5, data_format='NCDHW', name=None): + super(Dropout3D, self).__init__() + + self.p = p + self.training = _dygraph_tracer()._train_mode + self.data_format = data_format + self.name = name + + def forward(self, input): + out = F.dropout3d( + input, + p=self.p, + training=self.training, + data_format=self.data_format, + name=self.name) + return out + + class ReflectionPad1d(layers.Layer): """ This interface is used to construct a callable object of the ``ReflectionPad1d`` class. -- GitLab
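
---

For quick reference, below is a minimal usage sketch of the functional API introduced by this patch, assembled from the docstring examples above (Paddle 2.0-era dygraph style with `paddle.disable_static()`). The exact zeroed positions depend on the random mask, so outputs will vary between runs; this is an illustration, not part of the patch.

```python
import numpy as np
import paddle

paddle.disable_static()  # dygraph mode, as in the patch's docstring examples

x = paddle.to_tensor(np.array([[1., 2., 3.], [4., 5., 6.]], dtype='float32'))

# Element-wise dropout; the default 'upscale_in_train' mode rescales the
# surviving values by 1 / (1 - p) during training.
y_train = paddle.nn.functional.dropout(x, p=0.5)

# With training=False and the default mode, the input passes through unchanged.
y_test = paddle.nn.functional.dropout(x, p=0.5, training=False)

# Structured dropout: the mask is generated only along axis 0 and then
# broadcast, so whole rows are kept or dropped together.
y_rows = paddle.nn.functional.dropout(x, p=0.5, axis=0)

# 'downscale_in_infer' instead leaves the kept values unscaled at train time
# and multiplies by (1 - p) at inference time.
y_down = paddle.nn.functional.dropout(x, p=0.5, mode='downscale_in_infer')
```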
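The channel-wise helpers are thin wrappers over the same function: `dropout2d` forwards to `dropout` with `axis=[0, 1]` for `NCHW` (or `[0, 3]` for `NHWC`), and `dropout3d` does the same for 5-D input. A short sketch, again following the patch's own examples:

```python
import numpy as np
import paddle

paddle.disable_static()

# 4-D NCHW input: entire HxW feature maps are zeroed together.
img = paddle.to_tensor(np.random.random((2, 3, 4, 5)).astype('float32'))
y_2d_train = paddle.nn.functional.dropout2d(img, p=0.5, data_format='NCHW')
y_2d_test = paddle.nn.functional.dropout2d(img, training=False)

# 5-D NCDHW input: entire DxHxW feature volumes are zeroed together.
vol = paddle.to_tensor(np.random.random((2, 3, 4, 5, 6)).astype('float32'))
y_3d_train = paddle.nn.functional.dropout3d(vol, p=0.5, data_format='NCDHW')
y_3d_test = paddle.nn.functional.dropout3d(vol, training=False)
```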
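Finally, a sketch of the layer-style wrappers added in `paddle.nn.layer.common`. As the class docstrings note, in dygraph mode `eval()` switches the module to the test phase, so with the default `'upscale_in_train'` mode the input is returned unchanged:

```python
import numpy as np
import paddle

paddle.disable_static()

x = paddle.to_tensor(np.array([[1., 2., 3.], [4., 5., 6.]], dtype='float32'))

m = paddle.nn.Dropout(p=0.5)
y_train = m(x)   # training phase: elements dropped, survivors upscaled
m.eval()         # switch to the test phase
y_test = m(x)    # identical to x under 'upscale_in_train'

m2d = paddle.nn.Dropout2D(p=0.5, data_format='NCHW')
img = paddle.to_tensor(np.random.random((2, 3, 4, 5)).astype('float32'))
y_img = m2d(img)  # whole HxW feature maps are zeroed together
```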