From c22e1123091d7b6592b07a9f6acb1c8c108e271b Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 6 Jun 2022 19:06:42 +0800 Subject: [PATCH] [AutoParallel] fix gradient merge optimize parse (#43169) * fix gradient merge * bug fix * update annotation --- .../auto_parallel/parallelizer_v2.py | 6 +- .../passes/auto_parallel_gradient_merge.py | 70 ++++++++------- .../distributed_passes/CMakeLists.txt | 2 +- ...test_auto_parallel_gradient_merge_pass.py} | 88 +++++-------------- 4 files changed, 66 insertions(+), 100 deletions(-) rename python/paddle/fluid/tests/unittests/distributed_passes/{test_dist_gradient_merge_pass.py => test_auto_parallel_gradient_merge_pass.py} (72%) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index ce543988ea4..f02eb38f458 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -148,7 +148,7 @@ class Parallelizer: config) auto_parallel_recompute_pass.apply([main_program], [startup_program], - self._dist_context) + self._pass_context) def _apply_post_optimization(self, main_program, startup_program, rank, params_grads): @@ -162,7 +162,7 @@ class Parallelizer: auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", config) auto_parallel_sharding_pass.apply([main_program], [startup_program], - self._dist_context) + self._pass_context) if self._strategy.gradient_merge: config = copy.deepcopy(self._strategy.gradient_merge_configs) @@ -172,4 +172,4 @@ class Parallelizer: "auto_parallel_gradient_merge_pass", config) auto_parallel_gradient_merge_pass.apply([main_program], [startup_program], - self._dist_context) + self._pass_context) diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index bc40dad8ac0..394d71706c4 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -18,10 +18,10 @@ from typing import List, Tuple, Dict, Any import paddle from paddle.framework import core +from paddle.fluid import layers from paddle.fluid.framework import program_guard, device_guard -from paddle.fluid import unique_name, layers -from paddle.fluid.clip import append_gradient_clip_ops from .pass_base import PassBase, PassType, register_pass +from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.auto_parallel.utils import set_var_dist_attr from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping from paddle.distributed.auto_parallel.process_group import get_world_process_group @@ -29,16 +29,8 @@ from paddle.distributed.auto_parallel.process_group import get_world_process_gro world_process_group = get_world_process_group() -def _is_the_backward_op(op): - OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() - OpRole = core.op_proto_and_checker_maker.OpRole - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) - - def _is_the_optimizer_op(op): OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() - OpRole = core.op_proto_and_checker_maker.OpRole return OP_ROLE_KEY in op.attr_names and \ int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) @@ -47,13 +39,13 @@ def _remove_and_get_optimizer_op(main_program, dist_context): # 1 create 
tmp block # 2 mv optimizer op from global program to tmp block # 3 del the op from dist_context - from paddle.distributed.fleet.meta_optimizers.common import OpRole main_block = main_program.global_block() temp_block = main_program._create_block() removed_op_idx = [] optimize_ops_desc = [] + skip_ops = ["increment", "elementwise_mod", "equal"] for idx, op in enumerate(main_block.ops): - if _is_the_optimizer_op(op): + if _is_the_optimizer_op(op) and op.type not in skip_ops: # append optimizer op to tmp block new_op_desc = temp_block.desc.append_op() new_op_desc.copy_from(op.desc) @@ -111,8 +103,17 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks) with device_guard("cpu"): - # step_var = (step_var + 1) % k_step - layers.increment(x=step_var, value=1.0, in_place=True) + # step_var += 1 + increment_op = main_block.append_op(type='increment', + inputs={'X': [step_var]}, + outputs={'Out': [step_var]}, + attrs={ + 'step': float(1.0), + 'op_role': OpRole.Optimize + }) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + increment_op, world_process_group.ranks, [-1], dist_context) + # step_var %= k_step elementwise_mod_op = main_block.append_op(type='elementwise_mod', inputs={ 'X': step_var, @@ -121,18 +122,19 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): outputs={'Out': step_var}, attrs={ 'axis': -1, - 'use_mkldnn': False + 'use_mkldnn': False, + 'op_role': OpRole.Optimize }) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( elementwise_mod_op, world_process_group.ranks, [-1], dist_context) - # cond_var = (step_var == 0) equal_op = main_block.append_op(type='equal', inputs={ 'X': step_var, 'Y': zero_var }, - outputs={'Out': cond_var}) + outputs={'Out': cond_var}, + attrs={'op_role': OpRole.Optimize}) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( equal_op, world_process_group.ranks, [-1], dist_context) @@ -154,7 +156,9 @@ def _append_gradient_merge_backward_op( _remove_op_role_var(param, grad) - param_to_gradient_merge = {} + # {grad.name: gradient_merge_var.name} to rename opt inputs + grad_to_gradient_merge = {} + # {param: gradient_merge_var} to insert scale op and fill_constant op new_params_to_grads = [] # step2: create gradient_merge var and init with 0 for param, grad in params_grads: @@ -168,7 +172,6 @@ def _append_gradient_merge_backward_op( shape=param_var.shape, dtype=param_var.dtype, persistable=True) - param_to_gradient_merge[param_name] = gradient_merge_var ref_process_mesh = ref_dist_attr.process_mesh ref_dims_mapping = ref_dist_attr.dims_mapping @@ -197,17 +200,19 @@ def _append_gradient_merge_backward_op( outputs={'Out': gradient_merge_var}, attrs={ 'axis': -1, - 'use_mkldnn': False + 'use_mkldnn': False, + 'op_role': OpRole.Optimize }) new_params_to_grads.append([param, gradient_merge_var]) + grad_to_gradient_merge[grad.name] = gradient_merge_var.name naive_set_dist_op_attr_for_program_by_mesh_and_mapping( new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context) - return new_params_to_grads, param_to_gradient_merge + return new_params_to_grads, grad_to_gradient_merge def _create_cond_block_and_update_optimizer( main_program, cond_var, new_params_to_grads: List[Tuple[Any, Any]], - param_to_gradient_merge: Dict[str, Any], optimize_ops_desc: List[Any], + grad_to_gradient_merge: Dict[str, str], optimize_ops_desc: List[Any], k_steps, avg): def true_apply_gradient(): @@ -229,7 +234,7 @@ def _create_cond_block_and_update_optimizer( 'bias_after_scale': False 
}) new_grad.op._set_attr(op_maker.kOpRoleAttrName(), - op_maker.OpRole.Optimize) + OpRole.Optimize) # append optimizer ops for op_desc in optimize_ops_desc: @@ -238,14 +243,14 @@ def _create_cond_block_and_update_optimizer( #update input/output for input_name in new_op_desc.input_arg_names(): - if input_name in new_params_to_grads: - new_op_desc._rename_input(input_name, - new_params_to_grads[input_name]) + if input_name in grad_to_gradient_merge: + new_op_desc._rename_input( + input_name, grad_to_gradient_merge[input_name]) for output_name in new_op_desc.output_arg_names(): - if output_name in new_params_to_grads: - new_op_desc._rename_output(output_name, - new_params_to_grads[output_name]) + if output_name in grad_to_gradient_merge: + new_op_desc._rename_output( + output_name, grad_to_gradient_merge[output_name]) # remove op_role_var if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()): @@ -271,6 +276,8 @@ def _create_cond_block_and_update_optimizer( op_maker.OpRole.Optimize) layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) + cond_op = main_program.global_block().ops[-1] + cond_op._set_attr('op_role', OpRole.Optimize) def parse_program(main_program, startup_program, params_grads, k_steps, avg, @@ -285,14 +292,14 @@ def parse_program(main_program, startup_program, params_grads, k_steps, avg, main_program._rollback() # 3 append gradient merge backward op to main_program - new_params_to_grads, param_to_gradient_merge = _append_gradient_merge_backward_op( + new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op( main_program, startup_program, params_grads, cond_var.name, dist_context) # 4 create ConditionalBlock and append gradient merge optimizer ops _create_cond_block_and_update_optimizer(main_program, cond_var, new_params_to_grads, - param_to_gradient_merge, + grad_to_gradient_merge, optimize_ops_desc, k_steps, avg) @@ -303,7 +310,6 @@ class GradientMergePass(PassBase): super(GradientMergePass, self).__init__() self.set_attr("k_steps", -1) self.set_attr("avg", True) - self.set_attr("inner_optimizer", None) def _check_self(self): if self.get_attr("k_steps") < 1: diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt index c68cebaa25b..29e528edce9 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt @@ -14,12 +14,12 @@ if((NOT WITH_GPU) list(REMOVE_ITEM TEST_OPS "test_dist_fuse_momentum_pass") list(REMOVE_ITEM TEST_OPS "test_dist_fuse_relu_depthwise_conv_pass") list(REMOVE_ITEM TEST_OPS "test_dist_fuse_sgd_pass") - list(REMOVE_ITEM TEST_OPS "test_dist_gradient_merge_pass") list(REMOVE_ITEM TEST_OPS "test_dist_inplace_addto_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_amp_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_recompute_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass") + list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass") endif() foreach(TEST_OP ${TEST_OPS}) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py similarity index 72% rename from python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py rename to 
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py index f856059402e..50e18718201 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_gradient_merge_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py @@ -25,20 +25,14 @@ import paddle.nn as nn import paddle.utils as utils import paddle.static as static import paddle.nn.functional as F +import paddle.distributed.fleet as fleet import paddle.distributed.auto_parallel as auto -from paddle.fluid.initializer import NumpyArrayInitializer -from paddle.distributed.passes import new_pass, PassManager, PassContext -import paddle.distributed.fleet as fleet -from dist_pass_test_base import DistPassTestBase -from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.fluid.initializer import NumpyArrayInitializer +from auto_parallel_pass_test_base import AutoPallelPassTestBase logging.getLogger().setLevel(logging.INFO) paddle.enable_static() -_global_parallel_strategy = None -_global_process_mesh = None - -#np.set_printoptions(suppress=True) class MLPLayer(nn.Layer): @@ -103,13 +97,11 @@ class MLPLayer(nn.Layer): def mlp_forward(input, label, hidden_size): - if _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - + auto.shard_tensor(input, + dist_attr={ + "process_mesh": [0], + "dims_mapping": [-1, -1] + }) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, initializer_range=0.02) @@ -119,40 +111,33 @@ def mlp_forward(input, label, hidden_size): return loss -class TestGradientMergePass(DistPassTestBase): +class TestGradientMergePass(AutoPallelPassTestBase): def init(self): - self._params_grads = None - self._config = {"k_steps": 4, "avg": True} - #self._config["dist_context"] = DistributedContext() - - def apply_passes(self, main_prog, startup_prog): - #self._config["params_grads"] = self._params_grads - #pass_context = PassContext() - #auto_parallel_gradient_merge_pass = new_pass( - # "auto_parallel_gradient_merge_pass", self._config) - #auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog], - # pass_context) + paddle.seed(2022) + random.seed(2022) + np.random.seed(2022) + + def apply_passes(self): dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True dist_strategy.gradient_merge = True dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True} - dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) def test_result(self): no_pass_rets = self._distributed_launch(model=None, apply_pass=False, gpus=[0], - gradient_merge=False, batch_size=32, + hidden_size=128, max_step=2) pass_rets = self._distributed_launch(model=None, apply_pass=True, gpus=[0], - gradient_merge=True, batch_size=8, + hidden_size=128, max_step=8) - """ # avg loss for gradient_merge pass avg_loss = 0 pass_avg_ret_list = [] @@ -167,40 +152,16 @@ class TestGradientMergePass(DistPassTestBase): for no_pass_ret, pass_ret in zip(no_pass_rets[0], pass_avg_ret_list): print(f"no_pass_ret={no_pass_ret}, pass_ret={pass_ret}") self.assertTrue( - np.isclose( - no_pass_ret, - pass_ret, - rtol=self.rtol, - atol=self.atol, - equal_nan=self.equal_nan)) - """ - - def get_model(self, place, gradient_merge, batch_size, max_step): - paddle.seed(2021) - random.seed(2021) - np.random.seed(2021) + np.isclose(no_pass_ret, + pass_ret, + rtol=self.rtol, + 
atol=self.atol, + equal_nan=self.equal_nan)) - hidden_size = 128 - - global _global_parallel_strategy - global _global_process_mesh - world_size = paddle.distributed.get_world_size() - if world_size == 1: - _global_parallel_strategy = "dp" - _global_process_mesh = auto.ProcessMesh([0]) - elif world_size == 2: - _global_parallel_strategy = "dp" - _global_process_mesh = auto.ProcessMesh([0, 1]) + def get_model(self, place, batch_size, hidden_size, max_step): train_program = static.Program() startup_program = static.Program() - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True - #if gradient_merge: - # dist_strategy.gradient_merge = True - # dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True} - fleet.init(is_collective=True, strategy=dist_strategy) - with static.program_guard(train_program, startup_program), \ utils.unique_name.guard(): input = static.data(name="input", @@ -212,8 +173,7 @@ class TestGradientMergePass(DistPassTestBase): input.stop_gradient = False loss = mlp_forward(input, label, hidden_size) - optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01) - #optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01) + optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01) optimizer = fleet.distributed_optimizer(optimizer) _, self._params_grads, dist_startup_prog, dist_main_prog = optimizer.minimize( loss, startup_program) -- GitLab
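
For reference, the control flow this pass builds, and which the patch above annotates, can be summarized with a minimal, framework-independent Python sketch. A CPU-side step counter is incremented and taken modulo k_steps (the `increment`, `elementwise_mod`, and `equal` ops that the patch marks with `op_role: OpRole.Optimize` and excludes, via `skip_ops`, from the ops moved to the temporary block); gradients are accumulated into persistent gradient-merge buffers every micro-step; and the original optimizer ops, whose gradient inputs are renamed to those buffers through `grad_to_gradient_merge`, run only inside the conditional block when the counter wraps to zero. The names in the sketch (`GradientMergeSketch`, `apply_update`) are illustrative only and are not Paddle APIs.

    # Minimal, framework-independent sketch of the gradient-merge control flow.
    # This is NOT the Paddle pass itself; class and function names here are
    # hypothetical and chosen only to mirror the steps in the diff above.

    class GradientMergeSketch:
        def __init__(self, params, k_steps, avg=True):
            self.k_steps = k_steps
            self.avg = avg
            self.step = 0                              # step_var in the pass
            self.merged = {p: 0.0 for p in params}     # *@GRADIENT_MERGE vars

        def backward(self, grads):
            # Accumulate this micro-step's gradients into persistent buffers
            # (the pass emits an elementwise_add per parameter for this).
            for p, g in grads.items():
                self.merged[p] += g

        def maybe_update(self, apply_update):
            # step_var += 1; step_var %= k_steps; cond_var = (step_var == 0)
            self.step = (self.step + 1) % self.k_steps
            if self.step == 0:                         # the conditional block
                if self.avg:
                    for p in self.merged:
                        self.merged[p] /= self.k_steps # scale by 1/k_steps
                apply_update(self.merged)              # original optimizer ops
                for p in self.merged:
                    self.merged[p] = 0.0               # fill_constant reset

    # Usage: accumulate 4 micro-batches, then a single update is applied.
    gm = GradientMergeSketch(params=["w"], k_steps=4)
    for _ in range(4):
        gm.backward({"w": 1.0})
        gm.maybe_update(lambda merged: print("update with", merged))

Running the sketch with k_steps=4 triggers exactly one update after four accumulation steps, which mirrors the updated test: four merged micro-batches of batch_size=8 are compared against plain training with batch_size=32.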