Unverified · Commit 5c9fce0e authored by ShenLiang, committed by GitHub

add p2p (#33864)

Parent c522530a
@@ -164,9 +164,27 @@ class HybridCommunicateGroup(object):
self._dp_group, self._check_group)
logger.info(debug_str)
# create p2p_groups and no new group
self._p2p_groups = self._build_p2p_lists()
global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self
def _build_p2p_lists(self):
comm_lists = self._topo.get_comm_list('pipe')
p2p_lists = []
for rank in range(self.nranks):
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
if rank in comm_ranks:
idx = comm_ranks.index(rank)
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
p2p_lists.append([rank, next_rank])
break
assert len(
p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks"
return p2p_lists
def get_parallel_mode(self):
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
@@ -286,6 +304,9 @@ class HybridCommunicateGroup(object):
# TODO should the src rank related to the shard rank for each parameter ?
return self._sharding_comm_group.ranks[0]
def get_p2p_groups(self):
return self._p2p_groups
# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
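
_build_p2p_lists pairs every global rank with the rank of the next stage in its own pipeline, so _p2p_groups ends up holding one [rank, next_rank] link per rank, and get_p2p_groups simply exposes that list. A minimal sketch of the same loop on a toy topology (two pipelines of two stages each; the rank layout is illustrative, not taken from this commit):

# Toy 'pipe' comm_list: two pipelines, pp_degree = 2, nranks = 4 (illustrative values).
comm_lists = [[0, 2], [1, 3]]
nranks, pp_degree = 4, 2
p2p_lists = []
for rank in range(nranks):
    for comm_ranks in comm_lists:
        if rank in comm_ranks:
            idx = comm_ranks.index(rank)
            p2p_lists.append([rank, comm_ranks[(idx + 1) % pp_degree]])
            break
print(p2p_lists)  # [[0, 2], [1, 3], [2, 0], [3, 1]] -- one [rank, next_rank] pair per rank

initialize_p2p_groups in the new pp_utils/p2p_communication.py below turns each of these pairs into a dist.new_group, so later sends and recvs can target a single neighbour instead of the whole pipeline group.
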
......
@@ -24,6 +24,7 @@ from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from .pp_utils import p2p_communication as p2p
__all__ = []
@@ -63,6 +64,7 @@ class PipelineParallel(MetaParallelBase):
self.prev_stage_id = self.stage_id - 1
self.next_stage_id = self.stage_id + 1
self.pp_group = self._hcg.get_pipe_parallel_group()
p2p.initialize_p2p_groups(hcg)
self.is_first_stage = self.stage_id == 0
self.is_last_stage = (self.stage_id == (self.num_stages - 1))
@@ -275,97 +277,86 @@ class PipelineParallel(MetaParallelBase):
if isinstance(data, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
# send tensor type
paddle.distributed.send(
tensor_type, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(tensor_type, self.next_stage_id)
# send len(shape)
dims = paddle.to_tensor(len(data.shape))
paddle.distributed.send(
dims, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(dims, self.next_stage_id)
# send shape
shape = paddle.to_tensor(data.shape)
paddle.distributed.send(
shape, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(shape, self.next_stage_id)
# send dtype
dtype = paddle.to_tensor(paddle_2_number(data.dtype))
paddle.distributed.send(
dtype, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(dtype, self.next_stage_id)
elif isinstance(data, tuple):
tensor_type = paddle.to_tensor([1])
paddle.distributed.send(
tensor_type, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(tensor_type, self.next_stage_id)
nums = paddle.to_tensor(len(data))
paddle.distributed.send(
nums, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(nums, self.next_stage_id)
for idx, d in enumerate(data):
assert isinstance(d, paddle.Tensor)
# send len(shape)
dims = paddle.to_tensor(len(d.shape))
paddle.distributed.send(
dims, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(dims, self.next_stage_id)
# send shape
shape = paddle.to_tensor(d.shape)
paddle.distributed.send(
shape, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(shape, self.next_stage_id)
# send dtype
dtype = paddle.to_tensor(paddle_2_number(d.dtype))
paddle.distributed.send(
dtype, peer, use_calc_stream=True, group=self.pp_group)
p2p.send(dtype, self.next_stage_id)
def _recv_meta(self, peer):
tensor_type = paddle.to_tensor([0])
paddle.distributed.recv(
tensor_type, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(tensor_type, self.prev_stage_id)
tensor_type = tensor_type.item()
if tensor_type == 0:
# recv len(shape)
dims = paddle.to_tensor([0])
paddle.distributed.recv(
dims, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(dims, self.prev_stage_id)
dims = dims.item()
# recv shape
shape = paddle.to_tensor([0] * dims)
paddle.distributed.recv(
shape, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(shape, self.prev_stage_id)
shape = shape.numpy().tolist()
# recv dtype
dtype = paddle.to_tensor([0])
paddle.distributed.recv(
dtype, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(dtype, self.prev_stage_id)
return self._allocate_cache(
shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
elif tensor_type == 1:
num = paddle.to_tensor([0])
paddle.distributed.recv(
num, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(num, self.prev_stage_id)
num = num.item()
shapes = []
dtypes = []
for i in range(num):
# recv len(shape)
dims = paddle.to_tensor([0])
paddle.distributed.recv(
dims, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(dims, self.prev_stage_id)
# recv shape
dims = dims.item()
shape = paddle.to_tensor([0] * dims)
paddle.distributed.recv(
shape, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(shape, self.prev_stage_id)
shapes.append(shape.numpy().tolist())
# recv dtype
dtype = paddle.to_tensor([0])
paddle.distributed.recv(
dtype, peer, use_calc_stream=True, group=self.pp_group)
p2p.recv(dtype, self.prev_stage_id)
dtypes.append(number_2_dtype(dtype.item()))
caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
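
These hunks move each metadata transfer onto the p2p helpers while keeping the wire protocol unchanged. For reference, the per-tensor sequence that _send_meta emits and _recv_meta consumes (the shape and dtype below are illustrative):

#   tensor_type -> [0]                            # 0 = single tensor, 1 = tuple
#   dims        -> [2]                            # len(data.shape)
#   shape       -> [8, 1024]                      # data.shape, e.g. a float32 activation
#   dtype       -> [paddle_2_number(data.dtype)]  # dtype as an integer code
# For a tuple, tensor_type is [1] followed by len(data), then the
# dims/shape/dtype triple once per element; the receiver allocates matching
# buffers through _allocate_cache / _allocate_caches before the data exchange.
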
@@ -380,39 +371,25 @@ class PipelineParallel(MetaParallelBase):
self._send_meta(outputs, self.next_stage_id)
if isinstance(outputs, paddle.Tensor):
paddle.distributed.send(
outputs,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.send(outputs, self.next_stage_id)
elif isinstance(outputs, tuple):
for output in outputs:
paddle.distributed.send(
output,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.send(output, self.next_stage_id)
def _send_gradients(self, cache_id):
inputs = self.caches['inputs'][cache_id]
if isinstance(inputs, paddle.Tensor):
assert inputs.grad is not None
paddle.distributed.send(
paddle.to_tensor(inputs.grad),
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.send(inputs.grad, self.prev_stage_id)
else:
for idx, d in enumerate(inputs):
# Skip tensors that will not produce a grad
if not is_float_tensor(d):
assert d.grad is None
continue
paddle.distributed.send(
d.grad,
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.send(d.grad, self.prev_stage_id)
self.caches['inputs'][cache_id] = None
def _recv_activations(self, cache_id):
@@ -421,11 +398,7 @@ class PipelineParallel(MetaParallelBase):
self.recv_cache = self._recv_meta(self.prev_stage_id)
if isinstance(self.recv_cache, paddle.Tensor):
paddle.distributed.recv(
self.recv_cache,
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.recv(self.recv_cache, self.prev_stage_id)
inputs = self.recv_cache.clone().detach()
inputs.stop_gradient = not is_float_tensor(inputs)
else:
@@ -433,12 +406,7 @@ class PipelineParallel(MetaParallelBase):
inputs = [None] * len(self.recv_cache)
for idx, d in enumerate(self.recv_cache):
assert isinstance(d, paddle.Tensor)
paddle.distributed.recv(
d,
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.recv(d, self.prev_stage_id)
inputs[idx] = d.clone().detach()
inputs = tuple(inputs)
@@ -466,19 +434,11 @@ class PipelineParallel(MetaParallelBase):
sizes, dtypes, num_caches=1)[0]
if isinstance(self.grad_tensors, paddle.Tensor):
paddle.distributed.recv(
self.grad_tensors,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.recv(self.grad_tensors, self.next_stage_id)
else:
assert isinstance(outputs, tuple)
for d in self.grad_tensors:
paddle.distributed.recv(
d,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
p2p.recv(d, self.next_stage_id)
def _step(self):
self.optimizer.step()
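
With these hunks in place, every activation and gradient exchange goes through the two-rank links instead of send/recv over the full pipe_parallel group. The per-microbatch traffic of a middle stage s then looks like:

# Sketch only -- per-microbatch traffic of a middle stage s after this change:
#   forward:  p2p.recv(activation_cache, s - 1)   then   p2p.send(outputs, s + 1)
#   backward: p2p.recv(grad_tensors, s + 1)       then   p2p.send(inputs.grad, s - 1)
# Each call resolves, via _get_send_recv_group, to the two-rank group shared
# with that neighbour (see pp_utils/p2p_communication.py below).
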
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed as dist
_groups = None
_hcg = None
def initialize_p2p_groups(hcg):
global _groups, _hcg
_groups = [dist.new_group(ranks=group) for group in hcg.get_p2p_groups()]
_hcg = hcg
def send(tensor, dest_stage):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
return dist.broadcast(tensor, src_rank, group=group)
def recv(tensor, src_stage):
global _groups, _hcg
dest_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
return dist.broadcast(tensor, src_rank, group=group)
def _is_valid_communciate(src_stage, dest_stage):
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
assert abs(src_stage-dest_stage) == 1 or \
(src_stage == first_stage and dest_stage == last_stage) or \
(src_stage == last_stage and dest_stage == first_stage)
def _get_send_recv_group(src_stage, dest_stage):
global _groups, _hcg
stage_id = None
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
if (src_stage == first_stage and dest_stage == last_stage) or \
(dest_stage == first_stage and src_stage == last_stage):
stage_id = last_stage
elif src_stage > dest_stage:
stage_id = dest_stage
else:
stage_id = src_stage
group_id = _hcg.get_rank_from_stage(stage_id=stage_id)
return _groups[group_id]
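
Each group built by initialize_p2p_groups contains exactly two ranks, so send and recv both reduce to dist.broadcast with the sender's global rank as the source, which behaves as a point-to-point transfer. _get_send_recv_group picks the group owned by the lower of the two stages, except for the first/last wrap-around pair, which reuses the last stage's group. A minimal sketch of that selection rule (owning_stage is an illustrative helper, not part of this file; a 4-stage pipeline is assumed):

# Sketch of the stage-selection logic in _get_send_recv_group (illustrative helper).
def owning_stage(src_stage, dest_stage, last_stage=3):
    first_stage = 0
    if {src_stage, dest_stage} == {first_stage, last_stage}:
        return last_stage                  # wrap-around pair reuses the last stage's link
    return min(src_stage, dest_stage)      # otherwise the lower stage owns the link

# owning_stage(1, 2) == 1    owning_stage(2, 1) == 1    owning_stage(0, 3) == 3
# _groups is then indexed by get_rank_from_stage(stage_id=owning_stage(...)),
# i.e. by the caller's global rank at that stage, matching the per-rank order
# of get_p2p_groups() built in _build_p2p_lists.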