From 7e2e2ee71b3b21a92f3f54cc4708b741bf29e0e7 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 28 Sep 2022 09:14:49 +0800 Subject: [PATCH] Fix clip_extra logic in remove_training_info (#46534) * fix clip_extra code in remove_training_info * revert rnn opmaker clear --- paddle/fluid/operators/rnn_op.cc | 3 +++ paddle/phi/api/yaml/op_compat.yaml | 10 +++++----- python/paddle/fluid/framework.py | 4 +++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/rnn_op.cc b/paddle/fluid/operators/rnn_op.cc index 4a97afdfc4a..aba720a99ba 100644 --- a/paddle/fluid/operators/rnn_op.cc +++ b/paddle/fluid/operators/rnn_op.cc @@ -103,6 +103,9 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker { "mode", "(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH."); AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); + AddAttr("is_test", "True if in test phase.") + .SetDefault(false) + .AsExtra(); AddComment(R"DOC( )DOC"); } diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 613f2d38aee..bbaa9425201 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1,3 +1,8 @@ +# - op : rnn +# backward : rnn_grad +# extra : +# attrs : [bool is_test = false] + - op : abs backward : abs_grad extra : @@ -609,11 +614,6 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] -- op : rnn - backward : rnn_grad - extra : - attrs : [bool is_test = false] - - op : round backward : round_grad extra : diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 08ee9893553..e62cb956d98 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5932,6 +5932,8 @@ class Program(object): "activation_bits", "bit_length", "quantize_weight_bits", "weight_quant_scale" ] + for extra_attr_name in extra_attrs_map.keys(): + op.remove_attr(extra_attr_name) remove_attr_list = [] for name in op.attr_names(): if quant: @@ -5940,7 +5942,7 @@ class Program(object): if name.endswith("_threshold"): continue if len(extra_attrs_map) > 0: - if name in extra_attrs_map or name in common_clipped_attrs_list: + if name in common_clipped_attrs_list: op.remove_attr(name) continue find = False -- GitLab