Unverified commit 17b8446d authored by zhaoyingli, committed by GitHub

[AutoParallel] use original id in grad_op_id_to_op_id (#42992)

* use original id in dist_op_context.grad_op_id_to_op_id

* del assert

* remove redundant map
Parent: f87fa3c0
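A minimal, self-contained sketch of the idea behind this change (the `ToyOpDesc` class and counter below are hypothetical stand-ins for illustration, not Paddle API): `desc.id()` is unique to each desc instance and changes when a grad op desc is copied into the target block, whereas `desc.original_id()` is preserved, so a `grad_op_id_to_op_id` map keyed by original ids stays valid without being rebuilt.

```python
# Toy illustration (not Paddle code): why keys based on original_id() survive
# a desc copy, while keys based on id() do not.
import itertools

_next_id = itertools.count(1)


class ToyOpDesc:
    """Hypothetical stand-in for an op desc: id() changes on copy, original_id() does not."""

    def __init__(self):
        self._id = next(_next_id)        # unique per desc instance
        self._original_id = self._id     # a fresh desc is its own original

    def id(self):
        return self._id

    def original_id(self):
        return self._original_id

    def copy_from(self, other):
        # Copying keeps the original identity; the runtime id stays unique.
        self._original_id = other.original_id()
        return self


fwd = ToyOpDesc()
grad = ToyOpDesc()

# Mapping keyed by original ids, mirroring dist_op_context.grad_op_id_to_op_id.
grad_op_id_to_op_id = {grad.original_id(): fwd.original_id()}

# Appending the grad op to the target block copies its desc (new id()).
appended = ToyOpDesc().copy_from(grad)

assert appended.id() != grad.id()                     # id() changed
assert appended.original_id() in grad_op_id_to_op_id  # lookup still works
```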
@@ -771,7 +771,7 @@ class Completer:
         def _get_op_by_id(ops, id):
             for op in ops:
-                if op.desc.id() == id:
+                if op.desc.original_id() == id:
                     return op
             return None
@@ -796,10 +796,12 @@ class Completer:
             # complete the annotation of grad op (xxx_grad op or sum op)
             # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
             grad_op = ops[idx]
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
+            if grad_op.desc.original_id(
+            ) in dist_op_context.grad_op_id_to_op_id:
                 # TODO support the case where one forward op corresponding to multiple xxx_grad op
-                forward_op = _get_op_by_id(
-                    ops, dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
+                forward_op = _get_op_by_id(ops,
+                                           dist_op_context.grad_op_id_to_op_id[
+                                               grad_op.desc.original_id()])

                 assert forward_op is not None

                 fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
@@ -935,7 +937,7 @@ class Completer:
         def _get_op_by_id(ops, id):
             for op in ops:
-                if op.desc.id() == id:
+                if op.desc.original_id() == id:
                     return op
             return None
@@ -997,11 +999,12 @@ class Completer:
             # complete the annotation of grad op (xxx_grad op or sum op)
             # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
             grad_op = ops[idx]
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
+            if grad_op.desc.original_id(
+            ) in dist_op_context.grad_op_id_to_op_id:
                 # TODO support the case where one forward op corresponding to multiple xxx_grad op
-                forward_op = _get_op_by_id(
-                    ops[:first_backward_op_idx],
-                    dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
+                forward_op = _get_op_by_id(ops[:first_backward_op_idx],
+                                           dist_op_context.grad_op_id_to_op_id[
+                                               grad_op.desc.original_id()])

                 assert forward_op is not None

                 if grad_op.type == "concat" and forward_op.type == "split":
...
@@ -204,8 +204,12 @@ class DistributedContext:
             )
             self._serial_startup_program = self._original_serial_startup_program.clone(
             )
-            self._serial_main_program = self._original_serial_main_program
-            self._serial_startup_program = self._original_serial_startup_program
+            # self._serial_main_program = self._original_serial_main_program
+            # self._serial_startup_program = self._original_serial_startup_program
+            if self._original_serial_loss:
+                self._serial_loss = self._serial_main_program.global_block(
+                ).vars[self._original_serial_loss[0].name]
+            else:
+                self._serial_loss = self._original_serial_loss
             self._serial_optimizer = self._original_serial_optimizer
             self._init_dist_attr_for_program()
...
@@ -51,7 +51,7 @@ class Parallelizer:
         serial_optimizer = self._dist_context.serial_optimizer
         if self._mode == "train" and serial_optimizer:
             # Generate backward
-            serial_loss = self._dist_context.serial_fetch_vars["loss"][0]
+            serial_loss = self._dist_context.serial_loss
             params_grads = self._generate_backward(
                 serial_main_program, serial_startup_program, serial_loss)
             # Apply pre optimization passes
...
@@ -211,7 +211,7 @@ class Partitioner(object):
         forward_op_id2forward_op = {}
         for idx in range(len(serial_ops)):
             if idx <= last_fwd_op_idx:
-                forward_op_id2forward_op[serial_ops[idx].desc.id(
+                forward_op_id2forward_op[serial_ops[idx].desc.original_id(
                 )] = serial_ops[idx]

         appended_grad_times = 0
@@ -408,9 +408,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
 def _get_dist_op_backward_implement(backward_op, dist_context,
                                     forward_op_id2forward_op):
     dist_op_context = dist_context.dist_op_context
-    if backward_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-        forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op.desc.id(
-        )]
+    if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
+        forward_op_id = dist_op_context.grad_op_id_to_op_id[
+            backward_op.desc.original_id()]
         forward_op = forward_op_id2forward_op[forward_op_id]
         forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
             forward_op)
...
@@ -46,13 +46,13 @@ class AMPState(object):
             if int(op.attr('op_role')) == int(OpRole.Forward):
                 self._mark_black_white_ops(amp_lists)
             elif int(op.attr('op_role')) == int(OpRole.Backward):
-                if op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[op.desc.id(
-                    )]
+                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
+                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[
+                        op.desc.original_id()]
                     if self._is_fp16_op(fwd_op_id) == True:
-                        self._op_fp16_dict[op.desc.id()] = True
+                        self._op_fp16_dict[op.desc.original_id()] = True
                     elif self._is_fp16_op(fwd_op_id) == False:
-                        self._op_fp16_dict[op.desc.id()] = False
+                        self._op_fp16_dict[op.desc.original_id()] = False
             elif int(op.attr('op_role')) == int(OpRole.Optimize):
                 break
@@ -70,12 +70,12 @@ class AMPState(object):
                 continue
             if amp_lists.black_varnames is not None and _is_in_black_varnames(
                     op, amp_lists):
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
                 continue
             if op.type in amp_lists.black_list:
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
             elif op.type in amp_lists.white_list:
-                self._op_fp16_dict[op.desc.id()] = True
+                self._op_fp16_dict[op.desc.original_id()] = True
             elif op.type in amp_lists.gray_list:
                 is_black_op = False
                 is_white_op = False
@@ -95,22 +95,22 @@ class AMPState(object):
                     else:
                         prev_op = in_var.op
                     # if it's one of inputs
-                    if self._is_fp16_op(prev_op.desc.id()) == False or \
+                    if self._is_fp16_op(prev_op.desc.original_id()) == False or \
                             prev_op.type in amp_lists.black_list:
                         is_black_op = True
-                    elif self._is_fp16_op(prev_op.desc.id()) == True or \
+                    elif self._is_fp16_op(prev_op.desc.original_id()) == True or \
                             prev_op.type in amp_lists.white_list:
                         is_white_op = True
                 if is_black_op:
-                    self._op_fp16_dict[op.desc.id()] = False
+                    self._op_fp16_dict[op.desc.original_id()] = False
                 elif is_white_op:
-                    self._op_fp16_dict[op.desc.id()] = True
+                    self._op_fp16_dict[op.desc.original_id()] = True
                 else:
                     pass
             else:
                 # For numerical safe, we apply fp32 computation on ops that
                 # are not determined which list they should stay.
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False

     def cast_forward_program(self, dist_context):
         ops = self._block.ops
@@ -120,11 +120,11 @@ class AMPState(object):
             num_cast_ops = 0
             if int(op.attr('op_role')) == int(OpRole.Backward):
                 break
-            if self._is_fp16_op(op.desc.id()) == False:
+            if self._is_fp16_op(op.desc.original_id()) == False:
                 num_cast_ops = self._insert_cast_op_forward(
                     op, idx, core.VarDesc.VarType.FP16,
                     core.VarDesc.VarType.FP32, dist_context)
-            elif self._is_fp16_op(op.desc.id()) == True:
+            elif self._is_fp16_op(op.desc.original_id()) == True:
                 num_cast_ops = self._insert_cast_op_forward(
                     op, idx, core.VarDesc.VarType.FP32,
                     core.VarDesc.VarType.FP16, dist_context)
@@ -198,7 +198,7 @@ class AMPState(object):
                 else:
                     if op.has_attr('in_dtype'):
                         op._set_attr('in_dtype', dst_dtype)
-        self._var_name_dict[op.desc.id()] = var_name_dict
+        self._var_name_dict[op.desc.original_id()] = var_name_dict

         if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
             for out_name in op.output_names:
@@ -225,13 +225,14 @@ class AMPState(object):
         while idx < len(ops):
             num_cast_ops = 0
             grad_op = ops[idx]
+            grad_op_orig_id = grad_op.desc.original_id()
             dist_op_context = dist_context.dist_op_context
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-                if self._is_fp16_op(grad_op.desc.id()) == False:  # fp32
+            if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
+                if self._is_fp16_op(grad_op_orig_id) == False:  # fp32
                     num_cast_ops = self._insert_cast_op_backward(
                         grad_op, idx, core.VarDesc.VarType.FP16,
                         core.VarDesc.VarType.FP32, dist_context)
-                elif self._is_fp16_op(grad_op.desc.id()) == True:  # fp16
+                elif self._is_fp16_op(grad_op_orig_id) == True:  # fp16
                     num_cast_ops = self._insert_cast_op_backward(
                         grad_op, idx, core.VarDesc.VarType.FP32,
                         core.VarDesc.VarType.FP16, dist_context)
@@ -272,8 +273,9 @@ class AMPState(object):
             return False

         num_cast_ops = 0
+        original_id = grad_op.desc.original_id()
         dist_op_context = dist_context.dist_op_context
-        fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]
+        fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

         for in_name in grad_op.input_names:
             if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
...
@@ -153,23 +153,24 @@ class FP16State(object):
             # ernie inference trick
             if op.type == "assign" and "array_" in op.input_arg_names[0]:
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
                 return
             if _need_keep_fp32(op, self.amp_list.unsupported_list,
                                self.use_fp16_guard):
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
             else:
-                self._op_fp16_dict[op.desc.id()] = True
+                self._op_fp16_dict[op.desc.original_id()] = True
             for var_name in op.output_arg_names:
                 # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name)
                 self.forward_non_leaf_tensors[var_name] = op.desc.id()

         elif is_backward_op(op) == int(OpRole.Backward):

-            if op.desc.id() in self.grad_op_to_op_map:
-                fwd_op_id = self.grad_op_to_op_map[op.desc.id()]
+            if op.desc.original_id() in self.grad_op_to_op_map:
+                fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
                 assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
-                self._op_fp16_dict[op.desc.id()] = self._op_fp16_dict[fwd_op_id]
+                self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
+                    fwd_op_id]

         if int(op.attr('op_role')) == 257:
             self.is_train = True
@@ -192,10 +193,10 @@ class FP16State(object):
     def resolute_tensor_dtype(self, block):

         for op in block.ops:
-            op_id = op.desc.id()
             if is_forward_op(op):
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                if self._is_fp16_op(op_id) == True or op.type == "cast":
+                if self._is_fp16_op(op.desc.original_id()) == True \
+                        or op.type == "cast":
                     for in_name in op.input_names:
                         if _keep_fp32_input(op, in_name):
                             continue
@@ -209,7 +210,7 @@ class FP16State(object):
                         self.set_var_to_fp16(out_var_name, block)
                     set_op_dtype_to_fp16(op)
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                elif self._is_fp16_op(op_id) == False:
+                elif self._is_fp16_op(op.desc.original_id()) == False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
@@ -217,7 +218,7 @@ class FP16State(object):
                         if out_var.dtype == core.VarDesc.VarType.FP16:
                             out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
             elif is_backward_op(op):
-                if self._is_fp16_op(op_id) == True:
+                if self._is_fp16_op(op.desc.original_id()) == True:
                     for out_name in op.output_names:
                         if _keep_fp32_output(op, out_name):
                             continue
@@ -225,7 +226,7 @@ class FP16State(object):
                         self.set_var_to_fp16(out_var_name, block)
                     set_op_dtype_to_fp16(op)
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                elif self._is_fp16_op(op_id) == False:
+                elif self._is_fp16_op(op.desc.original_id()) == False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
@@ -238,28 +239,27 @@ class FP16State(object):
         idx = 0
         while idx < len(block.ops):
             op = block.ops[idx]
-            op_id = op.desc.id()
             num_cast_ops = 0

             if op.type in __amp_skip_ops__:
                 idx += 1
                 continue
             elif is_forward_op(op):
-                if self._is_fp16_op(op_id) == False:
+                if self._is_fp16_op(op.desc.original_id()) == False:
                     num_cast_ops = self._insert_forward_cast_ops(
                         op, idx, block, core.VarDesc.VarType.FP16,
                         core.VarDesc.VarType.FP32, self.dist_context)
-                elif self._is_fp16_op(op_id) == True:
+                elif self._is_fp16_op(op.desc.original_id()) == True:
                     num_cast_ops = self._insert_forward_cast_ops(
                         op, idx, block, core.VarDesc.VarType.FP32,
                         core.VarDesc.VarType.FP16, self.dist_context)
             elif is_backward_op(op):
-                if op_id in dist_op_context.grad_op_id_to_op_id:
-                    if self._is_fp16_op(op_id) == False:
+                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
+                    if self._is_fp16_op(op.desc.original_id()) == False:
                         num_cast_ops = self._insert_backward_cast_ops(
                             op, idx, block, core.VarDesc.VarType.FP16,
                             core.VarDesc.VarType.FP32, self.dist_context)
-                    elif self._is_fp16_op(op_id) == True:
+                    elif self._is_fp16_op(op.desc.original_id()) == True:
                         num_cast_ops = self._insert_backward_cast_ops(
                             op, idx, block, core.VarDesc.VarType.FP32,
                             core.VarDesc.VarType.FP16, self.dist_context)
@@ -282,7 +282,6 @@ class FP16State(object):
                                   dist_context):

         num_cast_ops = 0
-        op_id = op.desc.id()

         for in_name in op.input_names:
             if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
@@ -300,7 +299,7 @@ class FP16State(object):
                     cast_name = in_var.name + '.cast_' + _dtype_to_str(
                         dst_dtype)
                     cast_var = block.vars.get(cast_name)
-                    self.forward_input_cast_ops[op_id] += [(
+                    self.forward_input_cast_ops[op.desc.original_id()] += [(
                         cast_name, in_var.name, dst_dtype, src_dtype, in_name)]

                     in_var_dist_attr = consume_op_attr.get_input_dist_attr(
@@ -349,8 +348,9 @@ class FP16State(object):
         num_cast_ops = 0
-        op_id = op.desc.id()
+        original_id = op.desc.original_id()
         dist_op_context = dist_context.dist_op_context
-        forward_op_id = dist_op_context.grad_op_id_to_op_id[op_id]
+        forward_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

         grad_op_attr = dist_context.get_op_dist_attr_for_program(op)
         assert grad_op_attr is not None
...
@@ -315,7 +315,7 @@ class RecomputePass(PassBase):
             # When traversing all grad_ops in reverse, need to set a flag to indicate
             # whether the ckpt and its segment_descs can be used.
             ckpt_op = op_path[segment[1] - 1]
-            ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs]
+            ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

         # step 4: insert recomputed fwd ops
         ops = main_block.ops
@@ -339,9 +339,9 @@ class RecomputePass(PassBase):
                     _rename_arg_([grad_op.desc], key, var_name_dict[key])

             # insert recomputed ops
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-                fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id(
-                )]
+            original_id = grad_op.desc.original_id()
+            if original_id in dist_op_context.grad_op_id_to_op_id:
+                fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
                 if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
                     idx = grad_op.idx
                     while idx - 1 >= 0 and ops[idx - 1].type == "sum":
...
@@ -1107,8 +1107,10 @@ def _append_backward_ops_(block,
                 distop_context.grad_var_to_var[appending_grad_times].update(
                     op_grad_to_var)
                 for op_desc in grad_op_desc:
-                    assert op_desc.id() not in distop_context.grad_op_id_to_op_id
-                    distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
+                    assert op_desc.original_id(
+                    ) not in distop_context.grad_op_id_to_op_id
+                    distop_context.grad_op_id_to_op_id[op_desc.original_id(
+                    )] = op.desc.original_id()

         if callbacks is not None:
             assert (isinstance(callbacks, (list, tuple)))
@@ -1255,12 +1257,6 @@ def _append_backward_ops_(block,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-        # Rebuild the mapping because new_op_desc has a differnt id (Only for auto parallel)
-        if distop_context is not None:
-            if op_desc.id() in distop_context.grad_op_id_to_op_id:
-                distop_context.grad_op_id_to_op_id[new_op_desc.id(
-                )] = distop_context.grad_op_id_to_op_id[op_desc.id()]
-                distop_context.grad_op_id_to_op_id.pop(op_desc.id())
         new_op_desc._set_attr(op_role_attr_name, backward)
         grad_to_var["__current_op_desc__"] = new_op_desc
         if callbacks is not None:
...
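Note on the removed remapping above, stated as an inference from this diff rather than from separate documentation: `copy_from` gives `new_op_desc` a different `id()` but leaves its `original_id()` unchanged, so once `grad_op_id_to_op_id` is keyed by original ids (see the hunk at @@ -1107 above and the toy sketch after the commit message), the rebuild-and-pop step becomes the "redundant map" maintenance named in the commit message and can be dropped.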