提交 27a9326c 编写于 作者: M Megvii Engine Team

fix(mge/module): fix frozen batch norm

GitOrigin-RevId: 143d468a37694522591971a191a549e3f1dd2d05
上级 c3ba0280
...@@ -35,6 +35,10 @@ class _BatchNorm(Module): ...@@ -35,6 +35,10 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats self._track_running_stats_saved = track_running_stats
self.freeze = freeze self.freeze = freeze
if self.freeze:
assert (
self._track_running_stats_saved
), "track_running_stats must be initilized to True if freeze is True"
tshape = (1, self.num_features, 1, 1) tshape = (1, self.num_features, 1, 1)
if self.affine: if self.affine:
self.weight = Parameter(np.ones(tshape, dtype=np.float32)) self.weight = Parameter(np.ones(tshape, dtype=np.float32))
...@@ -84,10 +88,24 @@ class _BatchNorm(Module): ...@@ -84,10 +88,24 @@ class _BatchNorm(Module):
inp = inp.reshape(new_shape) inp = inp.reshape(new_shape)
if self.freeze and self.training and self._track_running_stats_saved: _weight = self.weight
scale = self.weight * (self.running_var + self.eps) ** (-0.5) _bias = self.bias
bias = self.bias - self.running_mean * scale
return inp * scale.detach() + bias.detach() if self.freeze:
if _weight is not None:
_weight = _weight.detach()
if _bias is not None:
_bias = _bias.detach()
# Need to expand to elementwise operations here
# see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
scale = (self.running_var + self.eps) ** (-0.5)
if _weight is not None:
scale *= _weight
bias = -self.running_mean * scale
if _bias is not None:
bias += _bias
return inp * scale + bias
if self.training and self.track_running_stats: if self.training and self.track_running_stats:
exponential_average_factor = self.momentum exponential_average_factor = self.momentum
...@@ -98,8 +116,8 @@ class _BatchNorm(Module): ...@@ -98,8 +116,8 @@ class _BatchNorm(Module):
inp, inp,
self.running_mean if self.track_running_stats else None, self.running_mean if self.track_running_stats else None,
self.running_var if self.track_running_stats else None, self.running_var if self.track_running_stats else None,
self.weight, _weight,
self.bias, _bias,
training=self.training training=self.training
or ((self.running_mean is None) and (self.running_var is None)), or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor, momentum=exponential_average_factor,
...@@ -121,7 +139,7 @@ class _BatchNorm(Module): ...@@ -121,7 +139,7 @@ class _BatchNorm(Module):
class SyncBatchNorm(_BatchNorm): class SyncBatchNorm(_BatchNorm):
r""" r"""
Applies Synchronization Batch Normalization. Applies Synchronized Batch Normalization for distributed training.
""" """
def __init__( def __init__(
...@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm): ...@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
else: else:
exponential_average_factor = 0.0 # useless exponential_average_factor = 0.0 # useless
_weight = self.weight
_bias = self.bias
if self.freeze:
if _weight is not None:
_weight = _weight.detach()
if _bias is not None:
_bias = _bias.detach()
output = sync_batch_norm( output = sync_batch_norm(
inp, inp,
self.running_mean, self.running_mean,
self.running_var, self.running_var,
self.weight, _weight,
self.bias, _bias,
self.training or not self.track_running_stats, training=(self.training and not self.freeze)
exponential_average_factor, or ((self.running_mean is None) and (self.running_var is None)),
self.eps, momentum=exponential_average_factor,
eps=self.eps,
group=self.group, group=self.group,
) )
...@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm): ...@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
:param freeze: when set to True, this module does not update the :param freeze: when set to True, this module does not update the
running mean and variance, and uses the running mean and variance instead of running mean and variance, and uses the running mean and variance instead of
the batch mean and batch variance to normalize the input. The parameter takes effect the batch mean and batch variance to normalize the input. The parameter takes effect
only when the module is initilized with track_running_stats as True and only when the module is initilized with track_running_stats as True.
the module is in training mode.
Default: False Default: False
Examples: Examples:
......
...@@ -11,15 +11,23 @@ import pytest ...@@ -11,15 +11,23 @@ import pytest
import megengine import megengine
import megengine.autodiff as ad import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
import megengine.optimizer as optimizer import megengine.optimizer as optimizer
from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace from megengine.jit import trace
from megengine.module import BatchNorm2d, Module from megengine.module import BatchNorm2d, Module, SyncBatchNorm
def test_frozen_bn(): def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
nchannel = 3 nchannel = 3
m = BatchNorm2d(nchannel, freeze=True) m = BNModule(nchannel, freeze=True)
var = 4.0
bias = 1.0
shape = (1, nchannel, 1, 1)
m.running_var[...] = var * F.ones(shape)
m.running_mean[...] = bias * F.ones(shape)
saved_var = m.running_var.numpy() saved_var = m.running_var.numpy()
saved_mean = m.running_mean.numpy() saved_mean = m.running_mean.numpy()
...@@ -31,16 +39,45 @@ def test_frozen_bn(): ...@@ -31,16 +39,45 @@ def test_frozen_bn():
optim.clear_grad() optim.clear_grad()
data = np.random.random((6, nchannel, 2, 2)).astype("float32") data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with gm:
loss = m(data).mean()
gm.backward(loss)
optim.step()
np.testing.assert_equal(m.running_var.numpy(), saved_var) def train_fn(d):
np.testing.assert_equal(m.running_mean.numpy(), saved_mean) for _ in range(3):
np.testing.assert_equal(m.weight.numpy(), saved_wt) with gm:
np.testing.assert_equal(m.bias.numpy(), saved_bias) loss = m(d).mean()
np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5) gm.backward(loss)
optim.step()
return loss
if use_trace:
train_fn = trace(train_fn, symbolic=use_symbolic)
for _ in range(3):
loss = train_fn(megengine.Tensor(data))
np.testing.assert_equal(m.running_var.numpy(), saved_var)
np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
np.testing.assert_equal(m.weight.numpy(), saved_wt)
np.testing.assert_equal(m.bias.numpy(), saved_bias)
np.testing.assert_almost_equal(
loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
)
def test_frozen_bn():
run_frozen_bn(BatchNorm2d)
run_frozen_bn(BatchNorm2d, True, False)
run_frozen_bn(BatchNorm2d, True, True)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_frozen_synced_bn():
@dist.launcher(n_gpus=2)
def worker():
run_frozen_bn(SyncBatchNorm)
run_frozen_bn(SyncBatchNorm, True, False)
run_frozen_bn(SyncBatchNorm, True, True)
worker()
def test_bn_no_track_stat(): def test_bn_no_track_stat():
...@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice(): ...@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
x = np.ones((1, 1, 32, 32), dtype=np.float32) x = np.ones((1, 1, 32, 32), dtype=np.float32)
y = train_bn(x, net=Simple()) y = train_bn(x, net=Simple())
np.testing.assert_equal(y.numpy(), 0) np.testing.assert_equal(y.numpy(), 0)
# https://github.com/MegEngine/MegEngine/issues/145
def test_frozen_bn_no_affine():
nchannel = 3
m = BatchNorm2d(nchannel, freeze=True, affine=False)
data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
m(data).numpy()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册