提交 74cd9a75 编写于 作者: D dongzhihong

"fix unittest"

上级 72839a76
...@@ -79,11 +79,11 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -79,11 +79,11 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
std::unordered_map<std::string /*var name*/, std::unordered_map<std::string /*var name*/,
std::vector<size_t> /*op offset*/> std::vector<size_t> /*op offset*/>
dup_output_ops; dup_output_ops;
size_t local_op_id = 0;
// Because it is a net op, it can static_cast. // Because it is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp); auto& forwardNet = static_cast<const NetOp&>(forwardOp);
// travesal subnet/op // travesal subnet/op
size_t local_op_id = 0;
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it) { ++it) {
auto fwd = *it; auto fwd = *it;
......
...@@ -149,7 +149,6 @@ TEST(Backward, simple_op_grad) { ...@@ -149,7 +149,6 @@ TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd); auto gop = f::OpRegistry::CreateGradOp(*fwd);
LOG(INFO) << gop->DebugString();
ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("rowwise_add_grad", gop->type_);
...@@ -161,18 +160,19 @@ TEST(Backward, simple_op_grad) { ...@@ -161,18 +160,19 @@ TEST(Backward, simple_op_grad) {
// LOG(INFO) << gop->Output("X" + "@GRAD"); // LOG(INFO) << gop->Output("X" + "@GRAD");
} }
TEST(Backward, simple_net_grad) { TEST(Backward, simple_op_not_need_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"x", "b"}, {"out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {}); auto gop = f::Backward(*fwd, {"x"});
LOG(INFO) << gop->DebugString(); LOG(INFO) << gop->DebugString();
ASSERT_NE(gop->outputs_.find("x" + f::OperatorBase::GRAD_VAR_SUFFIX()),
gop->outputs_.end());
} }
TEST(Backward, net_fc_backward_normal) { TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {}); "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
LOG(INFO) << fwd->DebugString();
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp()); ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get()); auto net = static_cast<f::NetOp *>(gop.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册