提交 65d26787 编写于 作者: D dongzhihong

"add simple net test"

上级 404cc056
...@@ -86,8 +86,6 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -86,8 +86,6 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
// travesal subnet/op // travesal subnet/op
for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) { for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) {
auto fwd = *it; auto fwd = *it;
// for (auto& fwd : forwardNet.ops_) {
// auto bwd = Backward(*fwd, no_grad_names);
auto bwd = Backward(*fwd, no_grad_names); auto bwd = Backward(*fwd, no_grad_names);
net->AddOp(bwd); net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) { for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
......
...@@ -63,10 +63,10 @@ class FcOp : public NetOp { ...@@ -63,10 +63,10 @@ class FcOp : public NetOp {
public: public:
void Init() override { void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("before_act")}, {})); {Output("mul_out")}, {}));
auto b_name = Input("b"); auto b_name = Input("b");
if (b_name != EMPTY_VAR_NAME()) { if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name}, AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_out"), b_name},
{Output("before_act")}, {})); {Output("before_act")}, {}));
} }
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")}, AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
...@@ -82,6 +82,7 @@ class FcOpMaker : public OpProtoAndCheckerMaker { ...@@ -82,6 +82,7 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x"); AddInput("X", "x");
AddInput("W", "w"); AddInput("W", "w");
AddInput("b", "b"); AddInput("b", "b");
AddOutput("mul_out", "mul output").SetTemporary();
AddOutput("before_act", "before act").SetTemporary(); AddOutput("before_act", "before act").SetTemporary();
AddOutput("Out", ""); AddOutput("Out", "");
AddComment(""); AddComment("");
...@@ -140,6 +141,7 @@ TEST(Backward, simple_op_grad) { ...@@ -140,6 +141,7 @@ TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd); auto gop = f::OpRegistry::CreateGradOp(*fwd);
LOG(INFO) << gop->DebugString();
ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("rowwise_add_grad", gop->type_);
...@@ -151,10 +153,18 @@ TEST(Backward, simple_op_grad) { ...@@ -151,10 +153,18 @@ TEST(Backward, simple_op_grad) {
// LOG(INFO) << gop->Output("X" + "@GRAD"); // LOG(INFO) << gop->Output("X" + "@GRAD");
} }
TEST(Backward, simple_net_grad) {
  // Build a single forward rowwise_add op, then derive its backward pass
  // via f::Backward and dump the resulting gradient op for inspection.
  auto forward_op =
      f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
  ASSERT_NE(forward_op, nullptr);
  auto grad_op = f::Backward(*forward_op, {});
  LOG(INFO) << grad_op->DebugString();
}
TEST(Backward, net_fc_backward_normal) { TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd = std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {}); f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
LOG(INFO) << fwd->DebugString();
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp()); ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get()); auto net = static_cast<f::NetOp *>(gop.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册