From 5e54821845eaff90e106d32d116b242cb1613575 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 May 2021 15:01:27 +0800 Subject: [PATCH] docs(syncbn): complete syncbn document GitOrigin-RevId: 9b8d546f0b1431d502ee6f69fd3ccbd072a73e71 --- imperative/python/megengine/functional/nn.py | 4 +++ .../python/megengine/module/batchnorm.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 3d46c400d..ce58ac654 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1125,6 +1125,10 @@ def sync_batch_norm( Default: 0.9 :param eps: a value added to the denominator for numerical stability. Default: 1e-5 + :param eps_mode: mode of calculation for eps, "max" or "additive". + Default: "additive" + :param group: communication group, caculate mean and variance between this group. + Default: :obj:`~megengine.distributed.WORLD` :return: output tensor. """ assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format( diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index b50b94ea7..7428d0c5b 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -141,6 +141,35 @@ class _BatchNorm(Module): class SyncBatchNorm(_BatchNorm): r""" Applies Synchronized Batch Normalization for distributed training. + + :type num_features: int + :param num_features: usually :math:`C` from an input of shape + :math:`(N, C, H, W)` or the highest ranked dimension of an input + less than 4D. + :type eps: float + :param eps: a value added to the denominator for numerical stability. + Default: 1e-5 + :type momentum: float + :param momentum: the value used for the ``running_mean`` and ``running_var`` computation. + Default: 0.9 + :type affine: bool + :param affine: a boolean value that when set to True, this module has + learnable affine parameters. Default: True + :type track_running_stats: bool + :param track_running_stats: when set to True, this module tracks the + running mean and variance. When set to False, this module does not + track such statistics and always uses batch statistics in both training + and eval modes. Default: True + :type freeze: bool + :param freeze: when set to True, this module does not update the + running mean and variance, and uses the running mean and variance instead of + the batch mean and batch variance to normalize the input. The parameter takes effect + only when the module is initilized with track_running_stats as True. + Default: False + :type group: :class:`~megengine.distributed.Group` + :param group: communication group, caculate mean and variance between this group. + Default: :obj:`~megengine.distributed.WORLD` + :return: output tensor. """ def __init__( -- GitLab