Unverified commit 31f422ae authored by Chen Weihang, committed by GitHub

Add interface to launch parallel dygraph by multiprocessing (#26044)

* add dygraph parallel run interface

* polish implement & unified env property name

* add print config arg

* refactor init_parallel_env function

* Compatible with multiprocessing and launch modes

* set default trainer start port

* support run in python 2

* polish python2 support code

* remove python2 support

* refine launch import

* polish demo design details

* refactor api implementation & path

* use new method _set_expected_place

* add spawn unittest framework & mnist test

* add more unittests & doc

* fix unittest failed

* polish english doc

* self review and polish details

* refactor code by reviewer's comments

* fix unittest failed

* fix parallel_env unittest

* fix several typos

* fix error introduced when fixing typos

* add non-public note for start_processes

* polish details by xiaoguang's comment

* verify correctly when spawn nprocs=-1

* refactor spawn & init_parallel_env design

* polish doc details

* open spawn unittests

* try to fix doc compile error

* try to fix unknown doc format error

* add skip unittest when not gpu
Parent 3390c7e2
......@@ -230,8 +230,6 @@ from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import prepare_context #DEFINE_ALIAS
from .framework import ParallelEnv #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS
from .framework import NoamDecay #DEFINE_ALIAS
......
......@@ -12,4 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import spawn
from .spawn import spawn
from . import parallel
from .parallel import init_parallel_env
from .parallel import get_rank
from .parallel import get_world_size
from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
from . import collective
from .collective import *
# start multiprocess apis
__all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env", "get_rank", "get_world_size", "prepare_context",
"ParallelEnv"
]
# collective apis
__all__ += collective.__all__
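A minimal sketch of the new public surface exposed by this ``__init__.py``, assuming a single process with none of the PADDLE_* environment variables set:

import paddle.distributed as dist

print(dist.get_rank())        # 0, the default of PADDLE_TRAINER_ID
print(dist.get_world_size())  # 1, the default of PADDLE_TRAINERS_NUM
env = dist.ParallelEnv()      # re-exported here for compatibility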
......@@ -44,11 +44,9 @@ import time
import six
import copy
from argparse import ArgumentParser, REMAINDER
import paddle
import paddle.fluid as fluid
from paddle.distributed.utils import *
import paddle.distributed.cloud_utils as cloud_utils
from paddle.distributed import cloud_utils
def _print_arguments(args):
......@@ -167,7 +165,8 @@ def get_cluster_from_args(args, selected_gpus):
def get_gpus(selected_gpus):
if selected_gpus is None:
gpus_num = fluid.core.get_cuda_device_count()
from paddle.fluid import core
gpus_num = core.get_cuda_device_count()
selected_gpus = [str(x) for x in range(0, gpus_num)]
else:
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
......@@ -190,7 +189,7 @@ def get_gpus(selected_gpus):
return selected_gpus
def launch(args):
def get_cluster_and_pod(args):
# parse arguments, used for cloud-single-machine and local
selected_gpus = get_gpus(args.selected_gpus)
trainers_num = cloud_utils.get_trainers_num()
......@@ -209,6 +208,12 @@ def launch(args):
cluster, pod = get_cluster_from_args(args, selected_gpus)
logger.info("get cluster from args:{}".format(cluster))
return cluster, pod
def launch(args):
cluster, pod = get_cluster_and_pod(args)
procs = start_local_trainers(
cluster,
pod,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import warnings
from paddle import compat as cpt
# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph.parallel import ParallelEnv
__all__ = ["init_parallel_env"]
ParallelStrategy = core.ParallelStrategy
def init_parallel_env(backend='nccl'):
"""
Initialize the parallel training environment in dynamic mode.
Args:
backend(str, optional): The backend for communication between multiple devices.
Currently only ``nccl`` is supported. Default value is ``nccl`` .
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
dist.spawn(train)
"""
# 1. input check
if not isinstance(backend, six.string_types):
raise TypeError("input `backend` type error, expected type is str, "
"but received type is %s." % type(backend))
if cpt.to_text(backend) != 'nccl':
raise ValueError(
"backend `%s` is not supported, now only supports `nccl` backend." %
backend)
# 2. check env
def _check_var_exists(var_name):
var = os.environ.get(var_name, None)
if var is None:
raise ValueError("paddle.distributed initialize error, "
"environment variable %s is needed, but not set." %
var_name)
_check_var_exists("FLAGS_selected_gpus")
_check_var_exists("PADDLE_TRAINER_ID")
_check_var_exists("PADDLE_CURRENT_ENDPOINT")
_check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS")
# 3. init ParallelStrategy
strategy = ParallelStrategy()
if cpt.to_text(backend) == 'nccl':
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config global place here? ]
# dygraph mode will become the default mode, and users will not
# call `dygraph.guard` or `enable_dygraph` directly; if they want
# to switch the default place, they need to call a function to
# change it, so here we just set the correct place for users
place = core.CUDAPlace(ParallelEnv().device_id)
_set_expected_place(place)
# init nccl context
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
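As a minimal single-process sanity check (mirroring the ``test_init_parallel_env_break`` unittest added in this commit), the required variables can be set by hand before calling ``init_parallel_env``; with ``PADDLE_TRAINERS_NUM=1`` the function returns before creating an NCCL context:

import os
import paddle.distributed as dist

os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170'

dist.init_parallel_env()  # nranks < 2, so this returns early as a no-op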
def get_rank():
"""
Returns the rank of the current trainer.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
The default value is 0.
Returns:
(int) The rank of the current trainer.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINER_ID=0
print("The rank is %d" % dist.get_rank())
# The rank is 0
"""
return ParallelEnv().rank
def get_world_size():
"""
The number of trainers (number of processes participating in the current job).
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
Returns:
(int) The number of trainers.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
print("The world_size is %d" % dist.get_world_size())
# The world_size is 4
"""
return ParallelEnv().world_size
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import multiprocessing
import os
import signal
import six
import sys
import warnings
from paddle.distributed.launch import get_cluster_and_pod, _print_arguments
from paddle.distributed.utils import _prepare_trainer_env
from paddle.device import get_device
# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _cpu_num
# NOTE(chenweihang): The existence of this class leads to
# maintaining two sets of arguments. When the launch.py arguments
# are updated, the arguments here also need to be updated,
# but I have not thought of a better way to handle this yet
class ParallelEnvArgs(object):
def __init__(self):
# Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..
self.cluster_node_ips = None
# The current node ip.
self.node_ip = None
# whether to use paddlecloud platform to run your multi-process job.
# If false, no need to set this argument.
self.use_paddlecloud = None
# The trainer's started port on a single node
self.started_port = None
# Print the config or not
self.print_config = True
# It's for gpu training and the training process will run
# on the selected_gpus, each process is bound to a single GPU.
# And if it's not set, this module will use all the gpu cards
# for training.
self.selected_gpus = None
def _py_supported_check():
if not sys.version_info >= (3, 4):
raise RuntimeError(
"Use `paddle.distributed.spawn` to start parallel training "
"requires python version greater than 3.4, if your python "
"is lower than this version, please use "
"`paddle.distributed.launch` instead.")
def _get_subprocess_env_list(nprocs, options):
# construct processes env list
processes_env_list = []
# get args from kwargs
args = ParallelEnvArgs()
# set default `node_ip` and `cluster_node_ips`
args.cluster_node_ips = options.get('cluster_node_ips', None)
args.node_ip = options.get('node_ip', None)
if args.cluster_node_ips is not None and args.node_ip is None:
raise ValueError("please input current node ip, "
"cannot only give `cluster_node_ips`.")
default_node_ip = "127.0.0.1"
if args.node_ip is None:
args.node_ip = default_node_ip
if args.cluster_node_ips is None:
args.cluster_node_ips = default_node_ip
# set default selected gpus
# e.g. if the nprocs is 4, the selected gpus is "0,1,2,3"
# NOTE(chenweihang): [ why not use FLAGS_selected_gpus directly? ]
# because the FLAGS_selected_gpus may be used in other place,
# if we set FLAGS_selected_gpus to be `0,1,2,3`, it may cause error
# when using `ParallelEnv`
# NOTE(chenweihang): use absolute gpu card id
args.selected_gpus = options.get('selected_gpus', None)
env_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
if env_devices is None or env_devices == "":
env_devices_list = [
str(x) for x in six.moves.range(core.get_cuda_device_count())
]
else:
env_devices_list = env_devices.split(',')
if args.selected_gpus is None:
if len(env_devices_list) < nprocs:
raise RuntimeError(
"the number of visible devices(%d) is less than the number "
"of spawn processes(%d), please ensure that the correct "
"`nprocs` argument is passed or the environment variable "
"`CUDA_VISIBLE_DEVICES` is correctly configured." %
(len(env_devices_list), nprocs))
args.selected_gpus = ",".join(
[str(env_devices_list[x]) for x in range(0, nprocs)])
else:
for card_id in args.selected_gpus.split(','):
if card_id not in env_devices_list:
raise ValueError("The selected gpu card %s cannot found in "
"CUDA_VISIBLE_DEVICES (%s)." %
(card_id, ",".join(env_devices_list)))
# set other arguments
args.started_port = options.get('started_port', None)
args.use_paddlecloud = options.get('use_paddlecloud', False)
args.print_config = options.get('print_config', False)
# reuse code of launch.py
cluster, pod = get_cluster_and_pod(args)
# prepare subprocess env list
for trainer in pod.trainers:
processes_env_list.append(_prepare_trainer_env(cluster, trainer))
# print config
if args.print_config:
_print_arguments(args)
return processes_env_list
def _remove_risky_env():
# remove useless env vars, same as launch.py
# no copy, each process will hold env vars itself
os.environ.pop("http_proxy", None)
os.environ.pop("https_proxy", None)
def _set_trainer_env(env_dict):
for var_name in env_dict:
os.environ[var_name] = env_dict[var_name]
def _func_wrapper(func, args, error_queue, return_queue, env_dict):
try:
# config subprocess environment variables
_remove_risky_env()
_set_trainer_env(env_dict)
# execute function
result = func(*args)
# record function return value
return_queue.put(result)
except KeyboardInterrupt:
pass
except Exception:
import traceback
error_queue.put(traceback.format_exc())
sys.exit(1)
class MultiprocessContext(object):
def __init__(self, processes, error_queues, return_queues):
_py_supported_check()
self.error_queues = error_queues
# NOTE(chenweihang): The `spawn` method is mainly used
# to wrap the outermost execution function of the program for
# parallel execution. Generally, the return value is not of concern,
# but if users need to obtain it, they can get the return result
# of each process from context.return_queues
self.return_queues = return_queues
self.processes = processes
self.sentinels = {
process.sentinel: index
for index, process in enumerate(processes)
}
def join(self, timeout=None):
if len(self.sentinels) == 0:
return True
ready = multiprocessing.connection.wait(
self.sentinels.keys(), timeout=timeout)
error_index = None
for sentinel in ready:
index = self.sentinels.pop(sentinel)
process = self.processes[index]
process.join()
if process.exitcode != 0:
error_index = index
break
if error_index is None:
return len(self.sentinels) == 0
for process in self.processes:
if process.is_alive():
process.terminate()
process.join()
self._throw_exception(error_index)
def _throw_exception(self, error_index):
if self.error_queues[error_index].empty():
exitcode = self.processes[error_index].exitcode
if exitcode < 0:
name = signal.Signals(-exitcode).name
raise Exception("Process %d terminated with signal %s." %
(error_index, name))
else:
raise Exception("Process %d terminated with exit code %d." & (
error_index, exitcode))
original_trace = self.error_queues[error_index].get()
msg = "\n\n----------------------------------------------\n" \
"Process %d terminated with the following error:\n" \
"----------------------------------------------\n\n" % error_index
msg += original_trace
raise Exception(msg)
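A usage sketch for the non-blocking mode described in the NOTE above (``train`` is assumed to be a picklable top-level function defined elsewhere):

import paddle.distributed as dist

if __name__ == '__main__':
    # start workers without blocking the parent process
    context = dist.spawn(train, nprocs=2, join=False)
    # poll until every spawned process has exited
    while not context.join(timeout=1):
        pass  # the parent could do other work here
    # collect each worker's return value recorded by _func_wrapper
    results = [q.get() for q in context.return_queues]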
def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
"""
Start multiple processes with ``spawn`` method for parallel training.
Args:
func (function): The target function called by the spawned process.
This function needs to be picklable, so it must be defined
at the top level of a module.
It is called as ``func(i, *args)``, where ``i`` is
the process index and ``args`` contains the other arguments as a tuple.
args (tuple, optional): Arguments passed to ``func``.
nprocs (int, optional): Number of processes to start. Default: -1.
When nprocs is -1, the available devices are obtained from
the environment when the model is executed: if GPUs are used,
the currently available device IDs are obtained from the environment
variable CUDA_VISIBLE_DEVICES; if CPUs are used, the currently available
CPU count is obtained from the environment variable CPU_NUM.
For example, ``export CPU_NUM=4``; if this environment variable is not set,
the executor will add the variable to the environment and
set its value to 1.
join (bool, optional): Perform a blocking join on all spawned processes.
Default: True.
daemon (bool, optional): The spawned processes' daemon flag. Default: False.
**options(dict, optional): Other initial parallel execution environment
configuration options. The following options are currently supported:
(1) start_method (string): the way to start a process.
The start method can be ``spawn`` , ``fork`` , ``forkserver`` .
Because the CUDA runtime does not support the ``fork`` start method,
when using CUDA in subprocesses, the process should be started by the ``spawn``
or ``forkserver`` method. Default: "spawn" ;
(2) cluster_node_ips (string): Paddle cluster nodes ips, such as
"192.168.0.16,192.168.0.17". Default: "127.0.0.1";
(3) node_ip (string): The current node ip, such as "192.168.0.16".
Default: "127.0.0.1";
(4) started_port (int): The trainer's started port on a single node,
such as 6170. Default: None;
(5) selected_gpus (string): The training process will run on the
selected_gpus, such as "0,1,2,3". Default: None;
(6) print_config: Print current parallel training config. Default: False;
(7) use_paddlecloud: Whether to use paddlecloud platform to run your
multi-process job. Default: False.
Returns:
``MultiprocessContext`` object, which holds the spawned processes.
Examples:
.. code-block:: python
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train(print_result=False):
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
if print_result is True:
print("loss:", loss.numpy())
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
# Usage 1: only pass the function.
# Use this if your training method needs no arguments and
# uses all visible devices for parallel training.
if __name__ == '__main__':
dist.spawn(train)
# Usage 2: pass the function and its arguments.
# Use this if your training method needs some arguments and
# uses all visible devices for parallel training.
if __name__ == '__main__':
dist.spawn(train, args=(True,))
# Usage 3: pass the function, arguments and nprocs.
# Use this if your training method needs some arguments and
# uses only part of the visible devices for parallel training.
# If your machine holds 8 cards {0,1,2,3,4,5,6,7},
# this case will use cards {0,1}; if you set
# CUDA_VISIBLE_DEVICES=4,5,6,7, this case will use
# cards {4,5}
if __name__ == '__main__':
dist.spawn(train, args=(True,), nprocs=2)
# Usage 4: pass the function, arguments, nprocs and selected_gpus.
# Use this if your training method needs some arguments,
# uses only part of the visible devices for parallel training,
# but you can't set your machine's environment variable
# CUDA_VISIBLE_DEVICES (e.g. it is unset, or lists all cards
# {0,1,2,3,4,5,6,7}); you can pass `selected_gpus` to
# select the GPU cards you want to use. For example,
# this case will use cards {4,5} if your machine holds 8 cards.
if __name__ == '__main__':
dist.spawn(train, args=(True,), nprocs=2, selected_gpus='4,5')
"""
# NOTE(chenweihang): [ why only support python3.4+ ? ]
# Python has supported setting the child process start method
# since 3.4. Earlier versions can only use the default start
# method, and the default start method on Unix is fork, which
# does not support the CUDA runtime in multi-process mode
_py_supported_check()
# get default nprocs
if nprocs == -1:
device = get_device()
if device == 'cpu':
# TODO: cpu parallel is not supported yet
nprocs = _cpu_num
else:
nprocs = core.get_cuda_device_count()
# NOTE(chenweihang): [ why get cluster info before running? ]
# when using `paddle.distributed.spawn` to start parallel training,
# we should get the cluster info before starting the subprocesses,
# and pass the correct info to each subprocess
procs_env_list = _get_subprocess_env_list(nprocs, options)
# start processes
# NOTE(chenweihang): [ why default start method is spawn? ]
# The CUDA runtime does not support the fork start method,
# so the spawn or forkserver start method is required
# to use CUDA in subprocesses.
start_method = options.get('start_method', None)
if start_method is None:
start_method = 'spawn'
mp = multiprocessing.get_context(start_method)
error_queues = []
return_queues = []
processes = []
for i in range(nprocs):
error_queue = mp.SimpleQueue()
return_queue = mp.SimpleQueue()
process = mp.Process(
target=_func_wrapper,
args=(func, args, error_queue, return_queue, procs_env_list[i]))
process.daemon = daemon
process.start()
error_queues.append(error_queue)
return_queues.append(return_queue)
processes.append(process)
context = MultiprocessContext(processes, error_queues, return_queues)
if not join:
return context
# loop until all processes end
while not context.join():
pass
# finally return context
return context
......@@ -327,6 +327,17 @@ def find_free_ports(num):
return None
def _prepare_trainer_env(cluster, trainer):
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]),
"PADDLE_TRAINER_ID": "%d" % trainer.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
return proc_env
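Illustratively (endpoints and port hypothetical), for trainer 0 of a two-trainer pod on a single node, ``_prepare_trainer_env`` returns a dict of this shape:

proc_env = {
    "FLAGS_selected_gpus": "0",
    "PADDLE_TRAINER_ID": "0",
    "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:6170",
    "PADDLE_TRAINERS_NUM": "2",
    "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:6170,127.0.0.1:6171",
}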
class TrainerProc(object):
def __init__(self):
self.proc = None
......@@ -352,14 +363,7 @@ def start_local_trainers(cluster,
procs = []
for idx, t in enumerate(pod.trainers):
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
"PADDLE_TRAINER_ID": "%d" % t.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
proc_env = _prepare_trainer_env(cluster, t)
current_env.update(proc_env)
logger.debug("trainer proc env:{}".format(current_env))
......
......@@ -11,21 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
import warnings
from collections import OrderedDict
from .. import core
from . import layers
from . import parallel_helper
from .. import framework
from . import to_variable, no_grad
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
ParallelStrategy = core.ParallelStrategy
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
'''
:api_attr: imperative
......@@ -39,17 +44,18 @@ def prepare_context(strategy=None):
if strategy.nranks < 2:
return
assert framework.in_dygraph_mode() is True, \
"dygraph.prepare_context should be used with dygrahp mode."
"dygraph.prepare_context should be used with dygraph mode."
place = framework._current_expected_place()
assert place is not None, \
"dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
if isinstance(place, core.CUDAPlace):
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
else:
# TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
assert ("Only support CUDAPlace for now.")
parallel_helper._init_parallel_ctx()
if not parallel_helper._is_parallel_ctx_initialized():
if isinstance(place, core.CUDAPlace):
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
else:
# TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
assert ("Only support CUDAPlace for now.")
parallel_helper._init_parallel_ctx()
return strategy
......@@ -112,84 +118,84 @@ class ParallelEnv(object):
"""
def __init__(self):
self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._dev_id = int(os.getenv("FLAGS_selected_gpus", "0"))
self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self._device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
"").split(",")
self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
@property
def nranks(self):
def rank(self):
"""
The number of trainers, generally refers to the number of GPU cards used in training.
Rank of current trainer.
Its value is equal to the value of the environment variable PADDLE_TRAINERS_NUM. The default value is 1.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
import paddle.fluid as fluid
# execute this command in terminal: export PADDLE_TRAINER_ID=0
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
print("The nranks is %d" % env.nranks)
# The nranks is 4
env = dist.ParallelEnv()
print("The rank is %d" % env.rank)
# The rank is 0
"""
return self._nranks
return self._rank
@property
def local_rank(self):
def world_size(self):
"""
The current trainer number.
The number of trainers (number of processes participating in the current job).
Its value is equal to the value of the environment variable PADDLE_TRAINER_ID. The default value is 0.
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_TRAINER_ID=0
import paddle.fluid as fluid
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
print("The local rank is %d" % env.local_rank)
# The local rank is 0
env = dist.ParallelEnv()
print("The world_size is %d" % env.world_size)
# The world_size is 4
"""
return self._local_rank
return self._world_size
@property
def dev_id(self):
def device_id(self):
"""
The ID of selected GPU card for parallel training.
Its value is equal to the value of the environment variable FLAGS_selected_gpus. The default value is 0.
Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0.
Examples:
.. code-block:: python
# execute this command in terminal: export FLAGS_selected_gpus=1
import paddle.fluid as fluid
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
print("The device id are %d" % env.dev_id)
env = dist.ParallelEnv()
print("The device id are %d" % env.device_id)
# The device id are 1
"""
return self._dev_id
return self._device_id
@property
def current_endpoint(self):
"""
The endpoint of current trainer, it is in the form of (node IP + port).
Its value is equal to the value of the environment variable PADDLE_CURRENT_ENDPOINT. The default value is "".
Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
import paddle.fluid as fluid
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
env = dist.ParallelEnv()
print("The current endpoint are %s" % env.current_endpoint)
# The current endpoint are 127.0.0.1:6170
"""
......@@ -201,20 +207,25 @@ class ParallelEnv(object):
The endpoints of all trainer nodes in the task,
which are used to broadcast the NCCL ID when NCCL2 is initialized.
Its value is equal to the value of the environment variable PADDLE_TRAINER_ENDPOINTS. The default value is "".
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
import paddle.fluid as fluid
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
env = dist.ParallelEnv()
print("The trainer endpoints are %s" % env.trainer_endpoints)
# The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
"""
return self._trainer_endpoints
# [aliases] Compatible with old method names
local_rank = rank
nranks = world_size
dev_id = device_id
# NOTE: [ Compatible ] Originally this class name is `Env`. The semantics of the old class names
# are inaccurate and may confuse users, so replace it with `ParallelEnv`, but to be compatible
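A short sketch of the compatibility aliases above in use (values depend on the PADDLE_* environment variables, as in the property examples):

import paddle.distributed as dist

env = dist.ParallelEnv()
assert env.local_rank == env.rank        # old name aliases new name
assert env.nranks == env.world_size
assert env.dev_id == env.device_id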
......@@ -227,61 +238,98 @@ class DataParallel(layers.Layer):
Run the dygraph module with data parallelism.
Currently, the DataParallel class only supports running the dynamic graph
with multi-process. The usage is:
`python -m paddle.distributed.launch --selected_gpus=0,1 dynamic_graph_test.py`.
And the content of `dynamic_graph_test.py` is the code of examples.
with multi-process.
Now supports two ways to start training:
1. start by ``paddle.distributed.spawn`` method, for example:
``python demo.py`` (spawn needs to be called in the ``__main__`` method)
2. start by ``paddle.distributed.launch`` module, for example:
``python -m paddle.distributed.launch --selected_gpus=0,1 demo.py`` .
And the content of `demo.py` is the code of examples.
Args:
layers(Layer): The module that should be executed by data parallel.
strategy(ParallelStrategy): The strategy of data parallelism, contains
environment configuration related to parallel execution.
strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
contains environment configuration related to parallel execution. Default: None.
Returns:
Layer: The data-parallelized module.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
with fluid.dygraph.guard(place):
# prepare the data parallel context
strategy = fluid.dygraph.prepare_context()
linear = fluid.dygraph.Linear(1, 10, act="softmax")
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, parameter_list=linear.parameters())
# make the module become the data parallelism module
linear = fluid.dygraph.DataParallel(linear, strategy)
x_data = np.random.random(size=[10, 1]).astype(np.float32)
data = fluid.dygraph.to_variable(x_data)
hidden = linear(data)
avg_loss = fluid.layers.mean(hidden)
# scale the loss according to the number of trainers.
avg_loss = linear.scale_loss(avg_loss)
avg_loss.backward()
# collect the gradients of trainers.
linear.apply_collective_grads()
adam.minimize(avg_loss)
linear.clear_gradients()
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
"""
def __init__(self, layers, strategy):
def __init__(self, layers, strategy=None):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")
self._layers = layers
self._strategy = strategy
# NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
# It just stores some environment variables, which can be constructed by
# ParallelEnv. Here it is set as an optional argument.
# This parameter is kept for compatibility with 1.x code.
if strategy is not None:
self._strategy = strategy
else:
self._strategy = ParallelStrategy()
self._strategy.nranks = ParallelEnv().nranks
self._strategy.local_rank = ParallelEnv().local_rank
self._strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
self._strategy.current_endpoint = ParallelEnv().current_endpoint
def forward(self, *inputs, **kwargs):
return self._layers(*inputs, **kwargs)
......
......@@ -23,6 +23,11 @@ def _is_data_parallel_mode():
os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1
def _is_parallel_ctx_initialized():
global __parallel_ctx__clz__
return __parallel_ctx__clz__ is not None
def _set_parallel_ctx(nccl_parallel_context):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, \
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import numpy as np
import unittest
import paddle
# used by model.run_trainer in test_dist_base
from test_dist_base import RUN_STEP
# NOTE: compatible with TestParallelDyGraphRunnerBase args
class SpawnAssistTestArgs(object):
update_method = "local"
trainer_id = 0
class TestDistSpawnRunner(unittest.TestCase):
def setUp(self):
# NOTE(chenweihang): keep consistent with
# TestDistBase.check_with_place
self.nprocs = 2
def _run(self, model, args):
args.update_method = "local"
return model.run_trainer_with_spawn(args)
def _run_parallel(self, model, args):
args.update_method = "nccl2"
context = paddle.distributed.spawn(
func=model.run_trainer_with_spawn,
args=(args, ),
nprocs=self.nprocs,
join=True)
result_list = []
for res_queue in context.return_queues:
result_list.append(res_queue.get())
return result_list
def check_dist_result_with_spawn(self, test_class, delta=1e-3):
# 0. prepare model and args
model = test_class()
args = SpawnAssistTestArgs()
# 1. calc single card loss
losses = self._run(model, args)
# 2. calc multi card loss (nccl mode)
dist_losses_list = self._run_parallel(model, args)
# 3. compare losses
for step_id in range(RUN_STEP):
loss = losses[step_id]
dist_loss_sum = None
for dist_losses in dist_losses_list:
if dist_loss_sum is None:
dist_loss_sum = np.array(dist_losses[step_id])
else:
dist_loss_sum += np.array(dist_losses[step_id])
dist_loss = dist_loss_sum / self.nprocs
self.assertAlmostEqual(
loss,
dist_loss,
delta=delta,
msg="The results of single-card execution and multi-card execution are inconsistent."
"signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n".
format(loss, dist_loss))
......@@ -38,9 +38,10 @@ class TestDirectory(unittest.TestCase):
'paddle.enable_static', 'paddle.disable_static',
'paddle.in_dynamic_mode', 'paddle.to_variable', 'paddle.grad',
'paddle.no_grad', 'paddle.save', 'paddle.load',
'paddle.static.save', 'paddle.static.load', 'paddle.ParallelEnv',
'paddle.prepare_context', 'paddle.DataParallel', 'paddle.jit',
'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.static.save', 'paddle.static.load',
'paddle.distributed.ParallelEnv',
'paddle.distributed.prepare_context', 'paddle.DataParallel',
'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig',
'paddle.NoamDecay', 'paddle.PiecewiseDecay',
......
......@@ -23,8 +23,11 @@ import subprocess
import six
import argparse
import pickle
import random
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.dygraph as dygraph
......@@ -382,22 +385,22 @@ class TestParallelDyGraphRunnerBase(object):
raise NotImplementedError(
"train_one_loop should be implemented by the child classes.")
def _get_data(self, batch, args):
if args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
else:
return batch
def run_trainer(self, args):
seed = 90
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
def _get_data(batch):
if args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
else:
return batch
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
......@@ -422,7 +425,7 @@ class TestParallelDyGraphRunnerBase(object):
out_losses = []
print_to_err(type(self).__name__, "begin to run dygraph training")
for step_id, data in enumerate(train_reader()):
data = _get_data(data)
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
......@@ -444,6 +447,47 @@ class TestParallelDyGraphRunnerBase(object):
model.clear_gradients()
print_to_out(out_losses)
def run_trainer_with_spawn(self, args):
# 1. enable dygraph
paddle.disable_static()
# 2. init seed
seed = 90
paddle.static.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method == "nccl2":
paddle.distributed.init_parallel_env()
# 4. train model
model, train_reader, opt = self.get_model()
if args.update_method == "nccl2":
model = paddle.DataParallel(model)
out_losses = []
for step_id, data in enumerate(train_reader()):
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
out_losses.append(loss.numpy())
if args.update_method == "nccl2":
loss = model.scale_loss(loss)
loss.backward()
if args.update_method == "nccl2":
model.apply_collective_grads()
opt.minimize(loss)
model.clear_gradients()
return out_losses
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run dist test.')
......
......@@ -43,7 +43,7 @@ class MLP(fluid.Layer):
class TestDataParallelStateDict(unittest.TestCase):
def test_data_parallel_state_dict(self):
with fluid.dygraph.guard():
strategy = paddle.prepare_context()
strategy = paddle.distributed.prepare_context()
mlp = MLP()
parallel_mlp = dygraph.parallel.DataParallel(mlp, strategy)
......
......@@ -13,11 +13,16 @@
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_mnist import TestMnist
import os
flag_name = os.path.splitext(__file__)[0]
......@@ -36,5 +41,11 @@ class TestParallelDygraphMnist(TestDistBase):
log_name=flag_name)
class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
def test_mnist_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(test_class=TestMnist, delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -13,11 +13,16 @@
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_se_resnext import TestSeResNeXt
import os
flag_name = os.path.splitext(__file__)[0]
......@@ -36,5 +41,12 @@ class TestParallelDygraphSeResNeXt(TestDistBase):
log_name=flag_name)
class TestParallelDygraphSeResNeXtSpawn(TestDistSpawnRunner):
def test_se_resnext_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSeResNeXt, delta=0.01)
if __name__ == "__main__":
unittest.main()
......@@ -15,10 +15,13 @@
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_sparse_embedding import TestSparseEmbedding
flag_name = os.path.splitext(__file__)[0]
......@@ -38,5 +41,12 @@ class TestParallelDygraphSparseEmdedding(TestDistBase):
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner):
def test_sparse_embedding_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSparseEmbedding, delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -15,10 +15,13 @@
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_transformer import TestTransformer
flag_name = os.path.splitext(__file__)[0]
......@@ -38,5 +41,12 @@ class TestParallelDygraphTransformer(TestDistBase):
log_name=flag_name)
class TestParallelDygraphTransformerSpawn(TestDistSpawnRunner):
def test_transformer_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestTransformer, delta=1e-5)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import unittest
import paddle
import paddle.distributed as dist
from paddle.distributed.spawn import _get_subprocess_env_list
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper
# NOTE(chenweihang): Coverage CI is currently not able to count python3
# unittest, so the unittests here cover some cases that will only be
# executed in the python3 sub-process.
class TestInitParallelEnv(unittest.TestCase):
def test_backend_type_error(self):
with self.assertRaises(TypeError):
dist.init_parallel_env(backend=1)
def test_backend_value_error(self):
with self.assertRaises(ValueError):
dist.init_parallel_env(backend="mpi")
def test_check_env_failed(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
with self.assertRaises(ValueError):
dist.init_parallel_env()
def test_init_parallel_env_break(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170'
# coverage success branch
dist.init_parallel_env()
self.assertFalse(parallel_helper._is_parallel_ctx_initialized())
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSpawnAssistMethod(unittest.TestCase):
def test_only_cluster_node_ips_error(self):
with self.assertRaises(ValueError):
options = dict()
options['cluster_node_ips'] = "127.0.0.1,127.0.0.2"
_get_subprocess_env_list(nprocs=1, options=options)
def test_nprocs_greater_than_device_num_error(self):
with self.assertRaises(RuntimeError):
_get_subprocess_env_list(nprocs=100, options=dict())
def test_selected_gpus_error(self):
with self.assertRaises(ValueError):
options = dict()
options['selected_gpus'] = "100,101"
_get_subprocess_env_list(nprocs=2, options=options)
def test_get_correct_env(self):
env_dict = _get_subprocess_env_list(nprocs=1, options=dict())[0]
self.assertEqual(env_dict['PADDLE_TRAINER_ID'], '0')
self.assertEqual(env_dict['PADDLE_TRAINERS_NUM'], '1')
if __name__ == "__main__":
unittest.main()
......@@ -50,8 +50,6 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS
from ..fluid.dygraph.base import grad #DEFINE_ALIAS
from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS
from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS
from ..fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from ..fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS
from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS
......