提交 bd817f3a 编写于 作者: M Megvii Engine Team

perf(syncbn): fallback to bn when sync is not required

GitOrigin-RevId: e89fdd87b0f333f369717f30ec3f0b550faac842
上级 ffcb4dac
...@@ -1269,15 +1269,27 @@ def sync_batch_norm( ...@@ -1269,15 +1269,27 @@ def sync_batch_norm(
group: communication group, caculate mean and variance between this group. group: communication group, caculate mean and variance between this group.
Default: :obj:`~megengine.distributed.WORLD` Default: :obj:`~megengine.distributed.WORLD`
""" """
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format( _eps_mode = eps_mode.lower()
eps_mode assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
) if _eps_mode == "additive" and not (is_distributed() and training):
# TODO: cudnnBn fastpath return batch_norm(
inp,
running_mean,
running_var,
weight,
bias,
training=training,
momentum=momentum,
eps=eps,
)
_channels = make_shape_tuple(inp.shape)[1] _channels = make_shape_tuple(inp.shape)[1]
_ndim = inp.ndim _ndim = inp.ndim
_device = inp.device _device = inp.device
_dtype = inp.dtype _dtype = inp.dtype
if _ndim != 4:
raise NotImplementedError("sync_batch_norm for ndim != 4")
def _make_full_if_none(x, value): def _make_full_if_none(x, value):
if x is None: if x is None:
(x,) = Const(value, dtype=inp.dtype, device=_device)() (x,) = Const(value, dtype=inp.dtype, device=_device)()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册