diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index c4515bbc45538ca211382aa119bbec5721c56c5a..fe347d6a45d0f499c4ec0d673a42b210a5a27769 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -203,11 +203,11 @@ REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass) paddle::framework::compatible::OpVersionComparatorCombination() .EQ("mul", 0) .EQ("gru", 0) - .EQ("fusion_gru", 0)); + .LE("fusion_gru", 1)); REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("mul", 0) .EQ("elementwise_add", 0) .EQ("gru", 0) - .EQ("fusion_gru", 0)); + .LE("fusion_gru", 1)); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index f5904039d4b6ef9794991687c535a0989864e9f6..9578cc247daaa691aa7f101fd404c5dbef39dc8f 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include // for memcpy #include #include +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" @@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker); REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel, ops::FusionGRUKernel); + +/* ========================== register checkpoint ===========================*/ +REGISTER_OP_VERSION(fusion_gru) + .AddCheckpoint( + R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "Scale_weights", + "The added attribute 'Scale_weights' is not yet " + "registered.", + {1.0f}));