提交 e0463acf 编写于 作者: Y Yu Yang

Rename PlainNet --> NetOp

上级 ef7e76fc
...@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc. ...@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto) cc_library(net SRCS net.cc DEPS op_registry)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
...@@ -20,17 +20,7 @@ ...@@ -20,17 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) { void NetOp::CompleteAddOp(bool calc) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
add_op_done_ = true; add_op_done_ = true;
if (!calc) return; if (!calc) return;
std::unordered_set<std::string> input_set; std::unordered_set<std::string> input_set;
...@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) { ...@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index; attrs_["temporary_index"] = tmp_index;
} }
std::string PlainNet::DebugString() const { std::string NetOp::DebugString() const {
std::ostringstream os; std::ostringstream os;
os << OperatorBase::DebugString() << std::endl; os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) { for (auto& op : ops_) {
...@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const { ...@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str(); return os.str();
} }
bool NetOp::IsNetOp() const { return true; }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -37,21 +37,7 @@ namespace framework { ...@@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs * This is the base class of network, all the networks should implement the APIs
* it defines. * it defines.
*/ */
class Net : public OperatorBase { class NetOp : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
using NetPtr = std::shared_ptr<Net>;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class PlainNet : public Net {
public: public:
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
...@@ -80,15 +66,17 @@ class PlainNet : public Net { ...@@ -80,15 +66,17 @@ class PlainNet : public Net {
/** /**
* @brief Add an operator by ptr * @brief Add an operator by ptr
*/ */
void AddOp(const std::shared_ptr<OperatorBase>& op) override { void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op); ops_.push_back(op);
} }
void CompleteAddOp(bool calculate = true) override; void CompleteAddOp(bool calculate = true);
std::string DebugString() const override; std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
...@@ -100,7 +88,5 @@ class PlainNet : public Net { ...@@ -100,7 +88,5 @@ class PlainNet : public Net {
} }
}; };
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected, ...@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
} }
TEST(OpKernel, all) { TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>(); auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr); ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>(); auto op1 = std::make_shared<TestOp>();
...@@ -71,28 +71,21 @@ TEST(OpKernel, all) { ...@@ -71,28 +71,21 @@ TEST(OpKernel, all) {
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
} }
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}
// TODO(zhihong): add fc grad without registering. //! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestNoGradOp) { // TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<PlainNet>(); // auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr); // ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, // net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) { // net->AddOp(
// op->DebugString(); // framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// } // net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// } // {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
//}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
syntax="proto2";
package paddle.framework;
import "op_proto.proto";
message NetDesc {
// network identification
optional string name = 1;
// operator contains in network
repeated OpProto operators = 2;
// network type to run with. e.g "plainNet", "DAG"
optional string net_type = 3;
// num worker always
optional int32 num_workers = 4;
}
...@@ -90,15 +90,17 @@ class OperatorBase { ...@@ -90,15 +90,17 @@ class OperatorBase {
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto` virtual bool IsNetOp() const { return false; }
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables. //! Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const; std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto` //! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const; const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables. //! Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const; std::vector<std::string> Outputs(const std::string& name) const;
public: public:
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FullyConnectedOp : public PlainNet { class FullyConnectedOp : public NetOp {
public: public:
void Init() override { void Init() override {
AddOp(OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
......
...@@ -43,7 +43,7 @@ using OpProto = framework::OpProto; ...@@ -43,7 +43,7 @@ using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker; using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace; using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace; using GPUPlace = platform::GPUPlace;
using PlainNet = framework::PlainNet; using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry; using OpRegistry = framework::OpRegistry;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle.
}); });
ExposeOperator(operator_base); ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>; py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
net.def_static("create", net.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> { []() -> std::shared_ptr<pd::NetOp> {
auto retv = std::make_shared<pd::PlainNet>(); auto retv = std::make_shared<pd::NetOp>();
retv->type_ = "plain_net"; retv->type_ = "plain_net";
return retv; return retv;
}) })
.def("add_op", &pd::PlainNet::AddOp) .def("add_op", &pd::NetOp::AddOp)
.def("add_op", .def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& net) -> void { [](pd::NetOp& self, const std::shared_ptr<pd::NetOp>& net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net)); self.AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
}) })
.def("complete_add_op", &pd::PlainNet::CompleteAddOp) .def("complete_add_op", &pd::NetOp::CompleteAddOp)
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); }); .def("complete_add_op",
[](std::shared_ptr<pd::NetOp>& self) { self->CompleteAddOp(); });
ExposeOperator(net); ExposeOperator(net);
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册