diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index dac57c2e22c750122712c378dc553e8e74909057..25ebcefa03ff657b6fc41e3be05c710606add194 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -50,50 +50,72 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
   return net_op;
 }
 
+/**
+ * @brief Build the backward pass of an operator (implementation).
+ * @param forwardOp the forward operator.
+ * @param no_grad_names variable names whose gradients are not needed, e.g.
+ * X@GRAD does not have to be computed.
+ * @param uniq_id a unique index used inside BackwardImpl; it is shared across
+ * recursive invocations.
+ * @return The backward operator. For a simple situation it is a plain
+ * operator; for a complex situation it is a NetOp.
+ *
+ * See backward.h for details.
+ */
 static std::shared_ptr<OperatorBase> BackwardImpl(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
+  /**
+   * If none of the input gradients of the forward operator needs to be
+   * calculated, just return an EmptyOp. We do not return a null pointer:
+   * an EmptyOp costs almost nothing to run, yet it keeps the logic simple.
+   */
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     return EmptyOp();
   }
 
+  /**
+   * If none of the output gradients of the forward operator needs to be
+   * calculated, then no input gradient can be computed at all; put them all
+   * into the `no_grad_names` set and return an EmptyOp.
+   */
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
-      // Mark all input is not need
+      /// Mark every input gradient as not needed.
       no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
     }
     return EmptyOp();
   }
 
+  //! The returned gradient network.
   auto net = std::make_shared<NetOp>();
 
   if (forwardOp.IsNetOp()) {
-    //! TODO(dzh)
-    std::unordered_map<std::string, std::vector<size_t> /*op offset*/>
-        dup_output_ops;
-    size_t local_op_id = 0;
-    // Because it is a net op, it can static_cast.
+    /// Because forwardOp is a net op, it can be static_cast to NetOp.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
-    // travesal subnet/op
+    //! Map from an output gradient variable name to the indices of the
+    //! operators in the backward net that generate that variable.
+    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
+
+    size_t local_op_id = 0;
+    /// Traverse forwardNet in reverse order.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
-         ++it) {
+         ++it, ++local_op_id) {
       auto fwd = *it;
       auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
-      for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
-        dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
+      for (auto& out : bwd->outputs_) {
+        dup_output_ops[out].emplace_back(local_op_id);
       }
-      local_op_id++;
     }
-    // unique the duplicate name
+    /// Get a unique ID for this method.
     auto uid = uniq_id++;
     // TODO(dzh): more comment
-    typedef std::pair<size_t, std::shared_ptr<OperatorBase>> Pos;
-    std::list<Pos> insert_postion;
+    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
+    std::list<Pos> insert_position;
     for (auto& dup_output_op : dup_output_ops) {
       const std::string& name = dup_output_op.first;
       auto& dup_op = dup_output_op.second;
@@ -106,16 +128,18 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
                               std::to_string(i));
         net->ops_[op_offset]->Rename(name, dup_outputs.back());
       }
-      insert_postion.push_back(
+      insert_position.push_back(
          {dup_op.back(),
           OpRegistry::CreateOp(
               "add", {dup_outputs}, {name},
               {{"input_format",
                 std::vector<int>{0, (int)dup_outputs.size()}}})});
     }
-    insert_postion.sort(
+
+    insert_position.sort(
        [](const Pos& l, const Pos& r) { return l.first > r.first; });
-    for (auto& pos : insert_postion) {
+
+    for (auto& pos : insert_position) {
       net->InsertOp(pos.first, pos.second);
     }
 
@@ -148,6 +172,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   return net;
 }
 
+//! See the header for comments.
 extern std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index ae85e6201beaf0ccb096673c5bcbfbe276c9ac0f..1167fbc5b623351387a2669c3f036b9eeec07b6b 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -13,6 +13,7 @@ limitations under the License. */
 
 #include "paddle/framework/backward.h"
+
 #include <gtest/gtest.h>
 #include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
@@ -142,6 +143,7 @@
 REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
 REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
 REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
 REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
+REGISTER_GRADIENT_OP(fc, fc_grad, f::EmptyOp);
 REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
 REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
@@ -160,6 +162,18 @@ TEST(Backward, simple_op_grad) {
   // LOG(INFO) << gop->Output("X" + "@GRAD");
 }
 
+TEST(Backward, simple_op_not_need_grad) {
+  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
+  ASSERT_NE(fwd, nullptr);
+  auto gop = f::Backward(*fwd, {"X"});
+  ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
+                      "X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            gop->outputs_.end());
+
+  auto no_input_gop = f::Backward(*fwd, {"X", "b"});
+  ASSERT_NE(no_input_gop, nullptr);
+}
+
 TEST(Backward, net_fc_backward_normal) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
       "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
@@ -217,6 +231,8 @@ TEST(Backward, net_input_of_network_not_need_grad) {
                                   bwd_net->outputs_.begin(),
                                   bwd_net->outputs_.end());
   all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
 
+  LOG(INFO) << bwd_net->DebugString();
+  LOG(INFO) << bwd_net->ops_.size();
   for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
     ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
               all_output.end());
@@ -230,9 +246,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-  ASSERT_EQ(
-      f::OperatorBase::EMPTY_VAR_NAME(),
-      first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+  LOG(INFO) << first_fc_grad->DebugString();
+  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
+            first_fc_grad->ops_[2]->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
 
 TEST(Backward, net_shared_weight) {
@@ -245,13 +261,14 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
+  LOG(INFO) << bwd_net->DebugString();
   ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"},
-      {{"temporary_index", std::vector<int>{1}}});
+  auto fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
+                              {{"temporary_index", std::vector<int>{1}}});
 
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
@@ -299,9 +316,11 @@ TEST(Backward, op_part_of_output_are_not_need) {
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_TRUE(!backward->IsNetOp());
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 1UL);
 
-  auto &grad_mul = *backward;
+  auto &grad_mul = *net->ops_[0];
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
@@ -324,11 +343,11 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                                     {"mul_out2", "tmp_out2", "out2"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
                                     {"mul_out3", "tmp_out3", "out3"}, {}));
-  net.CompleteAddOp();
+  net.CompleteAddOp(false);
   auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
   ASSERT_TRUE(backward->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(bwd_net->ops_.size(), 1UL);
+  ASSERT_EQ(bwd_net->ops_.size(), 3UL);
 
   auto &grad_fc = *bwd_net->ops_[0];
   ASSERT_EQ(grad_fc.type_, "fc_grad");
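
For reference, below is a minimal, self-contained sketch of the duplicate-gradient merging that BackwardImpl performs in the first hunk above: every backward operator writing the same gradient variable gets a uniquified output name, and an "add" operator is inserted to sum the renamed copies back into the original variable. It uses plain STL containers and a hypothetical FakeOp struct instead of Paddle's OperatorBase/NetOp and OpRegistry, and it omits the "input_format" attribute, so it illustrates the idea rather than the real implementation.

// Minimal sketch of the duplicate-output merging idea, assuming a
// hypothetical FakeOp in place of Paddle's OperatorBase/NetOp. C++11.
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct FakeOp {  // hypothetical stand-in for an operator description
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // A backward "net" in which two ops both produce W@GRAD (a shared weight).
  std::vector<FakeOp> net = {{"mul_grad", {"Out1@GRAD"}, {"W@GRAD"}},
                             {"mul_grad", {"Out2@GRAD"}, {"W@GRAD"}}};

  // Map each output gradient name to the indices of the ops producing it.
  std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
  for (size_t i = 0; i < net.size(); ++i) {
    for (auto& out : net[i].outputs) dup_output_ops[out].emplace_back(i);
  }

  size_t uid = 0;  // plays the role of uniq_id in BackwardImpl
  using Pos = std::pair<size_t, FakeOp>;
  std::list<Pos> insert_position;
  for (auto& dup_output_op : dup_output_ops) {
    const std::string& name = dup_output_op.first;
    auto& dup_op = dup_output_op.second;
    if (dup_op.size() <= 1) continue;  // a unique producer needs no merge
    std::vector<std::string> dup_outputs;
    for (size_t i = 0; i < dup_op.size(); ++i) {
      // Rename each producer's output to name@RENAME@<uid>@<i>.
      dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                            std::to_string(i));
      net[dup_op[i]].outputs = {dup_outputs.back()};
    }
    // Schedule an "add" op that sums the renamed copies back into `name`,
    // placed right after the last producer so its inputs already exist.
    insert_position.push_back(
        {dup_op.back(), FakeOp{"add", dup_outputs, {name}}});
  }

  // Insert from the back so earlier insertion indices stay valid.
  insert_position.sort(
      [](const Pos& l, const Pos& r) { return l.first > r.first; });
  for (auto& pos : insert_position) {
    net.insert(net.begin() + pos.first + 1, pos.second);
  }

  for (auto& op : net) {
    std::cout << op.type << " ->";
    for (auto& out : op.outputs) std::cout << ' ' << out;
    std::cout << '\n';
  }
  return 0;
}

Run as-is, the sketch prints the two renamed producers (W@GRAD@RENAME@0@0 and W@GRAD@RENAME@0@1) followed by the inserted "add -> W@GRAD", which is the merged shape the shared-weight backward net relies on.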