diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 89770ccc8cec1b0c5fcfc4c1033a691cc674566e..a770275bf08c05cbc2c5474f456df8de289e9689 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -57,6 +57,9 @@ std::map> op_outs_map = { {"batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"sync_batch_norm", + {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", + "ReserveSpace"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are @@ -76,6 +79,7 @@ std::map> op_passing_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"momentum", {"ParamOut", "VelocityOut"}}, {"batch_norm", {"MeanOut", "VarianceOut"}}, + {"sync_batch_norm", {"MeanOut", "VarianceOut"}}, {"accuracy", {"Correct", "Total"}}, {"fill_constant", {"Out"}}, {"matmul", {"Out"}}, diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index a7ffc2dd63ae6542582ae777b0c584cc8e22cd58..45744841fc5be5d76585b2dc70841cea5c7ca764 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -35,7 +35,7 @@ __all__ = [ 'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding', 'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', - 'SpectralNorm', 'TreeConv', 'Flatten' + 'SpectralNorm', 'TreeConv', 'Flatten', 'SyncBatchNorm' ] @@ -3202,6 +3202,220 @@ class TreeConv(layers.Layer): return self._helper.append_activation(pre_activation, act=self._act) +class SyncBatchNorm(layers.Layer): + """ + This interface is used to construct a callable object of the ``SyncBatchNorm`` class. + It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can + be used as a normalizer function for other operations, such as conv2d and fully connected + operations. + The data is normalized by the mean and variance of the channel based on whole mini-batch + , which including data in all gpus. + Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `_ + for more details. + + When model in training mode, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are the statistics of whole mini-batch data in all gpus. + Calculated as follows: + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + + - :math:`x` : whole mini-batch data in all gpus + - :math:`m` : the size of the whole mini-batch data + + When model in evaluation mode, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are global statistics (moving_mean and moving_variance, + which usually got from the pre-trained model). Global statistics calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\ + + The formula of normalization is as follows: + + .. math:: + + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\eps}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + - :math:`\\eps` : add a smaller value to the variance to prevent division by zero + - :math:`\\gamma` : trainable scale parameter vector + - :math:`\\beta` : trainable shift parameter vector + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of this layer. If it is set to None or one attribute of ParamAttr, this layerr + will create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. If it is set to False, + this layer will not have trainable scale parameter. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer. + If it is set to None or one attribute of ParamAttr, this layer + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. If it is set to False, this layer will not + have trainable bias parameter. Default: None. + track_running_stats(bool, optional): Whether to compute global stats, which including running mean and + running variance. Default: True. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + + x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') + paddle.disable_static() + x = paddle.to_tensor(x) + if paddle.fluid.is_compiled_with_cuda(): + sync_batch_norm = nn.SyncBatchNorm(2) + hidden1 = sync_batch_norm(x) + print(hidden1.numpy()) + # [[[[0.26824948, 1.0936325],[0.26824948, -1.6301316]],[[ 0.8095662, -0.665287],[-1.2744656, 1.1301866 ]]]] + """ + + def __init__(self, + num_features, + epsilon=1e-05, + momentum=0.9, + track_running_stats=True, + weight_attr=None, + bias_attr=None, + data_format='NCHW', + name=None): + super(SyncBatchNorm, self).__init__() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self._num_features = num_features + self._data_layout = data_format + self._momentum = momentum + self._epsilon = epsilon + self._track_running_stats = track_running_stats + + if self._track_running_stats == False: + logging.warn( + "moving mean and moving variance will be calculated whether `track_running_stats` is set to `True` or `False`, we will fix it in the next version." + ) + + param_shape = [self._num_features] + + # create parameter + if weight_attr == False: + self.weight = self.create_parameter( + attr=None, shape=param_shape, default_initializer=Constant(1.0)) + self.weight.stop_gradient = True + else: + self.weight = self.create_parameter( + attr=self._weight_attr, + shape=param_shape, + default_initializer=Constant(1.0)) + self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0. + + if bias_attr == False: + self.bias = self.create_parameter( + attr=None, + shape=param_shape, + default_initializer=Constant(0.0), + is_bias=True) + self.bias.stop_gradient = True + else: + self.bias = self.create_parameter( + attr=self._bias_attr, shape=param_shape, is_bias=True) + self.bias.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0. + + self._mean = self.create_parameter( + attr=ParamAttr( + name=None, + initializer=Constant(0.0), + trainable=False, + do_model_average=True), + shape=param_shape, + dtype=self._dtype) + self._mean.stop_gradient = True + + self._variance = self.create_parameter( + attr=ParamAttr( + name=None, + initializer=Constant(1.0), + trainable=False, + do_model_average=True), + shape=param_shape, + dtype=self._dtype) + self._variance.stop_gradient = True + + def forward(self, x): + # create output + # mean and mean_out share the same memory + mean_out = self._mean + # variance and variance out share the same memory + variance_out = self._variance + + ### train mode: use mini-batch stats, eval mode: use global stats + if in_dygraph_mode(): + attrs = ("momentum", self._momentum, "epsilon", self._epsilon, + "is_test", not self.training, "data_layout", + self._data_layout, "use_mkldnn", False, "fuse_with_relu", + False, "use_global_stats", not self.training, + 'trainable_statistics', False) + sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm( + x, self.weight, self.bias, self._mean, self._variance, mean_out, + variance_out, *attrs) + + return sync_batch_norm_out + + check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'], + 'BatchNorm') + + attrs = { + "momentum": self._momentum, + "epsilon": self._epsilon, + "is_test": not self.training, + "data_layout": self._data_layout, + "use_mkldnn": False, + "fuse_with_relu": False, + "use_global_stats": not self.training, + "trainable_statistics": False, + } + + inputs = { + "X": [x], + "Scale": [self.weight], + "Bias": [self.bias], + "Mean": [self._mean], + "Variance": [self._variance] + } + + saved_mean = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + saved_variance = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + sync_batch_norm_out = self._helper.create_variable_for_type_inference( + self._dtype) + + outputs = { + "Y": [sync_batch_norm_out], + "MeanOut": [mean_out], + "VarianceOut": [variance_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + self._helper.append_op( + type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) + return sync_batch_norm_out + + class Flatten(layers.Layer): """ :alias_main: paddle.nn.Flatten diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 05f5d174f8774b20089db6f383bbedc5ca8ab21f..126b4465eae480d1d012eb706667b125dce5f0ea 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -106,6 +106,7 @@ if (NOT ${WITH_GPU}) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_se_resnext) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer) + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) elseif(${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2059592b5170fdd623f4a20b9fa47612ff2a6a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import contextlib +import unittest +import numpy as np +import six +import pickle + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm +from paddle.fluid.dygraph.base import to_variable + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + + +class TestLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(TestLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._sync_batch_norm = SyncBatchNorm(num_filters) + + self._conv2 = Conv2D( + num_channels=num_filters, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._sync_batch_norm2 = SyncBatchNorm( + num_filters, + weight_attr=False, + bias_attr=False, + track_running_stats=False) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._sync_batch_norm(y) + y = self._conv2(y) + y = self._sync_batch_norm2(y) + + return y + + +class TestSyncBatchNorm(TestParallelDyGraphRunnerBase): + def get_model(self): + model = TestLayer(3, 64, 7) + train_reader = paddle.batch( + paddle.dataset.flowers.test(use_xmap=False), + batch_size=32, + drop_last=True) + opt = fluid.optimizer.Adam( + learning_rate=1e-3, parameter_list=model.parameters()) + return model, train_reader, opt + + def run_one_loop(self, model, opt, data): + batch_size = len(data) + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + img = to_variable(dy_x_data) + img.stop_gradient = False + + out = model(img) + + out = fluid.layers.mean(out) + + return out + + +if __name__ == "__main__": + runtime_main(TestSyncBatchNorm) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 9da70e85f01c0a13a87766a1befbda206c510cbe..91186b2e95ae00b15503dc29feb1c2b1039e744f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -283,6 +283,24 @@ class TestLayer(LayerTest): with self.assertRaises(ValueError): lm(base.to_variable(inp)) + def test_SyncBatchNorm(self): + if core.is_compiled_with_cuda(): + with self.static_graph(): + t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32') + my_sync_bn = nn.SyncBatchNorm(3) + ret = my_sync_bn(t) + static_ret = self.get_static_graph_result( + feed={'t': np.ones( + [3, 3, 5, 5], dtype='float32')}, + fetch_list=[ret])[0] + + with self.dynamic_graph(): + t = np.ones([3, 3, 5, 5], dtype='float32') + my_syncbn = paddle.nn.SyncBatchNorm(3) + dy_ret = my_syncbn(base.to_variable(t)) + dy_ret_value = dy_ret.numpy() + self.assertTrue(np.array_equal(static_ret, static_ret)) + def test_relu(self): with self.static_graph(): t = layers.data(name='t', shape=[3, 3], dtype='float32') diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..5c34b35fc83a3de6e1a33a51dad1e4e264afd52b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase +import paddle.fluid as fluid + +import os +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphMnist(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = False #True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sync_batch_norm.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 8fd118c0193035fce294aa6ac23951d57ba43f78..806b6b90e7e2d39a3d4e5f3792cd849022097777 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -25,6 +25,7 @@ import six import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler +from paddle.fluid import Program, program_guard from op_test import OpTest, _set_use_system_allocator @@ -202,5 +203,22 @@ class TestFP16SyncBatchNormOpTraining(TestSyncBatchNormOpTraining): self.atol = 1e-2 +class TestDygraphSyncBatchNormAPIError(unittest.TestCase): + def test_errors(self): + if not core.is_compiled_with_cuda(): + return + + with program_guard(Program(), Program()): + my_sync_batch_norm = fluid.dygraph.SyncBatchNorm(10) + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CUDAPlace(0)) + self.assertRaises(TypeError, my_sync_batch_norm, x1) + + # the input dtype of SyncBatchNorm must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") + self.assertRaises(TypeError, my_sync_batch_norm, x2) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 6249ce59e21d3655b4dc4a0a1dea60052558fcdb..7b6dcdf7f67dece7dd4e8ebafdc9a1c679dd18ba 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -92,6 +92,7 @@ from .layer.loss import BCELoss #DEFINE_ALIAS from .layer.loss import KLDivLoss #DEFINE_ALIAS from .layer.loss import MarginRankingLoss #DEFINE_ALIAS from .layer.norm import BatchNorm #DEFINE_ALIAS +from .layer.norm import SyncBatchNorm #DEFINE_ALIAS from .layer.norm import GroupNorm #DEFINE_ALIAS from .layer.norm import LayerNorm #DEFINE_ALIAS from .layer.norm import SpectralNorm #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 9fb8ea78a16ab4872c80f04849e239e73d0cf28a..f64252da5428a09a5530eb1d2cec375c9141ea9a 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -65,6 +65,7 @@ from .loss import BCELoss #DEFINE_ALIAS from .loss import KLDivLoss #DEFINE_ALIAS from .loss import MarginRankingLoss #DEFINE_ALIAS from .norm import BatchNorm #DEFINE_ALIAS +from .norm import SyncBatchNorm #DEFINE_ALIAS from .norm import GroupNorm #DEFINE_ALIAS from .norm import LayerNorm #DEFINE_ALIAS from .norm import SpectralNorm #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 1beba62c1809ffd94a22712fb24ac43a0ec23ff1..1d00f9c7b8b0204affed690bcea2f0ff78a943d1 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -20,7 +20,9 @@ from ...fluid.dygraph import BatchNorm #DEFINE_ALIAS from ...fluid.dygraph import GroupNorm #DEFINE_ALIAS from ...fluid.dygraph import LayerNorm #DEFINE_ALIAS from ...fluid.dygraph import SpectralNorm #DEFINE_ALIAS +from ...fluid.dygraph import SyncBatchNorm #DEFINE_ALIAS __all__ = [ - 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm' + 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm', + 'SyncBatchNorm' ]