diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 433edbfda742d3be9915eade7b0a455398a501dc..a29a81c994c8168afe4b2088d0f7cc6ad3807c5b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) -proto_library(net_proto SRCS net_proto.proto DEPS op_proto) -# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op) -cc_library(net SRCS net.cc DEPS operator net_proto op_registry) +cc_library(net SRCS net.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index bc23b63b35d37eea01ae6b9b8891e9cd94615898..2cd378c6b21303d1a24206ba3010b0d035aaa766 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -20,17 +20,7 @@ namespace paddle { namespace framework { -std::shared_ptr AddBackwardOp(std::shared_ptr ForwardOps) { - auto grad_ops = std::make_shared(); - for (auto& op : ForwardOps->ops_) { - auto op_grad = OpRegistry::CreateGradOp(op); - grad_ops->AddOp(op_grad); - } - grad_ops->CompleteAddOp(); - return grad_ops; -} - -void PlainNet::CompleteAddOp(bool calc) { +void NetOp::CompleteAddOp(bool calc) { add_op_done_ = true; if (!calc) return; std::unordered_set input_set; @@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) { attrs_["temporary_index"] = tmp_index; } -std::string PlainNet::DebugString() const { +std::string NetOp::DebugString() const { std::ostringstream os; os << OperatorBase::DebugString() << std::endl; for (auto& op : ops_) { @@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const { return os.str(); } +bool NetOp::IsNetOp() const { return true; } + } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 3264f1f565e3efc188e7835cb9b44e5741e1eea8..089c1355951f59d51db16d4b4bdce4282d6e5c25 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -37,21 +37,7 @@ namespace framework { * This is the base class of network, all the networks should implement the APIs * it defines. */ -class Net : public OperatorBase { - public: - virtual void AddOp(const std::shared_ptr& op) = 0; - virtual void CompleteAddOp(bool calc) = 0; -}; - -using NetPtr = std::shared_ptr; - -/** - * @brief a basic implementation of Net. - * - * PlainNet is a very simple Net, it create a list of operators, and run them - * sequentially following the order they added. - */ -class PlainNet : public Net { +class NetOp : public OperatorBase { public: /** * Infer all the operators' input and output variables' shapes, will be called @@ -80,15 +66,17 @@ class PlainNet : public Net { /** * @brief Add an operator by ptr */ - void AddOp(const std::shared_ptr& op) override { + void AddOp(const std::shared_ptr& op) { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); ops_.push_back(op); } - void CompleteAddOp(bool calculate = true) override; + void CompleteAddOp(bool calculate = true); std::string DebugString() const override; + bool IsNetOp() const override; + std::vector> ops_; private: @@ -100,7 +88,5 @@ class PlainNet : public Net { } }; -std::shared_ptr AddBackwardOp(std::shared_ptr ForwardOps); - } // namespace framework } // namespace paddle diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index d924058624bf334b015797c4e4f882db10203049..8048311fe54ee1827fb5b91577478a1d30803e43 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector& expected, } TEST(OpKernel, all) { - auto net = std::make_shared(); + auto net = std::make_shared(); ASSERT_NE(net, nullptr); auto op1 = std::make_shared(); @@ -71,28 +71,21 @@ TEST(OpKernel, all) { ASSERT_EQ(2, run_cnt); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); } -TEST(AddBackwardOp, TestGradOp) { - auto net = std::make_shared(); - ASSERT_NE(net, nullptr); - net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); - net->AddOp( - framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); - net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {})); - auto grad_ops = AddBackwardOp(net); - for (auto& op : grad_ops->ops_) { - op->DebugString(); - } -} -// TODO(zhihong): add fc grad without registering. -// TEST(AddBackwardOp, TestNoGradOp) { -// auto net = std::make_shared(); -// ASSERT_NE(net, nullptr); -// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, -// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) { -// op->DebugString(); -// } -// } +//! TODO(yuyang18): Refine Backward Op. +// TEST(AddBackwardOp, TestGradOp) { +// auto net = std::make_shared(); +// ASSERT_NE(net, nullptr); +// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); +// net->AddOp( +// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); +// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, +// {})); +// auto grad_ops = AddBackwardOp(net); +// for (auto& op : grad_ops->ops_) { +// op->DebugString(); +// } +//} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto deleted file mode 100644 index 0779f49fe2a9a6d0d1ea5ec11ba3befeb0a67fa1..0000000000000000000000000000000000000000 --- a/paddle/framework/net_proto.proto +++ /dev/null @@ -1,15 +0,0 @@ -syntax="proto2"; -package paddle.framework; - -import "op_proto.proto"; - -message NetDesc { - // network identification - optional string name = 1; - // operator contains in network - repeated OpProto operators = 2; - // network type to run with. e.g "plainNet", "DAG" - optional string net_type = 3; - // num worker always - optional int32 num_workers = 4; -} diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f59314f8288d37f0c645b99811b1355f9a496c00..65fddb68112d70bda6d1462fd83561cfae657b6d 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -90,15 +90,17 @@ class OperatorBase { virtual void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const = 0; - // Get a input with argument's name described in `op_proto` + virtual bool IsNetOp() const { return false; } + + //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; - // Get a input which has multiple variables. - // TODO add a vector_view to prevent memory copy. + //! Get a input which has multiple variables. + //! TODO add a vector_view to prevent memory copy. std::vector Inputs(const std::string& name) const; - // Get a output with argument's name described in `op_proto` + //! Get a output with argument's name described in `op_proto` const std::string& Output(const std::string& name) const; - // Get an output which has multiple variables. - // TODO add a vector_view to prevent memory copy. + //! Get an output which has multiple variables. + //! TODO add a vector_view to prevent memory copy. std::vector Outputs(const std::string& name) const; public: diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc index 40ff2f41dda0f2d092b4edd4b3d4665a99109395..c4a9f5937f4fa8c60989bea1726cedbb73330156 100644 --- a/paddle/operators/fc_op.cc +++ b/paddle/operators/fc_op.cc @@ -17,7 +17,7 @@ namespace paddle { namespace operators { -class FullyConnectedOp : public PlainNet { +class FullyConnectedOp : public NetOp { public: void Init() override { AddOp(OpRegistry::CreateOp("mul", diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 44ffefb2993ff4bd110e003990dfe04fb752db5f..b712e457ff60e8b30b87c0d549693d53e9f05d59 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -43,7 +43,7 @@ using OpProto = framework::OpProto; using OpAttrChecker = framework::OpAttrChecker; using CPUPlace = platform::CPUPlace; using GPUPlace = platform::GPUPlace; -using PlainNet = framework::PlainNet; +using NetOp = framework::NetOp; using OpRegistry = framework::OpRegistry; } // namespace operators } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 0b152d03c0641113370fd634aed05ce6b3fb6cf2..ccefcd2511ca0f132d127463166f9f40779a1d85 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle. }); ExposeOperator(operator_base); - using PlainNetPtr = std::shared_ptr; - py::class_ net(m, "Net"); + py::class_> net(m, "Net"); net.def_static("create", - []() -> std::shared_ptr { - auto retv = std::make_shared(); + []() -> std::shared_ptr { + auto retv = std::make_shared(); retv->type_ = "plain_net"; return retv; }) - .def("add_op", &pd::PlainNet::AddOp) + .def("add_op", &pd::NetOp::AddOp) .def("add_op", - [](PlainNetPtr& self, const PlainNetPtr& net) -> void { - self->AddOp(std::static_pointer_cast(net)); + [](pd::NetOp& self, const std::shared_ptr& net) -> void { + self.AddOp(std::static_pointer_cast(net)); }) - .def("complete_add_op", &pd::PlainNet::CompleteAddOp) - .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); }); + .def("complete_add_op", &pd::NetOp::CompleteAddOp) + .def("complete_add_op", + [](std::shared_ptr& self) { self->CompleteAddOp(); }); ExposeOperator(net); m.def("unique_integer", UniqueIntegerGenerator);