From cac6f408e7796e0b4d5153277d36be546c7e1535 Mon Sep 17 00:00:00 2001
From: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Date: Thu, 27 Jan 2022 11:57:44 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Auto=20Parallel=E3=80=91update=20dist?=
 =?UTF-8?q?=20param=20grad=20for=20pass=20(#38941)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update dist param grad for pass

* update unittest

* update unittests

* fix conflict
---
 .../distributed/auto_parallel/parallelizer.py |  3 +-
 .../distributed/auto_parallel/reshard.py      | 41 ++++++++++++++++---
 .../test_auto_parallel_cost_model.py          |  6 +--
 .../unittests/test_auto_parallel_mapper.py    |  3 +-
 .../unittests/test_auto_parallel_reshard.py   | 19 +++++----
 .../test_auto_parallel_reshard_dpmppp.py      |  7 ++--
 .../test_auto_parallel_reshard_mppp.py        |  9 ++--
 7 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 43f5fa26479..2f557ad3e9f 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -209,7 +209,8 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog,
                           self._dist_context)
 
-        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
+                dist_params_grads)
 
         self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
                                              rank, dist_params_grads)
diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index da0f2ebcba8..c28a48da838 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -20,6 +20,7 @@ import paddle.fluid.core as core
 from paddle.utils import unique_name
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.framework import Program, OpProtoHolder
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 import paddle.fluid.layers.utils as utils
 from ..collective import _get_global_env
 from .dist_context import DistributedContext
@@ -862,7 +863,7 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
             block._remove_op(idx)
 
 
-def _remove_no_need_vars(auto_parallel_main_prog):
+def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
     """Remove no need vars in the main program"""
     remove_vars = set()
     block = auto_parallel_main_prog.global_block()
@@ -879,14 +880,42 @@ def _remove_no_need_vars(auto_parallel_main_prog):
     for var in vars:
         if var not in need_vars:
             remove_vars.add(var)
+
+    # change dist_params_grads
+    param_grad_map = {}
+    for op in ops:
+        if int(op.attr('op_role')) == int(OpRole.Optimize):
+            if "Param" in op.input_names and "Grad" in op.input_names:
+                param_name = op.input("Param")[0]
+                grad_name = op.input("Grad")[0]
+                param_grad_map[param_name] = grad_name
+
+    need_remove_idx = []
+    for idx, item in enumerate(dist_params_grads):
+        if item[0].name not in param_grad_map.keys():
+            need_remove_idx.append(idx)
+
+    for idx in need_remove_idx[::-1]:
+        dist_params_grads.pop(idx)
+
+    idx = 0
+    while idx < len(dist_params_grads):
+        param_name = dist_params_grads[idx][0].name
+        grad_name = dist_params_grads[idx][1].name
+        if grad_name != param_grad_map[param_name]:
+            dist_params_grads[idx] = (vars[param_name],
+                                      vars[param_grad_map[param_name]])
+        idx += 1
+
     for var in remove_vars:
         block._remove_var(var)
 
 
-def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id):
+def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
+                           dist_params_grads):
     """Remove no need vars and ops in the main program."""
     _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
-    _remove_no_need_vars(auto_parallel_main_prog)
+    _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
 
 
 def remove_no_need_in_startup(auto_parallel_main_prog,
@@ -964,7 +993,7 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
 
 
 def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
-            dist_context):
+            dist_context, dist_params_grads):
     """
     Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
 
     Args:
         auto_parallel_main_prog (Program): An auto parallel main program.
         auto_parallel_startup_prog (Program): An auto parallel startup program.
         rank_id (int): The process id.
         dist_context (DistributedContext): The distributed context of this rank.
+        dist_params_grads (list): The list contains the tuple of param and grad.
     """
     assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
         "but got {}.".format(type(auto_parallel_main_prog))
@@ -1049,7 +1079,8 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
             idx += 1
 
     # remove no need vars and ops in the main program
-    remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
+    remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
+                           dist_params_grads)
 
     # remove no need vars and ops in the startup program
     remove_no_need_in_startup(auto_parallel_main_prog,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
index fd19a5bd8b8..52397f51321 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
@@ -175,7 +175,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_runtime_estimation(cost):
@@ -229,10 +229,10 @@ class TestCostModel(unittest.TestCase):
             train_program = paddle.static.Program()
             startup_program = paddle.static.Program()
             dist_context = DistributedContext()
-            distributed_program, dist_startup_prog = get_dist_prog(
+            distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
                 train_program, startup_program, dist_context, rank_id)
             reshard(distributed_program, dist_startup_prog, rank_id,
-                    dist_context)
+                    dist_context, dist_params_grads)
             dist_program.append(distributed_program)
         cluster = None
         cost = estimate_cost(
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
index 9d4de771076..8869fd6a59e 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -502,7 +502,8 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         dist_train_program,
         dist_startup_prog, dist_params_grads)
 
-    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
+    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context,
+            dist_params_grads)
 
     return dist_train_program, dist_startup_prog
 
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
index a93abd3c127..1d893878592 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
@@ -183,7 +183,7 @@ def get_dist_prog(train_program,
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
@@ -277,7 +277,7 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 0
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, 0)
 
         op_need_check = None
@@ -306,11 +306,12 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 1
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -326,11 +327,12 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 1
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id, True)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
         print_program_with_dist_attr(dist_main_prog, dist_context)
 
         # check send and recv result
@@ -347,9 +349,10 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 0
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
 
         # send and recv should not exist in dp scene.
         self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
 
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
index 40847a76903..5a79d1f9514 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
@@ -137,7 +137,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_send_recv_result(dist_main_prog, rank_id):
@@ -177,9 +177,10 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 2
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
         # print_program_with_dist_attr(dist_main_prog, dist_context)
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index 869bcd4c7ab..6696a9d3006 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -152,7 +152,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_send_recv_result(dist_main_prog, rank_id):
@@ -211,9 +211,10 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 2
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -271,7 +272,7 @@ class TestMLPReshard(unittest.TestCase):
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
-                dist_context)
+                dist_context, partitioned_params_grads)
 
         # the x should not be slice
         self.assertTrue(check_allgather(partitioned_main_prog))
-- 
GitLab
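
Usage sketch (not part of the patch): after this change, reshard() takes the partitioned (param, grad) list and prunes and rewrites it in place while removing unneeded vars, so the post-optimization passes only see pairs that still exist on the current rank. The snippet below is modeled on the test updates above; partitioner, complete_train_program, startup_program, rank_id, and dist_context are assumed to be set up as in test_auto_parallel_reshard_mppp.py.

    # Assumed setup: objects built as in the unit tests in this patch.
    partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = \
        partitioner.partition(complete_train_program, startup_program, [])
    # reshard() now mutates partitioned_params_grads in place: pairs whose
    # parameter was pruned from this rank's program are dropped, and stale
    # grad vars are re-pointed at the optimizer op's actual Grad input
    # (see _remove_no_need_vars above).
    reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
            dist_context, partitioned_params_grads)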