未验证 提交 1d04e021 编写于 作者: Z zyfncg 提交者: GitHub

set flag of clip_extra in save_inference_model to true (#46151)

上级 60a3929a
......@@ -397,7 +397,7 @@ def _parse_save_configs(configs):
inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False)
inner_config.combine_params = configs.get("combine_params", False)
inner_config.clip_extra = configs.get("clip_extra", False)
inner_config.clip_extra = configs.get("clip_extra", True)
inner_config.skip_forward = configs.get("skip_forward", False)
return inner_config
......
......@@ -5873,9 +5873,7 @@ class Program(object):
# Note: The op_role and op_role_var cann't be deleted currently,
# and we will try to remove them in the future.
common_clipped_attrs_list = [
'op_namescope', 'op_callstack', 'op_device', 'with_quant_attr'
]
common_clipped_attrs_list = ['op_callstack', 'with_quant_attr']
for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i)
......@@ -5904,8 +5902,9 @@ class Program(object):
break
if not find:
remove_input_list.append(name)
for name in remove_input_list:
op.remove_input(name)
# The extra input of op will be removed in the future
# for name in remove_input_list:
# op.remove_input(name)
remove_output_list = []
for name in op.output_names():
......@@ -5919,10 +5918,10 @@ class Program(object):
break
if not find:
remove_output_list.append(name)
for name in remove_output_list:
op.remove_output(name)
# The extra output of op will be removed in the future
# for name in remove_output_list:
# op.remove_output(name)
remove_attr_list = []
op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName(
)
quant = bool(op.attr(op_quant_name)
......@@ -5932,6 +5931,7 @@ class Program(object):
"activation_bits", "bit_length", "quantize_weight_bits",
"weight_quant_scale"
]
remove_attr_list = []
for name in op.attr_names():
if quant:
if name in quant_attrs:
......@@ -5946,8 +5946,6 @@ class Program(object):
for attr_proto in proto.attrs:
if attr_proto.name != name:
continue
if attr_proto.extra:
remove_attr_list.append(name)
find = True
break
if not find:
......
......@@ -1232,7 +1232,7 @@ def save_inference_model(dirname,
params_filename=None,
export_for_deployment=True,
program_only=False,
clip_extra=False):
clip_extra=True):
"""
Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` .
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册