diff --git a/python_module/megengine/distributed/functional.py b/python_module/megengine/distributed/functional.py
index dd353bf5f10bba15fcf1619d9164c763282b757b..ec8d1c55819064215478436f9008c8c0fa91350e 100644
--- a/python_module/megengine/distributed/functional.py
+++ b/python_module/megengine/distributed/functional.py
@@ -12,56 +12,14 @@ import megengine._internal as mgb
 from megengine._internal.opr_param_defs import CollectiveComm as CollParam
 
 from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
-from ..core.graph import get_default_graph
 from ..functional import add_update
-from .util import (
-    get_backend,
-    get_master_ip,
-    get_master_port,
-    get_rank,
-    get_world_size,
-    is_distributed,
-)
+from .helper import collective_comm_symvar
+from .util import get_rank, is_distributed
 
 
 @wrap_io_tensor
-def _collective_comm(
-    inp: Union[Tensor, mgb.CompGraph],
-    key: str,
-    op: CollParam.Mode,
-    nr_ranks: Optional[int] = None,
-    rank: Optional[int] = None,
-    root: Optional[int] = 0,
-    dtype: Optional[type] = None,
-    device: Optional[mgb.CompNode] = None,
-    comp_graph: Optional[mgb.CompGraph] = None,
-) -> Tensor:
-    """Helper function for creating collective_comm operators
-
-    :param inp: tensor or comp_graph
-    :param key: unique identifier for collective communication
-    :param op: mode of collective communication
-    :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
-    :param root: rank of root node, use 0 as default
-    :param dtype: output data type, use dtype of inp as default
-    :param device: output comp node, use comp node of inp as default
-    :param comp_graph: output comp graph, use comp graph of inp as default
-    """
-    return mgb.opr.collective_comm(
-        inp,
-        key=str(key),
-        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
-        rank=rank if rank is not None else get_rank(),
-        root=root,
-        server_addr=get_master_ip(),
-        port=get_master_port(),
-        param=CollParam(mode=op),
-        dtype=dtype,
-        backend=get_backend(),
-        comp_node=device,
-        comp_graph=comp_graph,
-    )
+def _collective_comm(*args, **kwargs):
+    return collective_comm_symvar(*args, **kwargs)
 
 
 def reduce_sum(
diff --git a/python_module/megengine/distributed/helper.py b/python_module/megengine/distributed/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4514228ebd48c11987456a7c3bef8cec88cbbfdb
--- /dev/null
+++ b/python_module/megengine/distributed/helper.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Optional, Union
+
+import megengine._internal as mgb
+from megengine._internal.opr_param_defs import CollectiveComm as CollParam
+
+from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
+
+
+def collective_comm_symvar(
+    inp: Union[mgb.SymbolVar, mgb.CompGraph],
+    key: str,
+    op: CollParam.Mode,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    root: Optional[int] = 0,
+    dtype: Optional[type] = None,
+    device: Optional[mgb.CompNode] = None,
+    comp_graph: Optional[mgb.CompGraph] = None,
+) -> mgb.SymbolVar:
+    """Helper function for creating collective_comm operators
+
+    :param inp: tensor or comp_graph
+    :param key: unique identifier for collective communication
+    :param op: mode of collective communication
+    :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param rank: rank of the current process, use util.get_rank() as default
+    :param root: rank of root node, use 0 as default
+    :param dtype: output data type, use dtype of inp as default
+    :param device: output comp node, use comp node of inp as default
+    :param comp_graph: output comp graph, use comp graph of inp as default
+    """
+    return mgb.opr.collective_comm(
+        inp,
+        key=str(key),
+        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
+        rank=rank if rank is not None else get_rank(),
+        root=root,
+        server_addr=get_master_ip(),
+        port=get_master_port(),
+        param=CollParam(mode=op),
+        dtype=dtype,
+        backend=get_backend(),
+        comp_node=device,
+        comp_graph=comp_graph,
+    )
diff --git a/python_module/megengine/distributed/util.py b/python_module/megengine/distributed/util.py
index ddeab73e38e4c3db7e2b4d42261579c70d47e159..5166a9fcedd1bc44a6d119d5f9159e51b29c04b4 100644
--- a/python_module/megengine/distributed/util.py
+++ b/python_module/megengine/distributed/util.py
@@ -19,6 +19,7 @@ _master_port = 0
 _world_size = 0
 _rank = 0
 _backend = None
+_group_id = 0
 
 
 def init_process_group(
@@ -43,6 +44,7 @@ def init_process_group(
     global _world_size  # pylint: disable=global-statement
     global _rank  # pylint: disable=global-statement
     global _backend  # pylint: disable=global-statement
+    global _group_id  # pylint: disable=global-statement
 
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
@@ -60,6 +62,7 @@ def init_process_group(
     _world_size = world_size
     _rank = rank
     _backend = backend
+    _group_id = 0
 
     set_default_device(mgb.comp_node("gpu" + str(dev)))
 
@@ -101,6 +104,13 @@ def get_backend() -> str:
     return str(_backend)
 
 
+def get_group_id() -> int:
+    """Get group id for collective communication"""
+    global _group_id
+    _group_id += 1
+    return _group_id
+
+
 def group_barrier() -> None:
     """Block until all ranks in the group reach this barrier"""
     mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)
diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py
index 756dfc986db493bd49c3c08b855e223b4d80667a..82cbc171aa455615476c047d7381b6be5822a9fa 100644
--- a/python_module/megengine/functional/__init__.py
+++ b/python_module/megengine/functional/__init__.py
@@ -76,6 +76,7 @@ from .nn import (
     roi_pooling,
     softmax,
     softplus,
+    sync_batch_norm,
     warp_perspective,
 )
 from .quantized import conv_bias_activation
diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py
index eea340239f167d5a1051ca2d277d7fc9d4f55ff8..5dc27f0dbb0aa801987cb3e38b3ae5a73e151aa4 100644
--- a/python_module/megengine/functional/nn.py
+++ b/python_module/megengine/functional/nn.py
@@ -11,15 +11,20 @@ from typing import Optional, Tuple, Union
 
 import megengine._internal as mgb
 from megengine._internal import CompGraph, CompNode
+from megengine._internal.config import add_extra_vardep
+from megengine._internal.opr import add_update
+from megengine._internal.opr_param_defs import CollectiveComm as CollParam
 
+from .. import distributed as dist
 from ..core import Tensor, wrap_io_tensor
 from ..core.graph import _use_default_if_none
+from ..distributed.util import get_group_id
 from ..jit import barrier, mark_impure
 from ..random import uniform
 from ..utils.types import _pair, _pair_nonzero
 from .debug_param import get_conv_execution_strategy
 from .elemwise import exp, log
-from .tensor import concat, where
+from .tensor import where
 from .utils import _decide_comp_node_and_comp_graph
 
 
@@ -474,6 +479,125 @@ def batch_norm2d(
     return output
 
 
+@wrap_io_tensor
+def sync_batch_norm(
+    input: Tensor,
+    running_mean: Tensor,
+    running_var: Tensor,
+    weight: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    training: bool = False,
+    momentum: Union[float, Tensor] = 0.9,
+    eps: float = 1e-5,
+    eps_mode="ADDITIVE",
+) -> Tensor:
+    r""" Applies synchronized batch normalization to the input.
+
+    Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
+
+    :param input: input tensor.
+    :param running_mean: tensor to store running mean.
+    :param running_var: tensor to store running variance.
+    :param weight: scaling tensor in the learnable affine parameters.
+        See :math:`\gamma` in :class:`~.BatchNorm2d`
+    :param bias: bias tensor in the learnable affine parameters.
+        See :math:`\beta` in :class:`~.BatchNorm2d`
+    :param training: a boolean value to indicate whether batch norm is performed
+        in training mode. Default: ``False``
+    :param momentum: the value used for the ``running_mean`` and ``running_var``
+        computation.
+        Default: 0.9
+    :param eps: a value added to the denominator for numerical stability.
+        Default: 1e-5.
+    :param eps_mode: how ``eps`` is applied to the variance: ``"MAX"`` uses
+        ``max(variance, eps)``, while ``"ADDITIVE"`` uses ``variance + eps``.
+        Default: ``"ADDITIVE"``
+    """
+
+    assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
+    input = mgb.opr.mark_no_broadcast_elemwise(input)
+    _channels = input.imm_shape[1]
+    _ndim = len(input.imm_shape)
+    _param_shape = (1, _channels) + (1,) * (_ndim - 2)
+
+    if training:
+
+        def _sum_on_channel(input):
+            return mgb.opr.reduce_general([input, _param_shape], mode="sum")
+
+        def _allreduce(stat, key):
+            return dist.helper.collective_comm_symvar(
+                stat, key, CollParam.Mode.ALL_REDUCE_SUM
+            )
+
+        reduce_size = input.shape[0]
+        for i in range(2, _ndim):
+            reduce_size = reduce_size * input.shape[i]
+        channel_x1s = _sum_on_channel(input)
+        channel_x2s = _sum_on_channel(input ** 2)
+
+        if dist.is_distributed():
+            # reduce all nodes' data to calculate mean and variance
+            reduce_size = reduce_size.reshape(*(1,) * _ndim)
+            stat = mgb.opr.concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+            stat = _allreduce(stat, key="sync_bn_" + str(get_group_id()))
+            reduce_size = stat[:, :1].reshape(1)
+            channel_x1s = stat[:, 1 : 1 + _channels]
+            channel_x2s = stat[:, 1 + _channels :]
+
+        channel_mean = channel_x1s / reduce_size
+        channel_variance = (
+            channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
+        )
+    else:
+        assert running_var is not None and running_mean is not None
+        channel_variance = running_var.reshape(*_param_shape)
+        channel_mean = running_mean.reshape(*_param_shape)
+
+    invsqrt_channel_variance = (
+        mgb.opr.elem.max(channel_variance, eps)
+        if eps_mode == "MAX"
+        else mgb.opr.elem.add(channel_variance, eps)
+    ) ** -0.5
+
+    if weight is not None:
+        weight = weight.reshape(*_param_shape)
+    if bias is not None:
+        bias = bias.reshape(*_param_shape)
+
+    # outvar = output * weight + bias
+    # where output = input * invsqrt_channel_variance + (
+    #     -channel_mean * invsqrt_channel_variance
+    # )
+    # Manually expand output for gopt (graph optimizer)
+
+    if weight is not None:
+        inv_var_wt = invsqrt_channel_variance * weight
+        neg_channel_mean = -channel_mean
+        if bias is not None:
+            outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
+        else:
+            outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
+    else:
+        outvar = input * invsqrt_channel_variance + (
+            -channel_mean * invsqrt_channel_variance
+        )
+        if bias is not None:
+            outvar = outvar + bias
+
+    if training and running_var is not None and running_mean is not None:
+        _mean_update = add_update(
+            running_mean, channel_mean, alpha=momentum, beta=1 - momentum,
+        )
+        channel_variance_unbiased = channel_x1s ** 2 / (
+            -reduce_size * (reduce_size - 1)
+        ) + channel_x2s / (reduce_size - 1)
+        _variance_update = add_update(
+            running_var, channel_variance_unbiased, alpha=momentum, beta=1 - momentum
+        )
+        for dep in (_mean_update, _variance_update):
+            add_extra_vardep(outvar, dep)
+
+    return outvar
+
+
 def one_hot(inp: Tensor, num_classes: int) -> Tensor:
     r""" Perform one-hot encoding for the input tensor.
diff --git a/python_module/megengine/module/__init__.py b/python_module/megengine/module/__init__.py
index 0391f29d1c788378645045f23e42b4eb47ac97a1..7fe65951cc63cb30427c58b5b347e2f9a7ab9590 100644
--- a/python_module/megengine/module/__init__.py
+++ b/python_module/megengine/module/__init__.py
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
-from .batchnorm import BatchNorm1d, BatchNorm2d
+from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
 from .conv import Conv2d, ConvTranspose2d, LocalConv2d
 from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
diff --git a/python_module/megengine/module/batchnorm.py b/python_module/megengine/module/batchnorm.py
index 127dd77a0d45ea378a828c14d5ba6d8cb01766fe..cc33d67511fbfe730ffa4b2fe76d0f1ce9a09a81 100644
--- a/python_module/megengine/module/batchnorm.py
+++ b/python_module/megengine/module/batchnorm.py
@@ -9,7 +9,7 @@ import numpy as np
 
 from ..core import Buffer, Parameter
-from ..functional import batch_norm2d
+from ..functional import batch_norm2d, sync_batch_norm
 from . import init
 from .module import Module
 
@@ -74,7 +74,6 @@ class _BatchNorm(Module):
 
         inp = inp.reshape(new_shape)
 
-        _iter_update = None
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
         else:
@@ -97,6 +96,54 @@ class _BatchNorm(Module):
         return output
 
 
+class SyncBatchNorm(_BatchNorm):
+    r"""
+    Applies Synchronized Batch Normalization.
+    """
+
+    def _check_input_ndim(self, inp):
+        if len(inp.shape) not in {2, 3, 4}:
+            raise ValueError(
+                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+            )
+
+    def forward(self, inp):
+        self._check_input_ndim(inp)
+
+        _ndims = len(inp.shape)
+        if _ndims != 4:
+            origin_shape = inp.shapeof()
+            if _ndims == 2:
+                n, c = inp.shapeof(0), inp.shapeof(1)
+                new_shape = (n, c, 1, 1)
+            elif _ndims == 3:
+                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
+                new_shape = (n, c, h, 1)
+
+            inp = inp.reshape(new_shape)
+
+        if self.training and self.track_running_stats:
+            exponential_average_factor = self.momentum
+        else:
+            exponential_average_factor = 0.0  # useless
+
+        output = sync_batch_norm(
+            inp,
+            self.running_mean,
+            self.running_var,
+            self.weight,
+            self.bias,
+            self.training or not self.track_running_stats,
+            exponential_average_factor,
+            self.eps,
+        )
+
+        if _ndims != 4:
+            output = output.reshape(origin_shape)
+
+        return output
+
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Applies Batch Normalization over a 2D/3D tensor.
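A minimal usage sketch for the new SyncBatchNorm module (not part of the patch itself): it mirrors the worker setup of test_syncbn below, and the rendezvous address, port, and world size are placeholder values.

import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.core import tensor
from megengine.module import SyncBatchNorm

def worker(rank):
    if not mge.is_cuda_available():
        return
    # placeholder rendezvous settings: 4 ranks, one GPU per rank
    dist.init_process_group("localhost", 2333, 4, rank, rank)
    bn = SyncBatchNorm(8, momentum=0.9, eps=1e-5)
    # each rank feeds its local slice of the global batch; the batch
    # statistics are all-reduced across ranks inside sync_batch_norm
    inp = tensor()
    inp.set_value(np.random.normal(size=(3, 8, 4, 4)).astype(np.float32))
    out = bn(inp)

When dist.is_distributed() is false, sync_batch_norm falls back to purely local statistics, which is why the single-process tests below can exercise SyncBatchNorm without a process group.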
diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py
index cfbbad978d2d5b1ebbfb7a2a3b194f5ee4a7aff4..295d1b6a09df21376631ebee122566b0bf36f649 100644
--- a/python_module/megengine/optimizer/optimizer.py
+++ b/python_module/megengine/optimizer/optimizer.py
@@ -18,6 +18,7 @@ from .._internal.config import opr_priority_scope
 from ..core import Buffer, Parameter, Tensor, TensorDict
 from ..core.graph import get_default_graph
 from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
+from ..distributed.util import get_group_id
 from ..functional import add_update
 from ..functional import grad as grad_func
 from ..jit import sideeffect
@@ -152,7 +153,7 @@ class Optimizer(metaclass=ABCMeta):
         :param loss: The obtained loss tensor
         """
         rst = []
-        key = 0
+        priority = 0
         params = []
         for group in self.param_groups:
             for param in group["params"]:
@@ -173,11 +174,14 @@ class Optimizer(metaclass=ABCMeta):
 
         for param, grad in zip(params, grads):
             if is_distributed():
-                key += 1
-                with opr_priority_scope(cg, -key):
+                priority += 1
+                with opr_priority_scope(cg, -priority):
                     # all_reduce_mean
-                    grad = all_reduce_sum(grad, key) / get_world_size()
-                with opr_priority_scope(cg, (1 << 30) - key):
+                    grad = (
+                        all_reduce_sum(grad, "grad_" + str(get_group_id()))
+                        / get_world_size()
+                    )
+                with opr_priority_scope(cg, (1 << 30) - priority):
                     grad_update = add_update(param.grad, grad)
             else:
                 grad_update = add_update(param.grad, grad)
@@ -216,11 +220,9 @@ class Optimizer(metaclass=ABCMeta):
             param.grad.reset_zero()
 
     def bcast_param(self):
-        key = 0
         for group in self.param_groups:
             for param in group["params"]:
-                bcast_param(param, key)
-                key += 1
+                bcast_param(param, "bcast_param_" + str(get_group_id()))
 
     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
diff --git a/python_module/test/unit/module/test_batchnorm.py b/python_module/test/unit/module/test_batchnorm.py
index 7f1b1b04a2a257bb51d9d399a8be48fb04e8f52c..d23a10f164ae307552aac91c0614fc50709596d3 100644
--- a/python_module/test/unit/module/test_batchnorm.py
+++ b/python_module/test/unit/module/test_batchnorm.py
@@ -6,15 +6,86 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+
 import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.distributed as dist
 from megengine.core import tensor
-from megengine.module import BatchNorm1d, BatchNorm2d
+from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from megengine.test import assertTensorClose
 
 
+@pytest.mark.isolated_distributed
+def test_syncbn():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4, 16)
+    momentum = 0.9
+    eps = 1e-5
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    steps = 4
+
+    def worker(rank, data, yv_expect, running_mean, running_var):
+        if not mge.is_cuda_available():
+            return
+        dist.init_process_group("localhost", 2333, 4, rank, rank)
+        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+        data_tensor = tensor()
+        for i in range(steps):
+            data_tensor.set_value(data[i])
+            yv = bn(data_tensor)
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    xv = []
+    for i in range(steps):
+        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        yv_expect = (xv[i] - mean) / sd
+
+    data = []
+    for i in range(4):
+        data.append([])
+        for j in range(steps):
+            data[i].append(xv[j][:, :, :, i * 4 : i * 4 + 4])
+
+    procs = []
+    for rank in range(4):
+        p = mp.Process(
+            target=worker,
+            args=(
+                rank,
+                data[rank],
+                yv_expect[:, :, :, rank * 4 : rank * 4 + 4],
+                running_mean,
+                running_var,
+            ),
+        )
+        p.start()
+        procs.append(p)
+
+    for p in procs:
+        p.join()
+        assert p.exitcode == 0
+
+
 def test_batchnorm():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -64,6 +135,55 @@
     assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
 
 
+def test_syncbn1d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
+    data = tensor()
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
+            (data_shape[0] * data_shape[2], nr_chan)
+        )
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(
+            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
+        )
+        assertTensorClose(
+            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
+        )
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
 def test_batchnorm2d():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -110,6 +230,52 @@
     assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
 
 
+def test_syncbn2d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    data = tensor()
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
 def test_batchnorm_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -135,6 +301,31 @@
     assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
 
 
+def test_syncbn_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor()
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        var = np.var(
+            np.transpose(xv, [0, 2, 1]).reshape(
+                (data_shape[0] * data_shape[2], nr_chan)
+            ),
+            axis=0,
+        ).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
 def test_batchnorm2d_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -157,3 +348,27 @@
     yv_expect = (xv - mean) / sd
 
     assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+def test_syncbn2d_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor()
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
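For reference, a plain-NumPy sketch (not part of the patch) of the statistics fusion that sync_batch_norm performs in training mode: each rank contributes its element count and per-channel first and second moments, a single ALL_REDUCE_SUM over the concatenated [size, x1s, x2s] vector yields the global sums, and the mean and variance follow from those.

import numpy as np

def local_moments(x):
    # x: the (N, C, H, W) slice owned by one rank
    size = x.shape[0] * x.shape[2] * x.shape[3]
    x1s = x.sum(axis=(0, 2, 3))         # per-channel sum
    x2s = (x ** 2).sum(axis=(0, 2, 3))  # per-channel sum of squares
    return size, x1s, x2s

def fused_stats(per_rank):
    # emulates the ALL_REDUCE_SUM over the concatenated statistics
    size = sum(s for s, _, _ in per_rank)
    x1s = sum(a for _, a, _ in per_rank)
    x2s = sum(b for _, _, b in per_rank)
    mean = x1s / size
    var_biased = x2s / size - mean ** 2                   # normalizes the output
    var_unbiased = (x2s - size * mean ** 2) / (size - 1)  # feeds running_var
    return mean, var_biased, var_unbiased

ranks = [np.random.normal(size=(3, 8, 4, 4)).astype(np.float32) for _ in range(4)]
mean, var, _ = fused_stats([local_moments(x) for x in ranks])

This matches the algebra in the patch: channel_variance = channel_x2s / reduce_size - (channel_x1s / reduce_size) ** 2, written there in the expanded form channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size.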