提交 65d26787 编写于 作者: D dongzhihong

"add simple net test"

上级 404cc056
......@@ -86,8 +86,6 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
// travesal subnet/op
for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) {
auto fwd = *it;
// for (auto& fwd : forwardNet.ops_) {
// auto bwd = Backward(*fwd, no_grad_names);
auto bwd = Backward(*fwd, no_grad_names);
net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
......
......@@ -63,10 +63,10 @@ class FcOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("before_act")}, {}));
{Output("mul_out")}, {}));
auto b_name = Input("b");
if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_out"), b_name},
{Output("before_act")}, {}));
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
......@@ -82,6 +82,7 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("mul_out", "mul output").SetTemporary();
AddOutput("before_act", "before act").SetTemporary();
AddOutput("Out", "");
AddComment("");
......@@ -140,6 +141,7 @@ TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
LOG(INFO) << gop->DebugString();
ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
......@@ -151,10 +153,18 @@ TEST(Backward, simple_op_grad) {
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST(Backward, simple_net_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {});
LOG(INFO) << gop->DebugString();
}
TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
ASSERT_NE(fwd, nullptr);
LOG(INFO) << fwd->DebugString();
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册