diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 7febaaa52729bf8af808d10aab5d74f989f895cb..c9a50d8968e56885e100cf287b8a0417ce526eb0 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -30,7 +30,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) cc_library(net SRCS net.cc DEPS op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) +cc_test(net_op_test SRCS net_op_test.cc DEPS net) cc_library(backward SRCS backward.cc DEPS net) cc_test(backward_test SRCS backward_test.cc DEPS backward) diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 089c1355951f59d51db16d4b4bdce4282d6e5c25..b584dd578f3b39af55c1be215a23c9ac46424fcb 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -68,9 +68,18 @@ class NetOp : public OperatorBase { */ void AddOp(const std::shared_ptr& op) { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); ops_.push_back(op); } + void InsertOp(size_t pos, const std::shared_ptr& op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot InsertOp when this network is sealed"); + PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); + PADDLE_ENFORCE(pos <= ops_.size(), "Out of range"); + ops_.insert(ops_.begin() + pos, op); + } + void CompleteAddOp(bool calculate = true); std::string DebugString() const override; diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index 8048311fe54ee1827fb5b91577478a1d30803e43..4b733e958e48a2e5f072d90f8e83c430bc8251d9 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -3,11 +3,6 @@ #include #include -USE_OP(add_two); -USE_OP(mul); -USE_OP(sigmoid); -USE_OP(softmax); - namespace paddle { namespace framework { @@ -26,6 +21,13 @@ class TestOp : public OperatorBase { } }; +class EmptyOp : public OperatorBase { + public: + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} +}; + template void AssertSameVectorWithoutOrder(const std::vector& expected, const std::vector& actual) { @@ -72,20 +74,17 @@ TEST(OpKernel, all) { ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); } -//! TODO(yuyang18): Refine Backward Op. -// TEST(AddBackwardOp, TestGradOp) { -// auto net = std::make_shared(); -// ASSERT_NE(net, nullptr); -// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); -// net->AddOp( -// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); -// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, -// {})); -// auto grad_ops = AddBackwardOp(net); -// for (auto& op : grad_ops->ops_) { -// op->DebugString(); -// } -//} +TEST(Net, insert_op) { + NetOp net; + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net.AddOp(op1); + net.InsertOp(0, op1); + ASSERT_EQ(2UL, net.ops_.size()); + net.InsertOp(2, op1); + ASSERT_EQ(3UL, net.ops_.size()); +} } // namespace framework } // namespace paddle