提交 f4e25550 编写于 作者: F fengjiayi

Fix compile error

Replace `OperatorPtr` with `std::shared_ptr<OperatorBase>`
上级 5f3bc2a4
......@@ -9,8 +9,9 @@ namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
OperatorPtr add_op(OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
OperatorPtr grad_add_op = OpRegistry::CreateGradOp(add_op);
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");
......
......@@ -298,9 +298,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static OperatorPtr CreateGradOp(OperatorPtr op) {
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
OperatorPtr grad_op(creator.Create());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
grad_op->Init();
return grad_op;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册