diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 45ea9a3c9dde4bd92e4d2db69d05e6e4eac7a424..ae2d9163435b906f17e9b28a680302d2bd305bbc 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -442,7 +442,7 @@ class Completer:
                 assert forward_op is not None

                 if grad_op.type == "concat" and forward_op.type == "split":
-                    forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
+                    forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                         forward_op)
                     output_var = vars[grad_op.desc.output('Out')[0]]
                     split_input_var_name = forward_op.input("X")[0]
@@ -458,14 +458,14 @@ class Completer:
                     output_var_dist_attr = TensorDistributedAttribute()
                     output_var_dist_attr.dims_mapping = ref_dims_mapping
                     output_var_dist_attr.process_mesh = ref_mesh
-                    dist_context.set_tensor_dist_attr_for_program(
+                    self._dist_context.set_tensor_dist_attr_for_program(
                         output_var, output_var_dist_attr)

                     grad_op_dist_attr.set_output_dims_mapping(output_var.name,
                                                               ref_dims_mapping)
                     grad_op_dist_attr.process_mesh = ref_mesh
-                    dist_context.set_op_dist_attr_for_program(grad_op,
-                                                              grad_op_dist_attr)
+                    self._dist_context.set_op_dist_attr_for_program(
+                        grad_op, grad_op_dist_attr)
                     continue

             # op dist attr
@@ -579,6 +579,28 @@ class Completer:
             # TODO to add attribute for moment var
             op = ops[idx]
             if int(op.attr('op_role')) == int(OpRole.Optimize):
+                if op.type == "clip_by_norm":
+                    param_grad = vars[op.input("X")[0]]
+                    param_grad_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        param_grad)
+                    assert param_grad_dist_attr is not None
+                    ref_process_mesh = param_grad_dist_attr.process_mesh
+                    ref_dims_mapping = param_grad_dist_attr.dims_mapping
+
+                    out = vars[op.output("Out")[0]]
+                    out_dist_attr = TensorDistributedAttribute()
+                    out_dist_attr.process_mesh = ref_process_mesh
+                    out_dist_attr.dims_mapping = ref_dims_mapping
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        out, out_dist_attr)
+
+                    op_dist_attr = OperatorDistributedAttribute()
+                    op_dist_attr.process_mesh = ref_process_mesh
+                    op_dist_attr.set_input_dist_attr(param_grad.name,
+                                                     param_grad_dist_attr)
+                    op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
+                    self._dist_context.set_op_dist_attr_for_program(
+                        op, op_dist_attr)

                 if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                     assert len(op.input(
diff --git a/python/paddle/distributed/auto_parallel/cost_model.py b/python/paddle/distributed/auto_parallel/cost_model.py
index 9252f8de905b5f6cedd834b520d8a7b9ad2e125a..1155c2817a21cd147ee1012fbaf11376a5183717 100644
--- a/python/paddle/distributed/auto_parallel/cost_model.py
+++ b/python/paddle/distributed/auto_parallel/cost_model.py
@@ -142,7 +142,8 @@ class TensorCostNode(CostNode):
             elif node.dtype == paddle.uint8:
                 self.dtype_factor = 1
             else:
-                raise NotImplementedError("{} not counted".format(node.dtype))
+                self.dtype_factor = 2
+                # raise NotImplementedError("{} not counted".format(node.dtype))
         self.batch_size = None
         if batch_size is not None:
             self.batch_size = batch_size
diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py
index a7cc2a9600c05f8e528860ffe4ed28a729ac0bbb..67de298564afc8caddad90d228131f1795f5707e 100644
--- a/python/paddle/distributed/auto_parallel/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/dist_op.py
@@ -86,7 +86,7 @@ class DistributedOperator:
                                                        tensor_dims_mapping)
         for tensor_name in self._serial_op.output_arg_names:
             tensor = self._serial_op.block._var_recursive(tensor_name)
-            if tensor.type == core.VarDesc.VarType.READER:
+            if tensor.type == core.VarDesc.VarType.READER or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                 tensor_shape = []
             else:
                 tensor_shape = tensor.shape
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
index 52d5e85c962eb2cb28578c43abf4dd7c6c5cce82..2870acfd367cab5236f8544c447bdd269b8e654b 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
@@ -26,7 +26,7 @@ from ..process_group import new_process_group
 from ..dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.auto_parallel.process_group import get_world_process_group

-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()


 class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
@@ -119,7 +119,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         main_block._sync_with_cpp()

         # sync result
-        group = new_process_group(global_process_mesh)
+        group = new_process_group(world_process_group.ranks)

         inf_var = main_block.var(kwargs['FoundInfinite'][0])
         inf_var_int32 = main_block.create_var(
diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 2f557ad3e9fe38e5eb6807e2abae35e9c996bd39..6278f0a2424a0fa89b5ae7ab2350aeec63a600a7 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -222,6 +222,8 @@ class AutoParallelizer:
             HAS_ALLGATHER.clear()
             _g_process_group_map.clear()
             _g_process_group_map[0] = ProcessGroup(0, [])
+            for process_mesh in dist_context._process_meshes:
+                _g_process_group_map[0].add_ranks(process_mesh.processes)
         return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map

     def parallelize(self,
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index a0a68efae3c3c1f4ca04144b3243d144651df817..e789d82632e073544b7efaac96f397bb9df9276c 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -381,7 +381,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
         op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
         assert op_dist_attr.impl_idx >= 0
         dist_op_impl = get_distributed_operator_impl_container(
-            backward_op.type).get_impl(op_dist_attr.impl_idx)
+            op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
         return dist_op_impl

     dist_op = get_distributed_operator_impl_container("default")
diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index c28a48da838fec88ffa1703251705564add993c2..4cc710b226d8f84fadd249a148e754d5330fb564 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -1013,18 +1013,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
         "but got {}.".format(type(dist_context))

+    def _is_special_op(op):
+        global _g_special_ops
+        if op.type in _g_special_ops:
+            return True
+        return False
+
     block = auto_parallel_main_prog.global_block()
     idx = 0
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]

-        def _is_special_op(op):
-            global _g_special_ops
-            if op.type in _g_special_ops:
-                return True
-            return False
-
         if _is_special_op(op):
             idx += 1
             continue
@@ -1053,6 +1053,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     # insert send and recv op if output process mesh is different from tensor process mesh
     idx = 0
     skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    skip_ops += _g_special_ops
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index d2af422bac023e21b342209bce54f6f72277f17a..d69d6d4ab3286368d242651652b34c4d11c853fb 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -26,7 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()


 class AMPState(object):
@@ -445,7 +445,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         type=core.VarDesc.VarType.LOD_TENSOR,
         persistable=False,
         stop_gradient=False)
-    set_var_dist_attr(dist_context, found_inf, [-1], global_process_mesh)
+    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
@@ -457,9 +457,10 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         attrs=attrs)

     new_op_dist_attr = OperatorDistributedAttribute()
-    new_op_dist_attr.process_mesh = global_process_mesh
-    if len(global_process_mesh) > 1:
-        new_op_dist_attr.impl_idx = 0
+    new_op_dist_attr.process_mesh = world_process_group.ranks
+    new_op_dist_attr.impl_idx = 0
+    if len(world_process_group.ranks) > 1:
+        new_op_dist_attr.impl_type = "check_finite_and_unscale"
     for g in grads:
         g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
         assert g_dist_attr is not None
@@ -550,7 +551,7 @@ class AMPPass(PassBase):
             dtype='float32',
             persistable=True)
         set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
-                          global_process_mesh)
+                          world_process_group.ranks)

         if self.get_attr("use_dynamic_loss_scaling"):
             self._num_good_steps = paddle.static.create_global_var(
@@ -560,7 +561,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)

             self._num_bad_steps = paddle.static.create_global_var(
                 name=unique_name.generate("num_bad_steps"),
@@ -569,7 +570,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)

     def _scale_loss(self):

@@ -700,9 +701,10 @@ class AMPPass(PassBase):
             attrs=attrs)

         new_op_dist_attr = OperatorDistributedAttribute()
-        new_op_dist_attr.process_mesh = global_process_mesh
-        if len(global_process_mesh) > 1:
-            new_op_dist_attr.impl_idx = 0
+        new_op_dist_attr.process_mesh = world_process_group.ranks
+        new_op_dist_attr.impl_idx = 0
+        if len(world_process_group.ranks) > 1:
+            new_op_dist_attr.impl_type = "update_loss_scaling"
         for g in grads:
             g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
             assert g_dist_attr is not None
diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py
index 4039f3ed746772a49fd022a99d5ff281aa177a7e..185fb453412eab69579197e10c43d5c9f8d7fc1a 100644
--- a/python/paddle/distributed/passes/auto_parallel_recompute.py
+++ b/python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -382,6 +382,7 @@ class RecomputePass(PassBase):
             new_dist_attr = OperatorDistributedAttribute()
             new_dist_attr.is_recompute = True
             new_dist_attr.impl_idx = old_dist_attr.impl_idx
+            new_dist_attr.impl_type = old_dist_attr.impl_type
             new_dist_attr.process_mesh = old_dist_attr.process_mesh
             for input in old_dist_attr.inputs_dist_attrs.keys():
                 if input in var_name_dict.keys():
diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
index 1875c8b1da983f965b3477d78b4a28768ef91efe..74a751881ddf24df41ead820ce41447f4c61210b 100644
--- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
+++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
@@ -40,7 +40,7 @@ class TestRecomputePass(AutoPallelPassTestBase):
     def apply_passes(self):
         dist_strategy = fleet.DistributedStrategy()
         dist_strategy.recompute = True
-        dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
+        dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}
         dist_strategy.semi_auto = True
         fleet.init(is_collective=True, strategy=dist_strategy)