未验证 提交 5e9845b8 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Completion Dist Attribute for Backward & Update stage (#36744)

* revise completion for backward

* revise completion for update

* revise completion for update

* update unitest
上级 e92e6b06
......@@ -623,24 +623,35 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
if dist_context is None:
dist_context = get_default_distributed_context()
grad_start_idx = -1
first_backward_op_idx = -1
for idx, op in enumerate(auto_parallel_main_prog.global_block().ops):
if int(op.attr('op_role')) == int(
int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
core.op_proto_and_checker_maker.OpRole.Loss)):
assert op.type == "fill_constant"
grad_start_idx = idx
first_backward_op_idx = idx
break
assert grad_start_idx >= 0, "No backward procedure found in this program."
assert first_backward_op_idx >= 0, "No backward procedure found in this program."
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
dist_op_helper = dist_context.get_dist_op_helper()
for idx in range(grad_start_idx, len(ops)):
for idx in range(first_backward_op_idx, len(ops)):
# complete the initial grad loss op
if idx == grad_start_idx:
if idx == first_backward_op_idx:
assert ops[idx].type == "fill_constant"
assert len(
ops[idx].input_arg_names
) == 0, "first backward op should has only ONE output, but got [{}]".format(
len(ops[idx].input_arg_names))
assert len(
ops[idx].output_arg_names
) == 1, "first backward op should has only ONE output, but got [{}]".format(
len(ops[idx].output_arg_names))
grad_var = vars[ops[idx].output_arg_names[0]]
forward_var_name = _get_forward_varname_from_grad_varname(
grad_var.name)
......@@ -659,90 +670,80 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue
# TODO remove this when dist op handle its own grad scale
# in the data parallel mode, the loss op followed by scale op.
if ops[idx].type == "scale" and idx == grad_start_idx + 1:
assert grad_var.name in ops[
idx].input_arg_names and grad_var.name in ops[
idx].output_arg_names
grad_var = vars[ops[idx].output_arg_names[0]]
forward_var_name = _get_forward_varname_from_grad_varname(
grad_var.name)
forward_var = vars[forward_var_name]
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh()
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue
# TODO remove this when dist op handle its own communication
# TODO should distinguish the dp allreduce and mp allreduce
# complete the c_allreduce_sum op for gradient in the data parallel mode.
if ops[idx].type == "c_allreduce_sum" and ops[
idx].input_arg_names == ops[idx].output_arg_names:
grad_var = vars[ops[idx].output_arg_names[0]]
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
grad_var).get_process_mesh()
op_attr.set_process_mesh(process_mesh)
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue
# complete the annotation of grad op (xxx_grad op or sum op)
grad_op = ops[idx]
# xxx_grad op will have a corresponding forward op in gradopidx2opidx
dist_op_helper = dist_context.get_dist_op_helper()
grad_op = ops[idx]
if grad_op.desc.id() in dist_op_helper.gradopidx2opidx:
# TODO support the case where one forward op corresponding to multiple xxx_grad op
forward_op = _get_op_by_id(
ops[:grad_start_idx],
ops[:first_backward_op_idx],
dist_op_helper.gradopidx2opidx[grad_op.desc.id()])
assert forward_op is not None
# op dist attr
forward_op_attr = dist_context.get_op_distributed_attr_for_program(
forward_op)
forward_op_process_mesh = forward_op_attr.get_process_mesh()
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
grad_op_attr.set_process_mesh(forward_op_attr.get_process_mesh())
for var_name in grad_op.input_arg_names:
if "@GRAD" in var_name:
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
vars[var_name]).get_dims_mapping()
grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
grad_op_attr.set_process_mesh(forward_op_process_mesh)
# var
for output_name in grad_op.desc.output_names():
assert len(grad_op.desc.output(output_name)) in [0, 1]
# if grad_op.type == "cast":
# input_name = "X"
# else:
if _is_grad_var_name(output_name):
input_name = _get_forward_varname_from_grad_varname(
output_name)
else:
dims_mapping = forward_op_attr.get_input_dims_mapping(
var_name)
# TODO fixed here
if dims_mapping == None:
dims_mapping = forward_op_attr.get_output_dims_mapping(
var_name)
assert dims_mapping is not None, "[{}]'s dims_mapping is None".format(
var_name)
grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
assert grad_op.type in [
"cast", "c_identity", "c_allreduce_sum"
]
input_name = "X"
assert input_name in forward_op.desc.input_names(
), "var [{}] in op [{}]'s output but coulf not find [{}] in its forward op".format(
output_name, grad_op.type, input_name)
if len(grad_op.desc.output(output_name)) == 1:
assert len(forward_op.desc.input(input_name)) == 1
input_var = vars[forward_op.desc.input(input_name)[0]]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping()
# tensor dist attr
output_var = vars[grad_op.desc.output(output_name)[0]]
output_var_attr = TensorDistributedAttribute(output_var,
dist_context)
output_var_attr.set_dims_mapping(ref_dims_mapping)
output_var_attr.set_process_mesh(forward_op_process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
output_var, output_var_attr)
# op dist attr
grad_op_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping)
for input_name in grad_op.input_arg_names:
input_var = vars[input_name]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping()
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_var.name)
grad_op_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
# var dist attr
for var_name in grad_op.output_arg_names:
if _is_grad_var_name(var_name):
forward_var_name = _get_forward_varname_from_grad_varname(
var_name)
forward_var = vars[forward_var_name]
tensor_attr = TensorDistributedAttribute(vars[var_name],
dist_context)
process_mesh = grad_op_attr.get_process_mesh()
dims_mapping = grad_op_attr.get_input_dims_mapping(
forward_var_name)
tensor_attr.set_process_mesh(process_mesh)
tensor_attr.set_dims_mapping(dims_mapping)
dist_context.set_tensor_distributed_attr_for_program(
vars[var_name], tensor_attr)
# only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx
else:
......@@ -775,6 +776,9 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
var_name) == ref_forward_var_name
grad_op_attr.set_input_dims_mapping(
var_name, ref_forward_var_dims_mapping)
grad_op_attr.set_output_dims_mapping(grad_op.output_arg_names[0],
ref_forward_var_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
......@@ -787,28 +791,86 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
learning_rate_completed = False
for idx in range(len(ops)):
# complete the annotation of the optimizer op.
# TODO to add attribute for moment var
if int(ops[idx].attr('op_role')) == int(OpRole.Optimize):
if "Grad" in ops[idx].input_names and "Param" in ops[
idx].input_names:
assert len(ops[idx].input(
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input(
"Param")) == 1, "Only support one-to-one now."
assert len(ops[idx].input(
assert len(op.input(
"Grad")) == 1, "Only support one-to-one now."
param = vars[ops[idx].input("Param")[0]]
grad_var = vars[ops[idx].input("Grad")[0]]
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]]
param_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
param)
grad_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
grad_var)
assert param_dist_attr is not None
assert grad_dist_attr is not None
assert param_dist_attr.get_dims_mapping(
) == grad_dist_attr.get_dims_mapping()
ref_process_mesh = dist_context.get_tensor_distributed_attr_for_program(
param).get_process_mesh()
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
assert ref_process_mesh is not None
ref_dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
param).get_dims_mapping()
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx],
op_attr)
assert ref_dims_mapping is not None
op_attr = OperatorDistributedAttribute(op, dist_context)
op_attr.set_process_mesh(ref_process_mesh)
op_attr.set_input_dims_mapping(grad_var.name, ref_dims_mapping)
op_attr.set_input_dims_mapping(param.name, ref_dims_mapping)
op_attr.set_output_dims_mapping(param.name, ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]]
op_attr.set_input_dims_mapping(learning_var.name, [-1])
op_attr.set_output_dims_mapping(learning_var.name, [-1])
if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute(learning_var,
dist_context)
var_dist_attr.set_process_mesh(ref_process_mesh)
var_dist_attr.set_dims_mapping([-1])
dist_context.set_tensor_distributed_attr_for_program(
learning_var, var_dist_attr)
for input_name in op.desc.input_names():
if input_name in [
'Param', 'Grad', 'LearningRate', "SkipUpdate",
"Beta1Tensor", "Beta2Tensor", "EpsilonTensor",
"MasterParam"
]:
continue
assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistributedAttribute(input_var,
dist_context)
if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
input_var_attr.set_dims_mapping([-1])
op_attr.set_input_dims_mapping(input_var.name, [-1])
op_attr.set_output_dims_mapping(input_var.name, [-1])
else:
assert "Moment" in input_name
input_var_attr.set_dims_mapping(ref_dims_mapping)
op_attr.set_input_dims_mapping(input_var.name,
ref_dims_mapping)
op_attr.set_output_dims_mapping(input_var.name,
ref_dims_mapping)
input_var_attr.set_process_mesh(ref_process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
input_var, input_var_attr)
dist_context.set_op_distributed_attr_for_program(op, op_attr)
continue
......@@ -55,6 +55,35 @@ def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
return True
def is_valid_completed_program(dist_context, program):
# TODO (ZJ-LIANG) should check all block
ops = program.global_block().ops
vars_ = program.list_vars()
for op in ops:
op_dist_attrs = dist_context.get_op_distributed_attr_for_program(op)
if op_dist_attrs == None:
return False
if op_dist_attrs.get_process_mesh == None:
return False
if None in op_dist_attrs._dims_mapping.values():
return False
for var in vars_:
var_dist_attrs = dist_context.get_tensor_distributed_attr_for_program(
var)
if var_dist_attrs == None:
return False
elif var_dist_attrs.get_process_mesh == None:
return False
elif var_dist_attrs.get_dims_mapping == None:
return False
return True
class MultiHeadAttention(nn.Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
......@@ -874,6 +903,9 @@ class TestGPTPartitioner(unittest.TestCase):
self.assertTrue(all_params == data_parallel_allreduce_vars)
self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)
self.assertTrue(
is_valid_completed_program(dist_context, auto_parallel_main_prog))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册