From 17b8446d459bc3ddde7eee71d04e5ed4c986fbc5 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 30 May 2022 17:46:12 +0800 Subject: [PATCH] [AutoParallel] use original id in grad_op_id_to_op_id (#42992) * use original id in dist_op_context.grad_op_id_to_op_id * del assert * remove redundant map --- .../distributed/auto_parallel/completion.py | 21 ++++++---- .../distributed/auto_parallel/dist_context.py | 10 +++-- .../auto_parallel/parallelizer_v2.py | 2 +- .../distributed/auto_parallel/partitioner.py | 8 ++-- .../distributed/passes/auto_parallel_amp.py | 42 ++++++++++--------- .../distributed/passes/auto_parallel_fp16.py | 40 +++++++++--------- .../passes/auto_parallel_recompute.py | 8 ++-- python/paddle/fluid/backward.py | 12 ++---- 8 files changed, 74 insertions(+), 69 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 31bdc4cc650..03996ec350d 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -771,7 +771,7 @@ class Completer: def _get_op_by_id(ops, id): for op in ops: - if op.desc.id() == id: + if op.desc.original_id() == id: return op return None @@ -796,10 +796,12 @@ class Completer: # complete the annotation of grad op (xxx_grad op or sum op) # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id grad_op = ops[idx] - if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: + if grad_op.desc.original_id( + ) in dist_op_context.grad_op_id_to_op_id: # TODO support the case where one forward op corresponding to multiple xxx_grad op - forward_op = _get_op_by_id( - ops, dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) + forward_op = _get_op_by_id(ops, + dist_op_context.grad_op_id_to_op_id[ + grad_op.desc.original_id()]) assert forward_op is not None fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( @@ -935,7 +937,7 @@ class Completer: def _get_op_by_id(ops, id): for op in ops: - if op.desc.id() == id: + if op.desc.original_id() == id: return op return None @@ -997,11 +999,12 @@ class Completer: # complete the annotation of grad op (xxx_grad op or sum op) # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id grad_op = ops[idx] - if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: + if grad_op.desc.original_id( + ) in dist_op_context.grad_op_id_to_op_id: # TODO support the case where one forward op corresponding to multiple xxx_grad op - forward_op = _get_op_by_id( - ops[:first_backward_op_idx], - dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) + forward_op = _get_op_by_id(ops[:first_backward_op_idx], + dist_op_context.grad_op_id_to_op_id[ + grad_op.desc.original_id()]) assert forward_op is not None if grad_op.type == "concat" and forward_op.type == "split": diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 7299f84504b..a47ef66ee84 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -204,9 +204,13 @@ class DistributedContext: ) self._serial_startup_program = self._original_serial_startup_program.clone( ) - self._serial_main_program = self._original_serial_main_program - self._serial_startup_program = self._original_serial_startup_program - self._serial_loss = self._original_serial_loss + # self._serial_main_program = 
self._original_serial_main_program + # self._serial_startup_program = self._original_serial_startup_program + if self._original_serial_loss: + self._serial_loss = self._serial_main_program.global_block( + ).vars[self._original_serial_loss[0].name] + else: + self._serial_loss = self._original_serial_loss self._serial_optimizer = self._original_serial_optimizer self._init_dist_attr_for_program() self._tensors_ids = list(self._dist_tensors_for_program.keys()) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 6a94bbd3130..4d736327610 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -51,7 +51,7 @@ class Parallelizer: serial_optimizer = self._dist_context.serial_optimizer if self._mode == "train" and serial_optimizer: # Generate backward - serial_loss = self._dist_context.serial_fetch_vars["loss"][0] + serial_loss = self._dist_context.serial_loss params_grads = self._generate_backward( serial_main_program, serial_startup_program, serial_loss) # Apply pre optimization passes diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 91a31dd1b92..ce686fd6a56 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -211,7 +211,7 @@ class Partitioner(object): forward_op_id2forward_op = {} for idx in range(len(serial_ops)): if idx <= last_fwd_op_idx: - forward_op_id2forward_op[serial_ops[idx].desc.id( + forward_op_id2forward_op[serial_ops[idx].desc.original_id( )] = serial_ops[idx] appended_grad_times = 0 @@ -408,9 +408,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, def _get_dist_op_backward_implement(backward_op, dist_context, forward_op_id2forward_op): dist_op_context = dist_context.dist_op_context - if backward_op.desc.id() in dist_op_context.grad_op_id_to_op_id: - forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op.desc.id( - )] + if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: + forward_op_id = dist_op_context.grad_op_id_to_op_id[ + backward_op.desc.original_id()] forward_op = forward_op_id2forward_op[forward_op_id] forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index fe94c25e12d..3cd04affa29 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -46,13 +46,13 @@ class AMPState(object): if int(op.attr('op_role')) == int(OpRole.Forward): self._mark_black_white_ops(amp_lists) elif int(op.attr('op_role')) == int(OpRole.Backward): - if op.desc.id() in dist_op_context.grad_op_id_to_op_id: - fwd_op_id = dist_op_context.grad_op_id_to_op_id[op.desc.id( - )] + if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: + fwd_op_id = dist_op_context.grad_op_id_to_op_id[ + op.desc.original_id()] if self._is_fp16_op(fwd_op_id) == True: - self._op_fp16_dict[op.desc.id()] = True + self._op_fp16_dict[op.desc.original_id()] = True elif self._is_fp16_op(fwd_op_id) == False: - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False elif int(op.attr('op_role')) == int(OpRole.Optimize): break @@ -70,12 +70,12 @@ class AMPState(object): continue if 
amp_lists.black_varnames is not None and _is_in_black_varnames( op, amp_lists): - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False continue if op.type in amp_lists.black_list: - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False elif op.type in amp_lists.white_list: - self._op_fp16_dict[op.desc.id()] = True + self._op_fp16_dict[op.desc.original_id()] = True elif op.type in amp_lists.gray_list: is_black_op = False is_white_op = False @@ -95,22 +95,22 @@ class AMPState(object): else: prev_op = in_var.op # if it's one of inputs - if self._is_fp16_op(prev_op.desc.id()) == False or \ + if self._is_fp16_op(prev_op.desc.original_id()) == False or \ prev_op.type in amp_lists.black_list: is_black_op = True - elif self._is_fp16_op(prev_op.desc.id()) == True or \ + elif self._is_fp16_op(prev_op.desc.original_id()) == True or \ prev_op.type in amp_lists.white_list: is_white_op = True if is_black_op: - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False elif is_white_op: - self._op_fp16_dict[op.desc.id()] = True + self._op_fp16_dict[op.desc.original_id()] = True else: pass else: # For numerical safe, we apply fp32 computation on ops that # are not determined which list they should stay. - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False def cast_forward_program(self, dist_context): ops = self._block.ops @@ -120,11 +120,11 @@ class AMPState(object): num_cast_ops = 0 if int(op.attr('op_role')) == int(OpRole.Backward): break - if self._is_fp16_op(op.desc.id()) == False: + if self._is_fp16_op(op.desc.original_id()) == False: num_cast_ops = self._insert_cast_op_forward( op, idx, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32, dist_context) - elif self._is_fp16_op(op.desc.id()) == True: + elif self._is_fp16_op(op.desc.original_id()) == True: num_cast_ops = self._insert_cast_op_forward( op, idx, core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP16, dist_context) @@ -198,7 +198,7 @@ class AMPState(object): else: if op.has_attr('in_dtype'): op._set_attr('in_dtype', dst_dtype) - self._var_name_dict[op.desc.id()] = var_name_dict + self._var_name_dict[op.desc.original_id()] = var_name_dict if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16: for out_name in op.output_names: @@ -225,13 +225,14 @@ class AMPState(object): while idx < len(ops): num_cast_ops = 0 grad_op = ops[idx] + grad_op_orig_id = grad_op.desc.original_id() dist_op_context = dist_context.dist_op_context - if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: - if self._is_fp16_op(grad_op.desc.id()) == False: # fp32 + if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id: + if self._is_fp16_op(grad_op_orig_id) == False: # fp32 num_cast_ops = self._insert_cast_op_backward( grad_op, idx, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32, dist_context) - elif self._is_fp16_op(grad_op.desc.id()) == True: # fp16 + elif self._is_fp16_op(grad_op_orig_id) == True: # fp16 num_cast_ops = self._insert_cast_op_backward( grad_op, idx, core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP16, dist_context) @@ -272,8 +273,9 @@ class AMPState(object): return False num_cast_ops = 0 + original_id = grad_op.desc.original_id() dist_op_context = dist_context.dist_op_context - fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()] + fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id] for in_name in grad_op.input_names: if src_dtype == 
core.VarDesc.VarType.FP32 and _keep_fp32_input( diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 9dda310e5c0..b01f3975aef 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -153,23 +153,24 @@ class FP16State(object): # ernie inference trick if op.type == "assign" and "array_" in op.input_arg_names[0]: - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False return if _need_keep_fp32(op, self.amp_list.unsupported_list, self.use_fp16_guard): - self._op_fp16_dict[op.desc.id()] = False + self._op_fp16_dict[op.desc.original_id()] = False else: - self._op_fp16_dict[op.desc.id()] = True + self._op_fp16_dict[op.desc.original_id()] = True for var_name in op.output_arg_names: # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name) self.forward_non_leaf_tensors[var_name] = op.desc.id() elif is_backward_op(op) == int(OpRole.Backward): - if op.desc.id() in self.grad_op_to_op_map: - fwd_op_id = self.grad_op_to_op_map[op.desc.id()] + if op.desc.original_id() in self.grad_op_to_op_map: + fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()] assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op)) - self._op_fp16_dict[op.desc.id()] = self._op_fp16_dict[fwd_op_id] + self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[ + fwd_op_id] if int(op.attr('op_role')) == 257: self.is_train = True @@ -192,10 +193,10 @@ class FP16State(object): def resolute_tensor_dtype(self, block): for op in block.ops: - op_id = op.desc.id() if is_forward_op(op): # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - if self._is_fp16_op(op_id) == True or op.type == "cast": + if self._is_fp16_op(op.desc.original_id()) == True \ + or op.type == "cast": for in_name in op.input_names: if _keep_fp32_input(op, in_name): continue @@ -209,7 +210,7 @@ class FP16State(object): self.set_var_to_fp16(out_var_name, block) set_op_dtype_to_fp16(op) # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - elif self._is_fp16_op(op_id) == False: + elif self._is_fp16_op(op.desc.original_id()) == False: for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: @@ -217,7 +218,7 @@ class FP16State(object): if out_var.dtype == core.VarDesc.VarType.FP16: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) elif is_backward_op(op): - if self._is_fp16_op(op_id) == True: + if self._is_fp16_op(op.desc.original_id()) == True: for out_name in op.output_names: if _keep_fp32_output(op, out_name): continue @@ -225,7 +226,7 @@ class FP16State(object): self.set_var_to_fp16(out_var_name, block) set_op_dtype_to_fp16(op) # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - elif self._is_fp16_op(op_id) == False: + elif self._is_fp16_op(op.desc.original_id()) == False: for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: @@ -238,28 +239,27 @@ class FP16State(object): idx = 0 while idx < len(block.ops): op = block.ops[idx] - op_id = op.desc.id() num_cast_ops = 0 if op.type in __amp_skip_ops__: idx += 1 continue elif is_forward_op(op): - if self._is_fp16_op(op_id) == False: + if self._is_fp16_op(op.desc.original_id()) == False: num_cast_ops = self._insert_forward_cast_ops( op, idx, block, core.VarDesc.VarType.FP16, 
core.VarDesc.VarType.FP32, self.dist_context) - elif self._is_fp16_op(op_id) == True: + elif self._is_fp16_op(op.desc.original_id()) == True: num_cast_ops = self._insert_forward_cast_ops( op, idx, block, core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP16, self.dist_context) elif is_backward_op(op): - if op_id in dist_op_context.grad_op_id_to_op_id: - if self._is_fp16_op(op_id) == False: + if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: + if self._is_fp16_op(op.desc.original_id()) == False: num_cast_ops = self._insert_backward_cast_ops( op, idx, block, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32, self.dist_context) - elif self._is_fp16_op(op_id) == True: + elif self._is_fp16_op(op.desc.original_id()) == True: num_cast_ops = self._insert_backward_cast_ops( op, idx, block, core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP16, self.dist_context) @@ -282,7 +282,6 @@ class FP16State(object): dist_context): num_cast_ops = 0 - op_id = op.desc.id() for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( @@ -300,7 +299,7 @@ class FP16State(object): cast_name = in_var.name + '.cast_' + _dtype_to_str( dst_dtype) cast_var = block.vars.get(cast_name) - self.forward_input_cast_ops[op_id] += [( + self.forward_input_cast_ops[op.desc.original_id()] += [( cast_name, in_var.name, dst_dtype, src_dtype, in_name)] in_var_dist_attr = consume_op_attr.get_input_dist_attr( @@ -349,8 +348,9 @@ class FP16State(object): num_cast_ops = 0 op_id = op.desc.id() + original_id = op.desc.original_id() dist_op_context = dist_context.dist_op_context - forward_op_id = dist_op_context.grad_op_id_to_op_id[op_id] + forward_op_id = dist_op_context.grad_op_id_to_op_id[original_id] grad_op_attr = dist_context.get_op_dist_attr_for_program(op) assert grad_op_attr is not None diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 258f46304d1..c6d16854462 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -315,7 +315,7 @@ class RecomputePass(PassBase): # When traversing all grad_ops in reverse, need to set a flag to indicate # whether the ckpt and its segment_descs can be used. 
ckpt_op = op_path[segment[1] - 1] - ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs] + ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs] # step 4: insert recomputed fwd ops ops = main_block.ops @@ -339,9 +339,9 @@ class RecomputePass(PassBase): _rename_arg_([grad_op.desc], key, var_name_dict[key]) # insert recomputed ops - if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: - fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id( - )] + original_id = grad_op.desc.original_id() + if original_id in dist_op_context.grad_op_id_to_op_id: + fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id] if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]: idx = grad_op.idx while idx - 1 >= 0 and ops[idx - 1].type == "sum": diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 145ecc83cfc..ed3e0bc98ed 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1107,8 +1107,10 @@ def _append_backward_ops_(block, distop_context.grad_var_to_var[appending_grad_times].update( op_grad_to_var) for op_desc in grad_op_desc: - assert op_desc.id() not in distop_context.grad_op_id_to_op_id - distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id() + assert op_desc.original_id( + ) not in distop_context.grad_op_id_to_op_id + distop_context.grad_op_id_to_op_id[op_desc.original_id( + )] = op.desc.original_id() if callbacks is not None: assert (isinstance(callbacks, (list, tuple))) @@ -1255,12 +1257,6 @@ def _append_backward_ops_(block, for op_desc in grad_op_descs: new_op_desc = target_block.desc.append_op() new_op_desc.copy_from(op_desc) - # Rebuild the mapping because new_op_desc has a differnt id (Only for auto parallel) - if distop_context is not None: - if op_desc.id() in distop_context.grad_op_id_to_op_id: - distop_context.grad_op_id_to_op_id[new_op_desc.id( - )] = distop_context.grad_op_id_to_op_id[op_desc.id()] - distop_context.grad_op_id_to_op_id.pop(op_desc.id()) new_op_desc._set_attr(op_role_attr_name, backward) grad_to_var["__current_op_desc__"] = new_op_desc if callbacks is not None: -- GitLab
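
Note (not part of the patch): a minimal, standalone sketch of why keying dist_op_context.grad_op_id_to_op_id by original_id() makes the removed "rebuild the mapping" block in python/paddle/fluid/backward.py unnecessary. The OpDescStub class below is a hypothetical stand-in, not Paddle's real OpDesc; it only models the one property the patch relies on: copy_from() gives the copy a fresh id() while preserving original_id().

    import itertools

    _next_id = itertools.count(1)


    class OpDescStub:
        """Hypothetical stand-in for an op desc (not Paddle code)."""

        def __init__(self):
            self._id = next(_next_id)
            self._original_id = self._id

        def id(self):
            return self._id

        def original_id(self):
            return self._original_id

        def copy_from(self, other):
            # The copy keeps its own new id but inherits the original id.
            self._original_id = other.original_id()


    fwd_desc = OpDescStub()
    grad_desc = OpDescStub()

    # Key the mapping by original ids, as _append_backward_ops_ now does.
    grad_op_id_to_op_id = {grad_desc.original_id(): fwd_desc.original_id()}

    # Appending the grad op to the target block copies its desc, so id() changes.
    appended_desc = OpDescStub()
    appended_desc.copy_from(grad_desc)

    assert appended_desc.id() != grad_desc.id()
    # The lookup still resolves, with no need to rebuild the map after the copy.
    assert grad_op_id_to_op_id[appended_desc.original_id()] == \
        fwd_desc.original_id()

Under this assumption, consumers such as Completer, Partitioner, and the AMP/FP16/recompute passes can look up the forward op with grad_op.desc.original_id() both before and after the grad op descs are copied into the target block, which is why the patch switches every producer and consumer of the map to original ids in one change.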