Commit 40412e26 authored by Megvii Engine Team

feat(mge/module): add sync bn

GitOrigin-RevId: ae71a540d1ee044a5879ad029479ed19bc99cfb8
Parent 3c32ad6d
......@@ -12,56 +12,14 @@ import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..core.graph import get_default_graph
from ..functional import add_update
from .util import (
get_backend,
get_master_ip,
get_master_port,
get_rank,
get_world_size,
is_distributed,
)
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed
@wrap_io_tensor
def _collective_comm(
inp: Union[Tensor, mgb.CompGraph],
key: str,
op: CollParam.Mode,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
dtype: Optional[type] = None,
device: Optional[mgb.CompNode] = None,
comp_graph: Optional[mgb.CompGraph] = None,
) -> Tensor:
"""Helper function for creating collective_comm operators
:param inp: tensor or comp_graph
:param key: unique identifier for collective communication
:param op: mode of collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param dtype: output data type, use dtype of inp as default
:param device: output comp node, use comp node of inp as default
:param comp_graph: output comp graph, use comp graph of inp as default
"""
return mgb.opr.collective_comm(
inp,
key=str(key),
nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
rank=rank if rank is not None else get_rank(),
root=root,
server_addr=get_master_ip(),
port=get_master_port(),
param=CollParam(mode=op),
dtype=dtype,
backend=get_backend(),
comp_node=device,
comp_graph=comp_graph,
)
def _collective_comm(*args, **kargs):
return collective_comm_symvar(*args, **kargs)
def reduce_sum(
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union
import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
def collective_comm_symvar(
inp: Union[mgb.SymbolVar, mgb.CompGraph],
key: str,
op: CollParam.Mode,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
dtype: Optional[type] = None,
device: Optional[mgb.CompNode] = None,
comp_graph: Optional[mgb.CompGraph] = None,
) -> mgb.SymbolVar:
"""Helper function for creating collective_comm operators
:param inp: tensor or comp_graph
:param key: unique identifier for collective communication
:param op: mode of collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param dtype: output data type, use dtype of inp as default
:param device: output comp node, use comp node of inp as default
:param comp_graph: output comp graph, use comp graph of inp as default
"""
return mgb.opr.collective_comm(
inp,
key=str(key),
nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
rank=rank if rank is not None else get_rank(),
root=root,
server_addr=get_master_ip(),
port=get_master_port(),
param=CollParam(mode=op),
dtype=dtype,
backend=get_backend(),
comp_node=device,
comp_graph=comp_graph,
)
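A usage sketch of the helper above (not part of this diff; names are illustrative): assuming init_process_group() has already run on every rank, an ALL_REDUCE_SUM over a SymbolVar can be built like this, with nr_ranks and rank falling back to the values recorded during initialization.

from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from megengine.distributed.helper import collective_comm_symvar
from megengine.distributed.util import get_group_id

def all_reduce_sum_var(x):
    # x: an mgb.SymbolVar produced elsewhere in the graph
    return collective_comm_symvar(
        x,
        key="demo_allreduce_" + str(get_group_id()),  # unique key per collective op
        op=CollParam.Mode.ALL_REDUCE_SUM,
    )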
......@@ -19,6 +19,7 @@ _master_port = 0
_world_size = 0
_rank = 0
_backend = None
_group_id = 0
def init_process_group(
......@@ -43,6 +44,7 @@ def init_process_group(
global _world_size # pylint: disable=global-statement
global _rank # pylint: disable=global-statement
global _backend # pylint: disable=global-statement
global _group_id # pylint: disable=global-statement
if not isinstance(master_ip, str):
raise TypeError("Expect type str but got {}".format(type(master_ip)))
......@@ -60,6 +62,7 @@ def init_process_group(
_world_size = world_size
_rank = rank
_backend = backend
_group_id = 0
set_default_device(mgb.comp_node("gpu" + str(dev)))
......@@ -101,6 +104,13 @@ def get_backend() -> str:
return str(_backend)
def get_group_id() -> int:
"""Get group id for collective communication"""
global _group_id
_group_id += 1
return _group_id
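The later diffs in this commit (sync_batch_norm, the optimizer's gradient all-reduce, bcast_param) turn this counter into a unique key for each collective op. A minimal sketch of that pattern, with illustrative prefixes:

from megengine.distributed.util import get_group_id

# each collective_comm op in the graph needs its own key;
# the incrementing counter makes every call unique within a process
bn_key = "sync_bn_" + str(get_group_id())
grad_key = "grad_" + str(get_group_id())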
def group_barrier() -> None:
"""Block until all ranks in the group reach this barrier"""
mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)
......
......@@ -76,6 +76,7 @@ from .nn import (
roi_pooling,
softmax,
softplus,
sync_batch_norm,
warp_perspective,
)
from .quantized import conv_bias_activation
......
......@@ -11,15 +11,20 @@ from typing import Optional, Tuple, Union
import megengine._internal as mgb
from megengine._internal import CompGraph, CompNode
from megengine._internal.config import add_extra_vardep
from megengine._internal.opr import add_update
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from .. import distributed as dist
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from ..distributed.util import get_group_id
from ..jit import barrier, mark_impure
from ..random import uniform
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .elemwise import exp, log
from .tensor import concat, where
from .tensor import where
from .utils import _decide_comp_node_and_comp_graph
......@@ -474,6 +479,125 @@ def batch_norm2d(
return output
@wrap_io_tensor
def sync_batch_norm(
input: Tensor,
running_mean: Tensor,
running_var: Tensor,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
training: bool = False,
momentum: Union[float, Tensor] = 0.9,
eps: float = 1e-5,
eps_mode="ADDITIVE",
) -> Tensor:
""" Applies synchronized batch normalization to the input.
Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
:param input: input tensor.
:param running_mean: tensor to store running mean.
:param running_var: tensor to store running variance.
:param weight: scaling tensor in the learnable affine parameters.
See :math:`\gamma` in :class:`~.BatchNorm2d`
:param bias: bias tensor in the learnable affine parameters.
See :math:`\beta` in :class:`~.BatchNorm2d`
:param training: a boolean value to indicate whether batch norm is performed
in training mode. Default: ``False``
:param momentum: the value used for the ``running_mean`` and ``running_var``
computation.
Default: 0.9
:param eps: a value added to the denominator for numerical stability.
Default: 1e-5.
:param eps_mode: how ``eps`` is combined with the variance, either
``"MAX"`` or ``"ADDITIVE"``. Default: ``"ADDITIVE"``
"""
assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
input = mgb.opr.mark_no_broadcast_elemwise(input)
_channels = input.imm_shape[1]
_ndim = len(input.imm_shape)
_param_shape = (1, _channels) + (1,) * (_ndim - 2)
if training:
def _sum_on_channel(input):
return mgb.opr.reduce_general([input, _param_shape], mode="sum")
def _allreduce(stat, key):
return dist.helper.collective_comm_symvar(
stat, key, CollParam.Mode.ALL_REDUCE_SUM
)
reduce_size = input.shape[0]
for i in range(2, _ndim):
reduce_size = reduce_size * input.shape[i]
channel_x1s = _sum_on_channel(input)
channel_x2s = _sum_on_channel(input ** 2)
if dist.is_distributed():
# reduce all nodes' data to calculate mean and variance
reduce_size = reduce_size.reshape(*(1,) * _ndim)
stat = mgb.opr.concat([reduce_size, channel_x1s, channel_x2s], axis=1)
stat = _allreduce(stat, key="sync_bn_" + str(get_group_id()))
reduce_size = stat[:, :1].reshape(1)
channel_x1s = stat[:, 1 : 1 + _channels]
channel_x2s = stat[:, 1 + _channels :]
channel_mean = channel_x1s / reduce_size
channel_variance = (
channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
)
else:
assert running_var is not None and running_mean is not None
channel_variance = running_var.reshape(*_param_shape)
channel_mean = running_mean.reshape(*_param_shape)
invsqrt_channel_variance = (
mgb.opr.elem.max(channel_variance, eps)
if eps_mode == "MAX"
else mgb.opr.elem.add(channel_variance, eps)
) ** -0.5
if weight is not None:
weight = weight.reshape(*_param_shape)
if bias is not None:
bias = bias.reshape(*_param_shape)
# outvar = output * weight + bias
# where output = input * invsqrt_channel_variance + (
# -channel_mean * invsqrt_channel_variance
# )
# Manually expand output for gopt
if weight is not None:
inv_var_wt = invsqrt_channel_variance * weight
neg_channel_mean = -channel_mean
if bias is not None:
outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
else:
outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
else:
outvar = input * invsqrt_channel_variance + (
-channel_mean * invsqrt_channel_variance
)
if bias is not None:
outvar = outvar + bias
if training and running_var is not None and running_mean is not None:
_mean_update = add_update(
running_mean, channel_mean, alpha=momentum, beta=1 - momentum,
)
channel_variance_unbiased = channel_x1s ** 2 / (
-reduce_size * (reduce_size - 1)
) + channel_x2s / (reduce_size - 1)
_variance_update = add_update(
running_var, channel_variance_unbiased, alpha=momentum, beta=1 - momentum
)
for dep in (_mean_update, _variance_update):
add_extra_vardep(outvar, dep)
return outvar
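The statistics above rely on the identity Var[x] = E[x^2] - (E[x])^2: channel_x1s and channel_x2s hold per-channel sums of x and x^2, so the mean and biased variance can be recovered after the all-reduce. A NumPy sketch (not part of this diff) checking that arithmetic:

import numpy as np

x = np.random.normal(size=(3, 8, 4, 4))                    # NCHW input
reduce_size = x.shape[0] * x.shape[2] * x.shape[3]          # N * H * W
channel_x1s = x.sum(axis=(0, 2, 3), keepdims=True)          # per-channel sum of x
channel_x2s = (x ** 2).sum(axis=(0, 2, 3), keepdims=True)   # per-channel sum of x^2

channel_mean = channel_x1s / reduce_size
channel_variance = (
    channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
)

assert np.allclose(channel_mean, x.mean(axis=(0, 2, 3), keepdims=True))
assert np.allclose(channel_variance, x.var(axis=(0, 2, 3), keepdims=True))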
def one_hot(inp: Tensor, num_classes: int) -> Tensor:
r"""
Perform one-hot encoding for the input tensor.
......
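A single-process usage sketch of the functional interface added above, hedged: tensor construction follows the test style later in this commit, and wrapping the (1, C, 1, 1)-shaped statistics in Buffer/Parameter is an assumption about the caller, not something this diff mandates.

import numpy as np
from megengine.core import Buffer, Parameter, tensor
from megengine.functional import sync_batch_norm

nr_chan = 8
inp = tensor()
inp.set_value(np.random.normal(size=(2, nr_chan, 4, 4)).astype(np.float32))
running_mean = Buffer(np.zeros((1, nr_chan, 1, 1), dtype=np.float32))
running_var = Buffer(np.ones((1, nr_chan, 1, 1), dtype=np.float32))
weight = Parameter(np.ones((1, nr_chan, 1, 1), dtype=np.float32))
bias = Parameter(np.zeros((1, nr_chan, 1, 1), dtype=np.float32))

# training=True updates running_mean/running_var through add_update;
# outside a process group the statistics stay local to this process
out = sync_batch_norm(
    inp, running_mean, running_var, weight, bias,
    training=True, momentum=0.9, eps=1e-5,
)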
......@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d, LocalConv2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
......
......@@ -9,7 +9,7 @@
import numpy as np
from ..core import Buffer, Parameter
from ..functional import batch_norm2d
from ..functional import batch_norm2d, sync_batch_norm
from . import init
from .module import Module
......@@ -74,7 +74,6 @@ class _BatchNorm(Module):
inp = inp.reshape(new_shape)
_iter_update = None
if self.training and self.track_running_stats:
exponential_average_factor = self.momentum
else:
......@@ -97,6 +96,54 @@ class _BatchNorm(Module):
return output
class SyncBatchNorm(_BatchNorm):
r"""
Applies Synchronized Batch Normalization.
"""
def _check_input_ndim(self, inp):
if len(inp.shape) not in {2, 3, 4}:
raise ValueError(
"expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
)
def forward(self, inp):
self._check_input_ndim(inp)
_ndims = len(inp.shape)
if _ndims != 4:
origin_shape = inp.shapeof()
if _ndims == 2:
n, c = inp.shapeof(0), inp.shapeof(1)
new_shape = (n, c, 1, 1)
elif _ndims == 3:
n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
new_shape = (n, c, h, 1)
inp = inp.reshape(new_shape)
if self.training and self.track_running_stats:
exponential_average_factor = self.momentum
else:
exponential_average_factor = 0.0 # useless
output = sync_batch_norm(
inp,
self.running_mean,
self.running_var,
self.weight,
self.bias,
self.training or not self.track_running_stats,
exponential_average_factor,
self.eps,
)
if _ndims != 4:
output = output.reshape(origin_shape)
return output
class BatchNorm1d(_BatchNorm):
r"""
Applies Batch Normalization over a 2D/3D tensor.
......
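A two-GPU sketch of the new module, mirroring the structure of the test added at the end of this commit; the hostname, port, and rank-to-device mapping are illustrative, and it only runs where CUDA and a reachable master are available:

import multiprocessing as mp
import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.core import tensor
from megengine.module import SyncBatchNorm

def worker(rank, world_size=2):
    if not mge.is_cuda_available():
        return
    # rank doubles as the GPU index here, as in the test below
    dist.init_process_group("localhost", 2333, world_size, rank, rank)
    bn = SyncBatchNorm(8, momentum=0.9, eps=1e-5)
    data = tensor()
    data.set_value(np.random.normal(size=(4, 8, 16, 16)).astype(np.float32))
    out = bn(data)  # per-channel statistics are all-reduced across ranks
    print(rank, out.numpy().shape)

if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()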
......@@ -18,6 +18,7 @@ from .._internal.config import opr_priority_scope
from ..core import Buffer, Parameter, Tensor, TensorDict
from ..core.graph import get_default_graph
from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
from ..distributed.util import get_group_id
from ..functional import add_update
from ..functional import grad as grad_func
from ..jit import sideeffect
......@@ -152,7 +153,7 @@ class Optimizer(metaclass=ABCMeta):
:param loss: The obtained loss tensor
"""
rst = []
key = 0
priority = 0
params = []
for group in self.param_groups:
for param in group["params"]:
......@@ -173,11 +174,14 @@ class Optimizer(metaclass=ABCMeta):
for param, grad in zip(params, grads):
if is_distributed():
key += 1
with opr_priority_scope(cg, -key):
priority += 1
with opr_priority_scope(cg, -priority):
# all_reduce_mean
grad = all_reduce_sum(grad, key) / get_world_size()
with opr_priority_scope(cg, (1 << 30) - key):
grad = (
all_reduce_sum(grad, "grad_" + str(get_group_id()))
/ get_world_size()
)
with opr_priority_scope(cg, (1 << 30) - priority):
grad_update = add_update(param.grad, grad)
else:
grad_update = add_update(param.grad, grad)
......@@ -216,11 +220,9 @@ class Optimizer(metaclass=ABCMeta):
param.grad.reset_zero()
def bcast_param(self):
key = 0
for group in self.param_groups:
for param in group["params"]:
bcast_param(param, key)
key += 1
bcast_param(param, "bcast_param_" + str(get_group_id()))
def state_dict(self) -> Dict:
r"""Export the optimizer state.
......
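The distributed branch above implements an all-reduce mean: ALL_REDUCE_SUM leaves the summed gradient on every rank, and dividing by get_world_size() turns it into the average. A plain NumPy sketch of that arithmetic, simulating the ranks in one process:

import numpy as np

world_size = 4
local_grads = [np.full((2, 2), float(r)) for r in range(world_size)]  # one gradient per rank
reduced_sum = sum(local_grads)         # what ALL_REDUCE_SUM leaves on every rank
grad_mean = reduced_sum / world_size   # the division by get_world_size() in the diff
assert np.allclose(grad_mean, np.mean(local_grads, axis=0))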
......@@ -6,15 +6,86 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import numpy as np
import pytest
import megengine as mge
import megengine.distributed as dist
from megengine.core import tensor
from megengine.module import BatchNorm1d, BatchNorm2d
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from megengine.test import assertTensorClose
@pytest.mark.isolated_distributed
def test_syncbn():
nr_chan = 8
data_shape = (3, nr_chan, 4, 16)
momentum = 0.9
eps = 1e-5
running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
steps = 4
def worker(rank, data, yv_expect, running_mean, running_var):
if not mge.is_cuda_available():
return
dist.init_process_group("localhost", 2333, 4, rank, rank)
bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
data_tensor = tensor()
for i in range(steps):
data_tensor.set_value(data[i])
yv = bn(data_tensor)
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
xv = []
for i in range(steps):
xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
(data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
)
mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
sd = np.sqrt(var_biased + eps)
var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum)
yv_expect = (xv[i] - mean) / sd
data = []
for i in range(4):
data.append([])
for j in range(steps):
data[i].append(xv[j][:, :, :, i * 4 : i * 4 + 4])
procs = []
for rank in range(4):
p = mp.Process(
target=worker,
args=(
rank,
data[rank],
yv_expect[:, :, :, rank * 4 : rank * 4 + 4],
running_mean,
running_var,
),
)
p.start()
procs.append(p)
for p in procs:
p.join()
assert p.exitcode == 0
def test_batchnorm():
nr_chan = 8
data_shape = (3, nr_chan, 4)
......@@ -64,6 +135,55 @@ def test_batchnorm():
assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
def test_syncbn1d():
nr_chan = 8
data_shape = (3, nr_chan, 4)
momentum = 0.9
bn = SyncBatchNorm(nr_chan, momentum=momentum)
running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
data = tensor()
for i in range(3):
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
(data_shape[0] * data_shape[2], nr_chan)
)
var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
sd = np.sqrt(var_biased + bn.eps)
var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum)
data.set_value(xv)
yv = bn(data)
yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
assertTensorClose(
running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
)
assertTensorClose(
running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
)
# test set 'training' flag to False
mean_backup = bn.running_mean.numpy()
var_backup = bn.running_var.numpy()
bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
data.set_value(xv)
yv1 = bn(data)
yv2 = bn(data)
assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
def test_batchnorm2d():
nr_chan = 8
data_shape = (3, nr_chan, 16, 16)
......@@ -110,6 +230,52 @@ def test_batchnorm2d():
assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
def test_syncbn2d():
nr_chan = 8
data_shape = (3, nr_chan, 16, 16)
momentum = 0.9
bn = SyncBatchNorm(nr_chan, momentum=momentum)
running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
data = tensor()
for i in range(3):
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
(data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
)
mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
sd = np.sqrt(var_biased + bn.eps)
var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum)
data.set_value(xv)
yv = bn(data)
yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
# test set 'training' flag to False
mean_backup = bn.running_mean.numpy()
var_backup = bn.running_var.numpy()
bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
data.set_value(xv)
yv1 = bn(data)
yv2 = bn(data)
assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
def test_batchnorm_no_stats():
nr_chan = 8
data_shape = (3, nr_chan, 4)
......@@ -135,6 +301,31 @@ def test_batchnorm_no_stats():
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
def test_syncbn_no_stats():
nr_chan = 8
data_shape = (3, nr_chan, 4)
bn = SyncBatchNorm(8, track_running_stats=False)
data = tensor()
for i in range(4):
if i == 2:
bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
var = np.var(
np.transpose(xv, [0, 2, 1]).reshape(
(data_shape[0] * data_shape[2], nr_chan)
),
axis=0,
).reshape((1, nr_chan, 1))
sd = np.sqrt(var + bn.eps)
data.set_value(xv)
yv = bn(data)
yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
def test_batchnorm2d_no_stats():
nr_chan = 8
data_shape = (3, nr_chan, 16, 16)
......@@ -157,3 +348,27 @@ def test_batchnorm2d_no_stats():
yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
def test_syncbn2d_no_stats():
nr_chan = 8
data_shape = (3, nr_chan, 16, 16)
bn = SyncBatchNorm(8, track_running_stats=False)
data = tensor()
for i in range(4):
if i == 2:
bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
(data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
)
mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
sd = np.sqrt(var + bn.eps)
data.set_value(xv)
yv = bn(data)
yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)