Unverified · Commit 72b5b5bf authored by Yuang Liu, committed by GitHub

[dygraph hybrid pp for interleave] The interleave scheduler for pipeline parallel (#45497)

Parent fd86a938
......@@ -2384,7 +2384,7 @@ def isend(tensor, dst, group=None):
assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
return group.process_group.send(tensor, group_dst_rank)
else:
raise RuntimeError("Don't support static graph mode currently.")
raise RuntimeError("Only support eager dygraph mode.")
def irecv(tensor, src=None, group=None):
......@@ -2433,7 +2433,7 @@ def irecv(tensor, src=None, group=None):
assert group_src_rank >= 0, ("src rank out of group, need global rank")
return group.process_group.recv(tensor, group_src_rank)
else:
raise RuntimeError("Don't support static graph mode currently.")
raise RuntimeError("Only support eager dygraph mode.")
class P2POp(object):
......
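Note for readers: in eager dygraph mode `isend`/`irecv` return a task handle instead of blocking. A minimal usage sketch (assumes a two-rank job launched with `paddle.distributed.launch`; not part of this patch):

```python
# Hypothetical two-rank example of the eager-mode isend/irecv touched above.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.zeros([4], dtype="float32")
if dist.get_rank() == 0:
    data += 1.0
    task = dist.isend(data, dst=1)   # returns a task, does not block
else:
    task = dist.irecv(data, src=0)   # fills `data` asynchronously
task.wait()                          # both sides wait before touching `data`
```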
......@@ -240,6 +240,14 @@ class HybridCommunicateGroup(object):
return parallel_group, parallel_comm_group
def _get_p2p_next_rank(self):
assert hasattr(self, 'next_rank'), "next_rank has not been inited"
return self.next_rank
def _get_p2p_prev_rank(self):
assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
return self.prev_rank
def _set_p2p_group(self):
comm_lists = self._topo.get_comm_list('pipe')
......@@ -255,6 +263,10 @@ class HybridCommunicateGroup(object):
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
prev_rank = comm_ranks[(idx - 1) % self._pp_degree]
if self.global_rank == curr_rank:
self.next_rank = next_rank
self.prev_rank = prev_rank
next_group = paddle.distributed.new_group(
ranks=[curr_rank, next_rank])
if self.global_rank == curr_rank:
......
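A minimal, self-contained sketch (hypothetical helper, not part of the patch) of the wrap-around arithmetic `_set_p2p_group` uses to record each rank's pipeline neighbors:

```python
# Illustrative only: how next/prev p2p ranks wrap around a pipeline ring.
def p2p_neighbors(comm_ranks, global_rank, pp_degree):
    idx = comm_ranks.index(global_rank)
    next_rank = comm_ranks[(idx + 1) % pp_degree]  # downstream stage (activations go here)
    prev_rank = comm_ranks[(idx - 1) % pp_degree]  # upstream stage (activations come from here)
    return next_rank, prev_rank

# Example: 4 pipeline stages mapped to global ranks [0, 2, 4, 6];
# the last stage wraps around to the first.
assert p2p_neighbors([0, 2, 4, 6], 6, 4) == (0, 4)
```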
......@@ -24,6 +24,7 @@ from .parallel_layers import model_parallel_random_seed # noqa: F401
from .parallel_layers import get_rng_state_tracker # noqa: F401
from .tensor_parallel import TensorParallel # noqa: F401
from .pipeline_parallel import PipelineParallel # noqa: F401
from .pipeline_parallel import PipelineParallelWithInterleave # noqa: F401
from .sharding_parallel import ShardingParallel # noqa: F401
__all__ = []
......@@ -189,7 +189,7 @@ class PipelineLayerChunk(Layer):
        # Users shouldn't call PipelineLayerChunk directly, since all logic related to recompute
        # is in the forward function of PipelineLayer. Any direct call will cause unexpected
        # behavior under recompute.
-        raise NotImplementedError(
+        raise PermissionError(
"The forward function of PipelineLayerChunk cannot be called directly. "
"Please call forward function of PipelineLayer.")
......@@ -385,6 +385,9 @@ class PipelineLayer(Layer):
start_idx + stage + 1]:
return stage
def get_num_virtual_stages(self):
return self._num_virtual_pipeline_stages
def get_model_chunks(self):
return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks
......
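`get_num_virtual_stages` and `get_model_chunks` expose the interleave layout: with virtual stages, each physical pipeline stage owns several non-adjacent model chunks. A rough sketch of that layout (illustrative arithmetic, not PipelineLayer's actual segmentation code):

```python
# With pp_degree physical stages and v virtual stages per rank, the layer list
# is cut into pp_degree * v chunks assigned round-robin over physical stages.
def chunks_for_stage(num_layers, pp_degree, num_virtual_stages, stage_id):
    total_chunks = pp_degree * num_virtual_stages
    per_chunk = num_layers // total_chunks
    chunk_ids = [stage_id + k * pp_degree for k in range(num_virtual_stages)]
    return [list(range(c * per_chunk, (c + 1) * per_chunk)) for c in chunk_ids]

# Example: 8 layers, 2 physical stages, 2 virtual stages per stage.
# Stage 0 holds layers [0, 1] and [4, 5]; stage 1 holds [2, 3] and [6, 7].
assert chunks_for_stage(8, 2, 2, 0) == [[0, 1], [4, 5]]
assert chunks_for_stage(8, 2, 2, 1) == [[2, 3], [6, 7]]
```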
......@@ -54,7 +54,7 @@ class SendRecvMeta:
def _recv_shape_dtype(self, group):
# recv len(shape)
dims = paddle.to_tensor([0])
-        src_rank = group.ranks[0]
+        src_rank = _hcg._get_p2p_prev_rank()
paddle.distributed.recv(dims, src=src_rank, group=group)
dims = dims.item()
......@@ -74,7 +74,7 @@ class SendRecvMeta:
def recv_meta(self, group):
tensor_type = paddle.to_tensor([0])
-        src_rank = group.ranks[0]
+        src_rank = _hcg._get_p2p_prev_rank()
paddle.distributed.recv(tensor_type, src=src_rank, group=group)
tensor_type = tensor_type.item()
......@@ -105,7 +105,7 @@ class SendRecvMeta:
def _send_dims_shape_dtype(self, tensor, group):
# send len(shape)
dims = paddle.to_tensor(len(tensor.shape))
-        dst_rank = group.ranks[1]
+        dst_rank = _hcg._get_p2p_next_rank()
paddle.distributed.send(dims, dst=dst_rank, group=group)
......@@ -122,7 +122,7 @@ class SendRecvMeta:
paddle.distributed.send(stop_grad, dst=dst_rank, group=group)
def send_meta(self, tensor, group):
-        dst_rank = group.ranks[1]
+        dst_rank = _hcg._get_p2p_next_rank()
if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
tensor_type = paddle.to_tensor([0])
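The `SendRecvMeta` changes above reroute the handshake from `group.ranks[0/1]` to the true pipeline neighbors. For orientation, a rough self-contained sketch of what the handshake exchanges (field order inferred from this diff, transport stubbed out; not the exact implementation):

```python
# The upstream rank announces dims, shape, a dtype code and the stop_gradient
# flag so the downstream rank can pre-allocate buffers before activations arrive.
def send_meta_sketch(shape, dtype_code, stop_gradient, send_fn, dst):
    send_fn(len(shape), dst)           # 1) number of dims
    for d in shape:                    # 2) each dim
        send_fn(d, dst)
    send_fn(dtype_code, dst)           # 3) dtype as an integer code
    send_fn(int(stop_gradient), dst)   # 4) whether grads flow back

# Dummy transport: just record what would be sent to rank 1.
sent = []
send_meta_sketch([2, 8, 16], 21, False, lambda v, dst: sent.append((dst, v)), dst=1)
assert sent[0] == (1, 3)  # three dims announced first
```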
......@@ -165,20 +165,17 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
rank_id):
dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
if _in_legacy_dygraph():
return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'peer', dst, 'num', nranks, 'id',
rank_id)
'peer', dst_rank_in_group, 'num',
nranks, 'id', rank_id)
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
task = group.process_group.send_partial(tensor, dst, nranks, rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
return group.process_group.send_partial(tensor, dst_rank_in_group,
nranks, rank_id)
def send_partial(tensor,
......@@ -192,33 +189,35 @@ def send_partial(tensor,
return
ring_id = 0 if group is None else group.id
dst_rank = _hcg._get_p2p_next_rank(
) if dst == 1 else _hcg._get_p2p_prev_rank()
if _is_valid_send_recv_partial(tensor, nranks):
return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
nranks, rank_id)
return _partial_send_op(tensor, group, use_calc_stream, ring_id,
dst_rank, nranks, rank_id)
else:
return paddle.distributed.send(tensor.detach(),
dst=group.ranks[dst],
group=group,
use_calc_stream=use_calc_stream)
if _in_legacy_dygraph():
send_op = paddle.distributed.send
elif in_dygraph_mode():
send_op = paddle.distributed.isend
return send_op(tensor.detach(), dst=dst_rank, group=group)
def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
rank_id):
src_rank_in_group = src if group is None else group.get_group_rank(src)
if _in_legacy_dygraph():
return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'peer', src, 'num', nranks, 'id',
rank_id, 'dtype', tensor.dtype,
'out_shape', tensor.shape)
'peer', src_rank_in_group, 'num',
nranks, 'id', rank_id, 'dtype',
tensor.dtype, 'out_shape',
tensor.shape)
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
return group.process_group.recv_partial(tensor, src_rank_in_group,
nranks, rank_id)
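`group.get_group_rank` maps a global rank to its position inside the communication group, which is what the partial ops above now pass to the new communication library. An illustrative equivalent (assuming `group.ranks` lists global ranks in group order):

```python
def get_group_rank_sketch(group_ranks, global_rank):
    # e.g. group_ranks == [3, 7] and global_rank == 7 -> group rank 1
    return group_ranks.index(global_rank) if global_rank in group_ranks else -1

assert get_group_rank_sketch([3, 7], 7) == 1
assert get_group_rank_sketch([3, 7], 5) == -1  # not a member of this group
```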
def recv_partial(tensor,
......@@ -232,14 +231,18 @@ def recv_partial(tensor,
return
ring_id = 0 if group is None else group.id
src_rank = _hcg._get_p2p_prev_rank(
) if src == 0 else _hcg._get_p2p_next_rank()
if _is_valid_send_recv_partial(tensor, nranks):
return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
nranks, rank_id)
return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
src_rank, nranks, rank_id)
else:
return paddle.distributed.recv(tensor.detach(),
src=group.ranks[src],
group=group,
use_calc_stream=use_calc_stream)
if _in_legacy_dygraph():
recv_op = paddle.distributed.recv
elif in_dygraph_mode():
recv_op = paddle.distributed.irecv
return recv_op(tensor.detach(), src=src_rank, group=group)
def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
......@@ -253,13 +256,8 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
task = group.process_group.all_gather_partial(tensor, tensor, nranks,
return group.process_group.all_gather_partial(tensor, tensor, nranks,
rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
def allgather_partial(tensor,
......@@ -268,9 +266,9 @@ def allgather_partial(tensor,
group=None,
use_calc_stream=True):
if not _is_valid_send_recv_partial(tensor, nranks):
return tensor
return None
if group is not None and not group.is_member():
return
return None
ring_id = 0 if group is None else group.id
return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
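The partial send/recv pair only moves the slice of the activation owned by the local tensor-parallel rank; `allgather_partial` then reassembles the full tensor on every rank. A data-layout sketch in plain NumPy (no real communication, illustrative only):

```python
import numpy as np

def partial_slice(flat, nranks, rank_id):
    # requires flat.size % nranks == 0, cf. _is_valid_send_recv_partial above
    block = flat.size // nranks
    return flat[rank_id * block:(rank_id + 1) * block]

full = np.arange(8, dtype=np.float32)
slices = [partial_slice(full, 4, r) for r in range(4)]   # what each mp rank sends
assert np.array_equal(np.concatenate(slices), full)      # the all-gather restores it
```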
......@@ -323,105 +321,124 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
tensor_recv_next = paddle.empty(
shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg))
    # TODO(Yuang Liu): use batch_isend_irecv to replace all these comm ops
tasks = []
# start to p2p communicate
if tensor_send_prev is not None:
if isinstance(tensor_send_prev, tuple):
for d in tensor_send_prev:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(d,
tasks.append(
send_partial(d,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False))
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
tasks.append(
send_partial(tensor_send_prev,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False)
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
send_partial(tensor_send_prev,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False)
use_calc_stream=False))
if tensor_recv_prev is not None:
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
recv_partial(d,
tasks.append(
recv_partial(d,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True))
tasks.append(
allgather_partial(d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True))
else:
tasks.append(
recv_partial(tensor_recv_prev,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(d,
use_calc_stream=True))
tasks.append(
allgather_partial(tensor_recv_prev,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(tensor_recv_prev,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(tensor_recv_prev,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
use_calc_stream=True))
if tensor_send_next is not None:
if isinstance(tensor_send_next, tuple):
for d in tensor_send_next:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(d,
tasks.append(
send_partial(d,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False))
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
tasks.append(
send_partial(tensor_send_next,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False)
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
send_partial(tensor_send_next,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False)
use_calc_stream=False))
if tensor_recv_next is not None:
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
recv_partial(d,
tasks.append(
recv_partial(d,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True))
tasks.append(
allgather_partial(d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True))
else:
tasks.append(
recv_partial(tensor_recv_next,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(d,
use_calc_stream=True))
tasks.append(
allgather_partial(tensor_recv_next,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(tensor_recv_next,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(tensor_recv_next,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
use_calc_stream=True))
if in_dygraph_mode():
# wait tasks in new dygraph mode with new comm library
for task in tasks:
if task is not None:
task.wait()
return tensor_recv_prev, tensor_recv_next
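The rewritten `_p2p_helper` no longer waits inside each partial op; it collects the returned task handles and waits once at the end when running on the eager-mode communication library. The pattern, reduced to a stub (illustrative only):

```python
# Issue every asynchronous point-to-point op first, then wait on all of them,
# so the sends and receives of one scheduling step can overlap.
class _FakeTask:
    def wait(self):
        pass

def issue_async_op(_name):
    return _FakeTask()   # stands in for isend/irecv/partial ops returning tasks

tasks = [issue_async_op(n) for n in
         ("send_prev", "recv_prev", "send_next", "recv_next")]
for task in tasks:
    if task is not None:   # some code paths return None (e.g. the legacy sync path)
        task.wait()
```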
-def recv_forward():
-    if _hcg.is_first_stage:
+def recv_forward(pp_first_stage):
+    if pp_first_stage:
input_tensor = None
else:
if not _send_recv_meta.has_recv_meta:
......@@ -435,8 +452,8 @@ def recv_forward():
return input_tensor
-def recv_backward():
-    if _hcg.is_last_stage:
+def recv_backward(pp_last_stage):
+    if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(tensor_send_next=None,
......@@ -446,8 +463,8 @@ def recv_backward():
return output_tensor_grad
-def send_forward(output_tensor):
-    if not _hcg.is_last_stage:
+def send_forward(output_tensor, pp_last_stage):
+    if not pp_last_stage:
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
......@@ -459,16 +476,16 @@ def send_forward(output_tensor):
recv_next=False)
-def send_backward(input_tensor_grad):
-    if not _hcg.is_first_stage:
+def send_backward(input_tensor_grad, pp_first_stage):
+    if not pp_first_stage:
_p2p_helper(tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False)
-def send_forward_recv_backward(output_tensor):
-    if _hcg.is_last_stage:
+def send_forward_recv_backward(output_tensor, pp_last_stage):
+    if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(tensor_send_next=output_tensor,
......@@ -478,8 +495,8 @@ def send_forward_recv_backward(output_tensor):
return output_tensor_grad
-def send_backward_recv_forward(input_tensor_grad):
-    if _hcg.is_first_stage:
+def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
+    if pp_first_stage:
input_tensor = None
else:
input_tensor, _ = _p2p_helper(tensor_send_next=None,
......@@ -487,3 +504,48 @@ def send_backward_recv_forward(input_tensor_grad):
recv_prev=True,
recv_next=False)
return input_tensor
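These wrappers now take explicit `pp_first_stage`/`pp_last_stage` flags so the interleaved scheduler can pass virtual-stage boundaries instead of the communicator's physical ones. A conceptual sketch of how a steady-state 1F1B step strings them together (the import path follows this PR's tree; `forward_step`/`backward_step` are placeholder callables, and running it requires an initialized hybrid group):

```python
from paddle.distributed.fleet.meta_parallel.pp_utils import p2p_communication as p2p

def one_f_one_b_step(is_first_stage, is_last_stage, forward_step, backward_step):
    input_tensor = p2p.recv_forward(is_first_stage)            # None on the first stage
    output_tensor = forward_step(input_tensor)
    output_grad = p2p.send_forward_recv_backward(output_tensor, is_last_stage)
    input_grad = backward_step(input_tensor, output_tensor, output_grad)
    p2p.send_backward(input_grad, is_first_stage)
```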
def send_forward_backward_recv_forward_backward(output_tensor,
input_tensor_grad, recv_prev,
recv_next):
    # always have to send dtype info to downstream
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
_send_recv_meta.has_send_meta = _use_cache
if recv_prev and not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.recv_prev_group)
_send_recv_meta.has_recv_meta = _use_cache
input_tensor, output_tensor_grad = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next)
return input_tensor, output_tensor_grad
def send_forward_recv_forward(output_tensor, recv_prev):
    # always have to send dtype info to downstream
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
_send_recv_meta.has_send_meta = _use_cache
if recv_prev and not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.recv_prev_group)
_send_recv_meta.has_recv_meta = _use_cache
input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False)
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next):
_, output_tensor_grad = _p2p_helper(tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next)
return output_tensor_grad
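The three new combined helpers let the interleaved scheduler exchange forward activations and backward gradients for different model chunks in one step. A standalone sketch of the chunk bookkeeping used by Megatron-style interleaved 1F1B schedules (illustrative, not this PR's exact implementation):

```python
# Micro-batch steps are grouped per physical stage; consecutive groups cycle
# through the model chunks owned by the stage (forward walks chunks upward,
# backward walks them downward).
def forward_chunk_id(step, pp_degree, num_model_chunks):
    return (step // pp_degree) % num_model_chunks

def backward_chunk_id(step, pp_degree, num_model_chunks):
    return num_model_chunks - 1 - forward_chunk_id(step, pp_degree, num_model_chunks)

# Example: 2 physical stages, 2 chunks per stage -> forward steps 0,1 run chunk 0,
# steps 2,3 run chunk 1, then the pattern repeats.
assert [forward_chunk_id(s, 2, 2) for s in range(6)] == [0, 0, 1, 1, 0, 0]
```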
......@@ -18,7 +18,7 @@ import numpy as np
from .base import topology as tp
from .base.topology import ParallelMode
from .meta_parallel import TensorParallel, model_parallel_random_seed
-from .meta_parallel import PipelineParallel, ShardingParallel
+from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer
from paddle.fluid import core
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
......@@ -185,6 +185,16 @@ def distributed_model(model):
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
-        model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)
+        assert isinstance(
+            model, PipelineLayer
+        ), "For pipeline parallel, the model should be an instance of PipelineLayer"
+        if model.get_num_virtual_stages() == 1:
+            # 1f1b pipeline
+            model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)
+        else:
+            # interleave pipeline
+            model = PipelineParallelWithInterleave(model,
+                                                   fleet_env._hcg,
+                                                   strategy=strategy)
return model
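A usage sketch mirroring the branch above: a `PipelineLayer` built with `num_virtual_pipeline_stages > 1` comes back wrapped in `PipelineParallelWithInterleave`, otherwise in the classic `PipelineParallel` (hybrid strategy and topology setup elided; see the new test below for a full configuration):

```python
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import (PipelineParallel,
                                                    PipelineParallelWithInterleave)

def wrap_for_pipeline(pipe_model):
    # pipe_model: a PipelineLayer instance; the number of virtual stages picks
    # the scheduler that fleet.distributed_model wraps it with.
    dist_model = fleet.distributed_model(pipe_model)
    expected = (PipelineParallel if pipe_model.get_num_virtual_stages() == 1
                else PipelineParallelWithInterleave)
    assert isinstance(dist_model, expected)
    return dist_model
```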
......@@ -27,8 +27,6 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer)
list(APPEND DIST_TEST_OPS
test_parallel_dygraph_pipeline_parallel_with_virtual_stage)
list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard)
list(APPEND DIST_TEST_OPS test_auto_parallel_save_load)
list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert)
......@@ -178,8 +176,6 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
# TODO(shenliang03): batch_fc_op support CPU device in future
# TODO(Yancey1989): parallel dygraph support CPU device in future
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
list(REMOVE_ITEM TEST_OPS
test_parallel_dygraph_pipeline_parallel_with_virtual_stage)
list(REMOVE_ITEM TEST_OPS test_fleet_base_single)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
......@@ -1178,9 +1174,6 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
if(WITH_DISTRIBUTE
AND WITH_GPU
AND WITH_NCCL)
set_tests_properties(
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
PROPERTIES TIMEOUT 500)
set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120)
......
......@@ -204,6 +204,20 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT)
set_tests_properties(test_parallel_dygraph_pipeline_parallel
PROPERTIES TIMEOUT "500")
endif()
if((WITH_GPU) AND LOCAL_ALL_PLAT)
bash_test_modules(
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
START_BASH
../../dist_test.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21282;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
PROPERTIES TIMEOUT "500" RUN_SERIAL 1)
endif()
if((WITH_GPU
OR WITH_XPU
OR WITH_ASCEND
......
......@@ -19,7 +19,7 @@ import paddle
from paddle.distributed import fleet
import paddle.nn as nn
from paddle.fluid.dygraph.layers import Layer
-from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
+from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer, PipelineParallelWithInterleave
import paddle.nn.functional as F
......@@ -87,7 +87,8 @@ class TestPipeLayerAPI(unittest.TestCase):
try:
model_chunks[0](paddle.to_tensor([1., 2.]))
-        except NotImplementedError:
-            raise NotImplementedError
+        except PermissionError:
+            pass
# fake call for the forward function of virtual pipeline layer
......@@ -102,6 +103,7 @@ class TestPipeLayerAPI(unittest.TestCase):
# just make sure the model can be wrapped with distributed model
dist_model = fleet.distributed_model(pipe_model)
assert isinstance(dist_model, PipelineParallelWithInterleave)
if __name__ == '__main__':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.fluid import layers
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc
from paddle.fluid.dygraph.layers import Layer
import paddle.nn as nn
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)
batch_size = 8
length = 8
micro_batch_size = 2
num_virtual_pipeline_stages = 2
vocab_size = 128
hidden_size = 16
d_model = hidden_size
dim_feedforward = 4 * d_model
class EmbeddingNet(Layer):
def __init__(self):
super(EmbeddingNet, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(vocab_size, hidden_size)
def forward(self, x):
attention_mask = paddle.tensor.triu((paddle.ones(
(length, length), dtype="float32") * -1e9), 1)
no_used = paddle.ones((3, 3), dtype="int32")
w_emb = self.word_embeddings(x)
p_emb = self.position_embeddings(x)
w_emb = w_emb + p_emb
attention_mask.stop_gradient = True
no_used.stop_gradient = True
# need to fix bug of backward()
return w_emb, attention_mask, no_used, p_emb
class TransformerNet(Layer):
def __init__(self):
super(TransformerNet, self).__init__()
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, x, mask):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)
weights = F.softmax(product + mask)
tgt = layers.matmul(weights, v)
residual = tgt
tgt = self.norm1(tgt)
tgt = residual + tgt
out = self.linear2(F.gelu(self.linear1(tgt), approximate=True))
return out
class EmbeddingPipe(EmbeddingNet):
def forward(self, x):
return super().forward(x)
class TransformerNetPipe(TransformerNet):
def forward(self, args):
x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]
output = super().forward(x, mask)
output = output + p_emb
mask.stop_gradient = True
return output, mask, no_used, p_emb
class CriterionPipe(Layer):
def __init__(self):
super(CriterionPipe, self).__init__()
def forward(self, out, label):
loss = out.mean()
return loss
class ModelPipe(PipelineLayer):
def __init__(self, topology):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(8):
self.descs.append(LayerDesc(TransformerNetPipe))
self.descs.append(lambda x: x[0])
super().__init__(
layers=self.descs,
loss_fn=CriterionPipe(),
topology=topology,
num_virtual_pipeline_stages=num_virtual_pipeline_stages,
seg_method="layer:TransformerNetPipe")
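The test asks for 2 virtual stages on 2 physical stages with `seg_method="layer:TransformerNetPipe"`, so segmentation counts only the 8 transformer layers. Rough arithmetic for the resulting layout (illustrative, not PipelineLayer's implementation):

```python
# 8 matching transformer layers over pp_degree * num_virtual_pipeline_stages
# = 2 * 2 = 4 model chunks -> 2 transformer layers per chunk; each physical
# stage owns two non-adjacent chunks that the interleaved scheduler alternates.
num_transformer_layers = 8
pp_degree, virtual_stages = 2, 2
num_chunks = pp_degree * virtual_stages
assert num_transformer_layers // num_chunks == 2  # layers per model chunk
```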
class TestDistPPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology)
scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
values=[0.001, 0.002],
verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
for step_id in range(5):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
e_loss = model.eval_batch([x, x], True)
loss = model.train_batch([x, x], optimizer, scheduler)
np.testing.assert_allclose(loss.numpy(), e_loss.numpy())
if __name__ == "__main__":
unittest.main()
......@@ -25,8 +25,10 @@ class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):
def test_hybrid_parallel_pp_layer_with_virtual_stage(self):
self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py')
self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py',
eager_mode=False)
def test_hybrid_parallel_pp_transformer_with_virtual_stage(self):
self.run_mnist_2gpu(
'hybrid_parallel_pp_transformer_with_virtual_stage.py')
if __name__ == "__main__":
......
......@@ -16,6 +16,7 @@ test_fleet_graph_execution_meta_optimizer,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,../../
test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_fleet_graph_executor,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,1,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
......