未验证 提交 0a2db7c8 编写于 作者: Z zhangkaihuo 提交者: GitHub

Add sparse SyncBatchNorm (#43520)

* add sparse SyncBatchNorm
上级 ab2aaf8b
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.incubate.sparse import nn
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard
import copy
......@@ -56,11 +57,10 @@ class TestSparseBatchNorm(unittest.TestCase):
# test backward
sparse_y.backward(sparse_y)
assert np.allclose(
dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_error_layout(self):
......@@ -86,5 +86,34 @@ class TestSparseBatchNorm(unittest.TestCase):
# [1, 6, 6, 6, 3]
class TestSyncBatchNorm(unittest.TestCase):
def test_sync_batch_norm(self):
with _test_eager_guard():
x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
[[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
x = x.to_sparse_coo(len(x.shape) - 1)
if paddle.is_compiled_with_cuda():
sync_batch_norm = nn.SyncBatchNorm(2)
hidden1 = sync_batch_norm(x)
print(hidden1)
def test_convert(self):
base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
nn.BatchNorm(5))
model = paddle.nn.Sequential(
nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
nn.BatchNorm(5,
weight_attr=fluid.ParamAttr(name='bn.scale'),
bias_attr=fluid.ParamAttr(name='bn.bias')))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(base_model.sublayers()):
if isinstance(sublayer, nn.BatchNorm):
self.assertEqual(isinstance(model[idx], nn.SyncBatchNorm), True)
if __name__ == "__main__":
unittest.main()
......@@ -15,10 +15,10 @@
from . import functional
from .layer.activation import ReLU
from .layer.norm import BatchNorm, SyncBatchNorm
from .layer.activation import Softmax
from .layer.activation import ReLU6
from .layer.activation import LeakyReLU
from .layer.norm import BatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
from .layer.pooling import MaxPool3D
......@@ -29,6 +29,7 @@ __all__ = [
'LeakyReLU',
'Softmax',
'BatchNorm',
'SyncBatchNorm',
'Conv3D',
'SubmConv3D',
'MaxPool3D',
......
......@@ -27,6 +27,8 @@
import paddle
import warnings
from paddle.nn.layer.norm import _BatchNormBase
from paddle.framework import no_grad
class BatchNorm(paddle.nn.BatchNorm1D):
......@@ -157,3 +159,177 @@ class BatchNorm(paddle.nn.BatchNorm1D):
batch_norm_out,
shape=input.shape,
stop_gradient=input.stop_gradient)
class SyncBatchNorm(paddle.nn.SyncBatchNorm):
r"""
This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can
be used as a normalizer function for other operations, such as conv2d and fully connected
operations.
The data is normalized by the mean and variance of the channel based on whole mini-batch
, which including data in all gpus.
Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.
When model in training mode, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are the statistics of whole mini-batch data in all gpus.
Calculated as follows:
.. math::
\mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\
\ mini-batch\ mean \\
\sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \
\mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\
- :math:`x` : whole mini-batch data in all gpus
- :math:`m` : the size of the whole mini-batch data
When model in evaluation mode, the :math:`\\mu_{\\beta}`
and :math:`\sigma_{\beta}^{2}` are global statistics (moving_mean and moving_variance,
which usually got from the pre-trained model). Global statistics calculated as follows:
.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global \ mean \\
moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global \ variance \\
The formula of normalization is as follows:
.. math::
\hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift
- :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
- :math:`\gamma` : trainable scale parameter vector
- :math:`\beta` : trainable shift parameter vector
Note:
If you want to use container to pack your model and has ``SyncBatchNorm`` in the
evaluation phase, please use ``nn.LayerList`` or ``nn.Sequential`` instead of
``list`` to pack the model.
Parameters:
num_features(int): Indicate the number of channels of the input ``Tensor``.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
of this layer. If it is set to None or one attribute of ParamAttr, this layerr
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. If it is set to False,
this layer will not have trainable scale parameter. Default: None.
bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer.
If it is set to None or one attribute of ParamAttr, this layer
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. If it is set to False, this layer will not
have trainable bias parameter. Default: None.
Shapes:
input: Tensor that the dimension from 2 to 5.
output: Tensor with the same shape as input.
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.sparse.nn as nn
import numpy as np
x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
x = x.to_sparse_coo(len(x.shape)-1)
if paddle.is_compiled_with_cuda():
sync_batch_norm = nn.SyncBatchNorm(2)
hidden1 = sync_batch_norm(x)
print(hidden1)
# Tensor(shape=[1, 2, 2, 2], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# indices=[[0, 0, 0, 0],
# [0, 0, 1, 1],
# [0, 1, 0, 1]],
# values=[[-0.40730840, -0.13725480],
# [-0.40730840, -1.20299828],
# [ 1.69877410, -0.23414057],
# [-0.88415730, 1.57439375]])
"""
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCHW',
name=None):
super(SyncBatchNorm,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, name)
def forward(self, x):
assert x.is_sparse_coo(
), "SyncBatchNorm only support SparseTensor in COO format."
out = super(SyncBatchNorm, self).forward(x.values())
return paddle.incubate.sparse.sparse_coo_tensor(
x.indices(), out, shape=x.shape, stop_gradient=x.stop_gradient)
@classmethod
def convert_sync_batchnorm(cls, layer):
"""
Helper function to convert :class: `paddle.incubate.sparse.nn.BatchNorm` layers in the model to :class: `paddle.incubate.sparse.nn.SyncBatchNorm` layers.
Parameters:
layer(paddle.nn.Layer): model containing one or more `BatchNorm` layers.
Returns:
The original model with converted SyncBatchNorm layers. If BatchNorm layer in the model, use SyncBatchNorm layer instead.
Examples:
.. code-block:: python
import paddle
import paddle.incubate.sparse.nn as nn
model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5))
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
"""
layer_output = layer
if isinstance(layer, _BatchNormBase):
if layer._weight_attr != None and not isinstance(
layer._weight_attr,
bool) and layer._weight_attr.name != None:
layer._weight_attr.name = layer._weight_attr.name + '_sync'
if layer._bias_attr != None and not isinstance(
layer._bias_attr, bool) and layer._bias_attr.name != None:
layer._bias_attr.name = layer._bias_attr.name + '_sync'
#convert sparse BatchNorm
if isinstance(layer, BatchNorm):
layer_output = SyncBatchNorm(layer._num_features,
layer._momentum, layer._epsilon,
layer._weight_attr,
layer._bias_attr,
layer._data_format, layer._name)
#convert dense BatchNorm
else:
layer_output = paddle.nn.SyncBatchNorm(
layer._num_features, layer._momentum, layer._epsilon,
layer._weight_attr, layer._bias_attr, layer._data_format,
layer._name)
if layer._weight_attr != False and layer._bias_attr != False:
with no_grad():
layer_output.weight = layer.weight
layer_output.bias = layer.bias
layer_output._mean = layer._mean
layer_output._variance = layer._variance
for name, sublayer in layer.named_children():
layer_output.add_sublayer(name,
cls.convert_sync_batchnorm(sublayer))
del layer
return layer_output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册