提交 b1b43645 编写于 作者: Y Yu Yang

Rename PlainNet --> NetOp

上级 ef7e76fc
......@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
......@@ -20,17 +20,7 @@
namespace paddle {
namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::unordered_set<std::string> input_set;
......@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index;
}
std::string PlainNet::DebugString() const {
std::string NetOp::DebugString() const {
std::ostringstream os;
os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) {
......@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str();
}
bool NetOp::IsNetOp() const { return true; }
} // namespace framework
} // namespace paddle
......@@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs
* it defines.
*/
class Net : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
using NetPtr = std::shared_ptr<Net>;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class PlainNet : public Net {
class NetOp : public OperatorBase {
public:
/**
* Infer all the operators' input and output variables' shapes, will be called
......@@ -80,15 +66,17 @@ class PlainNet : public Net {
/**
* @brief Add an operator by ptr
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) override {
void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op);
}
void CompleteAddOp(bool calculate = true) override;
void CompleteAddOp(bool calculate = true);
std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
......@@ -100,7 +88,5 @@ class PlainNet : public Net {
}
};
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework
} // namespace paddle
......@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>();
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>();
......@@ -71,28 +71,21 @@ TEST(OpKernel, all) {
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
//! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
//}
} // namespace framework
} // namespace paddle
syntax="proto2";
package paddle.framework;
import "op_proto.proto";
message NetDesc {
// network identification
optional string name = 1;
// operator contains in network
repeated OpProto operators = 2;
// network type to run with. e.g "plainNet", "DAG"
optional string net_type = 3;
// num worker always
optional int32 num_workers = 4;
}
......@@ -90,15 +90,17 @@ class OperatorBase {
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto`
virtual bool IsNetOp() const { return false; }
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto`
//! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;
public:
......
......@@ -17,7 +17,7 @@
namespace paddle {
namespace operators {
class FullyConnectedOp : public PlainNet {
class FullyConnectedOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul",
......
......@@ -43,7 +43,7 @@ using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using PlainNet = framework::PlainNet;
using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry;
} // namespace operators
} // namespace paddle
......
......@@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle.
});
ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
[]() -> std::shared_ptr<pd::NetOp> {
auto retv = std::make_shared<pd::NetOp>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &pd::PlainNet::AddOp)
.def("add_op", &pd::NetOp::AddOp)
.def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
[](pd::NetOp& self, const std::shared_ptr<pd::NetOp>& net) -> void {
self.AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
})
.def("complete_add_op", &pd::PlainNet::CompleteAddOp)
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
.def("complete_add_op", &pd::NetOp::CompleteAddOp)
.def("complete_add_op",
[](std::shared_ptr<pd::NetOp>& self) { self->CompleteAddOp(); });
ExposeOperator(net);
m.def("unique_integer", UniqueIntegerGenerator);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册