Unverified commit 28cb6531, authored by Chen Weihang, committed by GitHub

Remove backend argument of init_parallel_env (#26773)

* remove backend argument of init_parallel_env

* remove keep name table in transformer

* add cpu version check

* add skip unittest for init_parallel_env

* polish doc: remove func use & update example
Parent 9ded7565
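
The user-visible effect of this change is that `init_parallel_env` is now called with no arguments: NCCL is the only backend and is selected implicitly, and CPU-only builds fail fast with `NotImplementedError`. A minimal usage sketch (not part of the diff; it assumes a GPU build of Paddle and the 2.0 dygraph API shown in the hunks below):

    import paddle
    import paddle.distributed as dist

    def train():
        paddle.disable_static()      # switch to dynamic graph mode
        dist.init_parallel_env()     # previously: dist.init_parallel_env(backend='nccl')
        # ... build the model, wrap it with paddle.DataParallel, run the training loop ...

    if __name__ == '__main__':
        # spawn one training process per visible GPU
        dist.spawn(train)
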
@@ -29,13 +29,13 @@ __all__ = ["init_parallel_env"]
 ParallelStrategy = core.ParallelStrategy
-def init_parallel_env(backend='nccl'):
+def init_parallel_env():
     """
-    Initialize parallel training environments in dynamic mode.
+    Initialize parallel training environment in dynamic graph mode.
-    Args:
-        backend(str, optional): The backend to communication between multiple devices.
-            Now only support ``nccl`` . Default value is ``nccl`` .
+    .. note::
+        Now only supports initializing the GPU parallel training
+        environment and using NCCL for communication.
     Returns:
         None
@@ -89,14 +89,12 @@ def init_parallel_env(backend='nccl'):
            dist.spawn(train)
    """
-    # 1. input check
-    if not isinstance(backend, six.string_types):
-        raise TypeError("input `backend` type error, expected type is str, "
-                        "but received type is %s." % type(backend))
-    if cpt.to_text(backend) != 'nccl':
-        raise ValueError(
-            "backend `%s` is not supported, now only supports `nccl` backend." %
-            backend)
+    # 1. gpu check
+    if not core.is_compiled_with_cuda():
+        raise NotImplementedError(
+            "Cannot initialize parallel environment in CPU-only version, now only "
+            "supports initializing the GPU parallel environment. Please recompile "
+            "or reinstall paddle with GPU support.")
     # 2. check env
     def _check_var_exists(var_name):
@@ -112,9 +110,8 @@ def init_parallel_env(backend='nccl'):
     _check_var_exists("PADDLE_TRAINERS_NUM")
     _check_var_exists("PADDLE_TRAINER_ENDPOINTS")
-    # 3. init ParallelStrategy
+    # 3. init NCCL ParallelStrategy
     strategy = ParallelStrategy()
-    if cpt.to_text(backend) == 'nccl':
     if parallel_helper._is_parallel_ctx_initialized():
         warnings.warn("The parallel environment has been initialized.")
     strategy.nranks = ParallelEnv().world_size
@@ -133,8 +130,7 @@ def init_parallel_env(backend='nccl'):
     _set_expected_place(place)
     # init nccl context
-    parallel_helper._set_parallel_ctx(
-        core.NCCLParallelContext(strategy, place))
+    parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
     parallel_helper._init_parallel_ctx()
@@ -163,7 +159,7 @@ def get_rank():
 def get_world_size():
     """
-    The number of trainers (number of processes participating in current job).
+    Returns the number of trainers (number of processes participating in current job).
     Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
     The default value is 1.
...
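
The last hunk above also touches `get_world_size`, which is documented as returning the value of `PADDLE_TRAINERS_NUM`. A short, illustrative sketch of querying rank and world size after initialization (not part of the diff; the environment variables are assumed to be populated by `paddle.distributed.spawn` or the launch utility):

    import paddle
    import paddle.distributed as dist

    def train():
        paddle.disable_static()
        dist.init_parallel_env()
        # get_world_size() reflects PADDLE_TRAINERS_NUM; get_rank() identifies this trainer
        print("trainer %d of %d" % (dist.get_rank(), dist.get_world_size()))

    if __name__ == '__main__':
        dist.spawn(train, nprocs=2)
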
@@ -236,8 +236,6 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
         func (function): The target function is called by spawned process.
             This function need to be able to pickled, so it must be defined
             at the top level of a module.
-            This function should be called as ``func(i, *args)``, ``i`` is
-            the process index and ``args`` contains other arguments as tuple.
         args (tuple, optional): Arguments passed to ``func``.
         nprocs (int, optional): Number of processed to start. Default: -1.
             when nprocs is -1, the available device will be obtained from
@@ -246,8 +244,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
             variable CUDA_VISIBLE_DEVICES; If use CPU, the currently available
             CPU number is obtained from the environment variable CPU_NUM.
             For example, export CPU_NUM=4, if the environment variable is not set,
-            the executor will add the variable to the environment variable and
-            set its value to 1.
+            the spawn method will add default value to the environment variable
+            and set its value to 1.
         join (bool, optional): Perform a blocking join on all spawned processes.
             Default: True.
         daemon (bool, optional): The spawned processes' daemon flag. Default: False.
@@ -266,8 +264,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
                such as 6170. Default: None;
             (5) selected_gpus (string): The training process will run on the
                selected_gpus, such as "0,1,2,3". Default: None;
-            (6) print_config: Print current parallel training config. Default: False;
-            (7) use_paddlecloud: Whether to use paddlecloud platform to run your
+            (6) print_config (bool): Print current parallel training config. Default: False;
+            (7) use_paddlecloud (bool): Whether to use paddlecloud platform to run your
                multi-process job. Default: False.
     Returns:
...
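
As a reading aid for the `spawn` options documented in this hunk, a sketch of how they might be passed (not part of the diff; the option values and the trivial `train` body are illustrative only):

    import paddle
    import paddle.distributed as dist

    def train():
        paddle.disable_static()
        dist.init_parallel_env()
        # ... training loop ...

    if __name__ == '__main__':
        # run two processes on GPUs 0 and 1 and print the resulting parallel config;
        # selected_gpus and print_config are among the **options described above
        dist.spawn(train, args=(), nprocs=2, selected_gpus='0,1', print_config=True)
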
@@ -349,38 +349,53 @@ class DataParallel(layers.Layer):
         Examples:
             .. code-block:: python
-                import numpy as np
-                import paddle.fluid as fluid
-                place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
-                with fluid.dygraph.guard(place):
-                    # prepare the data parallel context
-                    strategy = fluid.dygraph.prepare_context()
-                    linear = fluid.dygraph.Linear(1, 10, act="softmax")
-                    adam = fluid.optimizer.AdamOptimizer(
-                        learning_rate=0.001, parameter_list=linear.parameters())
-                    # make the module become the data parallelism module
-                    linear = fluid.dygraph.DataParallel(linear, strategy)
-                    x_data = np.random.random(size=[10, 1]).astype(np.float32)
-                    data = fluid.dygraph.to_variable(x_data)
-                    hidden = linear(data)
-                    avg_loss = fluid.layers.mean(hidden)
-                    # scale the loss according to the number of trainers.
-                    avg_loss = linear.scale_loss(avg_loss)
-                    avg_loss.backward()
-                    # collect the gradients of trainers.
-                    linear.apply_collective_grads()
-                    adam.minimize(avg_loss)
-                    linear.clear_gradients()
+                import paddle
+                import paddle.nn as nn
+                import paddle.optimizer as opt
+                import paddle.distributed as dist
+                class LinearNet(nn.Layer):
+                    def __init__(self):
+                        super(LinearNet, self).__init__()
+                        self._linear1 = nn.Linear(10, 10)
+                        self._linear2 = nn.Linear(10, 1)
+                    def forward(self, x):
+                        return self._linear2(self._linear1(x))
+                def train():
+                    # 1. enable dynamic mode
+                    paddle.disable_static()
+                    # 2. initialize parallel environment
+                    dist.init_parallel_env()
+                    # 3. create data parallel layer & optimizer
+                    layer = LinearNet()
+                    dp_layer = paddle.DataParallel(layer)
+                    loss_fn = nn.MSELoss()
+                    adam = opt.Adam(
+                        learning_rate=0.001, parameters=dp_layer.parameters())
+                    # 4. run layer
+                    inputs = paddle.randn([10, 10], 'float32')
+                    outputs = dp_layer(inputs)
+                    labels = paddle.randn([10, 1], 'float32')
+                    loss = loss_fn(outputs, labels)
+                    loss = dp_layer.scale_loss(loss)
+                    loss.backward()
+                    dp_layer.apply_collective_grads()
+                    adam.step()
+                    adam.clear_grad()
+                if __name__ == '__main__':
+                    # 1. start by ``paddle.distributed.spawn`` (default)
+                    dist.spawn(train, nprocs=2)
+                    # 2. start by ``paddle.distributed.launch``
+                    # train()
         """
         if not self._is_data_parallel_mode():
             return loss
@@ -438,38 +453,53 @@ class DataParallel(layers.Layer):
         Examples:
             .. code-block:: python
-                import numpy as np
-                import paddle.fluid as fluid
-                place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
-                with fluid.dygraph.guard(place):
-                    # prepare the data parallel context
-                    strategy = fluid.dygraph.prepare_context()
-                    linear = fluid.dygraph.Linear(1, 10, act="softmax")
-                    adam = fluid.optimizer.AdamOptimizer(
-                        learning_rate=0.001, parameter_list=linear.parameters())
-                    # make the module become the data parallelism module
-                    linear = fluid.dygraph.DataParallel(linear, strategy)
-                    x_data = np.random.random(size=[10, 1]).astype(np.float32)
-                    data = fluid.dygraph.to_variable(x_data)
-                    hidden = linear(data)
-                    avg_loss = fluid.layers.mean(hidden)
-                    # scale the loss according to the number of trainers.
-                    avg_loss = linear.scale_loss(avg_loss)
-                    avg_loss.backward()
-                    # collect the gradients of trainers.
-                    linear.apply_collective_grads()
-                    adam.minimize(avg_loss)
-                    linear.clear_gradients()
+                import paddle
+                import paddle.nn as nn
+                import paddle.optimizer as opt
+                import paddle.distributed as dist
+                class LinearNet(nn.Layer):
+                    def __init__(self):
+                        super(LinearNet, self).__init__()
+                        self._linear1 = nn.Linear(10, 10)
+                        self._linear2 = nn.Linear(10, 1)
+                    def forward(self, x):
+                        return self._linear2(self._linear1(x))
+                def train():
+                    # 1. enable dynamic mode
+                    paddle.disable_static()
+                    # 2. initialize parallel environment
+                    dist.init_parallel_env()
+                    # 3. create data parallel layer & optimizer
+                    layer = LinearNet()
+                    dp_layer = paddle.DataParallel(layer)
+                    loss_fn = nn.MSELoss()
+                    adam = opt.Adam(
+                        learning_rate=0.001, parameters=dp_layer.parameters())
+                    # 4. run layer
+                    inputs = paddle.randn([10, 10], 'float32')
+                    outputs = dp_layer(inputs)
+                    labels = paddle.randn([10, 1], 'float32')
+                    loss = loss_fn(outputs, labels)
+                    loss = dp_layer.scale_loss(loss)
+                    loss.backward()
+                    dp_layer.apply_collective_grads()
+                    adam.step()
+                    adam.clear_grad()
+                if __name__ == '__main__':
+                    # 1. start by ``paddle.distributed.spawn`` (default)
+                    dist.spawn(train, nprocs=2)
+                    # 2. start by ``paddle.distributed.launch``
+                    # train()
         """
         if not self._is_data_parallel_mode():
             return
...
@@ -30,15 +30,9 @@ from paddle.fluid.dygraph import parallel_helper
 # executed in the python3 sub-process.
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
 class TestInitParallelEnv(unittest.TestCase):
-    def test_beckend_type_error(self):
-        with self.assertRaises(TypeError):
-            dist.init_parallel_env(backend=1)
-    def test_backend_value_error(self):
-        with self.assertRaises(ValueError):
-            dist.init_parallel_env(backend="mpi")
     def test_check_env_failed(self):
         os.environ['FLAGS_selected_gpus'] = '0'
         os.environ['PADDLE_TRAINER_ID'] = '0'
...
@@ -20,8 +20,8 @@ __all__ = [
 ]
 __all__ += [
-    'grad', 'LayerList', 'load', 'save', 'prepare_context', 'to_variable',
-    'no_grad', 'ParallelEnv', 'DataParallel'
+    'grad', 'LayerList', 'load', 'save', 'to_variable', 'no_grad',
+    'DataParallel'
 ]
 __all__ += [
...