From 3474f96c01385ca25c999ec5d0beb1860ffa2c21 Mon Sep 17 00:00:00 2001 From: JepsonWong <2013000149@qq.com> Date: Thu, 12 Mar 2020 12:20:28 +0800 Subject: [PATCH] polish,test=develop --- doc/fluid/api_cn/dygraph_cn/DataParallel_cn.rst | 6 +++--- doc/fluid/api_cn/dygraph_cn/Env_cn.rst | 2 +- scripts/api_white_list.txt | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/doc/fluid/api_cn/dygraph_cn/DataParallel_cn.rst b/doc/fluid/api_cn/dygraph_cn/DataParallel_cn.rst index d1ba63391..a16b0be5c 100644 --- a/doc/fluid/api_cn/dygraph_cn/DataParallel_cn.rst +++ b/doc/fluid/api_cn/dygraph_cn/DataParallel_cn.rst @@ -45,7 +45,7 @@ DataParallel hidden = linear(data) avg_loss = fluid.layers.mean(hidden) - # 根据trainers的数量来损失值进行缩放 + # 根据trainers的数量来损失值进行缩放,其中trainers为参与训练GPU卡的数量。 avg_loss = linear.scale_loss(avg_loss) avg_loss.backward() @@ -58,7 +58,7 @@ DataParallel .. py:method:: scale_loss(loss) -对损失值进行缩放。在数据并行模式下,损失值根据 ``trainers`` 的数量缩放一定的比例;反之,返回原始的损失值。在 ``backward`` 前调用,示例如上。 +对损失值进行缩放。在数据并行模式下,损失值根据 ``trainers`` 的数量缩放一定的比例;反之,返回原始的损失值。在 ``backward`` 前调用,示例如上。其中 ``trainers`` 为参与训练GPU卡的数量。 参数: - **loss** (Variable) - 当前模型的损失值 @@ -69,5 +69,5 @@ DataParallel .. py:method:: apply_collective_grads() -使用AllReduce模式来计算数据并行模式下多个 ``trainers`` 模型之间参数梯度的均值。在 ``backward`` 之后调用,示例如上。 +使用AllReduce模式来计算数据并行模式下多个 ``trainers`` 模型之间参数梯度的均值。在 ``backward`` 之后调用,示例如上。其中 ``trainers`` 为参与训练GPU卡的数量。 diff --git a/doc/fluid/api_cn/dygraph_cn/Env_cn.rst b/doc/fluid/api_cn/dygraph_cn/Env_cn.rst index 47c67f4b5..9ed878cc7 100644 --- a/doc/fluid/api_cn/dygraph_cn/Env_cn.rst +++ b/doc/fluid/api_cn/dygraph_cn/Env_cn.rst @@ -27,7 +27,7 @@ Env通常需要和 `fluid.dygraph.parallel.DataParallel` 一起使用,用于 # 准备数据并行的环境 strategy = dygraph.parallel.prepare_context() linear = Linear(1, 10, act="softmax") - adam = fluid.optimizer.AdamOptimizer() + adam = fluid.optimizer.AdamOptimizer(parameter_list=linear.parameters()) # 配置模型为并行模型 linear = dygraph.parallel.DataParallel(linear, strategy) x_data = np.random.random(size=[10, 1]).astype(np.float32) diff --git a/scripts/api_white_list.txt b/scripts/api_white_list.txt index e0ba3fbc5..fb3a21d7c 100644 --- a/scripts/api_white_list.txt +++ b/scripts/api_white_list.txt @@ -8,3 +8,4 @@ transpiler_cn/RoundRobin_cn.rst optimizer_cn/Dpsgd_cn.rst io_cn/ComposeNotAligned_cn.rst dygraph_cn/DataParallel_cn.rst +dygraph_cn/Env_cn.rst -- GitLab