未验证 提交 5a4e42ca 编写于 作者: J Jack Zhou 提交者: GitHub

add gru op_register_version; test=op_version; (#29931)

* add gru op_register_version; test=op_version;

* Update fc,mul version;test=op_version;
上级 2b1d796c
......@@ -203,11 +203,11 @@ REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass)
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("gru", 0)
.EQ("fusion_gru", 0));
.LE("fusion_gru", 1));
REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("gru", 0)
.EQ("fusion_gru", 0));
.LE("fusion_gru", 1));
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <cstring> // for memcpy
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
......@@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
ops::FusionGRUKernel<double>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(fusion_gru)
.AddCheckpoint(
R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_weights",
"The added attribute 'Scale_weights' is not yet "
"registered.",
{1.0f}));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册