未验证 提交 7e2e2ee7 编写于 作者: Z zyfncg 提交者: GitHub

Fix clip_extra logic in remove_training_info (#46534)

* fix clip_extra code in remove_training_info

* revert rnn opmaker clear
上级 7467221b
...@@ -103,6 +103,9 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -103,6 +103,9 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker {
"mode", "mode",
"(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH."); "(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH.");
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0); AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
} }
......
# - op : rnn
# backward : rnn_grad
# extra :
# attrs : [bool is_test = false]
- op : abs - op : abs
backward : abs_grad backward : abs_grad
extra : extra :
...@@ -609,11 +614,6 @@ ...@@ -609,11 +614,6 @@
extra : extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false] attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : rnn
backward : rnn_grad
extra :
attrs : [bool is_test = false]
- op : round - op : round
backward : round_grad backward : round_grad
extra : extra :
......
...@@ -5932,6 +5932,8 @@ class Program(object): ...@@ -5932,6 +5932,8 @@ class Program(object):
"activation_bits", "bit_length", "quantize_weight_bits", "activation_bits", "bit_length", "quantize_weight_bits",
"weight_quant_scale" "weight_quant_scale"
] ]
for extra_attr_name in extra_attrs_map.keys():
op.remove_attr(extra_attr_name)
remove_attr_list = [] remove_attr_list = []
for name in op.attr_names(): for name in op.attr_names():
if quant: if quant:
...@@ -5940,7 +5942,7 @@ class Program(object): ...@@ -5940,7 +5942,7 @@ class Program(object):
if name.endswith("_threshold"): if name.endswith("_threshold"):
continue continue
if len(extra_attrs_map) > 0: if len(extra_attrs_map) > 0:
if name in extra_attrs_map or name in common_clipped_attrs_list: if name in common_clipped_attrs_list:
op.remove_attr(name) op.remove_attr(name)
continue continue
find = False find = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册