未验证 提交 d67da3dc 编写于 作者: Z zyfncg 提交者: GitHub

[cherry-pick] Open the clip_extra flag in save_inference_model (#46577)

* set flag of clip_extra in save_inference_model to true (#46151)

* open the clip_extra flag in paddle.static.save_inference_model, test=allcase (#46456)

* Open the clip_extra flag in TracedLayer.save_inference_model (#46473)

* open the clip_extra flag in paddle.static.save_inference_model, test=allcase

* set the defalut value of clip_extra in TracedLayer from False to True, test=allcase

* update english doc of paddle.static.save_inference_model, test=document_fix (#46484)

* Fix clip_extra logic in remove_training_info (#46534)

* fix clip_extra code in remove_training_info

* revert rnn opmaker clear
上级 d90db9bd
...@@ -103,6 +103,9 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -103,6 +103,9 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker {
"mode", "mode",
"(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH."); "(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH.");
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0); AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
} }
......
# - op : rnn
# backward : rnn_grad
# extra :
# attrs : [bool is_test = false]
- op : abs - op : abs
backward : abs_grad backward : abs_grad
extra : extra :
...@@ -609,11 +614,6 @@ ...@@ -609,11 +614,6 @@
extra : extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false] attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : rnn
backward : rnn_grad
extra :
attrs : [bool is_test = false]
- op : round - op : round
backward : round_grad backward : round_grad
extra : extra :
......
...@@ -397,7 +397,7 @@ def _parse_save_configs(configs): ...@@ -397,7 +397,7 @@ def _parse_save_configs(configs):
inner_config.output_spec = configs.get('output_spec', None) inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False) inner_config.with_hook = configs.get('with_hook', False)
inner_config.combine_params = configs.get("combine_params", False) inner_config.combine_params = configs.get("combine_params", False)
inner_config.clip_extra = configs.get("clip_extra", False) inner_config.clip_extra = configs.get("clip_extra", True)
inner_config.skip_forward = configs.get("skip_forward", False) inner_config.skip_forward = configs.get("skip_forward", False)
return inner_config return inner_config
...@@ -1650,7 +1650,7 @@ class TracedLayer(object): ...@@ -1650,7 +1650,7 @@ class TracedLayer(object):
check_type( check_type(
f, "each element of fetch", int, f, "each element of fetch", int,
"fluid.dygraph.jit.TracedLayer.save_inference_model") "fluid.dygraph.jit.TracedLayer.save_inference_model")
clip_extra = kwargs.get('clip_extra', False) clip_extra = kwargs.get('clip_extra', True)
# path check # path check
file_prefix = os.path.basename(path) file_prefix = os.path.basename(path)
if file_prefix == "": if file_prefix == "":
......
...@@ -5867,9 +5867,7 @@ class Program(object): ...@@ -5867,9 +5867,7 @@ class Program(object):
# Note: The op_role and op_role_var cann't be deleted currently, # Note: The op_role and op_role_var cann't be deleted currently,
# and we will try to remove them in the future. # and we will try to remove them in the future.
common_clipped_attrs_list = [ common_clipped_attrs_list = ['op_callstack', 'with_quant_attr']
'op_namescope', 'op_callstack', 'op_device', 'with_quant_attr'
]
for i in six.moves.range(res.desc.num_blocks()): for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i) block = res.desc.block(i)
...@@ -5898,8 +5896,9 @@ class Program(object): ...@@ -5898,8 +5896,9 @@ class Program(object):
break break
if not find: if not find:
remove_input_list.append(name) remove_input_list.append(name)
for name in remove_input_list: # The extra input of op will be removed in the future
op.remove_input(name) # for name in remove_input_list:
# op.remove_input(name)
remove_output_list = [] remove_output_list = []
for name in op.output_names(): for name in op.output_names():
...@@ -5913,10 +5912,10 @@ class Program(object): ...@@ -5913,10 +5912,10 @@ class Program(object):
break break
if not find: if not find:
remove_output_list.append(name) remove_output_list.append(name)
for name in remove_output_list: # The extra output of op will be removed in the future
op.remove_output(name) # for name in remove_output_list:
# op.remove_output(name)
remove_attr_list = []
op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName( op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName(
) )
quant = bool(op.attr(op_quant_name) quant = bool(op.attr(op_quant_name)
...@@ -5926,6 +5925,9 @@ class Program(object): ...@@ -5926,6 +5925,9 @@ class Program(object):
"activation_bits", "bit_length", "quantize_weight_bits", "activation_bits", "bit_length", "quantize_weight_bits",
"weight_quant_scale" "weight_quant_scale"
] ]
for extra_attr_name in extra_attrs_map.keys():
op.remove_attr(extra_attr_name)
remove_attr_list = []
for name in op.attr_names(): for name in op.attr_names():
if quant: if quant:
if name in quant_attrs: if name in quant_attrs:
...@@ -5933,15 +5935,13 @@ class Program(object): ...@@ -5933,15 +5935,13 @@ class Program(object):
if name.endswith("_threshold"): if name.endswith("_threshold"):
continue continue
if len(extra_attrs_map) > 0: if len(extra_attrs_map) > 0:
if name in extra_attrs_map or name in common_clipped_attrs_list: if name in common_clipped_attrs_list:
op.remove_attr(name) op.remove_attr(name)
continue continue
find = False find = False
for attr_proto in proto.attrs: for attr_proto in proto.attrs:
if attr_proto.name != name: if attr_proto.name != name:
continue continue
if attr_proto.extra:
remove_attr_list.append(name)
find = True find = True
break break
if not find: if not find:
......
...@@ -1232,7 +1232,7 @@ def save_inference_model(dirname, ...@@ -1232,7 +1232,7 @@ def save_inference_model(dirname,
params_filename=None, params_filename=None,
export_for_deployment=True, export_for_deployment=True,
program_only=False, program_only=False,
clip_extra=False): clip_extra=True):
""" """
Prune the given `main_program` to build a new program especially for inference, Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` . and then save it and all related parameters to given `dirname` .
......
...@@ -454,8 +454,6 @@ def save_to_file(path, content): ...@@ -454,8 +454,6 @@ def save_to_file(path, content):
def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
**kwargs): **kwargs):
""" """
:api_attr: Static Graph
Save current model and its parameters to given path. i.e. Save current model and its parameters to given path. i.e.
Given path_prefix = "/path/to/modelname", after invoking Given path_prefix = "/path/to/modelname", after invoking
save_inference_model(path_prefix, feed_vars, fetch_vars, executor), save_inference_model(path_prefix, feed_vars, fetch_vars, executor),
...@@ -472,7 +470,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, ...@@ -472,7 +470,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
- program(Program): specify a program if you don't want to use default main program. - program(Program): specify a program if you don't want to use default main program.
- clip_extra(bool): set to True if you want to clip extra information for every operator. - clip_extra(bool): the flag indicating whether to clip extra information for every operator. Default: True.
Returns: Returns:
None None
...@@ -534,7 +532,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, ...@@ -534,7 +532,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
_check_vars('fetch_vars', fetch_vars) _check_vars('fetch_vars', fetch_vars)
program = _get_valid_program(kwargs.get('program', None)) program = _get_valid_program(kwargs.get('program', None))
clip_extra = kwargs.get('clip_extra', False) clip_extra = kwargs.get('clip_extra', True)
program = normalize_program(program, feed_vars, fetch_vars) program = normalize_program(program, feed_vars, fetch_vars)
# serialize and save program # serialize and save program
program_bytes = _serialize_program( program_bytes = _serialize_program(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册