未验证 提交 5a4e42ca 编写于 作者: J Jack Zhou 提交者: GitHub

add gru op_register_version; test=op_version; (#29931)

* add gru op_register_version; test=op_version;

* Update fc,mul version;test=op_version;
上级 2b1d796c
...@@ -203,11 +203,11 @@ REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass) ...@@ -203,11 +203,11 @@ REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass)
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0) .EQ("mul", 0)
.EQ("gru", 0) .EQ("gru", 0)
.EQ("fusion_gru", 0)); .LE("fusion_gru", 1));
REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass) REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0) .EQ("mul", 0)
.EQ("elementwise_add", 0) .EQ("elementwise_add", 0)
.EQ("gru", 0) .EQ("gru", 0)
.EQ("fusion_gru", 0)); .LE("fusion_gru", 1));
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/fc.h"
...@@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker); ...@@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>, REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
ops::FusionGRUKernel<double>); ops::FusionGRUKernel<double>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(fusion_gru)
.AddCheckpoint(
R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_weights",
"The added attribute 'Scale_weights' is not yet "
"registered.",
{1.0f}));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册