未验证 提交 213f8038 编写于 作者: Z zyfncg 提交者: GitHub

Revert the change of remove_training_info (#45582)

* revert the change of remove_training_info

* update

* update
上级 413d6e1b
......@@ -5867,26 +5867,23 @@ class Program(object):
# Note: The op_role and op_role_var cann't be deleted currently,
# and we will try to remove them in the future.
common_clipped_attrs_list = ['op_namescope', 'op_callstack']
common_clipped_attrs_list = [
'op_namescope', 'op_callstack', 'op_device', 'with_quant_attr'
]
for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i)
for var in block.all_vars():
var.clear_is_parameter()
var.clear_stop_gradient()
if not clip_extra:
continue
for op_idx in range(0, block.op_size()):
op = block.op(op_idx)
if op.type() not in OpProtoHolder.instance().op_proto_map:
continue
if not clip_extra:
continue
extra_attrs_map = core.get_op_extra_attrs(op.type())
for name in op.attr_names():
if name in extra_attrs_map:
op.remove_attr(name)
continue
proto = OpProtoHolder.instance().get_op_proto(op.type())
remove_input_list = []
......@@ -5901,9 +5898,8 @@ class Program(object):
break
if not find:
remove_input_list.append(name)
# The extra input of op will be removed in the future
# for name in remove_input_list:
# op.remove_input(name)
for name in remove_input_list:
op.remove_input(name)
remove_output_list = []
for name in op.output_names():
......@@ -5917,10 +5913,10 @@ class Program(object):
break
if not find:
remove_output_list.append(name)
# The extra input of op will be removed in the future
# for name in remove_output_list:
# op.remove_output(name)
for name in remove_output_list:
op.remove_output(name)
remove_attr_list = []
op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName(
)
quant = bool(op.attr(op_quant_name)
......@@ -5930,21 +5926,22 @@ class Program(object):
"activation_bits", "bit_length", "quantize_weight_bits",
"weight_quant_scale"
]
remove_attr_list = []
for name in op.attr_names():
if quant:
if name in quant_attrs:
continue
if name.endswith("_threshold"):
continue
if name in common_clipped_attrs_list:
remove_attr_list.append(name)
if len(extra_attrs_map) > 0:
if name in extra_attrs_map or name in common_clipped_attrs_list:
op.remove_attr(name)
continue
find = False
for attr_proto in proto.attrs:
if attr_proto.name != name:
continue
if attr_proto.extra:
remove_attr_list.append(name)
find = True
break
if not find:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册