diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 371a8b3e04121851f3922132724f1e9091f32819..4d40d0e7dedff19b733aed3011472c585877daa5 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -195,6 +195,83 @@ class PipelineLayerChunk(Layer):
 
 
 class PipelineLayer(Layer):
+    """PipelineLayer
+    Args:
+        layers(Iterable): a sequence of layer descriptions defining the structure of the pipeline.
+        num_stages(int, optional): pipeline parallel degree; if not specified, the 'topology' parameter must be given.
+        topology(CommunicateTopology, optional): topology of the hybrid parallel setup; if it is None, the 'num_stages' parameter must be given.
+        loss_fn(callable, optional): loss function.
+        seg_method(str, optional): the method for splitting pipeline layers, default 'uniform'; to split at a specific layer type, the method name must start with 'layer:'.
+        recompute_interval(int, optional): the interval (in layers) at which recompute is applied; a value of 0 disables recompute. Default 0.
+        recompute_ctx(dict, optional): the context for recompute; when 'recompute_interval' > 0, the context must be given.
+        num_virtual_pipeline_stages(int, optional): the number of virtual pipeline stages for interleaved pipeline parallelism.
+    Examples:
+        .. code-block:: python
+        import paddle.nn as nn
+        from paddle.distributed import fleet
+        from paddle.fluid.dygraph.layers import Layer
+        import paddle.nn.functional as F
+        from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
+
+        pipeline_parallel_size = 2
+        strategy = fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 1,
+            "pp_degree": pipeline_parallel_size
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": 4,
+            "micro_batch_size": 2
+        }
+
+        fleet.init(is_collective=True, strategy=strategy)
+
+        hcg = fleet.get_hybrid_communicate_group()
+
+        class ReshapeHelp(Layer):
+            def __init__(self, shape):
+                super(ReshapeHelp, self).__init__()
+                self.shape = shape
+
+            def forward(self, x):
+                return x.reshape(shape=self.shape)
+
+        class AlexNetPipeDesc(PipelineLayer):
+            def __init__(self, num_classes=10, **kwargs):
+                self.num_classes = num_classes
+                decs = [
+                    LayerDesc(
+                        nn.Conv2D, 1, 64, kernel_size=11, stride=4, padding=5),
+                    LayerDesc(nn.ReLU),
+                    LayerDesc(
+                        nn.MaxPool2D, kernel_size=2, stride=2),
+                    LayerDesc(
+                        nn.Conv2D, 64, 192, kernel_size=5, padding=2),
+                    F.relu,
+                    LayerDesc(
+                        nn.MaxPool2D, kernel_size=2, stride=2),
+                    LayerDesc(
+                        nn.Conv2D, 192, 384, kernel_size=3, padding=1),
+                    F.relu,
+                    LayerDesc(
+                        nn.Conv2D, 384, 256, kernel_size=3, padding=1),
+                    F.relu,
+                    LayerDesc(
+                        nn.Conv2D, 256, 256, kernel_size=3, padding=1),
+                    F.relu,
+                    LayerDesc(
+                        nn.MaxPool2D, kernel_size=2, stride=2),
+                    LayerDesc(
+                        ReshapeHelp, shape=[-1, 256]),
+                    LayerDesc(nn.Linear, 256, self.num_classes),  # classifier
+                ]
+                super(AlexNetPipeDesc, self).__init__(
+                    layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
+
+        model = AlexNetPipeDesc(num_stages=pipeline_parallel_size, topology=hcg._topo)
+
+    """
 
     def __init__(self,
                  layers,
@@ -203,8 +280,7 @@ class PipelineLayer(Layer):
                  loss_fn=None,
                  seg_method="uniform",
                  recompute_interval=0,
-                 recompute_offload=False,
-                 recompute_partition=False,
+                 recompute_ctx=None,
                  num_virtual_pipeline_stages=None):
         super(PipelineLayer, self).__init__()
         if num_stages is None and topology is None:
@@ -233,14 +309,16 @@ class PipelineLayer(Layer):
         self._loss_fn = loss_fn
         self._topo = topology
         self._recompute_interval = recompute_interval
-        self._recompute_offload = recompute_offload
-        self._recompute_partition = recompute_partition
 
         if recompute_interval > 0:
+            assert recompute_ctx is not None, "recompute_ctx must not be None when recompute is enabled."
+
+            offload = recompute_ctx.get('offload', False)
+            partition = recompute_ctx.get('partition', False)
+            _initialize_recompute_setting(offload, partition)
             logger.info(
                 "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
-                .format(recompute_offload, recompute_partition))
-            _initialize_recompute_setting(recompute_offload, recompute_partition)
+                .format(offload, partition))
 
         world_size = dist.get_world_size()
         self.global_rank = dist.get_rank()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_recompute.py
index 8e364290bae6733619cc10af008157c281da49e7..e5be0f52fe3098858a568fee87c4737a1dabbe58 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_recompute.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_recompute.py
@@ -113,20 +113,24 @@ class CriterionPipe(Layer):
 
 class ModelPipe(PipelineLayer):
 
-    def __init__(self, topology):
+    def __init__(self, hcg):
         self.descs = []
         self.descs.append(LayerDesc(EmbeddingPipe))
+        self.hcg = hcg
 
         for x in range(2):
             self.descs.append(LayerDesc(TransformerNetPipe))
 
         super().__init__(layers=self.descs,
                          loss_fn=CriterionPipe(),
-                         topology=topology,
+                         topology=self.hcg.topology(),
                          seg_method="layer:TransformerNetPipe",
                          recompute_interval=1,
-                         recompute_partition=False,
-                         recompute_offload=False)
+                         recompute_ctx={
+                             "mp_group": self.hcg.get_model_parallel_group(),
+                             "offload": False,
+                             "partition": False
+                         })
 
 
 class TestDistPPTraning(unittest.TestCase):
@@ -156,7 +160,7 @@ class TestDistPPTraning(unittest.TestCase):
         topology = hcg.topology()
         set_random_seed(1024, dp_id, rank_id)
 
-        model = ModelPipe(topology)
+        model = ModelPipe(hcg)
         scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
                                                        values=[0.001, 0.002],
                                                        verbose=True)
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py
index 0ff14ad5f545234c8ad46cb484f69556a16414eb..1bd8e9348080ea12a651cdd4f72f0060888d1e1e 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py
@@ -72,7 +72,12 @@ class TestPipeLayerAPI(unittest.TestCase):
             seg_method="layer:Linear",
             num_stages=self.pipeline_parallel_size,
             num_virtual_pipeline_stages=2,
-            recompute_interval=1)
+            recompute_interval=1,
+            recompute_ctx={
+                "mp_group": self.hcg.get_model_parallel_group(),
+                "offload": False,
+                "partition": False
+            })
         assert len(pipe_model.parameters()) > 0
         model_chunks = pipe_model.get_model_chunks()
         assert model_chunks is not None
diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py
index bb308444a5822df0ed354a80d9036f6da4255234..ebf300abf95457d69c8437c4c3abd93835b88969 100644
--- a/python/paddle/incubate/distributed/models/moe/moe_layer.py
+++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py
@@ -255,7 +255,8 @@ class MoELayer(nn.Layer):
         moe_group: moe group for experts communication
         mp_group: mp group for mp commutication
-        kwargs: other parameters
+        recompute_interval(int, optional): the interval at which recompute is applied; default 0, which disables recompute.
+        recompute_ctx(dict, optional): the context for recompute; if recompute_interval > 0, recompute_ctx must be given.
 
     Examples:
         .. code-block:: python
         from paddle.nn import layer, LayerList
@@ -310,10 +311,11 @@ class MoELayer(nn.Layer):
                  gate=None,
                  moe_group=None,
                  mp_group=None,
-                 **kwargs):
+                 recompute_interval=0,
+                 recompute_ctx=None):
         super(MoELayer, self).__init__()
 
-        recompute_interval = kwargs.get("recompute_interval", 0)
+        self.recompute_ctx = recompute_ctx
 
         if gate is None:
             gate = dict()
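
Reviewer note: a minimal before/after sketch of the PipelineLayer call-site change introduced by this patch. The two-layer `descs` list is hypothetical, and the snippet assumes `fleet.init(is_collective=True, strategy=...)` has already been called with a pipeline-parallel strategy, as in the docstring example above; all other names come from the diff itself.

    import paddle.nn as nn
    from paddle.distributed import fleet
    from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

    # Assumes fleet.init(...) already ran with pp_degree set.
    hcg = fleet.get_hybrid_communicate_group()

    # Hypothetical two-stage pipeline description.
    descs = [LayerDesc(nn.Linear, 8, 8), LayerDesc(nn.Linear, 8, 8)]

    # Old call site (removed by this patch):
    #   PipelineLayer(..., recompute_interval=1,
    #                 recompute_offload=False, recompute_partition=False)
    # New call site, bundling the flags into recompute_ctx as the updated
    # tests do:
    model = PipelineLayer(layers=descs,
                          topology=hcg.topology(),
                          loss_fn=nn.CrossEntropyLoss(),
                          seg_method="layer:Linear",
                          recompute_interval=1,
                          recompute_ctx={
                              "mp_group": hcg.get_model_parallel_group(),
                              "offload": False,
                              "partition": False
                          })

Keys missing from recompute_ctx fall back to the defaults shown in `__init__` ('offload' and 'partition' both default to False via `recompute_ctx.get`), but the dict itself must be supplied whenever recompute_interval > 0, since the new assert fires otherwise.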
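Similarly, a sketch of the MoELayer signature change. The leading `d_model`/`experts` arguments and the gshard gate config follow the class's existing docstring example and are assumptions here, as are the placeholder expert networks; only `recompute_interval` and `recompute_ctx` are taken directly from this diff.

    import paddle.nn as nn
    from paddle.incubate.distributed.models.moe import MoELayer

    d_model = 8
    # Placeholder experts and gate config; real expert networks and gate
    # settings depend on the surrounding model.
    experts = nn.LayerList([nn.Linear(d_model, d_model) for _ in range(2)])
    gate_config = {"type": "gshard", "top_k": 2}

    # recompute_interval and recompute_ctx are now explicit parameters
    # rather than being read out of **kwargs:
    moe = MoELayer(d_model=d_model,
                   experts=experts,
                   gate=gate_config,
                   moe_group=None,
                   mp_group=None,
                   recompute_interval=1,
                   recompute_ctx={"offload": False, "partition": False})

Making these parameters explicit means a misspelled keyword now raises a TypeError instead of being silently swallowed by **kwargs.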