Commit 404cc056 authored by D dongzhihong

"reverse travesal"

Parent: 7088654a
@@ -77,14 +77,16 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   if (forwardOp.IsNetOp()) {
     //! TODO(dzh)
     std::unordered_map<std::string /*var name*/,
-                       std::vector<size_t> /*op offs et*/>
+                       std::vector<size_t> /*op offset*/>
         dup_output_ops;
     size_t local_op_id = 0;
     // Because it is a net op, it can static_cast.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
-    // traversal of subnet/op
-    for (auto& fwd : forwardNet.ops_) {
+    // Traverse the sub-ops in reverse: gradient ops are built in the
+    // opposite order of the forward pass.
+    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it) {
+      auto& fwd = *it;
       auto bwd = Backward(*fwd, no_grad_names);
       net->AddOp(bwd);
       for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
...
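The hunk above replaces a forward range-for with an explicit reverse walk. A minimal self-contained sketch of why `rbegin()`/`rend()` is the safe way to iterate backwards (the `Op` struct and names here are hypothetical stand-ins, not PaddlePaddle types):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for an operator; not PaddlePaddle's OperatorBase.
struct Op {
  std::string name;
};

int main() {
  std::vector<std::shared_ptr<Op>> ops = {
      std::make_shared<Op>(Op{"mul"}),
      std::make_shared<Op>(Op{"add"}),
      std::make_shared<Op>(Op{"sigmoid"}),
  };

  // rbegin() points at the last element, so every dereference is valid
  // and every op is visited exactly once, back to front.
  for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
    std::cout << (*it)->name << "\n";  // prints: sigmoid, add, mul
  }

  // A manual loop of the form
  //   for (auto it = ops.end(); it != ops.begin(); --it) { auto& op = *it; }
  // is wrong twice over: the first iteration dereferences end() (undefined
  // behavior), and the loop exits before ever visiting ops.front().
  return 0;
}
```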
@@ -129,12 +129,12 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
 REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
 REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
 REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
-REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
-REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
-REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
 REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
 REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
 REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
+REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
+REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
+REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);

 TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
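Registration macros like `REGISTER_OP` conventionally expand to namespace-scope static objects whose constructors perform the registration, and within a single translation unit such initializers run top to bottom, so moving the `fc` registrations after `add`/`fill_zeros_like` changes the order in which those constructors fire. A minimal sketch of that ordering rule, using a hypothetical `Registry` rather than the real `OpRegistry`:

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical registry; a REGISTER_OP-style macro typically expands to a
// namespace-scope static object whose constructor does the registration.
std::map<std::string, bool>& Registry() {
  static std::map<std::string, bool> r;  // function-local static: built on
  return r;                              // first use, regardless of init order
}

struct Registrar {
  explicit Registrar(const std::string& name) { Registry()[name] = true; }
};

// Within one translation unit these constructors run top to bottom, so
// anything fc_reg's constructor depended on would already be registered.
static Registrar add_reg("add");
static Registrar fc_reg("fc");

int main() {
  std::cout << Registry().size() << "\n";  // prints: 2
}
```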
@@ -218,7 +218,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_EQ(2UL, bwd_net->ops_.size());
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
-  ASSERT_EQ(3, first_fc_grad->ops_.size());
+  ASSERT_EQ(3UL, first_fc_grad->ops_.size());
   ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
             first_fc_grad->ops_[2]->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
...
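The `3` to `3UL` change keeps both sides of the assertion unsigned, since `ops_.size()` returns `size_t`. A small sketch of the same pattern in isolation (assumes gtest is available; the test name is hypothetical):

```cpp
#include <gtest/gtest.h>

#include <vector>

// std::vector::size() returns std::size_t (unsigned), so writing the
// expected value as 3UL keeps both sides of EXPECT_EQ unsigned and avoids
// -Wsign-compare warnings (errors under -Werror) in the macro's expansion.
TEST(SizeCompare, UnsignedLiteral) {
  std::vector<int> v{1, 2, 3};
  EXPECT_EQ(3UL, v.size());
}
```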