diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index 310358436ae32384b6ac651034f7c020aa4fe6ef..7668dff36207ed700f5aa6378cb0f5532cfedd3f 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -22,6 +22,10 @@ from paddle.fluid.framework import program_guard, device_guard from paddle.fluid import unique_name, layers from paddle.fluid.clip import append_gradient_clip_ops from .pass_base import PassBase, PassType, register_pass +from paddle.distributed.auto_parallel.utils import set_var_dist_attr +from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping +from paddle.distributed.auto_parallel.process_group import get_world_process_group +world_process_group = get_world_process_group() def _is_the_backward_op(op): @@ -68,15 +72,11 @@ def _remove_and_get_optimizer_op(main_program, dist_context): def _remove_op_role_var(param, grad): op_maker = core.op_proto_and_checker_maker op = grad.op - assert _is_the_backward_op(op), \ - 'grad.op={} is not the backward op which produces the grad={}' \ - .format(op, grad.name) - if op.has_attr(op_maker.kOpRoleVarAttrName()): op._remove_attr(op_maker.kOpRoleVarAttrName()) -def _get_gm_cond_var(main_program, k_steps): +def _get_gm_cond_var(main_program, k_steps, dist_context): main_block = main_program.global_block() # Add const var k_step_var = layers.create_global_var( @@ -86,6 +86,7 @@ def _get_gm_cond_var(main_program, k_steps): dtype='int32', persistable=True, force_cpu=True) + set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks) zero_var = layers.create_global_var( name="gradient_merge_zero", @@ -94,6 +95,7 @@ def _get_gm_cond_var(main_program, k_steps): dtype='int32', persistable=True, force_cpu=True) + set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks) # Add step var & cond var step_var = layers.create_global_var( @@ -103,6 +105,7 @@ def _get_gm_cond_var(main_program, k_steps): dtype='int32', persistable=True, force_cpu=True) + set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks) cond_var = layers.create_global_var( name="gradient_merge_cond", @@ -111,24 +114,29 @@ def _get_gm_cond_var(main_program, k_steps): dtype='bool', persistable=False, force_cpu=True) + set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks) with device_guard("cpu"): # step_var = (step_var + 1) % k_step layers.increment(x=step_var, value=1.0, in_place=True) - main_block.append_op( + elementwise_mod_op = main_block.append_op( type='elementwise_mod', inputs={'X': step_var, 'Y': k_step_var}, outputs={'Out': step_var}, attrs={'axis': -1, 'use_mkldnn': False}) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + elementwise_mod_op, world_process_group.ranks, [-1], dist_context) # cond_var = (step_var == 0) - main_block.append_op( + equal_op = main_block.append_op( type='equal', inputs={'X': step_var, 'Y': zero_var}, outputs={'Out': cond_var}) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + equal_op, world_process_group.ranks, [-1], dist_context) return cond_var @@ -137,7 +145,8 @@ def _append_gradient_merge_backward_op( main_program, startup_program, params_grads: List[Tuple[Any, Any]], - cond_var_name: str) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]: + cond_var_name: str, + dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]: main_block = main_program.global_block() startup_block = startup_program.global_block() @@ -156,12 +165,19 @@ def _append_gradient_merge_backward_op( param_name = param.name param_var = main_block.var(param_name) assert (param_var is not None) + ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var) + assert ref_dist_attr is not None gradient_merge_var = main_block.create_var( name=param_name + "@GRAD@GradientMerge", shape=param_var.shape, dtype=param_var.dtype, persistable=True) param_to_gradient_merge[param_name] = gradient_merge_var + ref_process_mesh = ref_dist_attr.process_mesh + ref_dims_mapping = ref_dist_attr.dims_mapping + + set_var_dist_attr(dist_context, gradient_merge_var, ref_dims_mapping, + ref_process_mesh) startup_gradient_merge_var = startup_block.create_var( name=param_name + "@GRAD@GradientMerge", @@ -186,6 +202,8 @@ def _append_gradient_merge_backward_op( attrs={'axis': -1, 'use_mkldnn': False}) new_params_to_grads.append([param, gradient_merge_var]) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context) return new_params_to_grads, param_to_gradient_merge @@ -240,7 +258,7 @@ def _create_cond_block_and_update_optimizer( new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName()) # op's update Grad - if new_op_desc.input("Grad"): + if core.grad_var_suffix() in new_op_desc.input_arg_names(): grad_value = new_op_desc.input("Grad")[0] # TODO FIXME(xym) support fp16 grad_merge_value = grad_value + '@GradientMerge' @@ -265,7 +283,7 @@ def _create_cond_block_and_update_optimizer( def parse_program(main_program, startup_program, params_grads, k_steps, avg, dist_context): # 1 create gradient_merge_cond - cond_var = _get_gm_cond_var(main_program, k_steps) + cond_var = _get_gm_cond_var(main_program, k_steps, dist_context) # 2 remove optimizer_op from main_program optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context) @@ -275,7 +293,8 @@ def parse_program(main_program, startup_program, params_grads, k_steps, avg, # 3 append gradient merge backward op to main_program new_params_to_grads, param_to_gradient_merge = _append_gradient_merge_backward_op( - main_program, startup_program, params_grads, cond_var.name) + main_program, startup_program, params_grads, cond_var.name, + dist_context) # 4 create ConditionalBlock and append gradient merge optimizer ops _create_cond_block_and_update_optimizer( diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py index acb67e8a20c8c0cb81b473fecc442f3044a6a0b3..0c324ba8ee9aa2d16fadfd68e8b19e9e9a3a9abf 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py @@ -31,6 +31,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer from paddle.distributed.passes import new_pass, PassManager, PassContext import paddle.distributed.fleet as fleet from dist_pass_test_base import DistPassTestBase +from paddle.distributed.auto_parallel.dist_context import DistributedContext logging.getLogger().setLevel(logging.INFO) paddle.enable_static() @@ -111,14 +112,20 @@ class TestGradientMergePass(DistPassTestBase): def init(self): self._params_grads = None self._config = {"k_steps": 4, "avg": True} + #self._config["dist_context"] = DistributedContext() def apply_passes(self, main_prog, startup_prog): - self._config["params_grads"] = self._params_grads - pass_context = PassContext() - auto_parallel_gradient_merge_pass = new_pass( - "auto_parallel_gradient_merge_pass", self._config) - auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog], - pass_context) + #self._config["params_grads"] = self._params_grads + #pass_context = PassContext() + #auto_parallel_gradient_merge_pass = new_pass( + # "auto_parallel_gradient_merge_pass", self._config) + #auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog], + # pass_context) + dist_strategy = fleet.DistributedStrategy() + dist_strategy.gradient_merge = True + dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True} + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) def test_result(self): no_pass_rets = self._distributed_launch( @@ -135,7 +142,7 @@ class TestGradientMergePass(DistPassTestBase): gradient_merge=True, batch_size=8, max_step=8) - + """ # avg loss for gradient_merge pass avg_loss = 0 pass_avg_ret_list = [] @@ -156,6 +163,7 @@ class TestGradientMergePass(DistPassTestBase): rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan)) + """ def get_model(self, place, gradient_merge, batch_size, max_step): paddle.seed(2021)