Unverified commit 46b73e6c, authored by ShenLiang, committed by GitHub

Change the api of DataParallel and Fleet (#29224)

Parent 73e51a17
@@ -135,6 +135,7 @@ message DistributedStrategy {
   optional bool adaptive_localsgd = 24 [ default = false ];
   optional bool fp16_allreduce = 25 [ default = false ];
   optional bool sharding = 26 [ default = false ];
+  optional float last_comm_group_size_MB = 27 [ default = 1 ];
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
...
@@ -18,6 +18,7 @@ from paddle.fluid.framework import Variable, set_flags, core
 from paddle.fluid.wrapped_decorator import wrap_decorator
 import google.protobuf.text_format
 import google.protobuf
+from paddle.fluid.framework import dygraph_only
 __all__ = ["DistributedStrategy"]
@@ -555,6 +556,32 @@ class DistributedStrategy(object):
         else:
             print("WARNING: fuse_grad_size_in_MB should have value of int type")
 
+    @property
+    def last_comm_group_size_MB(self):
+        """
+        Specifying the size of gradient to fuse in Mega-Bytes when
+        the last group of each batch communicates. Making the last group
+        small is useful to improve performance.
+
+        Default value: 1
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.last_comm_group_size_MB = 2
+        """
+        return self.strategy.last_comm_group_size_MB
+
+    @last_comm_group_size_MB.setter
+    @is_strict_auto
+    def last_comm_group_size_MB(self, value):
+        if value > 0:
+            self.strategy.last_comm_group_size_MB = value
+        else:
+            raise ValueError("last_comm_group_size_MB should be greater than 0")
+
     @property
     def _fuse_grad_size_in_TFLOPS(self):
         return self.strategy.fuse_grad_size_in_TFLOPS
...
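Taken together, the proto field and the property above expose a new knob on the public DistributedStrategy API. A minimal sketch of how it behaves, restating only the documented default of 1 and the validation in the setter (the assert and try/except lines are illustrative and not part of this commit):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()

    # The protobuf default of 1 surfaces through the Python property.
    assert strategy.last_comm_group_size_MB == 1

    # Any positive size in MB is accepted; keeping the last fused gradient
    # group small lets its allreduce overlap better with computation.
    strategy.last_comm_group_size_MB = 2

    # Non-positive values are rejected by the setter.
    try:
        strategy.last_comm_group_size_MB = 0
    except ValueError:
        pass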
@@ -92,12 +92,11 @@ class Fleet(object):
             import paddle
             paddle.enable_static()
             import paddle.distributed.fleet as fleet
-            fleet.init()
             strategy = fleet.DistributedStrategy()
+            fleet.init(strategy)
             optimizer = paddle.optimizer.SGD(learning_rate=0.001)
-            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+            optimizer = fleet.distributed_optimizer(optimizer)
             if fleet.is_first_worker():
                 print("this is first worker")
@@ -127,7 +126,7 @@ class Fleet(object):
         self._util = None
         self._context = {}
 
-    def init(self, role_maker=None, is_collective=False):
+    def init(self, role_maker=None, is_collective=False, strategy=None):
         """
         Initialize role_maker in Fleet.
@@ -142,6 +141,10 @@ class Fleet(object):
             is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program
                 runs on the CPU or GPU. False means set distributed training using CPU, and True means
                 GPU. The default value is False.
+            strategy (DistributedStrategy): Extra properties for distributed training.
+                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.
 
         Returns:
             None
@@ -167,6 +170,14 @@ class Fleet(object):
                 role = fleet.PaddleCloudRoleMaker()
                 fleet.init(role)
 
+        Examples4:
+
+            .. code-block:: python
+
+                import paddle.distributed.fleet as fleet
+                strategy = fleet.DistributedStrategy()
+                fleet.init(strategy)
+
         """
         if role_maker is None:
@@ -209,6 +220,10 @@ class Fleet(object):
         else:
             paddle.distributed.init_parallel_env()
 
+        if strategy is None:
+            strategy = DistributedStrategy()
+        self._user_defined_strategy = copy.deepcopy(strategy)
+
     def is_first_worker(self):
         """
         Check whether the node is the first instance of worker.
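With the strategy now stored by fleet.init(), the gradient-fusion buffer sizes no longer have to be threaded through later calls. A minimal dygraph sketch of the new flow; the Linear layer, the optimizer, and the assumption that the process was started under a distributed launcher (e.g. paddle.distributed.launch) are illustrative, not part of this diff:

    import paddle
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.fuse_grad_size_in_MB = 32        # regular fusion buffer size (MB)
    strategy.last_comm_group_size_MB = 2      # size (MB) of the small last group

    fleet.init(is_collective=True, strategy=strategy)

    layer = paddle.nn.Linear(10, 10)
    optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                     parameters=layer.parameters())

    # Both buffer sizes are now read from the strategy registered above.
    model = fleet.distributed_model(layer)
    optimizer = fleet.distributed_optimizer(optimizer)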
@@ -575,7 +590,11 @@ class Fleet(object):
         Args:
             optimizer(Optimizer): The executor to run for init server.
             strategy(DistributedStrategy): Extra properties for distributed optimizer.
+                It is recommended to use DistributedStrategy in fleet.init(). The strategy
+                here is for compatibility. If the strategy in fleet.distributed_optimizer()
+                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
+                which will take effect in distributed training.
 
         Returns:
             Fleet: instance of fleet.
@@ -594,27 +613,25 @@ class Fleet(object):
         """
         self.user_defined_optimizer = optimizer
 
-        if strategy == None:
-            strategy = DistributedStrategy()
-        self._user_defined_strategy = copy.deepcopy(strategy)
+        if strategy is not None:
+            warnings.warn(
+                "It is recommended to pass in DistributedStrategy "
+                "in fleet.init. The strategy here is for compatibility. "
+                "If the `strategy` in fleet.distributed_optimizer() is "
+                "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
+                "which will take effect in distributed training.")
+            self._user_defined_strategy = copy.deepcopy(strategy)
         self._context = {}
         return self
 
     @dygraph_only
-    def distributed_model(self, model, group_size_limits=25,
-                          small_group_size=1):
+    def distributed_model(self, model):
         """
         Return distributed data parallel model (Only work in dygraph mode)
 
         Args:
             model (Layer): the user-defined model which inherits Layer.
-            group_size_limits(int, optional): It is up limited memory size(MB) of one group
-                                          parameters' gradient which is the input of communication
-                                          calling(e.g NCCLAllReduce). Default: 25.
-            small_group_size(int, optional): It is up limited memory size(MB) of last group in communication
-                                          calling. Making the last group small is useful to
-                                          improve performance. Default: 1.
 
         Returns:
             distributed data parallel model which inherits Layer.
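Passing a strategy to distributed_optimizer() is kept only for compatibility: it now warns and overwrites whatever was registered in fleet.init(). A short sketch of that fallback path, with illustrative values and the same launcher assumption as above:

    import paddle
    import paddle.distributed.fleet as fleet

    fleet.init(is_collective=True, strategy=fleet.DistributedStrategy())

    layer = paddle.nn.Linear(10, 10)
    optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                     parameters=layer.parameters())

    # Old style: still accepted, but a warning is emitted and this strategy
    # replaces the one given to fleet.init() for distributed training.
    override = fleet.DistributedStrategy()
    override.last_comm_group_size_MB = 4
    optimizer = fleet.distributed_optimizer(optimizer, strategy=override)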
@@ -667,8 +684,9 @@ class Fleet(object):
         assert model is not None
         self.model = paddle.DataParallel(
             model,
-            group_size_limits=group_size_limits,
-            small_group_size=small_group_size)
+            comm_buffer_size=self._user_defined_strategy.fuse_grad_size_in_MB,
+            last_comm_buffer_size=self._user_defined_strategy.
+            last_comm_group_size_MB)
         return self.model
 
     @dygraph_only
...
@@ -309,11 +309,11 @@ class DataParallel(layers.Layer):
         layers(Layer): The module that should be executed by data parallel.
         strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
             contains environment configuration related to parallel execution. Default: None.
-        group_size_limits(int, optional): It is up limited memory size(MB) of one group
+        comm_buffer_size(int, optional): It limits the memory size(MB) of one buffer
                                           parameters' gradient which is the input of communication
                                           calling(e.g NCCLAllReduce). Default: 25.
-        small_group_size(int, optional): It is up limited memory size(MB) of last group in communication
-                                          calling. Making the last group small is useful to
+        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
+                                          calling. Making the last communication buffer size small is useful to
                                           improve performance. Default: 1.
 
     Returns:
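The renamed constructor arguments can also be passed to DataParallel directly when Fleet is not used. A minimal sketch; it assumes the parallel environment has already been set up via paddle.distributed.launch or paddle.distributed.spawn:

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    layer = paddle.nn.Linear(10, 10)

    # Formerly group_size_limits / small_group_size; both values are in MB.
    model = paddle.DataParallel(layer,
                                comm_buffer_size=25,
                                last_comm_buffer_size=1)

    optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                     parameters=model.parameters())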
@@ -369,8 +369,8 @@ class DataParallel(layers.Layer):
     def __init__(self,
                  layers,
                  strategy=None,
-                 group_size_limits=25,
-                 small_group_size=1):
+                 comm_buffer_size=25,
+                 last_comm_buffer_size=1):
         super(DataParallel,
               self).__init__(layers.full_name() + "_data_parallel")
@@ -386,12 +386,13 @@ class DataParallel(layers.Layer):
         self._strategy = _build_default_parallel_strategy()
 
         if self._strategy.nranks > 1:
-            self.group_size_limits = int(group_size_limits * 1024 * 1024)
+            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
             # NOTE(shenliang03): We can set environment variables to control
             # the size of the group, Default: 1MB. The role of this small group is:
             # when the last group allreduce, the overlap cannot work. Making the
             # the last group small is useful to improve performance.
-            self.small_group_size = int(small_group_size * 1024 * 1024)
+            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
+                                             1024)
             self.init_reducer()
         else:
             warnings.warn(
@@ -431,7 +432,7 @@ class DataParallel(layers.Layer):
         self.group_indices = core.assign_group_by_size(
             trainable_parameters, is_sparse_gradient,
-            [self.small_group_size, self.group_size_limits])
+            [self.last_comm_buffer_size, self.comm_buffer_size])
 
         assert parallel_helper.__parallel_ctx__clz__ is not None, \
             "ParallelContext must be initialized before. You should use init_parallel_env() before" \
...
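core.assign_group_by_size receives the per-group size limits as a list, here [last_comm_buffer_size, comm_buffer_size], so one deliberately small group is capped before the regular-sized ones. A rough pure-Python sketch of that bucketing idea (illustrative only; the real implementation lives in C++ and also special-cases sparse gradients):

    def assign_group_by_size_sketch(param_sizes_bytes, size_limits):
        # Greedily pack parameter indices into groups. size_limits mirrors
        # the [last_comm_buffer_size, comm_buffer_size] list above: limits
        # are consumed in order, and once they run out every further group
        # uses the final (regular) limit.
        groups, current, current_bytes = [], [], 0
        limit_idx = 0
        for idx, size in enumerate(param_sizes_bytes):
            current.append(idx)
            current_bytes += size
            limit = size_limits[min(limit_idx, len(size_limits) - 1)]
            if current_bytes >= limit:
                groups.append(current)
                current, current_bytes = [], 0
                limit_idx += 1
        if current:
            groups.append(current)
        return groups

    # Example: a 1 MB cap for the first bucket, then 25 MB buckets.
    print(assign_group_by_size_sketch(
        [512 * 1024, 512 * 1024, 8 * 1024 * 1024, 30 * 1024 * 1024],
        [1 * 1024 * 1024, 25 * 1024 * 1024]))
    # -> [[0, 1], [2, 3]]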
@@ -169,6 +169,13 @@ class TestStrategyConfig(unittest.TestCase):
         strategy.fuse_grad_size_in_MB = "40"
         self.assertEqual(strategy.fuse_grad_size_in_MB, 50)
 
+    def test_last_comm_group_size_MB(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.last_comm_group_size_MB = 50
+        self.assertEqual(strategy.last_comm_group_size_MB, 50)
+        with self.assertRaises(ValueError):
+            strategy.last_comm_group_size_MB = -1
+
     def test_fuse_grad_size_in_TFLOPS(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy._fuse_grad_size_in_TFLOPS = 0.1
...