未验证 提交 213f8038 编写于 作者: Z zyfncg 提交者: GitHub

Revert the change of remove_training_info (#45582)

* revert the change of remove_training_info

* update

* update
上级 413d6e1b
...@@ -5867,26 +5867,23 @@ class Program(object): ...@@ -5867,26 +5867,23 @@ class Program(object):
# Note: The op_role and op_role_var cann't be deleted currently, # Note: The op_role and op_role_var cann't be deleted currently,
# and we will try to remove them in the future. # and we will try to remove them in the future.
common_clipped_attrs_list = ['op_namescope', 'op_callstack'] common_clipped_attrs_list = [
'op_namescope', 'op_callstack', 'op_device', 'with_quant_attr'
]
for i in six.moves.range(res.desc.num_blocks()): for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i) block = res.desc.block(i)
for var in block.all_vars(): for var in block.all_vars():
var.clear_is_parameter() var.clear_is_parameter()
var.clear_stop_gradient() var.clear_stop_gradient()
if not clip_extra:
continue
for op_idx in range(0, block.op_size()): for op_idx in range(0, block.op_size()):
op = block.op(op_idx) op = block.op(op_idx)
if op.type() not in OpProtoHolder.instance().op_proto_map: if op.type() not in OpProtoHolder.instance().op_proto_map:
continue continue
if not clip_extra:
continue
extra_attrs_map = core.get_op_extra_attrs(op.type()) extra_attrs_map = core.get_op_extra_attrs(op.type())
for name in op.attr_names():
if name in extra_attrs_map:
op.remove_attr(name)
continue
proto = OpProtoHolder.instance().get_op_proto(op.type()) proto = OpProtoHolder.instance().get_op_proto(op.type())
remove_input_list = [] remove_input_list = []
...@@ -5901,9 +5898,8 @@ class Program(object): ...@@ -5901,9 +5898,8 @@ class Program(object):
break break
if not find: if not find:
remove_input_list.append(name) remove_input_list.append(name)
# The extra input of op will be removed in the future for name in remove_input_list:
# for name in remove_input_list: op.remove_input(name)
# op.remove_input(name)
remove_output_list = [] remove_output_list = []
for name in op.output_names(): for name in op.output_names():
...@@ -5917,10 +5913,10 @@ class Program(object): ...@@ -5917,10 +5913,10 @@ class Program(object):
break break
if not find: if not find:
remove_output_list.append(name) remove_output_list.append(name)
# The extra input of op will be removed in the future for name in remove_output_list:
# for name in remove_output_list: op.remove_output(name)
# op.remove_output(name)
remove_attr_list = []
op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName( op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName(
) )
quant = bool(op.attr(op_quant_name) quant = bool(op.attr(op_quant_name)
...@@ -5930,21 +5926,22 @@ class Program(object): ...@@ -5930,21 +5926,22 @@ class Program(object):
"activation_bits", "bit_length", "quantize_weight_bits", "activation_bits", "bit_length", "quantize_weight_bits",
"weight_quant_scale" "weight_quant_scale"
] ]
remove_attr_list = []
for name in op.attr_names(): for name in op.attr_names():
if quant: if quant:
if name in quant_attrs: if name in quant_attrs:
continue continue
if name.endswith("_threshold"): if name.endswith("_threshold"):
continue continue
if name in common_clipped_attrs_list: if len(extra_attrs_map) > 0:
remove_attr_list.append(name) if name in extra_attrs_map or name in common_clipped_attrs_list:
op.remove_attr(name)
continue continue
find = False find = False
for attr_proto in proto.attrs: for attr_proto in proto.attrs:
if attr_proto.name != name: if attr_proto.name != name:
continue continue
if attr_proto.extra:
remove_attr_list.append(name)
find = True find = True
break break
if not find: if not find:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册