提交 5e548218 编写于 作者: M Megvii Engine Team

docs(syncbn): complete syncbn document

GitOrigin-RevId: 9b8d546f0b1431d502ee6f69fd3ccbd072a73e71
上级 3591ef1f
...@@ -1125,6 +1125,10 @@ def sync_batch_norm( ...@@ -1125,6 +1125,10 @@ def sync_batch_norm(
Default: 0.9 Default: 0.9
:param eps: a value added to the denominator for numerical stability. :param eps: a value added to the denominator for numerical stability.
Default: 1e-5 Default: 1e-5
:param eps_mode: mode of calculation for eps, "max" or "additive".
Default: "additive"
:param group: communication group, caculate mean and variance between this group.
Default: :obj:`~megengine.distributed.WORLD`
:return: output tensor. :return: output tensor.
""" """
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format( assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
......
...@@ -141,6 +141,35 @@ class _BatchNorm(Module): ...@@ -141,6 +141,35 @@ class _BatchNorm(Module):
class SyncBatchNorm(_BatchNorm): class SyncBatchNorm(_BatchNorm):
r""" r"""
Applies Synchronized Batch Normalization for distributed training. Applies Synchronized Batch Normalization for distributed training.
:type num_features: int
:param num_features: usually :math:`C` from an input of shape
:math:`(N, C, H, W)` or the highest ranked dimension of an input
less than 4D.
:type eps: float
:param eps: a value added to the denominator for numerical stability.
Default: 1e-5
:type momentum: float
:param momentum: the value used for the ``running_mean`` and ``running_var`` computation.
Default: 0.9
:type affine: bool
:param affine: a boolean value that when set to True, this module has
learnable affine parameters. Default: True
:type track_running_stats: bool
:param track_running_stats: when set to True, this module tracks the
running mean and variance. When set to False, this module does not
track such statistics and always uses batch statistics in both training
and eval modes. Default: True
:type freeze: bool
:param freeze: when set to True, this module does not update the
running mean and variance, and uses the running mean and variance instead of
the batch mean and batch variance to normalize the input. The parameter takes effect
only when the module is initilized with track_running_stats as True.
Default: False
:type group: :class:`~megengine.distributed.Group`
:param group: communication group, caculate mean and variance between this group.
Default: :obj:`~megengine.distributed.WORLD`
:return: output tensor.
""" """
def __init__( def __init__(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册