Unverified commit 9e0bb91c, authored by ShenLiang and committed by GitHub

[HybridParallel] Support 1F1B for PipelineParallel (#34483)

* Support the 1F1B schedule for pipeline parallelism

* Add a unit test

* Add send_partial/recv_partial

* Support AMP for pipeline parallelism

* Fix logging
Parent 3b5fc2ad
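For context: with the 1F1B ("one forward, one backward") schedule, each pipeline stage alternates one forward micro-batch with one backward micro-batch once the pipeline is full, so at most pp_degree activations stay live instead of one per micro-batch. Below is a rough sketch of how a stage's steady-state loop could drive the p2p helpers added in this commit; it is illustrative only — forward_step and backward_step are hypothetical stand-ins for the trainer's compute, and the actual schedule lives in the PipelineParallel runtime, which is not shown in this diff.

def steady_state_1f1b(num_steps, forward_step, backward_step):
    # Sketch of the 1F1B steady state, using recv_forward /
    # send_forward_recv_backward / send_backward_recv_forward from the
    # p2p helper module changed below.
    input_tensor = recv_forward()                   # activation from the previous stage
    for _ in range(num_steps):
        output_tensor = forward_step(input_tensor)  # 1F: one forward micro-batch
        # push the activation to the next stage and, in the same step, pull back
        # the gradient of an earlier activation
        output_tensor_grad = send_forward_recv_backward(output_tensor)
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad)  # 1B: one backward micro-batch
        # hand the gradient to the previous stage and fetch the next activation
        input_tensor = send_backward_recv_forward(input_tensor_grad)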
@@ -156,6 +156,10 @@ class HybridCommunicateGroup(object):
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self._pp_degree - 1))
# create p2p_groups
if self._pp_degree > 1:
self._set_p2p_group()
debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
"sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
self._sharding_degree, self._pp_degree, self._dp_degree)
@@ -164,27 +168,9 @@ class HybridCommunicateGroup(object):
self._dp_group, self._check_group)
logger.info(debug_str)
# create p2p_groups and no new group
self._p2p_groups = self._build_p2p_lists()
global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self
def _build_p2p_lists(self):
comm_lists = self._topo.get_comm_list('pipe')
p2p_lists = []
for rank in range(self.nranks):
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
if rank in comm_ranks:
idx = comm_ranks.index(rank)
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
p2p_lists.append([rank, next_rank])
break
assert len(
p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks"
return p2p_lists
def get_parallel_mode(self):
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
        # NOTE when sharding is combined with other parallel strategies, sharding should act like an optimizer and
@@ -236,6 +222,41 @@ class HybridCommunicateGroup(object):
return parallel_group, parallel_comm_group
def _set_p2p_group(self):
comm_lists = self._topo.get_comm_list('pipe')
self.send_next_group = None
self.send_prev_group = None
self.recv_next_group = None
self.recv_prev_group = None
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
for idx, rank in enumerate(comm_ranks):
curr_rank = rank
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
prev_rank = comm_ranks[(idx - 1) % self._pp_degree]
next_group = paddle.distributed.new_group(
ranks=[curr_rank, next_rank])
if self.global_rank == curr_rank:
self.send_next_group = next_group
elif self.global_rank == next_rank:
self.recv_prev_group = next_group
prev_group = paddle.distributed.new_group(
ranks=[prev_rank, curr_rank])
if self.global_rank == curr_rank:
self.send_prev_group = prev_group
elif self.global_rank == prev_rank:
self.recv_next_group = prev_group
assert self.send_next_group is not None
assert self.send_prev_group is not None
assert self.recv_next_group is not None
assert self.recv_prev_group is not None
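        # Net effect for a rank r in a pipeline ring of size pp_degree:
        #   send_next_group: the two-rank group {r, r+1} in which r sends forward
        #                    activations; on rank r+1 the same group is stored as
        #                    recv_prev_group.
        #   send_prev_group: the two-rank group {r-1, r} in which r sends backward
        #                    gradients; on rank r-1 it is stored as recv_next_group.
        # new_group() is called collectively by every rank for every pair, but each
        # rank only keeps the handles it participates in.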
def topology(self):
return self._topo
@@ -287,6 +308,9 @@ class HybridCommunicateGroup(object):
def get_pipe_parallel_group(self):
return self._pp_comm_group
def get_p2p_groups(self):
return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group
# sharding parallel message:
def _get_sharding_parallel_id(self):
return self._topo.get_coord(self.global_rank).sharding
@@ -304,9 +328,6 @@ class HybridCommunicateGroup(object):
# TODO should the src rank related to the shard rank for each parameter ?
return self._sharding_comm_group.ranks[0]
def get_p2p_groups(self):
return self._p2p_groups
# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
@@ -13,131 +13,388 @@
# limitations under the License.
import paddle
from .utils import paddle_2_number, number_2_dtype
from ...utils.log_util import logger
_groups = None
_hcg = None
def initialize_p2p_groups(hcg):
    global _hcg
_hcg = hcg
send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
)
debug_str = "P2pInfo: send_next_group: %s, send_prev_group: %s, " \
"recv_next_group: %s, recv_prev_group: %s" % (repr(send_next_group),
        repr(send_prev_group), repr(recv_next_group), repr(recv_prev_group))
logger.info(debug_str)
def _is_valid_communciate(src_stage, dest_stage):
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
assert abs(src_stage-dest_stage) == 1 or \
(src_stage == first_stage and dest_stage == last_stage) or \
(src_stage == last_stage and dest_stage == first_stage)
class SendRecvMeta:
"""Mainly used to help p2p communication context information"""
    def __init__(self):
        self.send_shape_message = None
        self.send_dtype_message = None
        self.recv_shape_message = None
        self.recv_dtype_message = None
        self.has_send_meta = False
        self.has_recv_meta = False
def _recv_shape_dtype(self, group):
# recv len(shape)
dims = paddle.to_tensor([0])
paddle.distributed.recv(dims, src=0, group=group)
dims = dims.item()
# recv shape
shape = paddle.to_tensor([0] * dims)
paddle.distributed.recv(shape, src=0, group=group)
# recv dtype
dtype = paddle.to_tensor([0])
paddle.distributed.recv(dtype, src=0, group=group)
return shape.numpy().tolist(), dtype.item()
def recv_meta(self, group):
tensor_type = paddle.to_tensor([0])
paddle.distributed.recv(tensor_type, src=0, group=group)
tensor_type = tensor_type.item()
if tensor_type == 0:
shape, dtype = self._recv_shape_dtype(group)
self.recv_shape_message = shape
self.recv_dtype_message = dtype
elif tensor_type == 1:
num = paddle.to_tensor([0])
paddle.distributed.recv(num, src=0, group=group)
num = num.item()
shapes = []
dtypes = []
for i in range(num):
shape, dtype = self._recv_shape_dtype(group)
shapes.append(shape)
dtypes.append(dtype)
self.recv_shape_message = tuple(shapes)
self.recv_dtype_message = tuple(dtypes)
def _send_dims_shape_dtype(self, tensor, group):
# send len(shape)
dims = paddle.to_tensor(len(tensor.shape))
paddle.distributed.send(dims, dst=1, group=group)
# send shape
shape = paddle.to_tensor(tensor.shape)
paddle.distributed.send(shape, dst=1, group=group)
# send dtype
dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
paddle.distributed.send(dtype, dst=1, group=group)
def send_meta(self, tensor, group):
if isinstance(tensor, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
# send tensor type
paddle.distributed.send(tensor_type, dst=1, group=group)
self._send_dims_shape_dtype(tensor, group)
elif isinstance(tensor, tuple):
tensor_type = paddle.to_tensor([1])
# send tensor type
paddle.distributed.send(tensor_type, dst=1, group=group)
nums = paddle.to_tensor(len(tensor))
paddle.distributed.send(nums, dst=1, group=group)
for d in tensor:
assert isinstance(d, paddle.Tensor)
self._send_dims_shape_dtype(d, group=group)
def set_send_message(self, tensor):
if isinstance(tensor, paddle.Tensor):
self.send_shape_message = tensor.shape
self.send_dtype_message = paddle_2_number(tensor.dtype)
elif isinstance(tensor, tuple):
self.send_shape_message = tuple(
[d.shape for d in tensor if not d.stop_gradient])
self.send_dtype_message = tuple(
[paddle_2_number(d.dtype) for d in tensor])
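# Meta handshake (performed once, before the first micro-batch): the sender
# transmits a type flag (0 = single tensor, 1 = tuple of tensors), then ndim,
# shape and an encoded dtype for every tensor, so the receiver can pre-allocate
# its buffers with paddle.empty() inside _p2p_helper. The dst=1 / src=0 values
# above appear to index ranks within the two-rank p2p group rather than global
# ranks.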
_send_recv_meta = SendRecvMeta()
def partial_send_operator(tensor,
                          dst=0,
                          mp_ranks=1,
                          mp_rank_id=0,
                          group=None,
                          use_calc_stream=True):
def send_partial(tensor,
                 dst=0,
                 nranks=1,
                 rank_id=0,
                 group=None,
                 use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', mp_ranks, 'id', mp_rank_id)
dst, 'num', nranks, 'id', rank_id)
def partial_recv_operator(tensor,
src=0,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
def recv_partial(tensor,
src=0,
nranks=1,
rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_recv(
paddle.fluid.core.ops.partial_recv(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype,
'out_shape', tensor.shape)
src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
def partial_allgather_operator(tensor,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
def allgather_partial(tensor,
nranks=1,
rank_id=0,
group=None,
use_calc_stream=True):
if nranks == 1:
return tensor
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_allgather_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
        'nranks', mp_ranks, 'rank', mp_rank_id)
        'nranks', nranks, 'rank', rank_id)
def send(tensor, dest_stage):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return paddle.distributed.send(
tensor, dst=1 if dest_stage > src_stage else 0, group=group)
def recv(tensor, src_stage):
global _groups, _hcg
dest_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return paddle.distributed.recv(
tensor, src=0 if dest_stage > src_stage else 1, group=group)
def send_partial(tensor, dest_stage, mp_degree, mp_rank):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return partial_send_operator(
tensor,
dst=1 if dest_stage > src_stage else 0,
mp_ranks=mp_degree,
mp_rank_id=mp_rank,
group=group)
def recv_partial(tensor, src_stage, mp_degree, mp_rank):
global _groups, _hcg
dest_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return partial_recv_operator(
tensor,
src=0 if dest_stage > src_stage else 1,
mp_ranks=mp_degree,
mp_rank_id=mp_rank,
group=group)
def _get_send_recv_group(src_stage, dest_stage):
global _groups, _hcg
stage_id = None
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
if (src_stage == first_stage and dest_stage == last_stage) or \
(dest_stage == first_stage and src_stage == last_stage):
stage_id = last_stage
elif src_stage > dest_stage:
stage_id = dest_stage
    else:
        stage_id = src_stage
    group_id = _hcg.get_rank_from_stage(stage_id=stage_id)
    return _groups[group_id]
def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
global _hcg
tensor_recv_prev = None
tensor_recv_next = None
# send / recv message
recv_shape_msg = _send_recv_meta.recv_shape_message
recv_dtype_msg = _send_recv_meta.recv_dtype_message
send_shape_msg = _send_recv_meta.send_shape_message
send_dtype_msg = _send_recv_meta.send_dtype_message
# model parallel message
mp_group = _hcg.get_model_parallel_group()
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
if recv_prev:
if isinstance(recv_shape_msg, tuple):
tensor_recv_prev = []
for idx, shape in enumerate(recv_shape_msg):
tensor_recv_prev.append(
paddle.empty(
shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])))
tensor_recv_prev = tuple(tensor_recv_prev)
else:
tensor_recv_prev = paddle.empty(
shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))
if recv_next:
if isinstance(send_shape_msg, tuple):
tensor_recv_next = []
for idx, shape in enumerate(send_shape_msg):
tensor_recv_next.append(
paddle.empty(
shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])))
tensor_recv_next = tuple(tensor_recv_next)
else:
tensor_recv_next = paddle.empty(
shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg))
# start to p2p communicate
if tensor_send_prev is not None:
if isinstance(tensor_send_prev, tuple):
for d in tensor_send_prev:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False)
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
send_partial(
tensor_send_prev,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False)
if tensor_recv_prev is not None:
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
recv_partial(
d,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
tensor_recv_prev,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
tensor_recv_prev,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
if tensor_send_next is not None:
if isinstance(tensor_send_next, tuple):
for d in tensor_send_next:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False)
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
send_partial(
tensor_send_next,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False)
if tensor_recv_next is not None:
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
recv_partial(
d,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
tensor_recv_next,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
tensor_recv_next,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
return tensor_recv_prev, tensor_recv_next
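# Why partial send/recv plus allgather: with tensor-model parallelism
# (mp_degree > 1) the activation is replicated across the mp group, so each rank
# sends only its 1/mp_degree slice ('num' = nranks, 'id' = rank_id) across the
# pipeline boundary, and the receiving side rebuilds the full tensor with
# allgather_partial over its own mp group. This should cut inter-stage p2p
# traffic by roughly a factor of mp_degree.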
def recv_forward():
if _hcg.is_first_stage:
input_tensor = None
else:
if not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.recv_prev_group)
_send_recv_meta.has_recv_meta = True
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
return input_tensor
def recv_backward():
if _hcg.is_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
return output_tensor_grad
def send_forward(output_tensor):
if not _hcg.is_last_stage:
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
_send_recv_meta.has_send_meta = True
_p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False)
def send_backward(input_tensor_grad):
if not _hcg.is_first_stage:
_p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False)
def send_forward_recv_backward(output_tensor):
if _hcg.is_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad):
    if _hcg.is_first_stage:
        input_tensor = None
    else:
        input_tensor, _ = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
    return input_tensor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.fluid import layers
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc
from paddle.fluid.dygraph.layers import Layer
import paddle.nn as nn
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)
batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128
hidden_size = 16
d_model = hidden_size
dim_feedforward = 4 * d_model
class EmbeddingNet(Layer):
def __init__(self):
super(EmbeddingNet, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(vocab_size, hidden_size)
def forward(self, x):
attention_mask = paddle.tensor.triu(
(paddle.ones(
(length, length), dtype="float32") * -1e9), 1)
attention_mask.stop_gradient = True
w_emb = self.word_embeddings(x)
p_emb = self.position_embeddings(x)
w_emb = w_emb + p_emb
# need to fix bug of backward()
return w_emb, attention_mask
class TransformerNet(Layer):
def __init__(self):
super(TransformerNet, self).__init__()
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, x, mask):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)
weights = F.softmax(product + mask)
weights = F.dropout(weights, 0.2)
tgt = layers.matmul(weights, v)
residual = tgt
tgt = self.norm1(tgt)
tgt = residual + tgt
out = self.linear2(F.gelu(self.linear1(tgt), approximate=True))
return out
class EmbeddingPipe(EmbeddingNet):
def forward(self, x):
return super().forward(x)
class TransformerNetPipe(TransformerNet):
def forward(self, args):
x, mask = args[0], args[1]
output = super().forward(x, mask)
output = output
mask.stop_gradient = True
return output, mask
class CriterionPipe(Layer):
def __init__(self):
super(CriterionPipe, self).__init__()
def forward(self, out, label):
loss = out.mean()
return loss
class ModelPipe(PipelineLayer):
def __init__(self, topology):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(5):
self.descs.append(LayerDesc(TransformerNetPipe))
self.descs.append(lambda x: x[0])
super().__init__(
layers=self.descs, loss_fn=CriterionPipe(), topology=topology)
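# LayerDesc defers construction of each sub-layer until PipelineLayer has decided
# which pipeline stage owns it, so only the owning rank allocates its parameters.
# The trailing lambda drops the attention mask from the (output, mask) tuple so
# that only the hidden states are fed to CriterionPipe.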
class TestDistPPTraining(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
        world_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
for step_id in range(5):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
# TODO(shenliang03) add utest for loss
if __name__ == "__main__":
unittest.main()
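The test above is registered in the multi-GPU test list below; run on its own, it would be launched with something like python -m paddle.distributed.launch --gpus "0,1" hybrid_parallel_pp_transformer.py, i.e. two devices to match pp_degree=2.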
@@ -33,6 +33,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
def test_pipeline_parallel(self):
self.run_mnist_2gpu('hybrid_parallel_pp_amp.py')
def test_hybrid_parallel_transformer(self):
self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')
if __name__ == "__main__":
unittest.main()