From 3dfaf44adc989c8798b811e784f48bf278bbec29 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 12 Mar 2020 21:37:48 +0800 Subject: [PATCH] Rename dygraph parallel env & add doc (#22925) * add dygraph parallel env doc, test=develop * polish details, test=develop, test=document_fix * fix examples error in other apis, test=develop * fix more example error in other api, test=develop * add white list for gpu examples, test=develop, test=document_fix --- python/paddle/fluid/dygraph/parallel.py | 251 +++++++++++++++++++++--- tools/wlist.json | 6 +- 2 files changed, 231 insertions(+), 26 deletions(-) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index ac75a5dded..bb201d41ec 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -22,7 +22,7 @@ from .. import framework from ..layers import collective from . import to_variable, no_grad -__all__ = ["prepare_context"] +__all__ = ["prepare_context", "ParallelEnv", "DataParallel"] ParallelStrategy = core.ParallelStrategy @@ -37,10 +37,10 @@ def prepare_context(strategy=None): if strategy.nranks < 2: return assert framework.in_dygraph_mode() is True, \ - "dygraph.parallel.prepare_context should be used with dygrahp mode." + "dygraph.prepare_context should be used with dygrahp mode." place = framework._current_expected_place() assert place is not None, \ - "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard." + "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard." if isinstance(place, core.CUDAPlace): parallel_helper._set_parallel_ctx( core.NCCLParallelContext(strategy, place)) @@ -51,7 +51,64 @@ def prepare_context(strategy=None): return strategy -class Env(object): +class ParallelEnv(object): + """ + **Notes**: + **The old class name was Env and will be deprecated. Please use new class name ParallelEnv.** + + This class is used to obtain the environment variables required for + the parallel execution of dynamic graph model. + + The dynamic graph parallel mode needs to be started using paddle.distributed.launch. + By default, the related environment variable is automatically configured by this module. + + This class is generally used in with `fluid.dygraph.DataParallel` to configure dynamic graph models + to run in parallel. + + Examples: + .. code-block:: python + + # This example needs to run with paddle.distributed.launch, The usage is: + # python -m paddle.distributed.launch --selected_gpus=0,1 example.py + # And the content of `example.py` is the code of following example. + + import numpy as np + import paddle.fluid as fluid + import paddle.fluid.dygraph as dygraph + from paddle.fluid.optimizer import AdamOptimizer + from paddle.fluid.dygraph.nn import Linear + from paddle.fluid.dygraph.base import to_variable + + place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id) + with fluid.dygraph.guard(place=place): + + # prepare the data parallel context + strategy=dygraph.prepare_context() + + linear = Linear(1, 10, act="softmax") + adam = fluid.optimizer.AdamOptimizer() + + # make the module become the data parallelism module + linear = dygraph.DataParallel(linear, strategy) + + x_data = np.random.random(size=[10, 1]).astype(np.float32) + data = to_variable(x_data) + + hidden = linear(data) + avg_loss = fluid.layers.mean(hidden) + + # scale the loss according to the number of trainers. + avg_loss = linear.scale_loss(avg_loss) + + avg_loss.backward() + + # collect the gradients of trainers. + linear.apply_collective_grads() + + adam.minimize(avg_loss) + linear.clear_gradients() + """ + def __init__(self): self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) @@ -62,34 +119,124 @@ class Env(object): @property def nranks(self): + """ + The number of trainers, generally refers to the number of GPU cards used in training. + + Its value is equal to the value of the environment variable PADDLE_TRAINERS_NUM. The default value is 1. + + Examples: + .. code-block:: python + + # execute this command in terminal: export PADDLE_TRAINERS_NUM=4 + import paddle.fluid as fluid + + env = fluid.dygraph.ParallelEnv() + print("The nranks is %d" % env.nranks) + # The nranks is 4 + """ return self._nranks @property def local_rank(self): + """ + The current trainer number. + + Its value is equal to the value of the environment variable PADDLE_TRAINER_ID. The default value is 0. + + Examples: + .. code-block:: python + + # execute this command in terminal: export PADDLE_TRAINER_ID=0 + import paddle.fluid as fluid + + env = fluid.dygraph.ParallelEnv() + print("The local rank is %d" % env.local_rank) + # The local rank is 0 + """ return self._local_rank @property def dev_id(self): + """ + The ID of selected GPU card for parallel training. + + Its value is equal to the value of the environment variable FLAGS_selected_gpus. The default value is 0. + + Examples: + .. code-block:: python + + # execute this command in terminal: export FLAGS_selected_gpus=1 + import paddle.fluid as fluid + + env = fluid.dygraph.ParallelEnv() + print("The device id are %d" % env.dev_id) + # The device id are 1 + """ return self._dev_id @property def current_endpoint(self): + """ + The endpoint of current trainer, it is in the form of (node IP + port). + + Its value is equal to the value of the environment variable PADDLE_CURRENT_ENDPOINT. The default value is "". + + Examples: + .. code-block:: python + + # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170 + import paddle.fluid as fluid + + env = fluid.dygraph.ParallelEnv() + print("The current endpoint are %s" % env.current_endpoint) + # The current endpoint are 127.0.0.1:6170 + """ return self._current_endpoint @property def trainer_endpoints(self): + """ + The endpoints of all trainer nodes in the task, + which are used to broadcast the NCCL ID when NCCL2 is initialized. + + Its value is equal to the value of the environment variable PADDLE_TRAINER_ENDPOINTS. The default value is "". + + Examples: + .. code-block:: python + + # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171 + import paddle.fluid as fluid + + env = fluid.dygraph.ParallelEnv() + print("The trainer endpoints are %s" % env.trainer_endpoints) + # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171'] + """ return self._trainer_endpoints +# NOTE: [ Compatible ] Originally this class name is `Env`. The semantics of the old class names +# are inaccurate and may confuse users, so replace it with `ParallelEnv`, but to be compatible +# with the old examples, here still need to keep this name. +Env = ParallelEnv + + class DataParallel(layers.Layer): """ - Runs the module with data parallelism. + Run the dygraph module with data parallelism. - Currently, DataParallel only supports to run the dynamic graph + Currently, DataParallel class only supports to run the dynamic graph with multi-process. The usage is: - `python -m paddle.distributed.launch --gpus 2 dynamic_graph_test.py`. + `python -m paddle.distributed.launch --selected_gpus=0,1 dynamic_graph_test.py`. And the content of `dynamic_graph_test.py` is the code of examples. + Args: + layers(Layer): The module that should be executed by data parallel. + strategy(ParallelStrategy): The strategy of data parallelism, contains + environment configuration related to parallel execution. + + Returns: + Layer: The data paralleled module. + Examples: .. code-block:: python @@ -100,17 +247,17 @@ class DataParallel(layers.Layer): from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.base import to_variable - place = fluid.CUDAPlace(0) + place = place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id) with fluid.dygraph.guard(place=place): # prepare the data parallel context - strategy=dygraph.parallel.prepare_context() + strategy=dygraph.prepare_context() linear = Linear(1, 10, act="softmax") adam = fluid.optimizer.AdamOptimizer() # make the module become the data parallelism module - linear = dygraph.parallel.DataParallel(linear, strategy) + linear = dygraph.DataParallel(linear, strategy) x_data = np.random.random(size=[10, 1]).astype(np.float32) data = to_variable(x_data) @@ -128,13 +275,6 @@ class DataParallel(layers.Layer): adam.minimize(avg_loss) linear.clear_gradients() - - Args: - layers(Layer): The module that should be executed by data parallel. - strategy(ParallelStrategy): The strategy of data parallelism. - - Returns: - Layer: The data paralleled module. """ def __init__(self, layers, strategy): @@ -154,10 +294,41 @@ class DataParallel(layers.Layer): directly. Args: - loss(Layer): The loss of the current Model. + loss(Variable): The loss of the current Model. Returns: - Layer: the scaled loss. + Variable: the scaled loss. + + Examples: + .. code-block:: python + + import numpy as np + import paddle.fluid as fluid + import paddle.fluid.dygraph as dygraph + from paddle.fluid.optimizer import AdamOptimizer + from paddle.fluid.dygraph.nn import Linear + from paddle.fluid.dygraph.base import to_variable + + place = place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id) + with fluid.dygraph.guard(place=place): + strategy=dygraph.prepare_context() + linear = Linear(1, 10, act="softmax") + adam = fluid.optimizer.AdamOptimizer() + linear = dygraph.DataParallel(linear, strategy) + + x_data = np.random.random(size=[10, 1]).astype(np.float32) + data = to_variable(x_data) + hidden = linear(data) + avg_loss = fluid.layers.mean(hidden) + + # scale the loss according to the number of trainers. + avg_loss = linear.scale_loss(avg_loss) + + avg_loss.backward() + linear.apply_collective_grads() + + adam.minimize(avg_loss) + linear.clear_gradients() """ if not self._is_data_parallel_mode(): return loss @@ -211,6 +382,36 @@ class DataParallel(layers.Layer): def apply_collective_grads(self): """ AllReduce the Parameters' gradient. + + Examples: + .. code-block:: python + + import numpy as np + import paddle.fluid as fluid + import paddle.fluid.dygraph as dygraph + from paddle.fluid.optimizer import AdamOptimizer + from paddle.fluid.dygraph.nn import Linear + from paddle.fluid.dygraph.base import to_variable + + place = place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id) + with fluid.dygraph.guard(place=place): + strategy=dygraph.prepare_context() + linear = Linear(1, 10, act="softmax") + adam = fluid.optimizer.AdamOptimizer() + linear = dygraph.DataParallel(linear, strategy) + + x_data = np.random.random(size=[10, 1]).astype(np.float32) + data = to_variable(x_data) + hidden = linear(data) + avg_loss = fluid.layers.mean(hidden) + avg_loss = linear.scale_loss(avg_loss) + avg_loss.backward() + + # collect the gradients of trainers. + linear.apply_collective_grads() + + adam.minimize(avg_loss) + linear.clear_gradients() """ if not self._is_data_parallel_mode(): return @@ -276,9 +477,9 @@ class DataParallel(layers.Layer): import paddle.fluid as fluid with fluid.dygraph.guard(): - strategy=dygraph.parallel.prepare_context() + strategy=fluid.dygraph.prepare_context() emb = fluid.dygraph.Embedding([10, 10]) - emb = dygraph.parallel.DataParallel(emb, strategy) + emb = fluid.dygraph.DataParallel(emb, strategy) state_dict = emb.state_dict() fluid.save_dygraph( state_dict, "paddle_dy") @@ -310,9 +511,9 @@ class DataParallel(layers.Layer): import paddle.fluid as fluid with fluid.dygraph.guard(): - strategy=dygraph.parallel.prepare_context() + strategy=fluid.dygraph.prepare_context() emb = fluid.dygraph.Embedding([10, 10]) - emb = dygraph.parallel.DataParallel(emb, strategy) + emb = fluid.dygraph.DataParallel(emb, strategy) state_dict = emb.state_dict() fluid.save_dygraph( state_dict, "paddle_dy") @@ -350,9 +551,9 @@ class DataParallel(layers.Layer): import paddle.fluid as fluid with fluid.dygraph.guard(): - strategy=dygraph.parallel.prepare_context() + strategy=fluid.dygraph.prepare_context() emb = fluid.dygraph.Embedding([10, 10]) - emb = dygraph.parallel.DataParallel(emb, strategy) + emb = fluid.dygraph.DataParallel(emb, strategy) state_dict = emb.state_dict() fluid.save_dygraph( state_dict, "paddle_dy") diff --git a/tools/wlist.json b/tools/wlist.json index 70f0e4f6e1..cb6f9a6c9a 100644 --- a/tools/wlist.json +++ b/tools/wlist.json @@ -310,6 +310,10 @@ "Recall.eval", "FC.forward", "While.block", - "DGCMomentumOptimizer" + "DGCMomentumOptimizer", + "ParallelEnv", + "DataParallel", + "DataParallel.scale_loss", + "DataParallel.apply_collective_grads" ] } -- GitLab