From 404cc056b8f0de18ee3633c7c6ba28b773320e2e Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Thu, 27 Jul 2017 17:50:17 +0800 Subject: [PATCH] "reverse travesal" --- paddle/framework/backward.cc | 7 +++++-- paddle/framework/backward_test.cc | 8 ++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 716e78f342..2d9efdd511 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -77,14 +77,17 @@ static std::shared_ptr BackwardImpl( if (forwardOp.IsNetOp()) { //! TODO(dzh) std::unordered_map /*op offs et*/> + std::vector /*op offset*/> dup_output_ops; size_t local_op_id = 0; // Because it is a net op, it can static_cast. auto& forwardNet = static_cast(forwardOp); // travesal subnet/op - for (auto& fwd : forwardNet.ops_) { + for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) { + auto fwd = *it; + // for (auto& fwd : forwardNet.ops_) { + // auto bwd = Backward(*fwd, no_grad_names); auto bwd = Backward(*fwd, no_grad_names); net->AddOp(bwd); for (size_t i = 0; i < bwd->outputs_.size(); ++i) { diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 0666bcc14c..54acc47599 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -129,12 +129,12 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); -REGISTER_OP(fc, f::FcOp, f::FcOpMaker); -REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); -REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); +REGISTER_OP(fc, f::FcOp, f::FcOpMaker); +REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); +REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); @@ -218,7 +218,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); - ASSERT_EQ(3, first_fc_grad->ops_.size()); + ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); } -- GitLab