From 157ff0943e40fbf341f28bd7c9eaa85ab73fc909 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Thu, 7 Jan 2021 08:56:08 +0800 Subject: [PATCH] Cherry pick bn (#30136) * fix bn docs (#30096) * add attribute for batch_norm (#29950) * add attribute for batch_norm --- .../tests/unittests/test_batch_norm_op_v2.py | 55 +++++++++++++++++++ python/paddle/nn/functional/norm.py | 14 ++++- python/paddle/nn/layer/norm.py | 20 ++++--- 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py index 8118961919..b1f751f5ac 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -222,5 +222,60 @@ class TestBatchNormChannelLast(unittest.TestCase): self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True) +class TestBatchNormUseGlobalStats(unittest.TestCase): + def setUp(self): + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + self.places.append(fluid.CUDAPlace(0)) + self.init_test() + + ### train mode + def init_test(self): + self.use_global_stats = True + self.trainable_statistics = False + + def test_global_stats(self): + for p in self.places: + with fluid.dygraph.guard(p): + x = paddle.randn([2, 6, 6, 4]) + net1 = paddle.fluid.dygraph.BatchNorm( + 6, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(1.0)), + use_global_stats=self.use_global_stats, + trainable_statistics=self.trainable_statistics) + net2 = paddle.nn.BatchNorm2D( + 6, use_global_stats=self.use_global_stats) + net2.weight = net1.weight + net2.bias = net1.bias + if self.trainable_statistics == True: + net1.training = False + net2.training = False + y1 = net1(x) + y2 = net2(x) + self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True) + + +class TestBatchNormUseGlobalStatsCase1(TestBatchNormUseGlobalStats): + ### test mode + def init_test(self): + self.use_global_stats = False + self.trainable_statistics = True + + +class TestBatchNormUseGlobalStatsCase2(TestBatchNormUseGlobalStats): + ### train mode + def init_test(self): + self.use_global_stats = False + self.trainable_statistics = False + + +class TestBatchNormUseGlobalStatsCase3(TestBatchNormUseGlobalStats): + ### test mode + def init_test(self): + self.use_global_stats = True + self.trainable_statistics = True + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 56b5068bfb..8d62535a25 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -123,6 +123,7 @@ def batch_norm(x, momentum=0.9, epsilon=1e-05, data_format="NCHW", + use_global_stats=None, name=None): """ Applies Batch Normalization as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . @@ -139,6 +140,7 @@ def batch_norm(x, momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False. data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Defalut "NCHW". + use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Returns: @@ -167,8 +169,6 @@ def batch_norm(x, assert len(x.shape) >= 2, "input dim must be larger than 1" - # we use not training means use_global_status, more details see nn._BatchNormBase - use_global_stats = not training # input ad out must share the memory mean_out = running_mean variance_out = running_var @@ -181,11 +181,18 @@ def batch_norm(x, data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' + if use_global_stats == None: + use_global_stats = not training + trainable_statistics = False + else: + trainable_statistics = not use_global_stats + if in_dygraph_mode(): # for dygraph need tuple attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout", data_format, "use_mkldnn", False, "fuse_with_relu", False, - "use_global_stats", use_global_stats) + "use_global_stats", use_global_stats, "trainable_statistics", + trainable_statistics) batch_norm_out, _, _, _, _, _ = core.ops.batch_norm( x, weight, bias, running_mean, running_var, mean_out, variance_out, *attrs) @@ -204,6 +211,7 @@ def batch_norm(x, "use_mkldnn": False, "fuse_with_relu": False, "use_global_stats": use_global_stats, + "trainable_statistics": trainable_statistics, } inputs = { diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index b1f6906386..d8a4066cf0 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -550,11 +550,13 @@ class _BatchNormBase(layers.Layer): weight_attr=None, bias_attr=None, data_format='NCHW', + use_global_stats=None, name=None): super(_BatchNormBase, self).__init__() self._num_features = num_features self._weight_attr = weight_attr self._bias_attr = bias_attr + self._use_global_stats = use_global_stats if get_default_dtype() == 'float16': set_default_dtype('float32') @@ -642,14 +644,15 @@ class _BatchNormBase(layers.Layer): training=self.training, momentum=self._momentum, epsilon=self._epsilon, - data_format=self._data_format) + data_format=self._data_format, + use_global_stats=self._use_global_stats) class BatchNorm1D(_BatchNormBase): r""" Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . - When track_running_stats = False, the :math:`\\mu_{\\beta}` + When use_global_stats = False, the :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch. Calculated as follows: @@ -660,7 +663,7 @@ class BatchNorm1D(_BatchNormBase): \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ - When track_running_stats = True, the :math:`\\mu_{\\beta}` + When use_global_stats = True, the :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. They are global or running statistics (moving_mean and moving_variance). It usually got from the pre-trained model. Calculated as follows: @@ -694,6 +697,7 @@ class BatchNorm1D(_BatchNormBase): will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Defalut "NCL". + use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: @@ -739,7 +743,7 @@ class BatchNorm2D(_BatchNormBase): r""" Applies Batch Normalization over a 4D input (a mini-batch of 2D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . - When track_running_stats = False, the :math:`\\mu_{\\beta}` + When use_global_stats = False, the :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch. Calculated as follows: @@ -750,7 +754,7 @@ class BatchNorm2D(_BatchNormBase): \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ - When track_running_stats = True, the :math:`\\mu_{\\beta}` + When use_global_stats = True, the :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. They are global or running statistics (moving_mean and moving_variance). It usually got from the pre-trained model. Calculated as follows: @@ -784,6 +788,7 @@ class BatchNorm2D(_BatchNormBase): will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW. + use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: @@ -827,7 +832,7 @@ class BatchNorm3D(_BatchNormBase): r""" Applies Batch Normalization over a 5D input (a mini-batch of 3D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . - When track_running_stats = False, the :math:`\\mu_{\\beta}` + When use_global_stats = False, the :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch. Calculated as follows: @@ -838,7 +843,7 @@ class BatchNorm3D(_BatchNormBase): \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ - When track_running_stats = True, the :math:`\\mu_{\\beta}` + When use_global_stats = True, the :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. They are global or running statistics (moving_mean and moving_variance). It usually got from the pre-trained model. Calculated as follows: @@ -872,6 +877,7 @@ class BatchNorm3D(_BatchNormBase): will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC. Default: NCDHW. + use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: -- GitLab