diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index fdc9eb7eda828291962100418da6058d246cfb37..305c30f7e9a2d1c8678db248cb96428c2139a214 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1269,15 +1269,27 @@ def sync_batch_norm( group: communication group, caculate mean and variance between this group. Default: :obj:`~megengine.distributed.WORLD` """ - assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format( - eps_mode - ) - # TODO: cudnnBn fastpath + _eps_mode = eps_mode.lower() + assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) + if _eps_mode == "additive" and not (is_distributed() and training): + return batch_norm( + inp, + running_mean, + running_var, + weight, + bias, + training=training, + momentum=momentum, + eps=eps, + ) _channels = make_shape_tuple(inp.shape)[1] _ndim = inp.ndim _device = inp.device _dtype = inp.dtype + if _ndim != 4: + raise NotImplementedError("sync_batch_norm for ndim != 4") + def _make_full_if_none(x, value): if x is None: (x,) = Const(value, dtype=inp.dtype, device=_device)()