Unverified commit 17b8446d, authored by zhaoyingli, committed by GitHub

[AutoParallel] use original id in grad_op_id_to_op_id (#42992)

* use original id in dist_op_context.grad_op_id_to_op_id

* del assert

* remove redundant map
Parent f87fa3c0
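
The gist of the change: every key into dist_op_context.grad_op_id_to_op_id switches from OpDesc.id() to OpDesc.original_id(). The remapping code deleted from backward.py below suggests that id() changes whenever a desc is copied into another block while original_id() does not, which is presumably why an original-id-keyed map no longer needs rebuilding. A minimal sketch of the lookup pattern used throughout this commit (the helper name is taken from the diff; the surrounding dist_op_context setup is assumed):

    # Sketch only: `ops` is a list of forward ops from the serial program and
    # `grad_op_id_to_op_id` maps a grad op's original id to its forward op's
    # original id, both recorded when the grad op descs are generated.
    def _get_op_by_id(ops, id):
        # Match on original_id(), which survives desc copies, rather than id(),
        # which is regenerated when the desc is copied into a new block.
        for op in ops:
            if op.desc.original_id() == id:
                return op
        return None

    # Typical call site (see the Completer and Partitioner hunks below):
    # fwd_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.original_id()]
    # forward_op = _get_op_by_id(ops, fwd_id)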
@@ -771,7 +771,7 @@ class Completer:
         def _get_op_by_id(ops, id):
             for op in ops:
-                if op.desc.id() == id:
+                if op.desc.original_id() == id:
                     return op
             return None
@@ -796,10 +796,12 @@ class Completer:
             # complete the annotation of grad op (xxx_grad op or sum op)
             # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
             grad_op = ops[idx]
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
+            if grad_op.desc.original_id(
+            ) in dist_op_context.grad_op_id_to_op_id:
                 # TODO support the case where one forward op corresponding to multiple xxx_grad op
-                forward_op = _get_op_by_id(
-                    ops, dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
+                forward_op = _get_op_by_id(ops,
+                                           dist_op_context.grad_op_id_to_op_id[
+                                               grad_op.desc.original_id()])
                 assert forward_op is not None

                 fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
@@ -935,7 +937,7 @@ class Completer:
         def _get_op_by_id(ops, id):
             for op in ops:
-                if op.desc.id() == id:
+                if op.desc.original_id() == id:
                     return op
             return None
@@ -997,11 +999,12 @@ class Completer:
             # complete the annotation of grad op (xxx_grad op or sum op)
             # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
             grad_op = ops[idx]
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
+            if grad_op.desc.original_id(
+            ) in dist_op_context.grad_op_id_to_op_id:
                 # TODO support the case where one forward op corresponding to multiple xxx_grad op
-                forward_op = _get_op_by_id(
-                    ops[:first_backward_op_idx],
-                    dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
+                forward_op = _get_op_by_id(ops[:first_backward_op_idx],
+                                           dist_op_context.grad_op_id_to_op_id[
+                                               grad_op.desc.original_id()])
                 assert forward_op is not None

                 if grad_op.type == "concat" and forward_op.type == "split":
...
@@ -204,9 +204,13 @@ class DistributedContext:
             )
             self._serial_startup_program = self._original_serial_startup_program.clone(
             )
-            self._serial_main_program = self._original_serial_main_program
-            self._serial_startup_program = self._original_serial_startup_program
-            self._serial_loss = self._original_serial_loss
+            # self._serial_main_program = self._original_serial_main_program
+            # self._serial_startup_program = self._original_serial_startup_program
+            if self._original_serial_loss:
+                self._serial_loss = self._serial_main_program.global_block(
+                ).vars[self._original_serial_loss[0].name]
+            else:
+                self._serial_loss = self._original_serial_loss
             self._serial_optimizer = self._original_serial_optimizer
             self._init_dist_attr_for_program()
             self._tensors_ids = list(self._dist_tensors_for_program.keys())
...
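
Because the serial programs are now clone()d rather than aliased, DistributedContext re-resolves the loss by name from the cloned main program's global block: _original_serial_loss (a one-element list here) refers to a Variable owned by the original program, not the clone. A standalone sketch of that by-name re-resolution, assuming the Paddle 2.x static-graph API (the program contents are illustrative):

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
        loss = paddle.mean(paddle.static.nn.fc(x, size=1))

    cloned_main = main_prog.clone()
    # The clone owns its own Variable objects; look the loss up again by name,
    # mirroring the global_block().vars[...] lookup in the hunk above.
    loss_in_clone = cloned_main.global_block().vars[loss.name]
    print(loss_in_clone is loss)             # False: different Variable objects
    print(loss_in_clone.name == loss.name)   # True: resolved by name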
@@ -51,7 +51,7 @@ class Parallelizer:
         serial_optimizer = self._dist_context.serial_optimizer
         if self._mode == "train" and serial_optimizer:
             # Generate backward
-            serial_loss = self._dist_context.serial_fetch_vars["loss"][0]
+            serial_loss = self._dist_context.serial_loss
             params_grads = self._generate_backward(
                 serial_main_program, serial_startup_program, serial_loss)
             # Apply pre optimization passes
...
@@ -211,7 +211,7 @@ class Partitioner(object):
         forward_op_id2forward_op = {}
         for idx in range(len(serial_ops)):
             if idx <= last_fwd_op_idx:
-                forward_op_id2forward_op[serial_ops[idx].desc.id(
-                )] = serial_ops[idx]
+                forward_op_id2forward_op[serial_ops[idx].desc.original_id(
+                )] = serial_ops[idx]

         appended_grad_times = 0
@@ -408,9 +408,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
 def _get_dist_op_backward_implement(backward_op, dist_context,
                                     forward_op_id2forward_op):
     dist_op_context = dist_context.dist_op_context
-    if backward_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-        forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op.desc.id(
-        )]
+    if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
+        forward_op_id = dist_op_context.grad_op_id_to_op_id[
+            backward_op.desc.original_id()]
         forward_op = forward_op_id2forward_op[forward_op_id]
         forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
             forward_op)
...
@@ -46,13 +46,13 @@ class AMPState(object):
             if int(op.attr('op_role')) == int(OpRole.Forward):
                 self._mark_black_white_ops(amp_lists)
             elif int(op.attr('op_role')) == int(OpRole.Backward):
-                if op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[op.desc.id(
-                    )]
+                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
+                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[
+                        op.desc.original_id()]
                     if self._is_fp16_op(fwd_op_id) == True:
-                        self._op_fp16_dict[op.desc.id()] = True
+                        self._op_fp16_dict[op.desc.original_id()] = True
                     elif self._is_fp16_op(fwd_op_id) == False:
-                        self._op_fp16_dict[op.desc.id()] = False
+                        self._op_fp16_dict[op.desc.original_id()] = False
             elif int(op.attr('op_role')) == int(OpRole.Optimize):
                 break
@@ -70,12 +70,12 @@ class AMPState(object):
                 continue
             if amp_lists.black_varnames is not None and _is_in_black_varnames(
                     op, amp_lists):
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
                 continue
             if op.type in amp_lists.black_list:
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
             elif op.type in amp_lists.white_list:
-                self._op_fp16_dict[op.desc.id()] = True
+                self._op_fp16_dict[op.desc.original_id()] = True
             elif op.type in amp_lists.gray_list:
                 is_black_op = False
                 is_white_op = False
@@ -95,22 +95,22 @@ class AMPState(object):
                         else:
                             prev_op = in_var.op
                         # if it's one of inputs
-                        if self._is_fp16_op(prev_op.desc.id()) == False or \
+                        if self._is_fp16_op(prev_op.desc.original_id()) == False or \
                                 prev_op.type in amp_lists.black_list:
                             is_black_op = True
-                        elif self._is_fp16_op(prev_op.desc.id()) == True or \
+                        elif self._is_fp16_op(prev_op.desc.original_id()) == True or \
                                 prev_op.type in amp_lists.white_list:
                             is_white_op = True
                 if is_black_op:
-                    self._op_fp16_dict[op.desc.id()] = False
+                    self._op_fp16_dict[op.desc.original_id()] = False
                 elif is_white_op:
-                    self._op_fp16_dict[op.desc.id()] = True
+                    self._op_fp16_dict[op.desc.original_id()] = True
                 else:
                     pass
             else:
                 # For numerical safe, we apply fp32 computation on ops that
                 # are not determined which list they should stay.
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False

     def cast_forward_program(self, dist_context):
         ops = self._block.ops
@@ -120,11 +120,11 @@ class AMPState(object):
             num_cast_ops = 0
             if int(op.attr('op_role')) == int(OpRole.Backward):
                 break
-            if self._is_fp16_op(op.desc.id()) == False:
+            if self._is_fp16_op(op.desc.original_id()) == False:
                 num_cast_ops = self._insert_cast_op_forward(
                     op, idx, core.VarDesc.VarType.FP16,
                     core.VarDesc.VarType.FP32, dist_context)
-            elif self._is_fp16_op(op.desc.id()) == True:
+            elif self._is_fp16_op(op.desc.original_id()) == True:
                 num_cast_ops = self._insert_cast_op_forward(
                     op, idx, core.VarDesc.VarType.FP32,
                     core.VarDesc.VarType.FP16, dist_context)
@@ -198,7 +198,7 @@ class AMPState(object):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dst_dtype)
-        self._var_name_dict[op.desc.id()] = var_name_dict
+        self._var_name_dict[op.desc.original_id()] = var_name_dict

         if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
             for out_name in op.output_names:
@@ -225,13 +225,14 @@ class AMPState(object):
         while idx < len(ops):
             num_cast_ops = 0
             grad_op = ops[idx]
+            grad_op_orig_id = grad_op.desc.original_id()
             dist_op_context = dist_context.dist_op_context
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-                if self._is_fp16_op(grad_op.desc.id()) == False:  # fp32
+            if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
+                if self._is_fp16_op(grad_op_orig_id) == False:  # fp32
                     num_cast_ops = self._insert_cast_op_backward(
                         grad_op, idx, core.VarDesc.VarType.FP16,
                         core.VarDesc.VarType.FP32, dist_context)
-                elif self._is_fp16_op(grad_op.desc.id()) == True:  # fp16
+                elif self._is_fp16_op(grad_op_orig_id) == True:  # fp16
                     num_cast_ops = self._insert_cast_op_backward(
                         grad_op, idx, core.VarDesc.VarType.FP32,
                         core.VarDesc.VarType.FP16, dist_context)
@@ -272,8 +273,9 @@ class AMPState(object):
                 return False

         num_cast_ops = 0
+        original_id = grad_op.desc.original_id()
         dist_op_context = dist_context.dist_op_context
-        fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]
+        fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

         for in_name in grad_op.input_names:
             if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
...
@@ -153,23 +153,24 @@ class FP16State(object):
             # ernie inference trick
             if op.type == "assign" and "array_" in op.input_arg_names[0]:
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
                 return
             if _need_keep_fp32(op, self.amp_list.unsupported_list,
                                self.use_fp16_guard):
-                self._op_fp16_dict[op.desc.id()] = False
+                self._op_fp16_dict[op.desc.original_id()] = False
             else:
-                self._op_fp16_dict[op.desc.id()] = True
+                self._op_fp16_dict[op.desc.original_id()] = True
             for var_name in op.output_arg_names:
                 # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name)
                 self.forward_non_leaf_tensors[var_name] = op.desc.id()

         elif is_backward_op(op) == int(OpRole.Backward):

-            if op.desc.id() in self.grad_op_to_op_map:
-                fwd_op_id = self.grad_op_to_op_map[op.desc.id()]
+            if op.desc.original_id() in self.grad_op_to_op_map:
+                fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
                 assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
-                self._op_fp16_dict[op.desc.id()] = self._op_fp16_dict[fwd_op_id]
+                self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
+                    fwd_op_id]

         if int(op.attr('op_role')) == 257:
             self.is_train = True
@@ -192,10 +193,10 @@ class FP16State(object):
     def resolute_tensor_dtype(self, block):

         for op in block.ops:
-            op_id = op.desc.id()
             if is_forward_op(op):
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                if self._is_fp16_op(op_id) == True or op.type == "cast":
+                if self._is_fp16_op(op.desc.original_id()) == True \
+                        or op.type == "cast":
                     for in_name in op.input_names:
                         if _keep_fp32_input(op, in_name):
                             continue
@@ -209,7 +210,7 @@ class FP16State(object):
                         self.set_var_to_fp16(out_var_name, block)
                     set_op_dtype_to_fp16(op)
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                elif self._is_fp16_op(op_id) == False:
+                elif self._is_fp16_op(op.desc.original_id()) == False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
@@ -217,7 +218,7 @@ class FP16State(object):
                         if out_var.dtype == core.VarDesc.VarType.FP16:
                             out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
             elif is_backward_op(op):
-                if self._is_fp16_op(op_id) == True:
+                if self._is_fp16_op(op.desc.original_id()) == True:
                     for out_name in op.output_names:
                         if _keep_fp32_output(op, out_name):
                             continue
@@ -225,7 +226,7 @@ class FP16State(object):
                         self.set_var_to_fp16(out_var_name, block)
                     set_op_dtype_to_fp16(op)
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                elif self._is_fp16_op(op_id) == False:
+                elif self._is_fp16_op(op.desc.original_id()) == False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
@@ -238,28 +239,27 @@ class FP16State(object):
         idx = 0
         while idx < len(block.ops):
             op = block.ops[idx]
-            op_id = op.desc.id()
             num_cast_ops = 0
             if op.type in __amp_skip_ops__:
                 idx += 1
                 continue
             elif is_forward_op(op):
-                if self._is_fp16_op(op_id) == False:
+                if self._is_fp16_op(op.desc.original_id()) == False:
                     num_cast_ops = self._insert_forward_cast_ops(
                         op, idx, block, core.VarDesc.VarType.FP16,
                         core.VarDesc.VarType.FP32, self.dist_context)
-                elif self._is_fp16_op(op_id) == True:
+                elif self._is_fp16_op(op.desc.original_id()) == True:
                     num_cast_ops = self._insert_forward_cast_ops(
                         op, idx, block, core.VarDesc.VarType.FP32,
                         core.VarDesc.VarType.FP16, self.dist_context)
             elif is_backward_op(op):
-                if op_id in dist_op_context.grad_op_id_to_op_id:
-                    if self._is_fp16_op(op_id) == False:
+                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
+                    if self._is_fp16_op(op.desc.original_id()) == False:
                         num_cast_ops = self._insert_backward_cast_ops(
                             op, idx, block, core.VarDesc.VarType.FP16,
                             core.VarDesc.VarType.FP32, self.dist_context)
-                    elif self._is_fp16_op(op_id) == True:
+                    elif self._is_fp16_op(op.desc.original_id()) == True:
                         num_cast_ops = self._insert_backward_cast_ops(
                             op, idx, block, core.VarDesc.VarType.FP32,
                             core.VarDesc.VarType.FP16, self.dist_context)
@@ -282,7 +282,6 @@ class FP16State(object):
                              dist_context):

         num_cast_ops = 0
-        op_id = op.desc.id()

         for in_name in op.input_names:
             if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
@@ -300,7 +299,7 @@ class FP16State(object):
                 cast_name = in_var.name + '.cast_' + _dtype_to_str(
                     dst_dtype)
                 cast_var = block.vars.get(cast_name)
-                self.forward_input_cast_ops[op_id] += [(
+                self.forward_input_cast_ops[op.desc.original_id()] += [(
                     cast_name, in_var.name, dst_dtype, src_dtype, in_name)]

                 in_var_dist_attr = consume_op_attr.get_input_dist_attr(
@@ -349,8 +348,9 @@ class FP16State(object):

         num_cast_ops = 0
         op_id = op.desc.id()
+        original_id = op.desc.original_id()
         dist_op_context = dist_context.dist_op_context
-        forward_op_id = dist_op_context.grad_op_id_to_op_id[op_id]
+        forward_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

         grad_op_attr = dist_context.get_op_dist_attr_for_program(op)
         assert grad_op_attr is not None
...
@@ -315,7 +315,7 @@ class RecomputePass(PassBase):
             # When traversing all grad_ops in reverse, need to set a flag to indicate
             # whether the ckpt and its segment_descs can be used.
             ckpt_op = op_path[segment[1] - 1]
-            ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs]
+            ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

         # step 4: insert recomputed fwd ops
         ops = main_block.ops
@@ -339,9 +339,9 @@ class RecomputePass(PassBase):
                     _rename_arg_([grad_op.desc], key, var_name_dict[key])

             # insert recomputed ops
-            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
-                fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id(
-                )]
+            original_id = grad_op.desc.original_id()
+            if original_id in dist_op_context.grad_op_id_to_op_id:
+                fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
                 if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
                     idx = grad_op.idx
                     while idx - 1 >= 0 and ops[idx - 1].type == "sum":
...
@@ -1107,8 +1107,10 @@ def _append_backward_ops_(block,
             distop_context.grad_var_to_var[appending_grad_times].update(
                 op_grad_to_var)
             for op_desc in grad_op_desc:
-                assert op_desc.id() not in distop_context.grad_op_id_to_op_id
-                distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
+                assert op_desc.original_id(
+                ) not in distop_context.grad_op_id_to_op_id
+                distop_context.grad_op_id_to_op_id[op_desc.original_id(
+                )] = op.desc.original_id()

     if callbacks is not None:
         assert (isinstance(callbacks, (list, tuple)))
@@ -1255,12 +1257,6 @@ def _append_backward_ops_(block,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-        # Rebuild the mapping because new_op_desc has a differnt id (Only for auto parallel)
-        if distop_context is not None:
-            if op_desc.id() in distop_context.grad_op_id_to_op_id:
-                distop_context.grad_op_id_to_op_id[new_op_desc.id(
-                )] = distop_context.grad_op_id_to_op_id[op_desc.id()]
-                distop_context.grad_op_id_to_op_id.pop(op_desc.id())
         new_op_desc._set_attr(op_role_attr_name, backward)
         grad_to_var["__current_op_desc__"] = new_op_desc
         if callbacks is not None:
...
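
The block deleted above was the old workaround: with the map keyed by desc id(), the key had to be rewritten after new_op_desc.copy_from(op_desc) because the copied desc has a different id. With the map keyed by original_id() at creation time (see the first backward.py hunk), the deletion relies on the copied desc keeping the same original_id() even though its id() changes; that is an inference from this diff rather than something shown in it. A hedged, runnable sketch of that assumed invariant, using only desc APIs that appear in the diff (the tiny program is illustrative):

    import paddle

    paddle.enable_static()

    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        x = paddle.static.data(name="x", shape=[4], dtype="float32")
        y = x * 2.0  # appends one op desc to the global block

    block_desc = prog.global_block().desc
    src_desc = block_desc.op(block_desc.op_size() - 1)

    new_desc = block_desc.append_op()
    new_desc.copy_from(src_desc)

    # Expected under the invariant this commit assumes: a copied desc gets a
    # fresh id() but keeps original_id(), so original_id-keyed maps stay valid
    # without the remapping that was removed above.
    print(new_desc.id() != src_desc.id())
    print(new_desc.original_id() == src_desc.original_id())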