From 27a9326c66b9643b32f57386fc525fc534a5c0e6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 22 Jan 2021 13:39:56 +0800 Subject: [PATCH] fix(mge/module): fix frozen batch norm GitOrigin-RevId: 143d468a37694522591971a191a549e3f1dd2d05 --- .../python/megengine/module/batchnorm.py | 55 +++++++++++---- imperative/python/test/integration/test_bn.py | 69 +++++++++++++++---- 2 files changed, 98 insertions(+), 26 deletions(-) diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index f4d49895..f6d313d2 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -35,6 +35,10 @@ class _BatchNorm(Module): self.track_running_stats = track_running_stats self._track_running_stats_saved = track_running_stats self.freeze = freeze + if self.freeze: + assert ( + self._track_running_stats_saved + ), "track_running_stats must be initilized to True if freeze is True" tshape = (1, self.num_features, 1, 1) if self.affine: self.weight = Parameter(np.ones(tshape, dtype=np.float32)) @@ -84,10 +88,24 @@ class _BatchNorm(Module): inp = inp.reshape(new_shape) - if self.freeze and self.training and self._track_running_stats_saved: - scale = self.weight * (self.running_var + self.eps) ** (-0.5) - bias = self.bias - self.running_mean * scale - return inp * scale.detach() + bias.detach() + _weight = self.weight + _bias = self.bias + + if self.freeze: + if _weight is not None: + _weight = _weight.detach() + if _bias is not None: + _bias = _bias.detach() + + # Need to expand to elementwise operations here + # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp + scale = (self.running_var + self.eps) ** (-0.5) + if _weight is not None: + scale *= _weight + bias = -self.running_mean * scale + if _bias is not None: + bias += _bias + return inp * scale + bias if self.training and self.track_running_stats: exponential_average_factor = self.momentum @@ -98,8 +116,8 @@ class _BatchNorm(Module): inp, self.running_mean if self.track_running_stats else None, self.running_var if self.track_running_stats else None, - self.weight, - self.bias, + _weight, + _bias, training=self.training or ((self.running_mean is None) and (self.running_var is None)), momentum=exponential_average_factor, @@ -121,7 +139,7 @@ class _BatchNorm(Module): class SyncBatchNorm(_BatchNorm): r""" - Applies Synchronization Batch Normalization. + Applies Synchronized Batch Normalization for distributed training. """ def __init__( @@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm): else: exponential_average_factor = 0.0 # useless + _weight = self.weight + _bias = self.bias + + if self.freeze: + if _weight is not None: + _weight = _weight.detach() + if _bias is not None: + _bias = _bias.detach() + output = sync_batch_norm( inp, self.running_mean, self.running_var, - self.weight, - self.bias, - self.training or not self.track_running_stats, - exponential_average_factor, - self.eps, + _weight, + _bias, + training=(self.training and not self.freeze) + or ((self.running_mean is None) and (self.running_var is None)), + momentum=exponential_average_factor, + eps=self.eps, group=self.group, ) @@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm): :param freeze: when set to True, this module does not update the running mean and variance, and uses the running mean and variance instead of the batch mean and batch variance to normalize the input. The parameter takes effect - only when the module is initilized with track_running_stats as True and - the module is in training mode. + only when the module is initilized with track_running_stats as True. Default: False Examples: diff --git a/imperative/python/test/integration/test_bn.py b/imperative/python/test/integration/test_bn.py index 84d92a35..6d351408 100644 --- a/imperative/python/test/integration/test_bn.py +++ b/imperative/python/test/integration/test_bn.py @@ -11,15 +11,23 @@ import pytest import megengine import megengine.autodiff as ad +import megengine.distributed as dist +import megengine.functional as F import megengine.optimizer as optimizer from megengine import Parameter, tensor +from megengine.distributed.helper import get_device_count_by_fork from megengine.jit import trace -from megengine.module import BatchNorm2d, Module +from megengine.module import BatchNorm2d, Module, SyncBatchNorm -def test_frozen_bn(): +def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False): nchannel = 3 - m = BatchNorm2d(nchannel, freeze=True) + m = BNModule(nchannel, freeze=True) + var = 4.0 + bias = 1.0 + shape = (1, nchannel, 1, 1) + m.running_var[...] = var * F.ones(shape) + m.running_mean[...] = bias * F.ones(shape) saved_var = m.running_var.numpy() saved_mean = m.running_mean.numpy() @@ -31,16 +39,45 @@ def test_frozen_bn(): optim.clear_grad() data = np.random.random((6, nchannel, 2, 2)).astype("float32") - with gm: - loss = m(data).mean() - gm.backward(loss) - optim.step() - np.testing.assert_equal(m.running_var.numpy(), saved_var) - np.testing.assert_equal(m.running_mean.numpy(), saved_mean) - np.testing.assert_equal(m.weight.numpy(), saved_wt) - np.testing.assert_equal(m.bias.numpy(), saved_bias) - np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5) + def train_fn(d): + for _ in range(3): + with gm: + loss = m(d).mean() + gm.backward(loss) + optim.step() + return loss + + if use_trace: + train_fn = trace(train_fn, symbolic=use_symbolic) + + for _ in range(3): + loss = train_fn(megengine.Tensor(data)) + np.testing.assert_equal(m.running_var.numpy(), saved_var) + np.testing.assert_equal(m.running_mean.numpy(), saved_mean) + np.testing.assert_equal(m.weight.numpy(), saved_wt) + np.testing.assert_equal(m.bias.numpy(), saved_bias) + np.testing.assert_almost_equal( + loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5 + ) + + +def test_frozen_bn(): + run_frozen_bn(BatchNorm2d) + run_frozen_bn(BatchNorm2d, True, False) + run_frozen_bn(BatchNorm2d, True, True) + + +@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") +@pytest.mark.isolated_distributed +def test_frozen_synced_bn(): + @dist.launcher(n_gpus=2) + def worker(): + run_frozen_bn(SyncBatchNorm) + run_frozen_bn(SyncBatchNorm, True, False) + run_frozen_bn(SyncBatchNorm, True, True) + + worker() def test_bn_no_track_stat(): @@ -112,3 +149,11 @@ def test_trace_bn_forward_twice(): x = np.ones((1, 1, 32, 32), dtype=np.float32) y = train_bn(x, net=Simple()) np.testing.assert_equal(y.numpy(), 0) + + +# https://github.com/MegEngine/MegEngine/issues/145 +def test_frozen_bn_no_affine(): + nchannel = 3 + m = BatchNorm2d(nchannel, freeze=True, affine=False) + data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32")) + m(data).numpy() -- GitLab