提交 b9b056f4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!399 Add Global Batch Normalization

Merge pull request !399 from JichenZhao/syncbn
......@@ -18,7 +18,7 @@ Layer.
The high-level components(Cells) used to construct the neural network.
"""
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, GlobalBatchNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
......@@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM, PSNR
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm',
'SequentialCell', 'CellList',
'Conv2d', 'Conv2dTranspose',
'LSTM',
......
......@@ -20,8 +20,11 @@ from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_int_positive, check_bool, check_typename
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell
......@@ -30,6 +33,7 @@ class _BatchNorm(Cell):
@cell_attr_register
def __init__(self,
num_features,
group=1,
eps=1e-5,
momentum=0.9,
affine=True,
......@@ -56,6 +60,21 @@ class _BatchNorm(Cell):
gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.group = check_int_positive(group)
if self.group != 1:
self.rank_id = get_rank()
self.rank_size = get_group_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group)
self.rank_list_idx = len(self.rank_list)
for i in range(self.rank_list_idx):
if self.rank_id in self.rank_list[i] and self.group != 1:
self.is_global = True
management.create_group('group' + str(i), self.rank_list[i])
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean()
self.square = P.Square()
if context.get_context("enable_ge"):
self.is_ge_backend = True
......@@ -82,22 +101,53 @@ class _BatchNorm(Cell):
def _check_data_dim(self, x):
raise NotImplementedError
def list_group(self, world_rank, group_size):
if group_size > get_group_size():
raise ValueError("group size can not be greater than local rank size, group size is {}, "
"local_rank_size is {}".format(group_size, get_group_size()))
if len(world_rank) % group_size != 0:
raise ValueError("please make your group size correct.")
world_rank_list = zip(*(iter(world_rank),) *group_size)
group_list = [list(i) for i in world_rank_list]
return group_list
def construct(self, x):
if self.training and self.use_batch_statistics:
if self.is_ge_backend:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
if self.is_global:
x_mean = self.reduce_mean(x)
x_mean_square = self.reduce_mean(self.square(x))
global_batch_mean = self.all_reduce(x_mean) / self.group
global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
global_mean = global_batch_mean
global_var = global_batch_mean_square - self.square(global_batch_mean)
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, global_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, global_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y = self.bn_train(x,
self.gamma,
......@@ -221,6 +271,55 @@ class BatchNorm2d(_BatchNorm):
pass
class GlobalBatchNorm(_BatchNorm):
r"""
Global normalization layer over a N-dimension input.
Global Normalization is cross device synchronized batch normalization. Batch Normalization implementation
only normalize the data within each device. Global normalization will normalize the input within the group.
It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
feature using a mini-batch of data and the learned parameters which can be described in the following formula.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
num_features (int): `C` from an expected input of size (N, C, H, W).
group (int): The number of device in each group.
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
momentum (float): A floating hyperparameter of the momentum for the
running_mean and running_var computation. Default: 0.9.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> global_bn_op(input)
"""
def _check_data_dim(self, x):
if x.dim == 0:
pass
class LayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册