From 213f80388903688fd595653b9a62771a3ad08586 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 31 Aug 2022 19:39:25 +0800 Subject: [PATCH] Revert the change of remove_training_info (#45582) * revert the change of remove_training_info * update * update --- python/paddle/fluid/framework.py | 33 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5f62e5d3ea..16c4fc6acb 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5867,26 +5867,23 @@ class Program(object): # Note: The op_role and op_role_var cann't be deleted currently, # and we will try to remove them in the future. - common_clipped_attrs_list = ['op_namescope', 'op_callstack'] + common_clipped_attrs_list = [ + 'op_namescope', 'op_callstack', 'op_device', 'with_quant_attr' + ] for i in six.moves.range(res.desc.num_blocks()): block = res.desc.block(i) for var in block.all_vars(): var.clear_is_parameter() var.clear_stop_gradient() + if not clip_extra: + continue for op_idx in range(0, block.op_size()): op = block.op(op_idx) if op.type() not in OpProtoHolder.instance().op_proto_map: continue - if not clip_extra: - continue - extra_attrs_map = core.get_op_extra_attrs(op.type()) - for name in op.attr_names(): - if name in extra_attrs_map: - op.remove_attr(name) - continue proto = OpProtoHolder.instance().get_op_proto(op.type()) remove_input_list = [] @@ -5901,9 +5898,8 @@ class Program(object): break if not find: remove_input_list.append(name) - # The extra input of op will be removed in the future - # for name in remove_input_list: - # op.remove_input(name) + for name in remove_input_list: + op.remove_input(name) remove_output_list = [] for name in op.output_names(): @@ -5917,10 +5913,10 @@ class Program(object): break if not find: remove_output_list.append(name) - # The extra input of op will be removed in the future - # for name in remove_output_list: - # op.remove_output(name) + for name in remove_output_list: + op.remove_output(name) + remove_attr_list = [] op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName( ) quant = bool(op.attr(op_quant_name) @@ -5930,21 +5926,22 @@ class Program(object): "activation_bits", "bit_length", "quantize_weight_bits", "weight_quant_scale" ] - remove_attr_list = [] for name in op.attr_names(): if quant: if name in quant_attrs: continue if name.endswith("_threshold"): continue - if name in common_clipped_attrs_list: - remove_attr_list.append(name) + if len(extra_attrs_map) > 0: + if name in extra_attrs_map or name in common_clipped_attrs_list: + op.remove_attr(name) continue - find = False for attr_proto in proto.attrs: if attr_proto.name != name: continue + if attr_proto.extra: + remove_attr_list.append(name) find = True break if not find: -- GitLab