Unverified commit c20eb7a6, authored by wuhuachaocoding, committed by GitHub

support stage2 for gradient merge. (#47711)

Parent 460d5040
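In short, this change drops the per-reduce dp all_reduce from the stage 2 reduce hooks and instead runs a single `_dp_allreduce` right before each optimizer step, so gradients merged (accumulated) over several micro-batches cross the data-parallel group only once per update. As a rough illustration of the training pattern this is meant to support, here is a minimal single-process gradient-merge sketch; it is not the Paddle sharding implementation, it omits sharding and the dp group entirely, and the model, optimizer and acc_steps below are placeholders:

import paddle

# Plain gradient merge: gradients accumulate locally for acc_steps micro-batches
# and the optimizer runs once per merged batch. With stage2 + dp, the diff below
# moves the dp all_reduce to the same point as opt.step().
model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
acc_steps = 4

for step in range(8):
    x = paddle.randn([2, 4])
    loss = model(x).mean() / acc_steps
    loss.backward()                  # grads keep accumulating across micro-steps
    if (step + 1) % acc_steps == 0:
        opt.step()                   # merged grads applied once per acc_steps
        opt.clear_grad()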
@@ -418,17 +418,6 @@ class GroupShardedStage2(nn.Layer):
                         )
                     )
-                    if self._dp_group and self._dp_group.nranks > 1:
-                        assert (
-                            not self._reduce_overlap
-                        ), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
-                        # TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
-                        dist.all_reduce(
-                            tensor=param.grad,
-                            group=self._dp_group,
-                            sync_op=True,
-                        )
                     # Clear the task flow and trigger callback to clear the redundant gradient
                     # self._clear_task_flow()
@@ -485,17 +474,6 @@ class GroupShardedStage2(nn.Layer):
                             )
                         )
-                        if self._dp_group and self._dp_group.nranks > 1:
-                            assert (
-                                not self._reduce_overlap
-                            ), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
-                            # TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
-                            dist.all_reduce(
-                                tensor=grad_storage.buffer,
-                                group=self._dp_group,
-                                sync_op=True,
-                            )
                         cleanup()
                     # Clear the task flow and trigger callback to clear the redundant gradient
@@ -648,8 +626,34 @@ class GroupShardedStage2(nn.Layer):
                 )
         return rank_buffer_size

+    def _dp_allreduce(self):
+        # do dp allreduce here for gradient merge.
+        if self._dp_group and self._dp_group.nranks > 1:
+            for dtype in self._grad_storages.keys():
+                for rank, g in sorted(
+                    self._grad_storages[dtype].items(), key=lambda x: x[0]
+                ):
+                    if g.destination == self._rank:
+                        assert g.buffer._is_initialized()
+                        dist.all_reduce(
+                            tensor=g.buffer,
+                            group=self._dp_group,
+                            sync_op=True,
+                        )
+
+            for param in self._trainable_params:
+                if param.name in self._param_grads and param.grad is not None:
+                    dst_rank = self._trainable_param2rank[param.name]
+                    if dst_rank == self._rank:
+                        dist.all_reduce(
+                            tensor=param.grad,
+                            group=self._dp_group,
+                            sync_op=True,
+                        )
+
     def _redefine_opt_step(self):
         grad_func = self._grad_scale
+        dp_allreduce_func = self._dp_allreduce
         for opt in self._sharding_optimizers:
             opt_step = opt.step
@@ -658,7 +662,9 @@ class GroupShardedStage2(nn.Layer):
                     # Wait for the last reduce task. This wait must before grad scale function.
                     assert self._comm_task is not None
                     self._comm_task.wait()

                 grad_func()
+                dp_allreduce_func()
                 opt_step()

             opt.step = MethodType(_opt_step, opt)
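The hunk above wires the new `_dp_allreduce` into the optimizer by rebinding each sharding optimizer's `step` with `types.MethodType`, so gradient scaling and the dp all_reduce run immediately before the original step. Below is a stripped-down, single-process sketch of that wrapping pattern only; `DummyOptimizer`, the stand-in `_grad_scale`/`_dp_allreduce` functions and the print calls are hypothetical placeholders, not Paddle code:

from types import MethodType


class DummyOptimizer:
    def step(self):
        print("apply merged gradients")


def _grad_scale():
    # stand-in for rescaling the accumulated gradients
    print("scale gradients")


def _dp_allreduce():
    # stand-in for dist.all_reduce over the dp group, done once per update
    print("all-reduce merged gradients across the dp group")


opt = DummyOptimizer()
opt_step = opt.step  # keep the original bound method


def _opt_step(self):
    _grad_scale()
    _dp_allreduce()
    opt_step()


opt.step = MethodType(_opt_step, opt)
opt.step()  # runs: scale -> dp all-reduce -> original step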
@@ -23,7 +23,6 @@ from paddle.fluid import core
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
-from paddle.distributed import fleet, ParallelMode


 class Taskflow:
@@ -245,18 +244,8 @@ def GroupShardedScaler(scaler):
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

-        hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
-        hybrid_parallel = (
-            hcg is not None
-            and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
-        )
-
         paddle.distributed.all_reduce(
-            is_found_inf,
-            op=paddle.distributed.ReduceOp.MAX,
-            group=hcg.get_check_parallel_group()
-            if hybrid_parallel
-            else optimizer._group,
+            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
         )
         self._found_inf = is_found_inf.numpy()[0]
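With the hunk above, the overflow flag used by `GroupShardedScaler` is all-reduced with `group=None`, i.e. over the default global process group, so every rank (sharding and dp alike) agrees on whether to skip the step, and the previous group selection through `fleet`/`ParallelMode` is no longer needed. A hedged sketch of that synchronization pattern, assuming the parallel environment is already initialized (for example when launched with paddle.distributed.launch):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# 1 on any rank that saw a NaN/Inf gradient under AMP, 0 otherwise.
found_inf_local = 0
is_found_inf = paddle.to_tensor([found_inf_local], dtype="int32")

# group=None selects the default (global) group; MAX makes the flag sticky,
# so all ranks skip the optimizer step together if any single rank overflowed.
dist.all_reduce(is_found_inf, op=dist.ReduceOp.MAX, group=None)
skip_step = bool(is_found_inf.numpy()[0])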
@@ -148,11 +148,6 @@ def test_sharding_api():
     output_dir = tempfile.mkdtemp()

-    # test sharding + dp, just for test
-    dp_group = paddle.distributed.new_group(
-        list(range(paddle.distributed.get_world_size()))
-    )
-
     # fp16
     stage2_params = train_mlp(
         mlp1,