未验证 提交 bb16c251 编写于 作者: C Chen Weihang 提交者: GitHub

Polish parallel api impl & doc details (#28980)

* polish parallel api impl & doc details

* add unittest for coverage

* remove spawn test in py2.7

* add parallel api into white list
上级 c91bb084
...@@ -32,6 +32,17 @@ __all__ = ["init_parallel_env"] ...@@ -32,6 +32,17 @@ __all__ = ["init_parallel_env"]
ParallelStrategy = core.ParallelStrategy ParallelStrategy = core.ParallelStrategy
# NOTE(chenweihang): Maintain a global parallel env to avoid
# initializing ParallelEnv every time and improve performance
_global_parallel_env = None
def _get_global_parallel_env():
    """Return the module-level ParallelEnv, building it on first access.

    The instance is stored in the module global `_global_parallel_env`, so
    later calls return the stored object instead of constructing a new
    `ParallelEnv` each time.

    Returns:
        The `ParallelEnv` object held in `_global_parallel_env`.
    """
    global _global_parallel_env
    env = _global_parallel_env
    if env is None:
        # First access: create the environment and remember it globally.
        env = ParallelEnv()
        _global_parallel_env = env
    return env
def _start_kv_server(port, http_server_d): def _start_kv_server(port, http_server_d):
from paddle.distributed.fleet.utils.http_server import KVServer from paddle.distributed.fleet.utils.http_server import KVServer
...@@ -48,8 +59,7 @@ def init_parallel_env(): ...@@ -48,8 +59,7 @@ def init_parallel_env():
Initialize parallel training environment in dynamic graph mode. Initialize parallel training environment in dynamic graph mode.
.. note:: .. note::
Now only supports initializing the GPU parallel training Now initialize both `NCCL` and `GLOO` contexts for communication.
environment and using NCCL for communication.
Returns: Returns:
None None
...@@ -72,13 +82,10 @@ def init_parallel_env(): ...@@ -72,13 +82,10 @@ def init_parallel_env():
return self._linear2(self._linear1(x)) return self._linear2(self._linear1(x))
def train(): def train():
# 1. enable dynamic mode # 1. initialize parallel environment
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env() dist.init_parallel_env()
# 3. create data parallel layer & optimizer # 2. create data parallel layer & optimizer
layer = LinearNet() layer = LinearNet()
dp_layer = paddle.DataParallel(layer) dp_layer = paddle.DataParallel(layer)
...@@ -86,7 +93,7 @@ def init_parallel_env(): ...@@ -86,7 +93,7 @@ def init_parallel_env():
adam = opt.Adam( adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters()) learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer # 3. run layer
inputs = paddle.randn([10, 10], 'float32') inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs) outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32') labels = paddle.randn([10, 1], 'float32')
...@@ -101,6 +108,18 @@ def init_parallel_env(): ...@@ -101,6 +108,18 @@ def init_parallel_env():
dist.spawn(train) dist.spawn(train)
""" """
# 0. get env & check world size
global _global_parallel_env
# when calling init_parallel_env, `_global_parallel_env` needs to be updated
_global_parallel_env = ParallelEnv()
parallel_env = _global_parallel_env
# if not parallel, `init_parallel_env` does nothing
if parallel_env.world_size < 2:
warnings.warn(
"Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
)
return
# 1. gpu check # 1. gpu check
if not core.is_compiled_with_cuda(): if not core.is_compiled_with_cuda():
raise NotImplementedError( raise NotImplementedError(
...@@ -122,17 +141,14 @@ def init_parallel_env(): ...@@ -122,17 +141,14 @@ def init_parallel_env():
_check_var_exists("PADDLE_TRAINERS_NUM") _check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS") _check_var_exists("PADDLE_TRAINER_ENDPOINTS")
if ParallelEnv().world_size < 2:
return
# 3: init gloo context (step 1: http server start) # 3: init gloo context (step 1: http server start)
ep_rank_0 = ParallelEnv().trainer_endpoints[0].split(":") ep_rank_0 = parallel_env.trainer_endpoints[0].split(":")
ep_rank = ParallelEnv().trainer_endpoints[ParallelEnv().rank].split(":") ep_rank = parallel_env.trainer_endpoints[parallel_env.rank].split(":")
manager = Manager() manager = Manager()
# global dict to store status # global dict to store status
http_server_d = manager.dict() http_server_d = manager.dict()
http_server_d["running"] = False http_server_d["running"] = False
if ParallelEnv().rank == 0: if parallel_env.rank == 0:
http_server = Process( http_server = Process(
target=_start_kv_server, args=(int(ep_rank_0[1]), http_server_d)) target=_start_kv_server, args=(int(ep_rank_0[1]), http_server_d))
http_server.daemon = True http_server.daemon = True
...@@ -143,10 +159,10 @@ def init_parallel_env(): ...@@ -143,10 +159,10 @@ def init_parallel_env():
strategy = ParallelStrategy() strategy = ParallelStrategy()
if parallel_helper._is_parallel_ctx_initialized(): if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.") warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size strategy.nranks = parallel_env.world_size
strategy.local_rank = ParallelEnv().rank strategy.local_rank = parallel_env.rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints strategy.trainer_endpoints = parallel_env.trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint strategy.current_endpoint = parallel_env.current_endpoint
# NOTE(chenweihang): [ why config global place here? ] # NOTE(chenweihang): [ why config global place here? ]
# the dygraph mode will be set to default mode, # the dygraph mode will be set to default mode,
...@@ -154,7 +170,7 @@ def init_parallel_env(): ...@@ -154,7 +170,7 @@ def init_parallel_env():
# directly, if they want to switch default place, # directly, if they want to switch default place,
# they need to call a function to change default place, # they need to call a function to change default place,
# here just set correctly place to users # here just set correctly place to users
place = core.CUDAPlace(ParallelEnv().device_id) place = core.CUDAPlace(parallel_env.device_id)
_set_expected_place(place) _set_expected_place(place)
# init nccl context # init nccl context
...@@ -165,11 +181,11 @@ def init_parallel_env(): ...@@ -165,11 +181,11 @@ def init_parallel_env():
# dividing init_gloo into two parts because nccl and gloo # dividing init_gloo into two parts because nccl and gloo
# are separately looking for free ports which sometimes # are separately looking for free ports which sometimes
# leads to port-conflict. # leads to port-conflict.
wait_server_ready([ParallelEnv().trainer_endpoints[0]]) wait_server_ready([parallel_env.trainer_endpoints[0]])
gloo_strategy = core.GlooParallelStrategy() gloo_strategy = core.GlooParallelStrategy()
gloo_strategy.rank = ParallelEnv().rank gloo_strategy.rank = parallel_env.rank
gloo_strategy.rank_num = ParallelEnv().world_size gloo_strategy.rank_num = parallel_env.world_size
gloo_strategy.ip_address = ep_rank_0[0] gloo_strategy.ip_address = ep_rank_0[0]
gloo_strategy.ip_port = int(ep_rank_0[1]) gloo_strategy.ip_port = int(ep_rank_0[1])
default_init_timeout_seconds = 3600 default_init_timeout_seconds = 3600
...@@ -178,7 +194,7 @@ def init_parallel_env(): ...@@ -178,7 +194,7 @@ def init_parallel_env():
gloo_strategy.run_seconds = default_run_timeout_seconds gloo_strategy.run_seconds = default_run_timeout_seconds
gloo = core.GlooParallelContext(gloo_strategy) gloo = core.GlooParallelContext(gloo_strategy)
gloo.init() gloo.init()
if ParallelEnv().rank == 0: if parallel_env.rank == 0:
http_server_d["running"] = False http_server_d["running"] = False
http_server.join() http_server.join()
...@@ -203,7 +219,7 @@ def get_rank(): ...@@ -203,7 +219,7 @@ def get_rank():
print("The rank is %d" % dist.get_rank()) print("The rank is %d" % dist.get_rank())
# The rank is 0 # The rank is 0
""" """
return ParallelEnv().rank return _get_global_parallel_env().rank
def get_world_size(): def get_world_size():
...@@ -226,4 +242,4 @@ def get_world_size(): ...@@ -226,4 +242,4 @@ def get_world_size():
print("The world_size is %d" % dist.get_world_size()) print("The world_size is %d" % dist.get_world_size())
# The world_size is 4 # The world_size is 4
""" """
return ParallelEnv().world_size return _get_global_parallel_env().world_size
...@@ -68,6 +68,18 @@ def _py_supported_check(): ...@@ -68,6 +68,18 @@ def _py_supported_check():
"`paddle.distributed.launch` instead.") "`paddle.distributed.launch` instead.")
def _options_valid_check(options):
    """Reject option names that `paddle.distributed.spawn` does not recognize.

    Args:
        options: An iterable of option names (typically the ``**options``
            dict passed to ``spawn``); only the keys are inspected.

    Raises:
        ValueError: On the first name, in iteration order, that is not one
            of the recognized option names.
    """
    known_options = (
        'start_method',
        'cluster_node_ips',
        'node_ip',
        'started_port',
        'selected_gpus',
        'print_config',
        'use_paddlecloud',
    )
    for name in options:
        if name in known_options:
            continue
        # Fail fast on the first unknown name so a misspelled option is
        # reported instead of being silently ignored.
        message = (
            "The config option (%s) of `paddle.distributed.spawn` is not supported."
            % name)
        raise ValueError(message)
def _get_subprocess_env_list(nprocs, options): def _get_subprocess_env_list(nprocs, options):
# construct processes env list # construct processes env list
processes_env_list = [] processes_env_list = []
...@@ -290,14 +302,11 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): ...@@ -290,14 +302,11 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
def forward(self, x): def forward(self, x):
return self._linear2(self._linear1(x)) return self._linear2(self._linear1(x))
def train(print_result=False): def train(print_result=False):
# 1. enable dynamic mode # 1. initialize parallel environment
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env() dist.init_parallel_env()
# 3. create data parallel layer & optimizer # 2. create data parallel layer & optimizer
layer = LinearNet() layer = LinearNet()
dp_layer = paddle.DataParallel(layer) dp_layer = paddle.DataParallel(layer)
...@@ -305,7 +314,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): ...@@ -305,7 +314,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
adam = opt.Adam( adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters()) learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer # 3. run layer
inputs = paddle.randn([10, 10], 'float32') inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs) outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32') labels = paddle.randn([10, 1], 'float32')
...@@ -344,13 +353,13 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): ...@@ -344,13 +353,13 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
# Usage 4: pass function, arguments, nprocs and selected_gpus. # Usage 4: pass function, arguments, nprocs and selected_gpus.
# If your training method need some arguments, and # If your training method need some arguments, and
# only use part of visible devices for parallel training, # only use part of visible devices for parallel training,
# but you can't set your machine's environment varibale # but you can't set your machine's environment variable
# CUDA_VISIBLE_DEVICES, such as it is None or all cards # CUDA_VISIBLE_DEVICES, such as it is None or all cards
# {0,1,2,3,4,5,6,7}, you can pass `selelcted_gpus` to # {0,1,2,3,4,5,6,7}, you can pass `selected_gpus` to
# select the GPU cards you want to use. For example, # select the GPU cards you want to use. For example,
# this case will use cards {4,5} if your machine hold 8 cards. # this case will use cards {4,5} if your machine hold 8 cards.
if __name__ == '__main__': if __name__ == '__main__':
dist.spawn(train, args=(True,), nprocs=2, selelcted_gpus='4,5') dist.spawn(train, args=(True,), nprocs=2, selected_gpus='4,5')
""" """
# NOTE(chenweihang): [ why only supports python3.4+ ? ] # NOTE(chenweihang): [ why only supports python3.4+ ? ]
# Python supported setting the child process startup method # Python supported setting the child process startup method
...@@ -359,6 +368,10 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): ...@@ -359,6 +368,10 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
# cannot support CUDA runtime multi-process # cannot support CUDA runtime multi-process
_py_supported_check() _py_supported_check()
# Give an error hint when the users enter a configuration option
# that does not exist
_options_valid_check(options)
# get default nprocs # get default nprocs
if nprocs == -1: if nprocs == -1:
device = get_device() device = get_device()
......
...@@ -377,13 +377,10 @@ class DataParallel(layers.Layer): ...@@ -377,13 +377,10 @@ class DataParallel(layers.Layer):
return self._linear2(self._linear1(x)) return self._linear2(self._linear1(x))
def train(): def train():
# 1. enable dynamic mode # 1. initialize parallel environment
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env() dist.init_parallel_env()
# 3. create data parallel layer & optimizer # 2. create data parallel layer & optimizer
layer = LinearNet() layer = LinearNet()
dp_layer = paddle.DataParallel(layer) dp_layer = paddle.DataParallel(layer)
...@@ -391,7 +388,7 @@ class DataParallel(layers.Layer): ...@@ -391,7 +388,7 @@ class DataParallel(layers.Layer):
adam = opt.Adam( adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters()) learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer # 3. run layer
inputs = paddle.randn([10, 10], 'float32') inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs) outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32') labels = paddle.randn([10, 1], 'float32')
...@@ -450,28 +447,28 @@ class DataParallel(layers.Layer): ...@@ -450,28 +447,28 @@ class DataParallel(layers.Layer):
include_sublayers=True, include_sublayers=True,
structured_name_prefix=""): structured_name_prefix=""):
''' '''
Get all parameters of self._layers and its sub-layers. And set all the parameters into a dict Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
Parameters: Parameters:
destination(dict, optional) : If provide, all the parameters will set to this dict . Default: None destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
structured_name_prefix(str, optional): If not empty str, all the key in state dict will start
with structured_name_prefix
Returns: Returns:
dict: a dict contains all the parameters of self._layers dict: a dict contains all the parameters and persistable buffers.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
with fluid.dygraph.guard(): import paddle.distributed as dist
strategy=fluid.dygraph.prepare_context()
emb = fluid.dygraph.Embedding([10, 10]) dist.init_parallel_env()
emb = fluid.dygraph.DataParallel(emb, strategy)
emb = fluid.dygraph.Embedding([10, 10])
emb = fluid.dygraph.DataParallel(emb)
state_dict = emb.state_dict() state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy") paddle.save(state_dict, "paddle_dy.pdparams")
''' '''
...@@ -486,12 +483,12 @@ class DataParallel(layers.Layer): ...@@ -486,12 +483,12 @@ class DataParallel(layers.Layer):
include_sublayers=True, include_sublayers=True,
use_structured_name=True): use_structured_name=True):
''' '''
Set parameters of self._layers from state_dict. All the parameters of self._layers will be reset by the tensor in the state_dict Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict
Parameters: Parameters:
state_dict(dict) : Dict contains all the parameters state_dict(dict) : Dict contains all the parameters and persistable buffers.
include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key. use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
Default: True Default: True
Returns: Returns:
None None
...@@ -499,18 +496,18 @@ class DataParallel(layers.Layer): ...@@ -499,18 +496,18 @@ class DataParallel(layers.Layer):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.distributed as dist
paddle.disable_static() dist.init_parallel_env()
emb = paddle.nn.Embedding(10, 10) emb = paddle.nn.Embedding(10, 10)
emb = fluid.dygraph.DataParallel(emb, strategy) emb = fluid.dygraph.DataParallel(emb)
state_dict = emb.state_dict() state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy.pdparams") paddle.save(state_dict, "paddle_dy.pdparams")
para_state_dict = paddle.load("paddle_dy.pdparams") para_state_dict = paddle.load("paddle_dy.pdparams")
emb.set_state_dict(para_state_dict) emb.set_state_dict(para_state_dict)
''' '''
......
...@@ -37,7 +37,7 @@ class TestInitParallelEnv(unittest.TestCase): ...@@ -37,7 +37,7 @@ class TestInitParallelEnv(unittest.TestCase):
os.environ['FLAGS_selected_gpus'] = '0' os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0' os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170' os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1' os.environ['PADDLE_TRAINERS_NUM'] = '2'
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
dist.init_parallel_env() dist.init_parallel_env()
......
...@@ -379,6 +379,8 @@ ...@@ -379,6 +379,8 @@
"While.block", "While.block",
"DGCMomentumOptimizer", "DGCMomentumOptimizer",
"ParallelEnv", "ParallelEnv",
"spawn",
"init_parallel_env",
"DataParallel", "DataParallel",
"DataParallel.scale_loss", "DataParallel.scale_loss",
"DataParallel.apply_collective_grads", "DataParallel.apply_collective_grads",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册