Unverified commit b11340a4, authored by songyouwei, committed by GitHub

support Layer level train/eval mode (#22463)

* Layer add training attr, add Dropout Layer

* add unit test for Dropout
test=develop

* minor fix
test=develop

* add missing args
test=develop

* support global flag in dropout, add docs
test=develop

* minor fix
test=develop

* minor fix
test=develop

* refine code comments
test=develop

* refine Dropout
test=develop

* fix ut
test=develop

* arg expansion
test=develop

* sample code update
test=develop

* prop -> p
test=develop

* fix ut
test=develop

* minor check fix
test=develop
Parent 7f3e0eaa
@@ -76,6 +76,7 @@ class Layer(core.Layer):
"""
def __init__(self, name_scope=None, dtype=core.VarDesc.VarType.FP32):
self.training = True
if name_scope is None:
name_scope = _convert_camel_to_snake(self.__class__.__name__)
self._full_name = unique_name.generate(name_scope)
@@ -91,10 +92,34 @@ class Layer(core.Layer):
self._forward_post_hooks = collections.OrderedDict()
def train(self):
"""
Sets this Layer and all its sublayers to training mode.
This only affects certain modules like `Dropout` and `BatchNorm`.
Returns:
None
"""
# global setting
framework._dygraph_tracer().train_mode()
# Layer-level setting
self.training = True
for layer in self.sublayers():
layer.train()
def eval(self):
"""
Sets this Layer and all its sublayers to evaluation mode.
This only affects certain modules like `Dropout` and `BatchNorm`.
Returns:
None
"""
# global setting
framework._dygraph_tracer().eval_mode()
# Layer-level setting
self.training = False
for layer in self.sublayers():
layer.eval()
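
With train() and eval() in place, mode switching is a single call on the top-level Layer and propagates to every sublayer. A minimal usage sketch of the new API (layer choice and variable names are illustrative):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        m = fluid.dygraph.Dropout(p=0.5)
        x = fluid.dygraph.to_variable(np.ones([4, 8], dtype='float32'))

        m.train()          # sets m.training (and every sublayer's training flag) to True
        train_out = m(x)   # elements are randomly zeroed during training

        m.eval()           # sets m.training (and every sublayer's training flag) to False
        eval_out = m(x)    # deterministic output; no random dropping at evaluation time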
def full_name(self):
"""Full name for this layer, composed by name_scope + "/" + MyLayer.__class__.__name__
......
@@ -17,11 +17,11 @@ from __future__ import print_function
from six.moves import reduce
from .. import core
from ..layers import utils
from ..layers import nn
from ..layers import nn as F
from .. import dygraph_utils
from . import layers
from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator, default_main_program
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant, NumpyArrayInitializer
from .. import unique_name
@@ -31,9 +31,10 @@ import numbers
import logging
__all__ = [
'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Embedding', 'GRUUnit',
'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', 'Conv2DTranspose',
'Conv3DTranspose', 'GroupNorm', 'SpectralNorm', 'TreeConv'
'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding',
'GRUUnit', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct',
'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', 'SpectralNorm',
'TreeConv'
]
@@ -1007,7 +1008,9 @@ class BatchNorm(layers.Layer):
Parameters:
num_channels(int): Indicates the number of channels of the input ``Tensor``.
act(str, optional): Activation to be applied to the output of batch normalization. Default: None.
is_test (bool, optional): A flag indicating whether it is in test phrase or not. Default: False.
is_test (bool, optional): A flag indicating whether it is in the test phase or not.
This flag only takes effect in static graph mode. For dygraph mode, please use ``eval()``.
Default: False.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
@@ -1134,8 +1137,7 @@ class BatchNorm(layers.Layer):
variance_out = self._variance
if in_dygraph_mode():
_is_test = (not _dygraph_tracer()._train_mode) and (
not self._trainable_statistics)
_is_test = not self.training and not self._trainable_statistics
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", _is_test, "data_layout", self._data_layout,
"use_mkldnn", False, "fuse_with_relu",
@@ -1157,8 +1159,7 @@
"data_layout": self._data_layout,
"use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats,
"trainable_statistics": self._trainable_statistics
"use_global_stats": self._use_global_stats
}
inputs = {
@@ -1191,6 +1192,115 @@ class BatchNorm(layers.Layer):
return self._helper.append_activation(batch_norm_out, self._act)
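
The dygraph branch above now derives _is_test from the layer-level training flag rather than the global tracer mode, so with the default trainable_statistics calling eval() is enough to switch BatchNorm to inference statistics. A minimal sketch of the intended behavior (assuming the constructor accepts num_channels as documented above; shapes and names are illustrative):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        bn = fluid.dygraph.BatchNorm(num_channels=3)
        x = fluid.dygraph.to_variable(
            np.random.random([2, 3, 8, 8]).astype('float32'))

        y_train = bn(x)   # training=True: normalize with the current batch statistics
        bn.eval()
        y_eval = bn(x)    # training=False: normalize with the accumulated moving statistics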
class Dropout(layers.Layer):
"""
This interface is used to construct a callable object of the ``Dropout`` class.
For more details, refer to code examples.
Drop or keep each element of the input independently. Dropout is a regularization
technique for reducing overfitting by preventing neuron co-adaptation during
training. The dropout operator randomly sets the outputs of some units to zero
(according to the given dropout probability), while the others remain unchanged.
At inference time, the Dropout layer can be removed for efficiency.
Parameters:
p (float, optional): Probability of setting units to zero. Default: 0.5
seed (int, optional): A Python integer used to create random seeds. If this
parameter is set to None, a random seed is used.
NOTE: If an integer seed is given, the same output
units will always be dropped. DO NOT use a fixed seed in training. Default: None.
dropout_implementation(string, optional): ['downgrade_in_infer'(default)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference time
- train: out = input * mask
- inference: out = input * (1.0 - p)
(mask is a tensor with the same shape as the input; its values are 0 or 1,
and the ratio of zeros is p)
2. upscale_in_train, upscale the outcome at training time
- train: out = input * mask / (1.0 - p)
- inference: out = input
(mask is a tensor with the same shape as the input; its values are 0 or 1,
and the ratio of zeros is p)
is_test (bool, optional): A flag indicating whether it is in the test phase or not.
This flag only takes effect in static graph mode. For dygraph mode, please use ``eval()``.
Default: False.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import numpy as np
x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
with fluid.dygraph.guard():
x = to_variable(x)
m = fluid.dygraph.Dropout(p=0.5)
dropped_train = m(x)
# switch to eval mode
m.eval()
dropped_eval = m(x)
"""
def __init__(self,
p=0.5,
seed=None,
dropout_implementation="downgrade_in_infer",
is_test=False):
super(Dropout, self).__init__()
assert isinstance(p, (float, int)), "p argument should be a number"
assert 0 <= p <= 1, "p argument should between 0 and 1"
self._dropout_prob = p
assert seed is None or isinstance(
seed, int), "seed argument should be None or a integer"
self._seed = seed
assert dropout_implementation in (
'downgrade_in_infer', 'upscale_in_train'
), "dropout_implementation argument should be 'downgrade_in_infer' or 'upscale_in_train'"
self._dropout_implementation = dropout_implementation
self._is_test = is_test
def forward(self, input):
prog = default_main_program()
if (self._seed is None or self._seed == 0) and prog.random_seed != 0:
self._seed = prog.random_seed
attrs = {
'dropout_prob': self._dropout_prob,
'is_test': not self.training
if in_dygraph_mode() else self._is_test,
'fix_seed': self._seed is not None,
'seed': self._seed if self._seed is not None else 0,
'dropout_implementation': self._dropout_implementation,
}
if in_dygraph_mode():
attrs = sum(attrs.items(), ())
out, mask = core.ops.dropout(input, *attrs)
return out
out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
mask = self._helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
self._helper.append_op(
type='dropout',
inputs={'X': [input]},
outputs={'Out': [out],
'Mask': [mask]},
attrs=attrs)
return out
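
In the dygraph branch of forward(), the attribute dict is flattened into a single tuple of alternating names and values before being unpacked into core.ops.dropout. A small standalone illustration of that flattening (the values here are examples only):

    attrs = {
        'dropout_prob': 0.5,
        'is_test': False,
        'fix_seed': False,
        'seed': 0,
        'dropout_implementation': 'downgrade_in_infer',
    }
    flat = sum(attrs.items(), ())
    # with insertion-ordered dicts (Python 3.7+):
    # flat == ('dropout_prob', 0.5, 'is_test', False, 'fix_seed', False,
    #          'seed', 0, 'dropout_implementation', 'downgrade_in_infer')
    # which forward() then unpacks as core.ops.dropout(input, *flat)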
class Embedding(layers.Layer):
"""
**Embedding Layer**
......
@@ -106,6 +106,33 @@ class TestLayer(LayerTest):
ret = custom(x, do_linear2=True)
self.assertTrue(np.array_equal(ret.numpy().shape, [3, 1]))
def test_dropout(self):
inp = np.ones([3, 32, 32], dtype='float32')
with self.static_graph():
t = layers.data(
name='data',
shape=[3, 32, 32],
dtype='float32',
append_batch_size=False)
dropout = nn.Dropout(p=0.35, seed=1, is_test=False)
ret = dropout(t)
ret2 = fluid.layers.dropout(
t, dropout_prob=0.35, seed=1, is_test=False)
static_ret, static_ret2 = self.get_static_graph_result(
feed={'data': inp}, fetch_list=[ret, ret2])
with self.dynamic_graph():
t = base.to_variable(inp)
dropout = nn.Dropout(p=0.35, seed=1, is_test=False)
dy_ret = dropout(t)
dy_ret2 = fluid.layers.dropout(
t, dropout_prob=0.35, seed=1, is_test=False)
dy_ret_value = dy_ret.numpy()
dy_ret2_value = dy_ret2.numpy()
self.assertTrue(np.array_equal(static_ret, static_ret2))
self.assertTrue(np.array_equal(dy_ret_value, dy_ret2_value))
self.assertTrue(np.array_equal(static_ret, dy_ret_value))
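
A further check one might add (not part of the test above) is that eval mode is deterministic: per the documented inference behavior of the default 'downgrade_in_infer' implementation, evaluation scales the input by (1 - p), so for an all-ones input and p=0.35 every output element should be 0.65. A sketch under those assumptions, meant to live inside the same test method:

    with self.dynamic_graph():
        t = base.to_variable(inp)
        dropout = nn.Dropout(p=0.35)
        dropout.eval()                       # layer-level switch added in this commit
        dy_eval_value = dropout(t).numpy()
    self.assertTrue(np.allclose(dy_eval_value, inp * (1.0 - 0.35)))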
def test_linear(self):
inp = np.ones([3, 32, 32], dtype='float32')
with self.static_graph():
......