From 5a4e42ca9a78e6f8eb3daab97f41dcbf59780955 Mon Sep 17 00:00:00 2001 From: Jack Zhou Date: Mon, 28 Dec 2020 19:59:38 +0800 Subject: [PATCH] add gru op_register_version; test=op_version; (#29931) * add gru op_register_version; test=op_version; * Update fc,mul version;test=op_version; --- paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 4 ++-- paddle/fluid/operators/fused/fusion_gru_op.cc | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index c4515bbc455..fe347d6a45d 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -203,11 +203,11 @@ REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass) paddle::framework::compatible::OpVersionComparatorCombination() .EQ("mul", 0) .EQ("gru", 0) - .EQ("fusion_gru", 0)); + .LE("fusion_gru", 1)); REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("mul", 0) .EQ("elementwise_add", 0) .EQ("gru", 0) - .EQ("fusion_gru", 0)); + .LE("fusion_gru", 1)); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index f5904039d4b..9578cc247da 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include // for memcpy #include #include +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" @@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker); REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel, ops::FusionGRUKernel); + +/* ========================== register checkpoint ===========================*/ +REGISTER_OP_VERSION(fusion_gru) + .AddCheckpoint( + R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "Scale_weights", + "The added attribute 'Scale_weights' is not yet " + "registered.", + {1.0f})); -- GitLab