Unverified · Commit 46879ff5 authored by ShenLiang, committed by GitHub

[HybridParallel]Add scatter-gather for pipeline (#34130)

* add scatter-gather opt

* fix topo for pp

* rename function
Parent e1e3e3b4
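With hybrid parallelism, the activations and gradients exchanged between adjacent pipeline stages are replicated on every model-parallel rank, so sending the full tensor from each rank wastes bandwidth. The scatter-gather optimization in this commit has each rank send only its 1/mp_degree slice over the pipeline link (partial send/recv) and rebuild the full tensor on the receiving stage with an in-place allgather over the model-parallel group. A minimal single-process sketch of that round trip, using numpy stand-ins rather than the Paddle collectives (all names below are illustrative):

```python
import numpy as np

def send_partial(tensor, mp_degree, mp_rank):
    # Each model-parallel rank ships only its contiguous 1/mp_degree slice.
    flat = tensor.reshape(-1)
    assert flat.size % mp_degree == 0, "numel must divide evenly by mp_degree"
    chunk = flat.size // mp_degree
    return flat[mp_rank * chunk:(mp_rank + 1) * chunk].copy()

def partial_allgather(slices, shape):
    # The receiving stage concatenates the gathered slices to rebuild the tensor.
    return np.concatenate(slices).reshape(shape)

mp_degree = 2
activation = np.arange(8.0).reshape(2, 4)   # replicated on both mp ranks
wire = [send_partial(activation, mp_degree, r) for r in range(mp_degree)]
rebuilt = partial_allgather(wire, activation.shape)
assert (rebuilt == activation).all()        # each rank moved 4 floats, not 8
```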
@@ -68,14 +68,19 @@ reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
 
+DECLARE_INPLACE_OP_INFERER(PartialAllGatherOpInplaceInferer, {"X", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_WITHOUT_GRADIENT(partial_allgather, ops::PartialAllGatherOp,
-                             ops::PartialAllGatherOpMaker);
+REGISTER_OPERATOR(
+    partial_allgather, ops::PartialAllGatherOp, ops::PartialAllGatherOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::PartialAllGatherOpInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(partial_allgather,
                        ops::PartialAllGatherOpCPUKernel<float>,
...
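The hunk above swaps REGISTER_OP_WITHOUT_GRADIENT for REGISTER_OPERATOR with two EmptyGradOpMakers, which keeps the op gradient-free while adding the {"X", "Out"} inplace pairing. That pairing is what the dygraph call sites later in this diff rely on: `partial_allgather_` (trailing underscore, Paddle's in-place naming convention for dygraph ops) gathers directly into the buffer that `partial_recv` just filled instead of allocating a new output. A numpy stand-in for that in-place semantics, assuming contiguous per-rank slices:

```python
import numpy as np

def partial_allgather_(buf, gathered_slices):
    # In-place: the result is written back into `buf` itself, mirroring the
    # {"X", "Out"} alias declared by PartialAllGatherOpInplaceInferer.
    chunk = buf.size // len(gathered_slices)
    for rank, piece in enumerate(gathered_slices):
        buf[rank * chunk:(rank + 1) * chunk] = piece
    return buf

buf = np.zeros(8)
buf[:4] = 1.0                                   # local slice, e.g. from partial_recv
partial_allgather_(buf, [buf[:4].copy(), np.full(4, 2.0)])
assert (buf == [1, 1, 1, 1, 2, 2, 2, 2]).all()  # buffer completed in place
```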
@@ -126,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"accuracy", {"Correct", "Total"}},
     {"fill_constant", {"Out"}},
     {"recv_v2", {"Out"}},
+    {"partial_recv", {"Out"}},
     {"matmul", {"Out"}},
     {"c_broadcast", {"Out"}},
     {"c_sync_calc_stream", {"Out"}},
...
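`partial_recv` joins `op_passing_outs_map` for the same reason `recv_v2` is there: a receive op cannot infer its output from any input, so the caller preallocates the `Out` tensor (shape and dtype come from the `_send_meta`/`_recv_meta` exchange in pipeline_parallel.py) and the kernel fills it, which is also why the `partial_recv` dygraph call later in the diff passes explicit 'dtype' and 'out_shape' attributes. Schematically (the shapes here are made up for illustration):

```python
import paddle

# Receiver side: build the buffer from previously exchanged metadata,
# then hand it to the recv op to fill in place.
meta_shape, meta_dtype = [8, 1024], 'float32'   # hypothetical meta values
buf = paddle.zeros(meta_shape, dtype=meta_dtype)
# p2p.recv_partial(buf, prev_stage, mp_degree, mp_rank) would now write
# this rank's slice into `buf` rather than returning a new tensor.
```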
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-from types import MethodType
+import numpy as np
 
 import paddle
 import paddle.fluid as fluid
@@ -39,6 +39,8 @@ class PipelineParallel(MetaParallelBase):
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
 
+        self.is_pipe_partitioned = self.use_model_parallel
+
         self.num_caches = 0
         self.caches = {
             'inputs': [],
@@ -70,6 +72,9 @@ class PipelineParallel(MetaParallelBase):
         self.is_last_stage = (self.stage_id == (self.num_stages - 1))
         self.global_rank = self._hcg.get_global_rank()
 
+        self.mp_degree = self._hcg.get_model_parallel_world_size()
+        self.mp_rank = self._hcg.get_model_parallel_rank()
+
         logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
             self.num_stages, self.stage_id))
@@ -159,8 +164,8 @@ class PipelineParallel(MetaParallelBase):
         else:
             inputs = self.caches['inputs'][cache_id]
 
-        outputs = self._layers.forward(inputs)
         self._clear_grads(inputs)
+        outputs = self._layers.forward(inputs)
         self.caches['outputs'][cache_id] = outputs
@@ -369,6 +374,11 @@ class PipelineParallel(MetaParallelBase):
             caches = tuple(caches)
         return caches
 
+    def _is_valid_send_recv(self, tensor):
+        tensor_numel = np.prod(tensor.shape)
+        assert tensor_numel != 0, "can't send/recv zero element"
+        return tensor_numel % self.mp_degree == 0
+
     def _send_activations(self, cache_id):
         outputs = self.caches['outputs'][cache_id]
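`_is_valid_send_recv` is the gate between the two paths: partial send slices the flattened tensor into mp_degree equal chunks, so the element count must be non-zero and evenly divisible; anything else falls back to a plain p2p.send/recv. The same predicate as a standalone check (a hypothetical helper mirroring the method above):

```python
import numpy as np

def is_valid_send_recv(shape, mp_degree):
    numel = int(np.prod(shape))
    assert numel != 0, "can't send/recv zero element"
    return numel % mp_degree == 0

assert is_valid_send_recv((4, 6), mp_degree=4)        # 24 % 4 == 0 -> partial path
assert not is_valid_send_recv((5, 3), mp_degree=4)    # 15 % 4 != 0 -> full send/recv
```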
@@ -377,24 +387,56 @@ class PipelineParallel(MetaParallelBase):
         self._send_meta(outputs, self.next_stage_id)
 
         if isinstance(outputs, paddle.Tensor):
-            p2p.send(outputs, self.next_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(outputs):
+                p2p.send_partial(
+                    outputs.detach(),
+                    self.next_stage_id,
+                    mp_degree=self.mp_degree,
+                    mp_rank=self.mp_rank)
+            else:
+                p2p.send(outputs.detach(), self.next_stage_id)
 
         elif isinstance(outputs, tuple):
             for output in outputs:
-                p2p.send(output, self.next_stage_id)
+                if self.is_pipe_partitioned and self._is_valid_send_recv(
+                        output):
+                    p2p.send_partial(
+                        output.detach(),
+                        self.next_stage_id,
+                        mp_degree=self.mp_degree,
+                        mp_rank=self.mp_rank)
+                else:
+                    p2p.send(output.detach(), self.next_stage_id)
 
     def _send_gradients(self, cache_id):
         inputs = self.caches['inputs'][cache_id]
 
         if isinstance(inputs, paddle.Tensor):
             assert inputs.grad is not None
-            p2p.send(inputs.grad, self.prev_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(
+                    inputs.grad):
+                grad = p2p.send_partial(
+                    inputs.grad,
+                    self.prev_stage_id,
+                    mp_degree=self.mp_degree,
+                    mp_rank=self.mp_rank)
+            else:
+                p2p.send(inputs.grad, self.prev_stage_id)
         else:
             for idx, d in enumerate(inputs):
                 # Skip tensors that will not produce a grad
                 if not is_float_tensor(d):
                     assert d.grad is None
                     continue
-                p2p.send(d.grad, self.prev_stage_id)
+
+                if self.is_pipe_partitioned and self._is_valid_send_recv(
+                        d.grad):
+                    grad = p2p.send_partial(
+                        d.grad,
+                        self.prev_stage_id,
+                        mp_degree=self.mp_degree,
+                        mp_rank=self.mp_rank)
+                else:
+                    p2p.send(d.grad, self.prev_stage_id)
+
         self.caches['inputs'][cache_id] = None
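Note the `.detach()` now applied to every sent activation: the p2p ops move raw data only, and the backward edge between stages is rebuilt explicitly by `_send_gradients`/`_recv_gradients`, so the payload should not drag the sender's autograd graph along. A quick sketch of the property being relied on:

```python
import paddle

x = paddle.rand([2, 3])
x.stop_gradient = False
payload = (x * 2).detach()     # what actually crosses the pipeline link
assert payload.stop_gradient   # no autograd history attached to the payload
```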
@@ -404,15 +446,39 @@ class PipelineParallel(MetaParallelBase):
             self.recv_cache = self._recv_meta(self.prev_stage_id)
 
         if isinstance(self.recv_cache, paddle.Tensor):
-            p2p.recv(self.recv_cache, self.prev_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(
+                    self.recv_cache):
+                p2p.recv_partial(self.recv_cache, self.prev_stage_id,
+                                 self.mp_degree, self.mp_rank)
+                p2p.partial_allgather_operator(
+                    self.recv_cache,
+                    mp_ranks=self.mp_degree,
+                    mp_rank_id=self.mp_rank,
+                    group=self._hcg.get_model_parallel_group(),
+                    use_calc_stream=True)
+            else:
+                p2p.recv(self.recv_cache, self.prev_stage_id)
+
             inputs = self.recv_cache.clone().detach()
             inputs.stop_gradient = not is_float_tensor(inputs)
         else:
             assert isinstance(self.recv_cache, tuple)
             inputs = [None] * len(self.recv_cache)
             for idx, d in enumerate(self.recv_cache):
-                assert isinstance(d, paddle.Tensor)
-                p2p.recv(d, self.prev_stage_id)
+                if self.is_pipe_partitioned and self._is_valid_send_recv(d):
+                    assert isinstance(d, paddle.Tensor)
+                    p2p.recv_partial(d, self.prev_stage_id, self.mp_degree,
+                                     self.mp_rank)
+                    p2p.partial_allgather_operator(
+                        d,
+                        mp_ranks=self.mp_degree,
+                        mp_rank_id=self.mp_rank,
+                        group=self._hcg.get_model_parallel_group(),
+                        use_calc_stream=True)
+                else:
+                    assert isinstance(d, paddle.Tensor)
+                    p2p.recv(d, self.prev_stage_id)
+
                 inputs[idx] = d.clone().detach()
             inputs = tuple(inputs)
@@ -440,11 +506,33 @@ class PipelineParallel(MetaParallelBase):
                 sizes, dtypes, num_caches=1)[0]
 
         if isinstance(self.grad_tensors, paddle.Tensor):
-            p2p.recv(self.grad_tensors, self.next_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(
+                    self.grad_tensors):
+                p2p.recv_partial(self.grad_tensors, self.next_stage_id,
+                                 self.mp_degree, self.mp_rank)
+                p2p.partial_allgather_operator(
+                    self.grad_tensors,
+                    mp_ranks=self.mp_degree,
+                    mp_rank_id=self.mp_rank,
+                    group=self._hcg.get_model_parallel_group(),
+                    use_calc_stream=True)
+            else:
+                p2p.recv(self.grad_tensors, self.next_stage_id)
         else:
             assert isinstance(outputs, tuple)
             for d in self.grad_tensors:
-                p2p.recv(d, self.next_stage_id)
+                if self.is_pipe_partitioned and self._is_valid_send_recv(d):
+                    p2p.recv_partial(d, self.next_stage_id, self.mp_degree,
+                                     self.mp_rank)
+                    p2p.partial_allgather_operator(
+                        d,
+                        mp_ranks=self.mp_degree,
+                        mp_rank_id=self.mp_rank,
+                        group=self._hcg.get_model_parallel_group(),
+                        use_calc_stream=True)
+                else:
+                    p2p.recv(d, self.next_stage_id)
 
     def _step(self):
         if self.scaler:
...
@@ -27,15 +27,67 @@ def initialize_p2p_groups(hcg):
     _hcg = hcg
 
 
+def _is_valid_communciate(src_stage, dest_stage):
+    first_stage = 0
+    last_stage = _hcg.get_pipe_parallel_world_size() - 1
+    assert abs(src_stage-dest_stage) == 1 or \
+        (src_stage == first_stage and dest_stage == last_stage) or \
+        (src_stage == last_stage and dest_stage == first_stage)
+
+
+def partial_send_operator(tensor,
+                          dst=0,
+                          mp_ranks=1,
+                          mp_rank_id=0,
+                          group=None,
+                          use_calc_stream=True):
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    return paddle.fluid.core.ops.partial_send(
+        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
+        dst, 'num', mp_ranks, 'id', mp_rank_id)
+
+
+def partial_recv_operator(tensor,
+                          src=0,
+                          mp_ranks=1,
+                          mp_rank_id=0,
+                          group=None,
+                          use_calc_stream=True):
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    return paddle.fluid.core.ops.partial_recv(
+        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
+        src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype,
+        'out_shape', tensor.shape)
+
+
+def partial_allgather_operator(tensor,
+                               mp_ranks=1,
+                               mp_rank_id=0,
+                               group=None,
+                               use_calc_stream=True):
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    return paddle.fluid.core.ops.partial_allgather_(
+        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
+        'nranks', mp_ranks, 'rank', mp_rank_id)
+
+
 def send(tensor, dest_stage):
     global _groups, _hcg
     src_stage = _hcg.get_stage_id()
-    src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
     _is_valid_communciate(src_stage, dest_stage)
     group = _get_send_recv_group(src_stage, dest_stage)
-    dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
-    return paddle.distributed.broadcast(tensor, src_rank, group=group)
+    return paddle.distributed.send(
+        tensor, dst=1 if dest_stage > src_stage else 0, group=group)
 
 
 def recv(tensor, src_stage):
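Besides replacing the old broadcast-over-a-pair emulation with true paddle.distributed.send/recv, the rewrite introduces the `dst=1 if dest_stage > src_stage else 0` expression (mirrored by `src=` in `recv` below). It addresses ranks inside the two-rank group returned by `_get_send_recv_group`: the arithmetic implies the group orders the lower pipeline stage as local rank 0 and the higher stage as local rank 1, so sending downstream targets rank 1 and receiving from upstream reads from rank 0 (an inference from the code, not stated in the diff):

```python
# Peer addressing inside the 2-rank send/recv group, assuming the lower
# stage is local rank 0 and the higher stage local rank 1:
def peer_rank(my_stage, other_stage):
    return 1 if other_stage > my_stage else 0

assert peer_rank(my_stage=3, other_stage=4) == 1   # send downstream: dst=1
assert peer_rank(my_stage=4, other_stage=3) == 0   # recv from upstream: src=0
```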
@@ -44,16 +96,35 @@ def recv(tensor, src_stage):
     _is_valid_communciate(src_stage, dest_stage)
     group = _get_send_recv_group(src_stage, dest_stage)
-    src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
-    return paddle.distributed.broadcast(tensor, src_rank, group=group)
+    return paddle.distributed.recv(
+        tensor, src=0 if dest_stage > src_stage else 1, group=group)
 
 
-def _is_valid_communciate(src_stage, dest_stage):
-    first_stage = 0
-    last_stage = _hcg.get_pipe_parallel_world_size() - 1
-    assert abs(src_stage-dest_stage) == 1 or \
-        (src_stage == first_stage and dest_stage == last_stage) or \
-        (src_stage == last_stage and dest_stage == first_stage)
+def send_partial(tensor, dest_stage, mp_degree, mp_rank):
+    global _groups, _hcg
+    src_stage = _hcg.get_stage_id()
+    _is_valid_communciate(src_stage, dest_stage)
+    group = _get_send_recv_group(src_stage, dest_stage)
+    return partial_send_operator(
+        tensor,
+        dst=1 if dest_stage > src_stage else 0,
+        mp_ranks=mp_degree,
+        mp_rank_id=mp_rank,
+        group=group)
+
+
+def recv_partial(tensor, src_stage, mp_degree, mp_rank):
+    global _groups, _hcg
+    dest_stage = _hcg.get_stage_id()
+    _is_valid_communciate(src_stage, dest_stage)
+    group = _get_send_recv_group(src_stage, dest_stage)
+    return partial_recv_operator(
+        tensor,
+        src=0 if dest_stage > src_stage else 1,
+        mp_ranks=mp_degree,
+        mp_rank_id=mp_rank,
+        group=group)
 
 
 def _get_send_recv_group(src_stage, dest_stage):
...