diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 43f5fa264790f9e53c60069831257ef21d7d8c81..2f557ad3e9fe38e5eb6807e2abae35e9c996bd39 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -209,7 +209,8 @@ class AutoParallelizer: make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context) - reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context) + reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context, + dist_params_grads) self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog, rank, dist_params_grads) diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index da0f2ebcba89ef1ffddf1870eeba75ca07c4a6bb..c28a48da838fec88ffa1703251705564add993c2 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -20,6 +20,7 @@ import paddle.fluid.core as core from paddle.utils import unique_name from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import Program, OpProtoHolder +from paddle.distributed.fleet.meta_optimizers.common import OpRole import paddle.fluid.layers.utils as utils from ..collective import _get_global_env from .dist_context import DistributedContext @@ -862,7 +863,7 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): block._remove_op(idx) -def _remove_no_need_vars(auto_parallel_main_prog): +def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads): """Remove no need vars in the main program""" remove_vars = set() block = auto_parallel_main_prog.global_block() @@ -879,14 +880,42 @@ def _remove_no_need_vars(auto_parallel_main_prog): for var in vars: if var not in need_vars: remove_vars.add(var) + + # change dist_params_grads + param_grad_map = {} + for op in ops: + if int(op.attr('op_role')) == int(OpRole.Optimize): + if "Param" in op.input_names and "Grad" in op.input_names: + param_name = op.input("Param")[0] + grad_name = op.input("Grad")[0] + param_grad_map[param_name] = grad_name + + need_remove_idx = [] + for idx, item in enumerate(dist_params_grads): + if item[0].name not in param_grad_map.keys(): + need_remove_idx.append(idx) + + for idx in need_remove_idx[::-1]: + dist_params_grads.pop(idx) + + idx = 0 + while idx < len(dist_params_grads): + param_name = dist_params_grads[idx][0].name + grad_name = dist_params_grads[idx][1].name + if grad_name != param_grad_map[param_name]: + dist_params_grads[idx] = (vars[param_name], + vars[param_grad_map[param_name]]) + idx += 1 + for var in remove_vars: block._remove_var(var) -def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id): +def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, + dist_params_grads): """Remove no need vars and ops in the main program.""" _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id) - _remove_no_need_vars(auto_parallel_main_prog) + _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads) def remove_no_need_in_startup(auto_parallel_main_prog, @@ -964,7 +993,7 @@ def remove_no_need_in_startup(auto_parallel_main_prog, def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, - dist_context): + dist_context, dist_params_grads): """ Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute. @@ -973,6 +1002,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, auto_parallel_startup_prog (Program): An auto parallel startup program. rank_id (int): The process id. dist_context (DistributedContext): The distributed context of this rank. + dist_params_grads (list): The list contains the tuple of param and grad. """ assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \ "but got {}.".format(type(auto_parallel_main_prog)) @@ -1049,7 +1079,8 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, idx += 1 # remove no need vars and ops in the main program - remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id) + remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, + dist_params_grads) # remove no need vars and ops in the startip program remove_no_need_in_startup(auto_parallel_main_prog, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index fd19a5bd8b866960c922923dcbb68521c3e4c8c0..52397f51321f585784b52c4a39bd707cf97f7dc4 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -175,7 +175,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): partitioned_optimize_ops = parallelizer._apply_optimize( auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) - return auto_parallel_main_prog, auto_parallel_startup_prog + return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads def check_runtime_estimation(cost): @@ -229,10 +229,10 @@ class TestCostModel(unittest.TestCase): train_program = paddle.static.Program() startup_program = paddle.static.Program() dist_context = DistributedContext() - distributed_program, dist_startup_prog = get_dist_prog( + distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) reshard(distributed_program, dist_startup_prog, rank_id, - dist_context) + dist_context, dist_params_grads) dist_program.append(distributed_program) cluster = None cost = estimate_cost( diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 9d4de771076cd85e136c8058642f9774ee02461c..8869fd6a59e3772507aa6413afd7c872bab7a533 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -502,7 +502,8 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): partitioned_optimize_ops = parallelizer._apply_optimize( dist_train_program, dist_startup_prog, dist_params_grads) - reshard(dist_train_program, dist_startup_prog, rank_id, dist_context) + reshard(dist_train_program, dist_startup_prog, rank_id, dist_context, + dist_params_grads) return dist_train_program, dist_startup_prog diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index a93abd3c1277681234209c27f54f0d019bf4e9df..1d8938785924cfadfdb232aeeb42b7af045af09a 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -183,7 +183,7 @@ def get_dist_prog(train_program, partitioned_optimize_ops = parallelizer._apply_optimize( auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) - return auto_parallel_main_prog, auto_parallel_startup_prog + return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check): @@ -277,7 +277,7 @@ class TestMLPReshard(unittest.TestCase): startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 0 - dist_main_prog, dist_startup_prog = get_dist_prog( + dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, 0) op_need_check = None @@ -306,11 +306,12 @@ class TestMLPReshard(unittest.TestCase): startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 1 - dist_main_prog, dist_startup_prog = get_dist_prog( + dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) for key in list(_g_process_group_map.keys()): del _g_process_group_map[key] - reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, + dist_params_grads) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) @@ -326,11 +327,12 @@ class TestMLPReshard(unittest.TestCase): startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 1 - dist_main_prog, dist_startup_prog = get_dist_prog( + dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id, True) for key in list(_g_process_group_map.keys()): del _g_process_group_map[key] - reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, + dist_params_grads) print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result @@ -347,9 +349,10 @@ class TestMLPReshard(unittest.TestCase): startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 0 - dist_main_prog, dist_startup_prog = get_dist_prog( + dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) - reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, + dist_params_grads) # send and recv should not exist in dp scene. self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index 40847a769033ab5ae15ee30e83d9db960bbf7781..5a79d1f9514ab2c8ce1f6de7956653df463a1f9d 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -137,7 +137,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): partitioned_optimize_ops = parallelizer._apply_optimize( auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) - return auto_parallel_main_prog, auto_parallel_startup_prog + return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads def check_send_recv_result(dist_main_prog, rank_id): @@ -177,9 +177,10 @@ class TestMLPReshard(unittest.TestCase): startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 2 - dist_main_prog, dist_startup_prog = get_dist_prog( + dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) - reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, + dist_params_grads) # print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 869bcd4c7ab32723c12478dba5a5732ed0f0d537..6696a9d3006d2bdec61b14fc49a639060d5fa4cd 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -152,7 +152,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): partitioned_optimize_ops = parallelizer._apply_optimize( auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) - return auto_parallel_main_prog, auto_parallel_startup_prog + return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads def check_send_recv_result(dist_main_prog, rank_id): @@ -211,9 +211,10 @@ class TestMLPReshard(unittest.TestCase): startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 2 - dist_main_prog, dist_startup_prog = get_dist_prog( + dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) - reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, + dist_params_grads) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) @@ -271,7 +272,7 @@ class TestMLPReshard(unittest.TestCase): partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition( complete_train_program, startup_program, []) reshard(partitioned_main_prog, partitioned_startup_prog, rank_id, - dist_context) + dist_context, partitioned_params_grads) # the x should not be slice self.assertTrue(check_allgather(partitioned_main_prog))