未验证 提交 46879ff5 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel]Add scatter-gather for pipeline (#34130)

* add scatter-gather opt

* fix topo for pp

* rename function
上级 e1e3e3b4
......@@ -68,14 +68,19 @@ reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
}
};
DECLARE_INPLACE_OP_INFERER(PartialAllGatherOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(partial_allgather, ops::PartialAllGatherOp,
ops::PartialAllGatherOpMaker);
REGISTER_OPERATOR(
partial_allgather, ops::PartialAllGatherOp, ops::PartialAllGatherOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::PartialAllGatherOpInplaceInferer)
REGISTER_OP_CPU_KERNEL(partial_allgather,
ops::PartialAllGatherOpCPUKernel<float>,
......
......@@ -126,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"accuracy", {"Correct", "Total"}},
{"fill_constant", {"Out"}},
{"recv_v2", {"Out"}},
{"partial_recv", {"Out"}},
{"matmul", {"Out"}},
{"c_broadcast", {"Out"}},
{"c_sync_calc_stream", {"Out"}},
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from types import MethodType
import numpy as np
import paddle
import paddle.fluid as fluid
......@@ -39,6 +39,8 @@ class PipelineParallel(MetaParallelBase):
self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
self.is_pipe_partitioned = self.use_model_parallel
self.num_caches = 0
self.caches = {
'inputs': [],
......@@ -70,6 +72,9 @@ class PipelineParallel(MetaParallelBase):
self.is_last_stage = (self.stage_id == (self.num_stages - 1))
self.global_rank = self._hcg.get_global_rank()
self.mp_degree = self._hcg.get_model_parallel_world_size()
self.mp_rank = self._hcg.get_model_parallel_rank()
logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
self.num_stages, self.stage_id))
......@@ -159,8 +164,8 @@ class PipelineParallel(MetaParallelBase):
else:
inputs = self.caches['inputs'][cache_id]
outputs = self._layers.forward(inputs)
self._clear_grads(inputs)
outputs = self._layers.forward(inputs)
self.caches['outputs'][cache_id] = outputs
......@@ -369,6 +374,11 @@ class PipelineParallel(MetaParallelBase):
caches = tuple(caches)
return caches
def _is_valid_send_recv(self, tensor):
tensor_numel = np.prod(tensor.shape)
assert tensor_numel != 0, "can't send/recv zero element"
return tensor_numel % self.mp_degree == 0
def _send_activations(self, cache_id):
outputs = self.caches['outputs'][cache_id]
......@@ -377,16 +387,39 @@ class PipelineParallel(MetaParallelBase):
self._send_meta(outputs, self.next_stage_id)
if isinstance(outputs, paddle.Tensor):
p2p.send(outputs, self.next_stage_id)
if self.is_pipe_partitioned and self._is_valid_send_recv(outputs):
p2p.send_partial(
outputs.detach(),
self.next_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(outputs.detach(), self.next_stage_id)
elif isinstance(outputs, tuple):
for output in outputs:
p2p.send(output, self.next_stage_id)
if self.is_pipe_partitioned and self._is_valid_send_recv(
output):
p2p.send_partial(
output.detach(),
self.next_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(output.detach(), self.next_stage_id)
def _send_gradients(self, cache_id):
inputs = self.caches['inputs'][cache_id]
if isinstance(inputs, paddle.Tensor):
assert inputs.grad is not None
if self.is_pipe_partitioned and self._is_valid_send_recv(
inputs.grad):
grad = p2p.send_partial(
inputs.grad,
self.prev_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(inputs.grad, self.prev_stage_id)
else:
for idx, d in enumerate(inputs):
......@@ -394,6 +427,15 @@ class PipelineParallel(MetaParallelBase):
if not is_float_tensor(d):
assert d.grad is None
continue
if self.is_pipe_partitioned and self._is_valid_send_recv(
d.grad):
grad = p2p.send_partial(
d.grad,
self.prev_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(d.grad, self.prev_stage_id)
self.caches['inputs'][cache_id] = None
......@@ -404,13 +446,37 @@ class PipelineParallel(MetaParallelBase):
self.recv_cache = self._recv_meta(self.prev_stage_id)
if isinstance(self.recv_cache, paddle.Tensor):
if self.is_pipe_partitioned and self._is_valid_send_recv(
self.recv_cache):
p2p.recv_partial(self.recv_cache, self.prev_stage_id,
self.mp_degree, self.mp_rank)
p2p.partial_allgather_operator(
self.recv_cache,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
p2p.recv(self.recv_cache, self.prev_stage_id)
inputs = self.recv_cache.clone().detach()
inputs.stop_gradient = not is_float_tensor(inputs)
else:
assert isinstance(self.recv_cache, tuple)
inputs = [None] * len(self.recv_cache)
for idx, d in enumerate(self.recv_cache):
if self.is_pipe_partitioned and self._is_valid_send_recv(d):
assert isinstance(d, paddle.Tensor)
p2p.recv_partial(d, self.prev_stage_id, self.mp_degree,
self.mp_rank)
p2p.partial_allgather_operator(
d,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
assert isinstance(d, paddle.Tensor)
p2p.recv(d, self.prev_stage_id)
inputs[idx] = d.clone().detach()
......@@ -440,10 +506,32 @@ class PipelineParallel(MetaParallelBase):
sizes, dtypes, num_caches=1)[0]
if isinstance(self.grad_tensors, paddle.Tensor):
if self.is_pipe_partitioned and self._is_valid_send_recv(
self.grad_tensors):
p2p.recv_partial(self.grad_tensors, self.next_stage_id,
self.mp_degree, self.mp_rank)
p2p.partial_allgather_operator(
self.grad_tensors,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
p2p.recv(self.grad_tensors, self.next_stage_id)
else:
assert isinstance(outputs, tuple)
for d in self.grad_tensors:
if self.is_pipe_partitioned and self._is_valid_send_recv(d):
p2p.recv_partial(d, self.next_stage_id, self.mp_degree,
self.mp_rank)
p2p.partial_allgather_operator(
d,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
p2p.recv(d, self.next_stage_id)
def _step(self):
......
......@@ -27,15 +27,67 @@ def initialize_p2p_groups(hcg):
_hcg = hcg
def _is_valid_communciate(src_stage, dest_stage):
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
assert abs(src_stage-dest_stage) == 1 or \
(src_stage == first_stage and dest_stage == last_stage) or \
(src_stage == last_stage and dest_stage == first_stage)
def partial_send_operator(tensor,
dst=0,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', mp_ranks, 'id', mp_rank_id)
def partial_recv_operator(tensor,
src=0,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_recv(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype,
'out_shape', tensor.shape)
def partial_allgather_operator(tensor,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_allgather_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
'nranks', mp_ranks, 'rank', mp_rank_id)
def send(tensor, dest_stage):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
return paddle.distributed.broadcast(tensor, src_rank, group=group)
return paddle.distributed.send(
tensor, dst=1 if dest_stage > src_stage else 0, group=group)
def recv(tensor, src_stage):
......@@ -44,16 +96,35 @@ def recv(tensor, src_stage):
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
return paddle.distributed.broadcast(tensor, src_rank, group=group)
return paddle.distributed.recv(
tensor, src=0 if dest_stage > src_stage else 1, group=group)
def _is_valid_communciate(src_stage, dest_stage):
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
assert abs(src_stage-dest_stage) == 1 or \
(src_stage == first_stage and dest_stage == last_stage) or \
(src_stage == last_stage and dest_stage == first_stage)
def send_partial(tensor, dest_stage, mp_degree, mp_rank):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return partial_send_operator(
tensor,
dst=1 if dest_stage > src_stage else 0,
mp_ranks=mp_degree,
mp_rank_id=mp_rank,
group=group)
def recv_partial(tensor, src_stage, mp_degree, mp_rank):
global _groups, _hcg
dest_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return partial_recv_operator(
tensor,
src=0 if dest_stage > src_stage else 1,
mp_ranks=mp_degree,
mp_rank_id=mp_rank,
group=group)
def _get_send_recv_group(src_stage, dest_stage):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册