diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5f62e5d3ea5ce484b2a40695a74feaa37d30cda4..16c4fc6acbe83ea945b7daf0c49c678ea78c5929 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5867,26 +5867,23 @@ class Program(object): # Note: The op_role and op_role_var cann't be deleted currently, # and we will try to remove them in the future. - common_clipped_attrs_list = ['op_namescope', 'op_callstack'] + common_clipped_attrs_list = [ + 'op_namescope', 'op_callstack', 'op_device', 'with_quant_attr' + ] for i in six.moves.range(res.desc.num_blocks()): block = res.desc.block(i) for var in block.all_vars(): var.clear_is_parameter() var.clear_stop_gradient() + if not clip_extra: + continue for op_idx in range(0, block.op_size()): op = block.op(op_idx) if op.type() not in OpProtoHolder.instance().op_proto_map: continue - if not clip_extra: - continue - extra_attrs_map = core.get_op_extra_attrs(op.type()) - for name in op.attr_names(): - if name in extra_attrs_map: - op.remove_attr(name) - continue proto = OpProtoHolder.instance().get_op_proto(op.type()) remove_input_list = [] @@ -5901,9 +5898,8 @@ class Program(object): break if not find: remove_input_list.append(name) - # The extra input of op will be removed in the future - # for name in remove_input_list: - # op.remove_input(name) + for name in remove_input_list: + op.remove_input(name) remove_output_list = [] for name in op.output_names(): @@ -5917,10 +5913,10 @@ class Program(object): break if not find: remove_output_list.append(name) - # The extra input of op will be removed in the future - # for name in remove_output_list: - # op.remove_output(name) + for name in remove_output_list: + op.remove_output(name) + remove_attr_list = [] op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName( ) quant = bool(op.attr(op_quant_name) @@ -5930,21 +5926,22 @@ class Program(object): "activation_bits", "bit_length", "quantize_weight_bits", "weight_quant_scale" ] - remove_attr_list = [] for name in op.attr_names(): if quant: if name in quant_attrs: continue if name.endswith("_threshold"): continue - if name in common_clipped_attrs_list: - remove_attr_list.append(name) + if len(extra_attrs_map) > 0: + if name in extra_attrs_map or name in common_clipped_attrs_list: + op.remove_attr(name) continue - find = False for attr_proto in proto.attrs: if attr_proto.name != name: continue + if attr_proto.extra: + remove_attr_list.append(name) find = True break if not find: