From bbf31a4ef6ba9734c8ef6384e77d3906ffec1fb6 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Fri, 18 Feb 2022 18:09:15 +0800
Subject: [PATCH] bug fix (#39630)

---
 .../distributed/auto_parallel/completion.py   | 30 ++++++++++++++++---
 .../distributed/auto_parallel/cost_model.py   |  3 +-
 .../distributed/auto_parallel/dist_op.py      |  2 +-
 .../dist_check_finite_and_unscale.py          |  4 +--
 .../distributed/auto_parallel/parallelizer.py |  2 ++
 .../distributed/auto_parallel/partitioner.py  |  2 +-
 .../distributed/auto_parallel/reshard.py      | 13 ++++----
 .../distributed/passes/auto_parallel_amp.py   | 24 ++++++++-------
 .../passes/auto_parallel_recompute.py         |  1 +
 .../test_auto_parallel_recompute_pass.py      |  2 +-
 10 files changed, 56 insertions(+), 27 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 45ea9a3c9dd..ae2d9163435 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -442,7 +442,7 @@ class Completer:
             assert forward_op is not None

             if grad_op.type == "concat" and forward_op.type == "split":
-                forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
+                forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                     forward_op)
                 output_var = vars[grad_op.desc.output('Out')[0]]
                 split_input_var_name = forward_op.input("X")[0]
@@ -458,14 +458,14 @@ class Completer:
                 output_var_dist_attr = TensorDistributedAttribute()
                 output_var_dist_attr.dims_mapping = ref_dims_mapping
                 output_var_dist_attr.process_mesh = ref_mesh
-                dist_context.set_tensor_dist_attr_for_program(
+                self._dist_context.set_tensor_dist_attr_for_program(
                     output_var, output_var_dist_attr)

                 grad_op_dist_attr.set_output_dims_mapping(output_var.name,
                                                           ref_dims_mapping)
                 grad_op_dist_attr.process_mesh = ref_mesh
-                dist_context.set_op_dist_attr_for_program(grad_op,
-                                                          grad_op_dist_attr)
+                self._dist_context.set_op_dist_attr_for_program(
+                    grad_op, grad_op_dist_attr)
                 continue

             # op dist attr
@@ -579,6 +579,28 @@ class Completer:
             # TODO to add attribute for moment var
             op = ops[idx]
             if int(op.attr('op_role')) == int(OpRole.Optimize):
+                if op.type == "clip_by_norm":
+                    param_grad = vars[op.input("X")[0]]
+                    param_grad_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        param_grad)
+                    assert param_grad_dist_attr is not None
+                    ref_process_mesh = param_grad_dist_attr.process_mesh
+                    ref_dims_mapping = param_grad_dist_attr.dims_mapping
+
+                    out = vars[op.output("Out")[0]]
+                    out_dist_attr = TensorDistributedAttribute()
+                    out_dist_attr.process_mesh = ref_process_mesh
+                    out_dist_attr.dims_mapping = ref_dims_mapping
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        out, out_dist_attr)
+
+                    op_dist_attr = OperatorDistributedAttribute()
+                    op_dist_attr.process_mesh = ref_process_mesh
+                    op_dist_attr.set_input_dist_attr(param_grad.name,
+                                                     param_grad_dist_attr)
+                    op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
+                    self._dist_context.set_op_dist_attr_for_program(
+                        op, op_dist_attr)

                 if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                     assert len(op.input(
diff --git a/python/paddle/distributed/auto_parallel/cost_model.py b/python/paddle/distributed/auto_parallel/cost_model.py
index 9252f8de905..1155c2817a2 100644
--- a/python/paddle/distributed/auto_parallel/cost_model.py
+++ b/python/paddle/distributed/auto_parallel/cost_model.py
@@ -142,7 +142,8 @@ class TensorCostNode(CostNode):
             elif node.dtype == paddle.uint8:
                 self.dtype_factor = 1
             else:
-                raise NotImplementedError("{} not counted".format(node.dtype))
+                self.dtype_factor = 2
+                # raise NotImplementedError("{} not counted".format(node.dtype))
         self.batch_size = None
         if batch_size is not None:
             self.batch_size = batch_size
diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py
index a7cc2a9600c..67de298564a 100644
--- a/python/paddle/distributed/auto_parallel/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/dist_op.py
@@ -86,7 +86,7 @@ class DistributedOperator:
                     tensor_dims_mapping)
         for tensor_name in self._serial_op.output_arg_names:
             tensor = self._serial_op.block._var_recursive(tensor_name)
-            if tensor.type == core.VarDesc.VarType.READER:
+            if tensor.type == core.VarDesc.VarType.READER or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                 tensor_shape = []
             else:
                 tensor_shape = tensor.shape
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
index 52d5e85c962..2870acfd367 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
@@ -26,7 +26,7 @@ from ..process_group import new_process_group
 from ..dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.auto_parallel.process_group import get_world_process_group

-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()


 class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
@@ -119,7 +119,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         main_block._sync_with_cpp()

         # sync result
-        group = new_process_group(global_process_mesh)
+        group = new_process_group(world_process_group.ranks)

         inf_var = main_block.var(kwargs['FoundInfinite'][0])
         inf_var_int32 = main_block.create_var(
diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 2f557ad3e9f..6278f0a2424 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -222,6 +222,8 @@ class AutoParallelizer:
             HAS_ALLGATHER.clear()
             _g_process_group_map.clear()
             _g_process_group_map[0] = ProcessGroup(0, [])
+            for process_mesh in dist_context._process_meshes:
+                _g_process_group_map[0].add_ranks(process_mesh.processes)
         return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map

     def parallelize(self,
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index a0a68efae3c..e789d82632e 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -381,7 +381,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
         op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
         assert op_dist_attr.impl_idx >= 0
         dist_op_impl = get_distributed_operator_impl_container(
-            backward_op.type).get_impl(op_dist_attr.impl_idx)
+            op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
         return dist_op_impl

     dist_op = get_distributed_operator_impl_container("default")
diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index c28a48da838..4cc710b226d 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -1013,18 +1013,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
         "but got {}.".format(type(dist_context))

+    def _is_special_op(op):
+        global _g_special_ops
+        if op.type in _g_special_ops:
+            return True
+        return False
+
     block = auto_parallel_main_prog.global_block()
     idx = 0
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]

-        def _is_special_op(op):
-            global _g_special_ops
-            if op.type in _g_special_ops:
-                return True
-            return False
-
         if _is_special_op(op):
             idx += 1
             continue
@@ -1053,6 +1053,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     # insert send and recv op if output process mesh is different from tensor process mesh
     idx = 0
     skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    skip_ops += _g_special_ops
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index d2af422bac0..d69d6d4ab32 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -26,7 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()


 class AMPState(object):
@@ -445,7 +445,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         type=core.VarDesc.VarType.LOD_TENSOR,
         persistable=False,
         stop_gradient=False)
-    set_var_dist_attr(dist_context, found_inf, [-1], global_process_mesh)
+    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
@@ -457,9 +457,10 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         attrs=attrs)

     new_op_dist_attr = OperatorDistributedAttribute()
-    new_op_dist_attr.process_mesh = global_process_mesh
-    if len(global_process_mesh) > 1:
-        new_op_dist_attr.impl_idx = 0
+    new_op_dist_attr.process_mesh = world_process_group.ranks
+    new_op_dist_attr.impl_idx = 0
+    if len(world_process_group.ranks) > 1:
+        new_op_dist_attr.impl_type = "check_finite_and_unscale"
     for g in grads:
         g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
         assert g_dist_attr is not None
@@ -550,7 +551,7 @@ class AMPPass(PassBase):
             dtype='float32',
             persistable=True)
         set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
-                          global_process_mesh)
+                          world_process_group.ranks)

         if self.get_attr("use_dynamic_loss_scaling"):
             self._num_good_steps = paddle.static.create_global_var(
@@ -560,7 +561,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)

             self._num_bad_steps = paddle.static.create_global_var(
                 name=unique_name.generate("num_bad_steps"),
@@ -569,7 +570,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)

     def _scale_loss(self):

@@ -700,9 +701,10 @@ class AMPPass(PassBase):
             attrs=attrs)

         new_op_dist_attr = OperatorDistributedAttribute()
-        new_op_dist_attr.process_mesh = global_process_mesh
-        if len(global_process_mesh) > 1:
-            new_op_dist_attr.impl_idx = 0
+        new_op_dist_attr.process_mesh = world_process_group.ranks
+        new_op_dist_attr.impl_idx = 0
+        if len(world_process_group.ranks) > 1:
+            new_op_dist_attr.impl_type = "update_loss_scaling"
         for g in grads:
             g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
             assert g_dist_attr is not None
diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py
index 4039f3ed746..185fb453412 100644
--- a/python/paddle/distributed/passes/auto_parallel_recompute.py
+++ b/python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -382,6 +382,7 @@ class RecomputePass(PassBase):
             new_dist_attr = OperatorDistributedAttribute()
             new_dist_attr.is_recompute = True
             new_dist_attr.impl_idx = old_dist_attr.impl_idx
+            new_dist_attr.impl_type = old_dist_attr.impl_type
             new_dist_attr.process_mesh = old_dist_attr.process_mesh
             for input in old_dist_attr.inputs_dist_attrs.keys():
                 if input in var_name_dict.keys():
diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
index 1875c8b1da9..74a751881dd 100644
--- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
+++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
@@ -40,7 +40,7 @@ class TestRecomputePass(AutoPallelPassTestBase):
     def apply_passes(self):
         dist_strategy = fleet.DistributedStrategy()
         dist_strategy.recompute = True
-        dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
+        dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}
         dist_strategy.semi_auto = True
         fleet.init(is_collective=True, strategy=dist_strategy)

-- 
GitLab