From d7cfee9b315bd5a54a07b388b2b3c2dedd55a63a Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 11 Nov 2020 11:29:15 +0800 Subject: [PATCH] Checkout point add (#28488) * upgrade pass capability --- paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc | 2 +- paddle/fluid/operators/fill_constant_op.cc | 10 ++++++++++ paddle/fluid/operators/gather_op.cc | 6 +++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index d74843611cd..542aadbe53d 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -394,5 +394,5 @@ REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass) .EQ("square", 0) .EQ("elementwise_mul", 0) .EQ("elementwise_sub", 0) - .EQ("fill_constant", 0) + .EQ("fill_constant", 1) .EQ("fusion_squared_mat_sub", 0)); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 35d54577bfe..cc85c295965 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_constant_op.h" #include +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -143,3 +144,12 @@ REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel); + +REGISTER_OP_VERSION(fill_constant) + .AddCheckpoint( + R"ROC( + Upgrade fill_constant, add a new input [ValueTensor]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "ValueTensor", + "In order to support new feature tensor support of Value")); diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 72b44b22f9c..34fd11e8c0d 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -171,6 +171,6 @@ REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel, ops::GatherGradientOpKernel, ops::GatherGradientOpKernel); REGISTER_OP_VERSION(gather) - .AddCheckpoint(R"ROC(upgrad gather, add attribut [axis])ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "axis", "Specify the axis of gather operation.", {})); + .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "Axis", "Specify the axis of gather operation.")); -- GitLab