From a5e422c85dc01f2a4084dca89495120c80cc8660 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 5 Jan 2021 16:56:32 +0800 Subject: [PATCH] add trace op_register_version and fix version bug; test=op_version (#30000) * add trace op_register_version and fix defaulf bug; test=op_version * add trace op_register_version; test=op_version * add trace op_register_version; test=op_version * add trace op_register_version; test=op_version * fix missing the template bug of vector; test=op_version --- paddle/fluid/operators/trace_op.cc | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index e90cf2054f7..50f9f0b9f4d 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/operators/trace_op.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -88,13 +89,13 @@ class TraceOpMaker : public framework::OpProtoAndCheckerMaker { R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken. Can be either positive or negative. Default: 0. )DOC") - .SetDefault(-2); + .SetDefault(0); AddAttr( "axis2", R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken. Can be either positive or negative. Default: 1. )DOC") - .SetDefault(-1); + .SetDefault(1); AddComment(R"DOC( Trace Operator. Return the sum along diagonals of the input tensor. @@ -177,3 +178,21 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex64>, ops::TraceGradKernel); + +/* ========================== register checkpoint ===========================*/ +REGISTER_OP_VERSION(trace) + .AddCheckpoint( + R"ROC(Upgrade trace add a new attribute [axis2])ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("axis1", + "The added attribute 'axis1' is not yet registered.", + std::vector{0.0f}) + .NewAttr("axis2", + "The added attribute 'axis2' is not yet registered.", + std::vector{1.0f}) + .DeleteAttr("dim1", + "The attribute 'dim1' is not recommend according to " + "the specification 2.0.") + .DeleteAttr("dim2", + "The attribute 'dim2' is not recommend according to " + "the specification 2.0.")); -- GitLab