diff --git a/paddle/framework/grad_op_creator_test.cc b/paddle/framework/grad_op_creator_test.cc index ad836727c355f2223c3a3e2802fff0cae94f789c..27ac65813120a2a682535a02bcecb882c4a7640d 100644 --- a/paddle/framework/grad_op_creator_test.cc +++ b/paddle/framework/grad_op_creator_test.cc @@ -9,8 +9,9 @@ namespace paddle { namespace framework { TEST(GradOpCreator, AddTwo) { - OperatorPtr add_op(OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); - OperatorPtr grad_add_op = OpRegistry::CreateGradOp(add_op); + std::shared_ptr add_op( + OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); + std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(add_op); EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Input("X"), "x"); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 5c8b86c8a941f71bc741a0d2c64e2a1ebd28fc57..41c78309327342ff47982fc105eadf777c7e59c7 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -298,9 +298,10 @@ class OpRegistry { return CreateOp(op_desc.type(), inputs, outputs, attrs); } - static OperatorPtr CreateGradOp(OperatorPtr op) { + static std::shared_ptr CreateGradOp( + std::shared_ptr op) { GradOpCreator creator(op.get()); - OperatorPtr grad_op(creator.Create()); + std::shared_ptr grad_op(creator.Create()); grad_op->Init(); return grad_op; }