diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index 6e6d2a672fd18631c4f0ac7073eaada488b37967..da0f2ebcba89ef1ffddf1870eeba75ca07c4a6bb 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -279,7 +279,7 @@ def _is_overlapped(shape_x, shape_y):
     return overlapped
 
 
-def _need_reshard(dist_tensor, dist_op):
+def _need_reshard(dist_tensor, dist_op, op_input=True):
     """Judge the tensor whether needs to be resharded."""
     is_reshard = False
     tensor_dist_attr = dist_tensor.dist_attr
@@ -289,13 +289,31 @@ def _need_reshard(dist_tensor, dist_op):
     op_dist_attr = dist_op.dist_attr
     op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
     op_process_mesh = op_dist_attr.process_mesh
-    if all(
-            map(lambda x: x is not None, [
-                tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping,
-                op_process_mesh
-            ])):
-        if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
-            is_reshard = True
+    if op_input:
+        op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
+        op_process_mesh = op_dist_attr.process_mesh
+        if all(
+                map(lambda x: x is not None, [
+                    tensor_dims_mapping, tensor_process_mesh,
+                    op_input_dims_mapping, op_process_mesh
+                ])):
+            if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
+                is_reshard = True
+    else:
+        op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
+            tensor_name)
+        op_process_mesh = op_dist_attr.process_mesh
+        if all(
+                map(lambda x: x is not None, [
+                    tensor_dims_mapping, tensor_process_mesh,
+                    op_output_dims_mapping, op_process_mesh
+                ])):
+            if tensor_process_mesh != op_process_mesh:
+                is_reshard = True
+            if tensor_dims_mapping != op_output_dims_mapping:
+                raise ValueError(
+                    "It is not supported that tensor dims mapping is different from op output dims mapping."
+                )
     return is_reshard
 
 
@@ -948,12 +966,13 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
 def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
             dist_context):
     """
-    Reshard tensor in the program according to its dist attr and corresponding op dist attr.
+    Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
 
     Args:
         auto_parallel_main_prog (Program): An auto parallel main program.
         auto_parallel_startup_prog (Program): An auto parallel startup program.
         rank_id (int): The process id.
+        dist_context (DistributedContext): The distributed context of this rank.
     """
     assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
         "but got {}.".format(type(auto_parallel_main_prog))
@@ -1001,6 +1020,34 @@
         else:
             idx += 1
 
+    # insert send and recv op if output process mesh is different from tensor process mesh
+    idx = 0
+    skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    while idx < len(block.ops):
+        pre_op_count = len(block.ops)
+        op = block.ops[idx]
+        dist_op = dist_context.get_dist_op_for_program(op)
+        if dist_op is not None and op.type not in skip_ops:
+            for var_name in op.output_arg_names:
+                var = block.vars[var_name]
+                dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                if dist_tensor is not None and _need_reshard(dist_tensor,
+                                                             dist_op, False):
+                    for index, item in enumerate(
+                            dist_op.dist_attr.process_mesh.processes):
+                        recv_rank = dist_tensor.dist_attr.process_mesh.processes[
+                            index]
+                        if rank_id == item:
+                            _insert_send_op(block, idx + 1, var, recv_rank)
+                        if rank_id == recv_rank:
+                            _insert_recv_op(block, idx + 1, var, item)
+            cur_op_count = len(block.ops)
+            idx_offset = idx_offset + cur_op_count - pre_op_count
+            pre_op_count = cur_op_count
+            idx = idx + idx_offset + 1
+        else:
+            idx += 1
+
     # remove no need vars and ops in the main program
     remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
 
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
index b234e25823f4b370b9a4150ee3f8b7d635468952..a93abd3c1277681234209c27f54f0d019bf4e9df 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
 from paddle.distributed.auto_parallel.process_group import _g_process_group_map
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 
@@ -143,7 +143,11 @@ def mlp_forward(train_program, start_program):
     return loss, train_program, start_program
 
 
-def get_dist_prog(train_program, startup_program, dist_context, rank_id):
+def get_dist_prog(train_program,
+                  startup_program,
+                  dist_context,
+                  rank_id,
+                  change_process_mesh=False):
     loss, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
 
@@ -157,6 +161,12 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = completer.complete_forward_annotation(
         train_program)
 
+    if change_process_mesh:
+        global PP_MESH_1
+        dist_context.get_tensor_dist_attr_for_program(
+            train_program.global_block().vars[
+                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1
+
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
@@ -308,6 +318,25 @@ class TestMLPReshard(unittest.TestCase):
         # parameter initialization of every rank should be different in the pipeline scene
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))
 
+    def test_mlp_pp_diff_process_mesh(self):
+        HAS_SENT.clear()
+        HAS_RECV.clear()
+        HAS_ALLGATHER.clear()
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        dist_context = DistributedContext()
+        rank_id = 1
+        dist_main_prog, dist_startup_prog = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id, True)
+        for key in list(_g_process_group_map.keys()):
+            del _g_process_group_map[key]
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        print_program_with_dist_attr(dist_main_prog, dist_context)
+
+        # check send and recv result
+        self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
+        self.assertTrue(check_initialization(dist_startup_prog, rank_id))
+
     def test_mlp_dp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
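Note on the new output-side pass in reshard(): it walks each op's output_arg_names, and when _need_reshard(dist_tensor, dist_op, False) reports that the op's process mesh differs from the tensor's process mesh, it pairs process i of the op's process mesh with process i of the tensor's process mesh, inserting a send on the former rank and the matching recv on the latter. The standalone sketch below illustrates only that pairing rule; plan_output_reshard and the plain integer lists standing in for process meshes are hypothetical helpers for illustration, not part of this patch.

# Standalone sketch of the rank-pairing rule applied by the output-side pass.
# op_mesh and tensor_mesh are plain lists of global ranks standing in for
# dist_op.dist_attr.process_mesh.processes and
# dist_tensor.dist_attr.process_mesh.processes.
def plan_output_reshard(op_mesh, tensor_mesh, rank_id):
    """Return the send/recv actions the given rank would perform."""
    actions = []
    for index, item in enumerate(op_mesh):
        recv_rank = tensor_mesh[index]
        if rank_id == item:
            # This rank runs the op, so it sends the produced tensor.
            actions.append(("send", recv_rank))
        if rank_id == recv_rank:
            # This rank owns the tensor per its dist_attr, so it receives.
            actions.append(("recv", item))
    return actions


# For example, with the op placed on process mesh [0] and the tensor annotated
# with process mesh [1], rank 0 sends and rank 1 receives:
assert plan_output_reshard([0], [1], rank_id=0) == [("send", 1)]
assert plan_output_reshard([0], [1], rank_id=1) == [("recv", 0)]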