提交 bd817f3a 编写于 作者: M Megvii Engine Team

perf(syncbn): fallback to bn when sync is not required

GitOrigin-RevId: e89fdd87b0f333f369717f30ec3f0b550faac842
上级 ffcb4dac
......@@ -1269,15 +1269,27 @@ def sync_batch_norm(
group: communication group; mean and variance are calculated across this group.
Default: :obj:`~megengine.distributed.WORLD`
"""
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
eps_mode
)
# TODO: cudnnBn fastpath
_eps_mode = eps_mode.lower()
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
if _eps_mode == "additive" and not (is_distributed() and training):
return batch_norm(
inp,
running_mean,
running_var,
weight,
bias,
training=training,
momentum=momentum,
eps=eps,
)
_channels = make_shape_tuple(inp.shape)[1]
_ndim = inp.ndim
_device = inp.device
_dtype = inp.dtype
if _ndim != 4:
raise NotImplementedError("sync_batch_norm for ndim != 4")
def _make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=_device)()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册