From 5e9845b807fd26fe9f3dd72d569efdda9dad4722 Mon Sep 17 00:00:00 2001 From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com> Date: Wed, 27 Oct 2021 16:42:28 +0800 Subject: [PATCH] [Auto Parallel] Completion Dist Attribute for Backward & Update stage (#36744) * revise completion for backward * revise completion for update * revise completion for update * update unitest --- .../distributed/auto_parallel/completion.py | 234 +++++++++++------- .../test_auto_parallel_partitioner_gpt.py | 32 +++ 2 files changed, 180 insertions(+), 86 deletions(-) mode change 100644 => 100755 python/paddle/distributed/auto_parallel/completion.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py old mode 100644 new mode 100755 index 855eb656bd..0097a38e23 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -623,24 +623,35 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): if dist_context is None: dist_context = get_default_distributed_context() - grad_start_idx = -1 + first_backward_op_idx = -1 for idx, op in enumerate(auto_parallel_main_prog.global_block().ops): if int(op.attr('op_role')) == int( int(core.op_proto_and_checker_maker.OpRole.Backward) | int( core.op_proto_and_checker_maker.OpRole.Loss)): assert op.type == "fill_constant" - grad_start_idx = idx + first_backward_op_idx = idx break - assert grad_start_idx >= 0, "No backward procedure found in this program." + assert first_backward_op_idx >= 0, "No backward procedure found in this program." ops = list(auto_parallel_main_prog.global_block().ops) vars = auto_parallel_main_prog.global_block().vars + dist_op_helper = dist_context.get_dist_op_helper() - for idx in range(grad_start_idx, len(ops)): + for idx in range(first_backward_op_idx, len(ops)): # complete the initial grad loss op - if idx == grad_start_idx: + if idx == first_backward_op_idx: + assert ops[idx].type == "fill_constant" + assert len( + ops[idx].input_arg_names + ) == 0, "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].input_arg_names)) + assert len( + ops[idx].output_arg_names + ) == 1, "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].output_arg_names)) + grad_var = vars[ops[idx].output_arg_names[0]] forward_var_name = _get_forward_varname_from_grad_varname( grad_var.name) @@ -659,90 +670,80 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): op_attr = OperatorDistributedAttribute(ops[idx], dist_context) op_attr.set_process_mesh(process_mesh) - dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) - continue - - # TODO remove this when dist op handle its own grad scale - # in the data parallel mode, the loss op followed by scale op. - if ops[idx].type == "scale" and idx == grad_start_idx + 1: - assert grad_var.name in ops[ - idx].input_arg_names and grad_var.name in ops[ - idx].output_arg_names - grad_var = vars[ops[idx].output_arg_names[0]] - forward_var_name = _get_forward_varname_from_grad_varname( - grad_var.name) - forward_var = vars[forward_var_name] - process_mesh = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_process_mesh() - op_attr = OperatorDistributedAttribute(ops[idx], dist_context) - op_attr.set_process_mesh(process_mesh) - dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) - continue - - # TODO remove this when dist op handle its own communication - # TODO should distinguish the dp allreduce and mp allreduce - # complete the c_allreduce_sum op for gradient in the data parallel mode. - if ops[idx].type == "c_allreduce_sum" and ops[ - idx].input_arg_names == ops[idx].output_arg_names: - grad_var = vars[ops[idx].output_arg_names[0]] - op_attr = OperatorDistributedAttribute(ops[idx], dist_context) - process_mesh = dist_context.get_tensor_distributed_attr_for_program( - grad_var).get_process_mesh() - op_attr.set_process_mesh(process_mesh) + op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) continue # complete the annotation of grad op (xxx_grad op or sum op) - grad_op = ops[idx] - # xxx_grad op will have a corresponding forward op in gradopidx2opidx - dist_op_helper = dist_context.get_dist_op_helper() + grad_op = ops[idx] if grad_op.desc.id() in dist_op_helper.gradopidx2opidx: # TODO support the case where one forward op corresponding to multiple xxx_grad op forward_op = _get_op_by_id( - ops[:grad_start_idx], + ops[:first_backward_op_idx], dist_op_helper.gradopidx2opidx[grad_op.desc.id()]) assert forward_op is not None # op dist attr forward_op_attr = dist_context.get_op_distributed_attr_for_program( forward_op) + forward_op_process_mesh = forward_op_attr.get_process_mesh() grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) - grad_op_attr.set_process_mesh(forward_op_attr.get_process_mesh()) - - for var_name in grad_op.input_arg_names: - if "@GRAD" in var_name: - dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - vars[var_name]).get_dims_mapping() - grad_op_attr.set_input_dims_mapping(var_name, dims_mapping) + grad_op_attr.set_process_mesh(forward_op_process_mesh) + + # var + for output_name in grad_op.desc.output_names(): + assert len(grad_op.desc.output(output_name)) in [0, 1] + # if grad_op.type == "cast": + # input_name = "X" + # else: + if _is_grad_var_name(output_name): + input_name = _get_forward_varname_from_grad_varname( + output_name) else: - dims_mapping = forward_op_attr.get_input_dims_mapping( - var_name) - # TODO fixed here - if dims_mapping == None: - dims_mapping = forward_op_attr.get_output_dims_mapping( - var_name) - assert dims_mapping is not None, "[{}]'s dims_mapping is None".format( - var_name) - grad_op_attr.set_input_dims_mapping(var_name, dims_mapping) + assert grad_op.type in [ + "cast", "c_identity", "c_allreduce_sum" + ] + input_name = "X" + assert input_name in forward_op.desc.input_names( + ), "var [{}] in op [{}]'s output but coulf not find [{}] in its forward op".format( + output_name, grad_op.type, input_name) + if len(grad_op.desc.output(output_name)) == 1: + assert len(forward_op.desc.input(input_name)) == 1 + input_var = vars[forward_op.desc.input(input_name)[0]] + input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + input_var) + assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( + input_var.name) + ref_dims_mapping = input_var_dist_attr.get_dims_mapping() + + # tensor dist attr + output_var = vars[grad_op.desc.output(output_name)[0]] + output_var_attr = TensorDistributedAttribute(output_var, + dist_context) + output_var_attr.set_dims_mapping(ref_dims_mapping) + output_var_attr.set_process_mesh(forward_op_process_mesh) + dist_context.set_tensor_distributed_attr_for_program( + output_var, output_var_attr) + + # op dist attr + grad_op_attr.set_output_dims_mapping(output_var.name, + ref_dims_mapping) + + for input_name in grad_op.input_arg_names: + input_var = vars[input_name] + input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + input_var) + assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( + input_var.name) + ref_dims_mapping = input_var_dist_attr.get_dims_mapping() + assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( + input_var.name) + grad_op_attr.set_input_dims_mapping(input_name, + ref_dims_mapping) + dist_context.set_op_distributed_attr_for_program(grad_op, grad_op_attr) - # var dist attr - for var_name in grad_op.output_arg_names: - if _is_grad_var_name(var_name): - - forward_var_name = _get_forward_varname_from_grad_varname( - var_name) - forward_var = vars[forward_var_name] - tensor_attr = TensorDistributedAttribute(vars[var_name], - dist_context) - process_mesh = grad_op_attr.get_process_mesh() - dims_mapping = grad_op_attr.get_input_dims_mapping( - forward_var_name) - tensor_attr.set_process_mesh(process_mesh) - tensor_attr.set_dims_mapping(dims_mapping) - dist_context.set_tensor_distributed_attr_for_program( - vars[var_name], tensor_attr) # only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx else: @@ -775,6 +776,9 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): var_name) == ref_forward_var_name grad_op_attr.set_input_dims_mapping( var_name, ref_forward_var_dims_mapping) + + grad_op_attr.set_output_dims_mapping(grad_op.output_arg_names[0], + ref_forward_var_dims_mapping) dist_context.set_op_distributed_attr_for_program(grad_op, grad_op_attr) @@ -787,28 +791,86 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): ops = list(auto_parallel_main_prog.global_block().ops) vars = auto_parallel_main_prog.global_block().vars + learning_rate_completed = False for idx in range(len(ops)): # complete the annotation of the optimizer op. # TODO to add attribute for moment var - if int(ops[idx].attr('op_role')) == int(OpRole.Optimize): - if "Grad" in ops[idx].input_names and "Param" in ops[ - idx].input_names: - assert len(ops[idx].input( + op = ops[idx] + if int(op.attr('op_role')) == int(OpRole.Optimize): + + if "Grad" in op.input_names and "Param" in ops[idx].input_names: + assert len(op.input( "Param")) == 1, "Only support one-to-one now." - assert len(ops[idx].input( + assert len(op.input( "Grad")) == 1, "Only support one-to-one now." - param = vars[ops[idx].input("Param")[0]] - grad_var = vars[ops[idx].input("Grad")[0]] - process_mesh = dist_context.get_tensor_distributed_attr_for_program( + param = vars[op.input("Param")[0]] + grad_var = vars[op.input("Grad")[0]] + + param_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + param) + grad_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + grad_var) + + assert param_dist_attr is not None + assert grad_dist_attr is not None + assert param_dist_attr.get_dims_mapping( + ) == grad_dist_attr.get_dims_mapping() + + ref_process_mesh = dist_context.get_tensor_distributed_attr_for_program( param).get_process_mesh() - dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + assert ref_process_mesh is not None + ref_dims_mapping = dist_context.get_tensor_distributed_attr_for_program( param).get_dims_mapping() - op_attr = OperatorDistributedAttribute(ops[idx], dist_context) - op_attr.set_process_mesh(process_mesh) - op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) - op_attr.set_input_dims_mapping(param.name, dims_mapping) - dist_context.set_op_distributed_attr_for_program(ops[idx], - op_attr) + assert ref_dims_mapping is not None + op_attr = OperatorDistributedAttribute(op, dist_context) + op_attr.set_process_mesh(ref_process_mesh) + op_attr.set_input_dims_mapping(grad_var.name, ref_dims_mapping) + op_attr.set_input_dims_mapping(param.name, ref_dims_mapping) + op_attr.set_output_dims_mapping(param.name, ref_dims_mapping) + learning_var = vars[op.input("LearningRate")[0]] + op_attr.set_input_dims_mapping(learning_var.name, [-1]) + op_attr.set_output_dims_mapping(learning_var.name, [-1]) + + if not learning_rate_completed: + learning_rate_completed = True + var_dist_attr = TensorDistributedAttribute(learning_var, + dist_context) + var_dist_attr.set_process_mesh(ref_process_mesh) + var_dist_attr.set_dims_mapping([-1]) + dist_context.set_tensor_distributed_attr_for_program( + learning_var, var_dist_attr) + + for input_name in op.desc.input_names(): + + if input_name in [ + 'Param', 'Grad', 'LearningRate', "SkipUpdate", + "Beta1Tensor", "Beta2Tensor", "EpsilonTensor", + "MasterParam" + ]: + continue + + assert len(op.desc.input(input_name)) == 1 + input_var = vars[op.desc.input(input_name)[0]] + input_var_attr = TensorDistributedAttribute(input_var, + dist_context) + + if "Beta1Pow" in input_name or "Beta2Pow" in input_name: + input_var_attr.set_dims_mapping([-1]) + op_attr.set_input_dims_mapping(input_var.name, [-1]) + op_attr.set_output_dims_mapping(input_var.name, [-1]) + else: + assert "Moment" in input_name + input_var_attr.set_dims_mapping(ref_dims_mapping) + op_attr.set_input_dims_mapping(input_var.name, + ref_dims_mapping) + op_attr.set_output_dims_mapping(input_var.name, + ref_dims_mapping) + + input_var_attr.set_process_mesh(ref_process_mesh) + dist_context.set_tensor_distributed_attr_for_program( + input_var, input_var_attr) + + dist_context.set_op_distributed_attr_for_program(op, op_attr) continue diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index 11b3338bc6..3c395fbdf7 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -55,6 +55,35 @@ def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): return True +def is_valid_completed_program(dist_context, program): + + # TODO (ZJ-LIANG) should check all block + ops = program.global_block().ops + vars_ = program.list_vars() + for op in ops: + op_dist_attrs = dist_context.get_op_distributed_attr_for_program(op) + if op_dist_attrs == None: + return False + + if op_dist_attrs.get_process_mesh == None: + return False + + if None in op_dist_attrs._dims_mapping.values(): + return False + + for var in vars_: + var_dist_attrs = dist_context.get_tensor_distributed_attr_for_program( + var) + if var_dist_attrs == None: + return False + elif var_dist_attrs.get_process_mesh == None: + return False + elif var_dist_attrs.get_dims_mapping == None: + return False + + return True + + class MultiHeadAttention(nn.Layer): """ Attention mapps queries and a set of key-value pairs to outputs, and @@ -874,6 +903,9 @@ class TestGPTPartitioner(unittest.TestCase): self.assertTrue(all_params == data_parallel_allreduce_vars) self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars) + self.assertTrue( + is_valid_completed_program(dist_context, auto_parallel_main_prog)) + if __name__ == "__main__": unittest.main() -- GitLab