From 55bd7e03627c07f3240474b49625ef65225ce523 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Mon, 14 Sep 2020 19:38:11 +0800 Subject: [PATCH] Add convert_syncbn (#2526) (#2629) * add docs --- doc/fluid/api_cn/nn_cn/SyncBatchNorm_cn.rst | 24 +++++++++++++++++++ .../paddle/nn/layer/norm/SyncBatchNorm_cn.rst | 23 ++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/doc/fluid/api_cn/nn_cn/SyncBatchNorm_cn.rst b/doc/fluid/api_cn/nn_cn/SyncBatchNorm_cn.rst index c4d2b742d..6e6b4615c 100644 --- a/doc/fluid/api_cn/nn_cn/SyncBatchNorm_cn.rst +++ b/doc/fluid/api_cn/nn_cn/SyncBatchNorm_cn.rst @@ -62,3 +62,27 @@ SyncBatchNorm print(hidden1.numpy()) # [[[[0.26824948, 1.0936325],[0.26824948, -1.6301316]],[[ 0.8095662, -0.665287],[-1.2744656, 1.1301866 ]]]] + +方法 +::::::::: +convert_sync_batchnorm(layer) +''''''''''''''''''''''''''''' + +该接口用于把 ``BatchNorm*d`` 层转换为 ``SyncBatchNorm`` 层。 + +参数: + - **layer** (paddle.nn.Layer) - 包含一个或多个 ``BatchNorm*d`` 层的模型。 + +返回: + 如果原始模型中有 ``BatchNorm*d`` 层, 则把 ``BatchNorm*d`` 层转换为 ``SyncBatchNorm`` 层的原始模型。 + +**代码示例** + +.. code-block:: python + + import paddle + import paddle.nn as nn + paddle.disable_static() + model = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5)) + sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + diff --git a/doc/paddle/api/paddle/nn/layer/norm/SyncBatchNorm_cn.rst b/doc/paddle/api/paddle/nn/layer/norm/SyncBatchNorm_cn.rst index c4d2b742d..84618185a 100644 --- a/doc/paddle/api/paddle/nn/layer/norm/SyncBatchNorm_cn.rst +++ b/doc/paddle/api/paddle/nn/layer/norm/SyncBatchNorm_cn.rst @@ -62,3 +62,26 @@ SyncBatchNorm print(hidden1.numpy()) # [[[[0.26824948, 1.0936325],[0.26824948, -1.6301316]],[[ 0.8095662, -0.665287],[-1.2744656, 1.1301866 ]]]] +方法 +::::::::: +convert_sync_batchnorm(layer) +''''''''''''''''''''''''''''' + +该接口用于把 ``BatchNorm*d`` 层转换为 ``SyncBatchNorm`` 层。 + +参数: + - **layer** (paddle.nn.Layer) - 包含一个或多个 ``BatchNorm*d`` 层的模型。 + +返回: + 如果原始模型中有 ``BatchNorm*d`` 层, 则把 ``BatchNorm*d`` 层转换为 ``SyncBatchNorm`` 层的原始模型。 + +**代码示例** + +.. code-block:: python + + import paddle + import paddle.nn as nn + paddle.disable_static() + model = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5)) + sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + -- GitLab