Unverified commit bdd0b0f1 authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Support High Order Differential with Data Parallel Calc-Comm Overlapping (#45388)

* support high order differential with data parallel overlap

* update unittest
Parent 256bf6ff
@@ -24,6 +24,7 @@ from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
from .utils import is_loss_grad_op, is_loss_op
# There always exists a default context for user. And user can set it to another one.
_g_default_distributed_context = None
@@ -895,6 +896,11 @@ class DistributedOperatorContext:
        self.already_init_sync_vars = set()
        self.varname_mapping = None
        self.rank_id = None
        # NOTE Support correct parallelism for high-order differential models.
        # By default _exceed_backward_init_op is False, which means we are in the Forward phase;
        # after _exceed_backward_init_op becomes True, we are in the Backward phase.
        # The eventual solution should be to revise the high-order differential logic
        # for these two phases in the future.
        self._exceed_backward_init_op = False
    def __deepcopy__(self, memo):
        cls = self.__class__
@@ -951,10 +957,16 @@ class DistributedOperatorContext:
        assert self._cur_src_op is not None
        return self._cur_src_op

    def in_backward_phase(self):
        return self._exceed_backward_init_op

    def prepare_context(self, src_op):

        self._cur_src_op = src_op

        if is_loss_grad_op(src_op):
            self._exceed_backward_init_op = True

        # build input varname mapping
        kinputs = {}
        for input_name in src_op.desc.input_names():
......
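To make the phase switch above concrete, here is a minimal, runnable sketch; `PhaseTrackingContext` and the op-type strings are hypothetical stand-ins for `DistributedOperatorContext` and `is_loss_grad_op`, not Paddle code:

```python
# Sketch of the phase tracking introduced above: the context stays in the
# Forward phase until it prepares the loss gradient op, after which every
# op it dispatches is treated as Backward-phase.

class PhaseTrackingContext:

    def __init__(self):
        self._exceed_backward_init_op = False

    def in_backward_phase(self):
        return self._exceed_backward_init_op

    def prepare_context(self, op_type):
        # The real code calls is_loss_grad_op(src_op); this sketch just
        # matches a hypothetical op-type string.
        if op_type == "loss_grad":
            self._exceed_backward_init_op = True


ctx = PhaseTrackingContext()
for op_type in ["matmul", "matmul_grad", "loss_grad", "matmul_grad"]:
    ctx.prepare_context(op_type)
    print(op_type, "->", "backward" if ctx.in_backward_phase() else "forward")
# The first matmul_grad (a gradient computation emitted in the Forward phase
# under high-order differentiation) reports "forward"; ops from loss_grad
# onward report "backward".
```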
@@ -428,6 +428,9 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
        rank (int): global rank index of the current process.
    """

    if not is_in_backward_phase(dist_ctx):
        return

    if is_optimize_op(op) or len(act_grad_names) == 0 or len(
            out_grad_names) == 0:
        return
@@ -448,3 +451,12 @@ def is_data_parallel_scale_op(op):
def is_data_parallel_reduce_op(op):
    return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
def is_in_backward_phase(dist_ctx):
    # NOTE Currently, high-order differentiation in Paddle does NOT distinguish gradient-computation
    # operators in the Forward phase from operators in the Backward phase (both have op_role=1),
    # which misleads auto parallel into adding gradient synchronization for gradient-computation
    # operators in the Forward phase. We use this flag to distinguish the two phases temporarily.
    return dist_ctx.dist_op_context.in_backward_phase()
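The guard exists because high-order (double-backward) programs legitimately contain gradient-computation ops in the Forward phase. A hedged sketch of such a program, using public `paddle.static` APIs (the variable names are illustrative):

```python
# paddle.static.gradients appends *_grad ops while the forward graph is still
# being built (a double-backward / gradient-penalty pattern), yet those ops
# carry op_role=1 just like true Backward-phase ops.
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", [4, 8], "float32")
    x.stop_gradient = False
    w = paddle.static.create_parameter([8, 1], "float32")
    y = paddle.sum(paddle.matmul(x, w))
    # These *_grad ops run in the Forward phase of the final program.
    (x_grad,) = paddle.static.gradients([y], [x])
    loss = y + paddle.sum(x_grad * x_grad)

print([op.type for op in main.global_block().ops])
```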
@@ -42,29 +42,6 @@ class DistributedPNorm(DistributedOperatorImplContainer):
register_distributed_operator_impl_container(DistributedPNorm("p_norm"))


def _insert_fill_constant_op(block, op_role):
    """Insert a fill_constant op into the block."""
    helper = LayerHelper("fill_constant", **locals())
    with paddle.static.program_guard(block.program):
        out = helper.create_variable_for_type_inference(dtype="int32")
        inputs = {}
        attrs = {'force_cpu': False}
        attrs['str_value'] = str(int("1"))
        attrs['value'] = int("1")
        attrs['dtype'] = out.dtype
        attrs['op_role'] = op_role
        utils.get_shape_tensor_inputs(inputs=inputs,
                                      attrs=attrs,
                                      shape=[0],
                                      op_type='fill_constant')
        fill_constant_op = block.append_op(type='fill_constant',
                                           inputs=inputs,
                                           outputs={'Out': [out]},
                                           attrs=attrs)
        out.stop_gradient = True
    return out, fill_constant_op
# Row Parallel
class DistributedPNormImpl(DistributedOperatorImpl):
@@ -182,32 +159,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'norm')

        # 1. insert barrier op
        ref_process_mesh = op_dist_attr.process_mesh
        constant_out_dims_mapping = [-1]
        fill_constant_out, fill_constant_op = _insert_fill_constant_op(
            main_block, src_op.attr('op_role'))
        # set fill_constant_out tensor dist_attr
        constant_out_dist_attr = TensorDistributedAttribute()
        constant_out_dist_attr.process_mesh = ref_process_mesh
        constant_out_dist_attr.dims_mapping = constant_out_dims_mapping
        ctx.set_tensor_dist_attr_for_program(fill_constant_out,
                                             constant_out_dist_attr)
        # set fill_constant op dist_attr
        constant_op_dist_attr = OperatorDistributedAttribute()
        constant_op_dist_attr.process_mesh = ref_process_mesh
        constant_op_dist_attr.set_output_dims_mapping(
            fill_constant_out.name, constant_out_dims_mapping)
        ctx.set_op_dist_attr_for_program(fill_constant_op,
                                         constant_op_dist_attr)
        barrier_op = main_block.append_op(type='barrier',
                                          inputs={'X': [fill_constant_out]},
                                          outputs={'Out': [fill_constant_out]},
                                          attrs={'ring_id': group.id})
        # set barrier op dist attr
        set_comm_op_dist_attr_for_program(barrier_op, ref_process_mesh,
                                          constant_out_dist_attr, ctx)

        # 2. insert c_allgather op
        # create c_allgather output var
        allgather_out = main_block.create_var(
......
@@ -111,9 +111,9 @@ class DataParallelOptimizationPass(PassBase):
        scaled_grads = []

        for op in ops:
            grad_name = op.output_arg_names[0]
            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
@@ -132,6 +132,7 @@ class DataParallelOptimizationPass(PassBase):
                self._group_to_grad_name_map[group].append(grad_name)

            elif is_data_parallel_scale_op(op):
                grad_name = op.output_arg_names[0]
                scaled_grads.append(grad_name)

        # TODO support multiple optimizers in one network in future.
......
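The hunks above move the `grad_name = op.output_arg_names[0]` read from the top of the loop into the two branches. A small stand-alone illustration of the pattern; the `Op` stub is hypothetical, and the rationale (ops outside the two branches may have no outputs) is an inference from the change, not stated in the commit:

```python
# Reading output_arg_names[0] only after the op kind is known avoids
# indexing into ops that may have no outputs at all.

class Op:

    def __init__(self, op_type, output_arg_names):
        self.type = op_type
        self.output_arg_names = output_arg_names


ops = [
    Op("c_allreduce_sum", ["w@GRAD"]),
    Op("nop", []),  # no outputs: an unconditional [0] read would raise here
    Op("scale", ["w@GRAD"]),
]

for op in ops:
    if op.type in ["c_reduce_sum", "c_allreduce_sum"]:  # cf. is_data_parallel_reduce_op
        grad_name = op.output_arg_names[0]
        print("reduce ->", grad_name)
    elif op.type == "scale":  # cf. is_data_parallel_scale_op
        grad_name = op.output_arg_names[0]
        print("scale  ->", grad_name)
```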
@@ -108,10 +108,8 @@ class TestDistPNorm(unittest.TestCase):
        for output_attr in op_dist_attr.outputs_dist_attrs.values():
            assert output_attr.dims_mapping[0] == 0
            assert set(output_attr.dims_mapping[1:]) == set([-1])
        assert op_types == [
            "fill_constant", "barrier", "c_allgather", "p_norm",
            "fill_constant", "p_norm_grad", "slice"
            "c_allgather", "p_norm", "fill_constant", "p_norm_grad", "slice"
        ]

    def test_dist_pnorm_serial(self):
......