Unverified commit 4c780311, authored by wuhuachaocoding, committed by GitHub

update some input for pp and moe about recompute. (#45628)

Parent 7c542c29
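This commit replaces the separate recompute_offload / recompute_partition boolean arguments of PipelineLayer, and the **kwargs-smuggled recompute_interval of MoELayer, with explicit recompute arguments built around a single recompute_ctx dict. A minimal migration sketch for callers, assuming fleet has already been initialized; the dict keys are exactly the ones the new code reads:

    # Migration sketch (assumes fleet.init(is_collective=True, ...) has run).
    from paddle.distributed import fleet

    hcg = fleet.get_hybrid_communicate_group()

    # Before: recompute_interval=1, recompute_offload=False, recompute_partition=False
    # After: the two flags move into one dict, which also carries the mp group.
    recompute_ctx = {
        "mp_group": hcg.get_model_parallel_group(),  # new: required alongside the flags
        "offload": False,     # replaces recompute_offload
        "partition": False,   # replaces recompute_partition
    }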
......
@@ -195,6 +195,83 @@ class PipelineLayerChunk(Layer):
class PipelineLayer(Layer):
"""PipelineLayer
Args:
        layers(Iterable): A sequence of layer descriptions that defines the structure of the pipeline.
        num_stages(int, optional): pipeline parallel degree; if not specified, the 'topology' parameter must be given.
        topology(CommunicateTopology, optional): topology of the hybrid parallel setup; if it is None, the 'num_stages' parameter must be given.
        loss_fn(callable, optional): Loss function.
        seg_method(str, optional): the method for splitting the pipeline layers, 'uniform' by default; to split at a specific layer, the method name must start with 'layer:'.
        recompute_interval(int, optional): the interval (in layers) at which recompute is applied; 0 means no recompute. Default 0.
        recompute_ctx(dict, optional): the context of recompute; when 'recompute_interval' > 0, the context must be given.
        num_virtual_pipeline_stages(int, optional): the number of virtual pipeline stages for interleaved pipeline parallelism.
Examples:
.. code-block:: python
import paddle.nn as nn
from paddle.distributed import fleet
from paddle.fluid.dygraph.layers import Layer
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
pipeline_parallel_size = 2
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": pipeline_parallel_size
}
strategy.pipeline_configs = {
"accumulate_steps": 4,
"micro_batch_size": 2
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
class ReshapeHelp(Layer):
def __init__(self, shape):
super(ReshapeHelp, self).__init__()
self.shape = shape
def forward(self, x):
return x.reshape(shape=self.shape)
class AlexNetPipeDesc(PipelineLayer):
def __init__(self, num_classes=10, **kwargs):
self.num_classes = num_classes
decs = [
LayerDesc(
nn.Conv2D, 1, 64, kernel_size=11, stride=4, padding=5),
LayerDesc(nn.ReLU),
LayerDesc(
nn.MaxPool2D, kernel_size=2, stride=2),
LayerDesc(
nn.Conv2D, 64, 192, kernel_size=5, padding=2),
F.relu,
LayerDesc(
nn.MaxPool2D, kernel_size=2, stride=2),
LayerDesc(
nn.Conv2D, 192, 384, kernel_size=3, padding=1),
F.relu,
LayerDesc(
nn.Conv2D, 384, 256, kernel_size=3, padding=1),
F.relu,
LayerDesc(
nn.Conv2D, 256, 256, kernel_size=3, padding=1),
F.relu,
LayerDesc(
nn.MaxPool2D, kernel_size=2, stride=2),
LayerDesc(
ReshapeHelp, shape=[-1, 256]),
LayerDesc(nn.Linear, 256, self.num_classes), # classifier
]
super(AlexNetPipeDesc, self).__init__(
layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
model = AlexNetPipeDesc(num_stages=pipeline_parallel_size, topology=hcg._topo)
"""
def __init__(self,
layers,
......
@@ -203,8 +280,7 @@ class PipelineLayer(Layer):
loss_fn=None,
seg_method="uniform",
recompute_interval=0,
-                 recompute_offload=False,
-                 recompute_partition=False,
+                 recompute_ctx=None,
num_virtual_pipeline_stages=None):
super(PipelineLayer, self).__init__()
if num_stages is None and topology is None:
......
@@ -233,14 +309,16 @@ class PipelineLayer(Layer):
self._loss_fn = loss_fn
self._topo = topology
self._recompute_interval = recompute_interval
-        self._recompute_offload = recompute_offload
-        self._recompute_partition = recompute_partition
        if recompute_interval > 0:
+            assert recompute_ctx is not None, "recompute_ctx must not be None for recompute."
+            offload = recompute_ctx.get('offload', False)
+            partition = recompute_ctx.get('partition', False)
+            _initialize_recompute_setting(offload, partition)
            logger.info(
                "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
-                .format(recompute_offload, recompute_partition))
-        _initialize_recompute_setting(recompute_offload, recompute_partition)
+                .format(offload, partition))
world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
......
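With the new signature, passing recompute_interval > 0 without recompute_ctx now trips the assert above. A minimal construction sketch, reusing AlexNetPipeDesc, pipeline_parallel_size and hcg from the docstring example in this file:

    # Sketch only: all names except the recompute arguments come from the
    # docstring example above.
    model = AlexNetPipeDesc(
        num_stages=pipeline_parallel_size,
        topology=hcg._topo,
        recompute_interval=1,
        recompute_ctx={
            "mp_group": hcg.get_model_parallel_group(),
            "offload": False,    # read via recompute_ctx.get('offload', False)
            "partition": False,  # read via recompute_ctx.get('partition', False)
        })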
......
@@ -113,20 +113,24 @@ class CriterionPipe(Layer):
class ModelPipe(PipelineLayer):
-    def __init__(self, topology):
+    def __init__(self, hcg):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
+        self.hcg = hcg
for x in range(2):
self.descs.append(LayerDesc(TransformerNetPipe))
super().__init__(layers=self.descs,
loss_fn=CriterionPipe(),
-                         topology=topology,
+                         topology=self.hcg.topology(),
seg_method="layer:TransformerNetPipe",
recompute_interval=1,
-                         recompute_partition=False,
-                         recompute_offload=False)
+                         recompute_ctx={
+                             "mp_group": self.hcg.get_model_parallel_group(),
+                             "offload": False,
+                             "partition": False
+                         })
class TestDistPPTraning(unittest.TestCase):
......
@@ -156,7 +160,7 @@ class TestDistPPTraning(unittest.TestCase):
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
-        model = ModelPipe(topology)
+        model = ModelPipe(hcg)
scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
values=[0.001, 0.002],
verbose=True)
......
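The test now receives the whole hybrid communicate group instead of a bare topology because recompute_ctx needs the model-parallel group as well; both values come from the same object, using the calls shown in the diff above:

    # Both values recompute now needs are derived from one hcg object.
    hcg = fleet.get_hybrid_communicate_group()
    topo = hcg.topology()                      # pipeline topology, as before
    mp_group = hcg.get_model_parallel_group()  # newly required for recompute_ctx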
......
@@ -72,7 +72,12 @@ class TestPipeLayerAPI(unittest.TestCase):
seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size,
num_virtual_pipeline_stages=2,
-                                   recompute_interval=1)
+                                   recompute_interval=1,
+                                   recompute_ctx={
+                                       "mp_group": self.hcg.get_model_parallel_group(),
+                                       "offload": False,
+                                       "partition": False
+                                   })
assert len(pipe_model.parameters()) > 0
model_chunks = pipe_model.get_model_chunks()
assert model_chunks is not None
......
......
@@ -255,7 +255,8 @@ class MoELayer(nn.Layer):
        moe_group: moe group for experts communication
        mp_group: mp group for mp communication
        kwargs: other parameters
        recompute_interval(int, optional): whether to use recompute; the default 0 disables recompute.
        recompute_ctx(dict, optional): the context for recompute; if recompute_interval > 0, recompute_ctx must be given.
Examples:
.. code-block:: python
from paddle.nn import layer, LayerList
......
@@ -310,10 +311,11 @@
gate=None,
moe_group=None,
mp_group=None,
-                 **kwargs):
+                 recompute_interval=0,
+                 recompute_ctx=None):
        super(MoELayer, self).__init__()
-        recompute_interval = kwargs.get("recompute_interval", 0)
+        self.recompute_ctx = recompute_ctx
if gate is None:
gate = dict()
......
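MoELayer used to pull recompute_interval out of **kwargs; both recompute arguments are now explicit parameters with defaults. A hedged sketch of the new call shape; d_model, experts, the gate config and moe_group below are illustrative placeholders, not taken from this diff:

    # Illustrative only: everything except recompute_interval / recompute_ctx
    # stands in for whatever the surrounding MoE setup provides.
    moe_layer = MoELayer(
        d_model=512,                          # assumed model width
        experts=experts,                      # assumed LayerList of expert layers
        gate={"type": "gshard", "top_k": 2},  # assumed gate config
        moe_group=moe_group,                  # assumed expert-parallel group
        mp_group=hcg.get_model_parallel_group(),
        recompute_interval=1,                 # previously read from **kwargs
        recompute_ctx={
            "mp_group": hcg.get_model_parallel_group(),
            "offload": False,
            "partition": False,
        })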