Commit 37042d5b authored by zhaojichen

add global batch normalization

Parent 549bfb97
......@@ -18,7 +18,7 @@ Layer.
The high-level components(Cells) used to construct the neural network.
"""
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, GlobalBatchNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
......@@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm',
'SequentialCell', 'CellList',
'Conv2d', 'Conv2dTranspose',
'LSTM',
......
......@@ -20,16 +20,28 @@ from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_int_positive, check_bool, check_typename
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_local_rank_size, get_rank
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell
class _GlobalBNHelper(Cell):
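    """Helper cell that sums a tensor across all devices of the given communication group via AllReduce."""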
def __init__(self, group):
super(_GlobalBNHelper, self).__init__()
self.group = group
self.reduce = P.AllReduce(P.ReduceOp.SUM, group=self.group).add_prim_attr('fusion', 1)
def construct(self, x):
x = self.reduce(x)
return x
class _BatchNorm(Cell):
"""Batch Normalization base class."""
@cell_attr_register
def __init__(self,
num_features,
group=1,
eps=1e-5,
momentum=0.9,
affine=True,
......@@ -56,6 +68,20 @@ class _BatchNorm(Cell):
gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
        self.group = check_int_positive(group)
        self.is_global = False
        if self.group != 1:
            # Split the visible devices into groups of `group` ranks and create a communication
            # group for the one this rank belongs to, so statistics can be synchronized within it.
            self.rank_id = get_rank()
            self.rank_size = get_local_rank_size()
            self.device_list = list(range(self.rank_size))
            self.rank_list = self.list_group(self.device_list, self.group)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i]:
                    self.is_global = True
                    management.create_group('group' + str(i), self.rank_list[i])
                    self.all_reduce = _GlobalBNHelper('group' + str(i))
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean()
        self.square = P.Square()
if context.get_context("enable_ge"):
self.is_ge_backend = True
......@@ -82,22 +108,52 @@ class _BatchNorm(Cell):
def _check_data_dim(self, x):
raise NotImplementedError
def list_group(self, world_rank, group_size):
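        """Split the device ranks into consecutive groups of `group_size` ranks, e.g. 8 ranks with a group size of 4 give [[0, 1, 2, 3], [4, 5, 6, 7]]."""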
        if group_size > get_local_rank_size():
            raise ValueError("group size cannot be greater than local rank size, group size is {}, "
                             "local rank size is {}".format(group_size, get_local_rank_size()))
        if len(world_rank) % group_size != 0:
            raise ValueError("the number of ranks ({}) must be divisible by the group size ({})."
                             .format(len(world_rank), group_size))
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list
def construct(self, x):
if self.training and self.use_batch_statistics:
if self.is_ge_backend:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
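                # Global branch: each device computes its local mean and mean of squares, sums them
                # across the group via AllReduce, and divides by the group size; the group variance is
                # then recovered as E[x^2] - (E[x])^2.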
if self.is_global:
x_mean = self.reduce_mean(x)
x_mean_square = self.reduce_mean(self.square(x))
global_batch_mean = self.all_reduce(x_mean) / self.group
global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
global_mean = global_batch_mean
global_var = global_batch_mean_square - self.square(global_batch_mean)
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
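                    # Update the running statistics as an exponential moving average of the group
                    # statistics: moving_mean <- moving_mean - momentum * (moving_mean - global_mean),
                    # and likewise for the variance.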
mean_sub = self.sub_mean(self.moving_mean, global_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, global_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y = self.bn_train(x,
self.gamma,
......@@ -221,6 +277,52 @@ class BatchNorm2d(_BatchNorm):
pass
class GlobalBatchNorm(_BatchNorm):
r"""
    Global normalization layer over an N-dimensional input.

    Global Batch Normalization is cross-device synchronized batch normalization: a plain Batch Normalization
    implementation only normalizes the data within each device, whereas global normalization normalizes the
    input across all devices in a group. It has been described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales
    and recenters the feature using a mini-batch of data and the learned parameters, which can be described
    by the following formula:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
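
    The group statistics are aggregated from the per-device statistics; the following is a sketch of the
    AllReduce-based computation performed in ``construct``, where :math:`G` is the number of devices in a
    group and :math:`\mathrm{E}_d` denotes the per-device mean:

    .. math::
        \mathrm{E}_{group}[x] = \frac{1}{G}\sum_{d=1}^{G}\mathrm{E}_{d}[x], \qquad
        \mathrm{Var}_{group}[x] = \frac{1}{G}\sum_{d=1}^{G}\mathrm{E}_{d}[x^2] - \mathrm{E}_{group}[x]^2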
Args:
num_features (int): `C` from an expected input of size (N, C, H, W).
        group (int): The number of devices in each group.
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating point hyperparameter for the momentum of the
            running_mean and running_var computation. Default: 0.9.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean and variance values of the current batch of data;
            otherwise, use the stored moving mean and moving variance values. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> global_bn_op(input)
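
    Note:
        Distributed communication must be initialized (for example, ``init("hccl")``) before the cell is
        created, and the total number of devices must be a multiple of `group`.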
"""
class LayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
......
......@@ -19,6 +19,7 @@ import pytest
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore import Tensor, Parameter
from mindspore.communication.management import init
def test_bn_pars_valid1():
......@@ -70,3 +71,17 @@ def test_compile_groupnorm():
net = nn.GroupNorm(16, 64)
    input_data = Tensor(np.random.rand(1, 64, 256, 256).astype(np.float32))
_executor.compile(net, input_data)
class GlobalBNNet(nn.Cell):
def __init__(self):
super(GlobalBNNet, self).__init__()
        self.bn = nn.GlobalBatchNorm(num_features=2, group=4)
def construct(self, x):
return self.bn(x)
def test_global_bn():
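    """GlobalBatchNorm forward smoke test; assumes an initialized HCCL environment whose device count is a multiple of 4 (group=4)."""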
init("hccl")
net = GlobalBNNet()
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
net.set_train()
out = net(input_data)