提交 f4e25550 编写于 作者: F fengjiayi

Fix compile error

Replace `OperatorPtr` with `std::shared_ptr<OperatorBase>`
上级 5f3bc2a4
...@@ -9,8 +9,9 @@ namespace paddle { ...@@ -9,8 +9,9 @@ namespace paddle {
namespace framework { namespace framework {
TEST(GradOpCreator, AddTwo) { TEST(GradOpCreator, AddTwo) {
OperatorPtr add_op(OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); std::shared_ptr<OperatorBase> add_op(
OperatorPtr grad_add_op = OpRegistry::CreateGradOp(add_op); OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2); EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("X"), "x");
......
...@@ -298,9 +298,10 @@ class OpRegistry { ...@@ -298,9 +298,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
static OperatorPtr CreateGradOp(OperatorPtr op) { static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get()); GradOpCreator creator(op.get());
OperatorPtr grad_op(creator.Create()); std::shared_ptr<OperatorBase> grad_op(creator.Create());
grad_op->Init(); grad_op->Init();
return grad_op; return grad_op;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册