提交 d2583bd4 编写于 作者: Y Yu Yang

InsertOp for NetOp

上级 84198f75
...@@ -30,7 +30,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch ...@@ -30,7 +30,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
cc_library(net SRCS net.cc DEPS op_registry) cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net)
cc_library(backward SRCS backward.cc DEPS net) cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS backward) cc_test(backward_test SRCS backward_test.cc DEPS backward)
...@@ -68,9 +68,18 @@ class NetOp : public OperatorBase { ...@@ -68,9 +68,18 @@ class NetOp : public OperatorBase {
*/ */
void AddOp(const std::shared_ptr<OperatorBase>& op) { void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
ops_.push_back(op); ops_.push_back(op);
} }
void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_,
"Cannot InsertOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
PADDLE_ENFORCE(pos <= ops_.size(), "Out of range");
ops_.insert(ops_.begin() + pos, op);
}
void CompleteAddOp(bool calculate = true); void CompleteAddOp(bool calculate = true);
std::string DebugString() const override; std::string DebugString() const override;
......
...@@ -3,11 +3,6 @@ ...@@ -3,11 +3,6 @@
#include <paddle/framework/op_registry.h> #include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h> #include <paddle/framework/operator.h>
USE_OP(add_two);
USE_OP(mul);
USE_OP(sigmoid);
USE_OP(softmax);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -26,6 +21,13 @@ class TestOp : public OperatorBase { ...@@ -26,6 +21,13 @@ class TestOp : public OperatorBase {
} }
}; };
class EmptyOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
template <typename T> template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected, void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
const std::vector<T>& actual) { const std::vector<T>& actual) {
...@@ -72,20 +74,17 @@ TEST(OpKernel, all) { ...@@ -72,20 +74,17 @@ TEST(OpKernel, all) {
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
} }
//! TODO(yuyang18): Refine Backward Op. TEST(Net, insert_op) {
// TEST(AddBackwardOp, TestGradOp) { NetOp net;
// auto net = std::make_shared<NetOp>(); auto op1 = std::make_shared<EmptyOp>();
// ASSERT_NE(net, nullptr); op1->inputs_ = {"x", "w1", "b1"};
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); op1->outputs_ = {"y"};
// net->AddOp( net.AddOp(op1);
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); net.InsertOp(0, op1);
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, ASSERT_EQ(2UL, net.ops_.size());
// {})); net.InsertOp(2, op1);
// auto grad_ops = AddBackwardOp(net); ASSERT_EQ(3UL, net.ops_.size());
// for (auto& op : grad_ops->ops_) { }
// op->DebugString();
// }
//}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册