diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index be4c68d97d840c8803c2e172d45e0fa700eb75cf..c31642a9e2af315ff0ca3f2d4d6a624baae8733e 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1407,6 +1407,27 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
     ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
 
 
+def naive_set_dist_op_attr_for_program_by_mesh(
+    new_op, process_mesh, ctx, is_recompute=False
+):
+    assert process_mesh is not None
+
+    new_op_dist_attr = OperatorDistributedAttribute()
+
+    for input_varname in new_op.desc.input_arg_names():
+        var = ctx.serial_main_program.global_block().var(input_varname)
+        mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping
+        new_op_dist_attr.set_input_dims_mapping(input_varname, mapping)
+    for output_varname in new_op.desc.output_arg_names():
+        var = ctx.serial_main_program.global_block().var(output_varname)
+        mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping
+        new_op_dist_attr.set_output_dims_mapping(output_varname, mapping)
+
+    new_op_dist_attr.process_mesh = process_mesh
+    new_op_dist_attr.is_recompute = is_recompute
+    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
+
+
 def update_op_dims_mapping_by_default_dist_impl(dist_op):
     changed = False
     op_dist_attr = dist_op.dist_attr
@@ -2102,3 +2123,73 @@ def _copy_dist_attr_from_cpp_for_graph(dist_context):
             py_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
             cpp_dist_attr = node.op().dist_attr
             _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr)
+
+
+def insert_dependencies_for_two_ops(
+    block,
+    idx,
+    prior_op,
+    posterior,
+    dist_context,
+    is_recompute=False,
+    sync=False,
+):
+    """
+    Insert a dependency so that prior_op is scheduled before posterior.
+    """
+
+    assert (
+        len(prior_op.output_arg_names) >= 1
+    ), "first op of dependency should have at least one output. [{}]".format(
+        str(prior_op)
+    )
+    assert (
+        len(posterior.input_arg_names) >= 1
+    ), "second op of dependency should have at least one input. [{}]".format(
+        str(posterior)
+    )
+    prior_op_mesh = dist_context.get_op_dist_attr_for_program(
+        prior_op
+    ).process_mesh
+    posterior_mesh = dist_context.get_op_dist_attr_for_program(
+        posterior
+    ).process_mesh
+    assert (
+        prior_op_mesh == posterior_mesh
+    ), "two ops of dependency should have the same mesh but got [{}] and [{}]".format(
+        str(prior_op_mesh), str(posterior_mesh)
+    )
+
+    def _select_best_depend_var(vars):
+        # prefer the variable with the largest number of elements
+        vars_with_numels = [(var, get_var_numel(var)) for var in vars]
+        vars_with_numels.sort(key=lambda x: x[1])
+
+        return vars_with_numels[-1][0]
+
+    first_var = _select_best_depend_var(
+        [block.var(name) for name in prior_op.output_arg_names]
+    )
+    second_var = _select_best_depend_var(
+        [block.var(name) for name in posterior.input_arg_names]
+    )
+
+    depend_op = block._insert_op_without_sync(
+        idx,
+        type='nop',
+        inputs={
+            "X": first_var,
+        },
+        outputs={"Out": second_var},
+    )
+    # depend_op.desc.set_type("depend")
+    depend_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
+    # depend_op.desc.set_input("Dep", [first_var.name])
+    # self.desc.set_output(out_proto.name, out_arg_names)
+
+    naive_set_dist_op_attr_for_program_by_mesh(
+        depend_op, prior_op_mesh, dist_context, is_recompute
+    )
+
+    if sync:
+        block._sync_with_cpp()
diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py
index 5bdbe9d2dd5d9d2cf1b424e2bdb25c3f68a5f76a..44e02fb3ffad8121738652e117af1b0ccc569948 100644
--- a/python/paddle/distributed/passes/auto_parallel_recompute.py
+++ b/python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -29,6 +29,7 @@ from paddle.distributed.auto_parallel.utils import (
 )
 from paddle.distributed.auto_parallel.utils import (
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
+    insert_dependencies_for_two_ops,
 )
 
 
@@ -449,6 +450,7 @@ class RecomputePass(PassBase):
             while idx - 1 >= 0 and ops[idx - 1].type == "sum":
                 idx -= 1
             segment_descs = ckpt_ops_dict[fwd_op_id][1]
+            rc_op = None
             for _, op_desc in reversed(list(enumerate(segment_descs))):
                 rc_op = main_block._insert_op_without_sync(
                     idx, type='nop'
@@ -466,7 +468,15 @@ class RecomputePass(PassBase):
                 )
 
                 ckpt_ops_dict[fwd_op_id][0] = False
-
+            if rc_op:
+                insert_dependencies_for_two_ops(
+                    main_block,
+                    idx,
+                    main_block.ops[rc_op.idx - 1],
+                    rc_op,
+                    self._dist_context,
+                    sync=False,
+                )
         main_program._sync_with_cpp()
 
     def reset_op_dist_attr(self, op, var_name_dict):