diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index dac57c2e22c750122712c378dc553e8e74909057..25ebcefa03ff657b6fc41e3be05c710606add194 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -50,50 +50,72 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
   return net_op;
 }
 
+/**
+ * @brief Backward an operator, implementation.
+ * @param forwardOp the forward operator.
+ * @param no_grad_names variable names whose gradients are not calculated, e.g.
+ * X@GRAD is not needed.
+ * @param uniq_id a unique index used inside BackwardImpl; it is shared
+ * across recursive invocations.
+ * @return The backward operator. For a simple situation it is a plain
+ * operator; for a complex situation it is a NetOp.
+ *
+ * See Backward.h for details.
+ */
 static std::shared_ptr<OperatorBase> BackwardImpl(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
+  /**
+   * If none of the input gradients of the forward operator need to be
+   * calculated, just return an EmptyOp. We do not return a null pointer
+   * because an EmptyOp costs little to run and it keeps the logic simple.
+   */
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     return EmptyOp();
   }
 
+  /**
+   * If none of the output gradients of the forward operator need to be
+   * calculated, then none of its input gradients can be computed either; put
+   * them into the `no_grad_names` set and return an EmptyOp.
+   */
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
-      // Mark all input is not need
+      /// Mark every input gradient as not needed.
       no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
     }
     return EmptyOp();
   }
 
+  //! The gradient network to be returned.
   auto net = std::make_shared<NetOp>();
 
   if (forwardOp.IsNetOp()) {
-    //! TODO(dzh)
-    std::unordered_map<std::string /*var name*/,
-                       std::vector<size_t> /*op offset*/>
-        dup_output_ops;
-    size_t local_op_id = 0;
-    // Because it is a net op, it can static_cast.
+    /// Because forwardOp is a net op, it can be static_cast to NetOp.
    auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
-    // travesal subnet/op
+    //! Map from an output gradient variable name to the indices of the
+    //! operators in the backward net that generate that variable.
+    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
+
+    size_t local_op_id = 0;
+    /// Traverse forwardNet in reverse order.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
-         ++it) {
+         ++it, ++local_op_id) {
       auto fwd = *it;
       auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
-      for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
-        dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
+      for (auto& out : bwd->outputs_) {
+        dup_output_ops[out].emplace_back(local_op_id);
       }
-      local_op_id++;
     }
-    // unique the duplicate name
+    /// Get a unique ID for this invocation.
     auto uid = uniq_id++;  // TODO(dzh): more comment
-    typedef std::pair<size_t, std::shared_ptr<OperatorBase>> Pos;
-    std::list<Pos> insert_postion;
+    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
+    std::list<Pos> insert_position;
     for (auto& dup_output_op : dup_output_ops) {
       const std::string& name = dup_output_op.first;
       auto& dup_op = dup_output_op.second;
@@ -106,16 +128,18 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
                                      std::to_string(i));
         net->ops_[op_offset]->Rename(name, dup_outputs.back());
       }
-      insert_postion.push_back(
+      insert_position.push_back(
           {dup_op.back(),
            OpRegistry::CreateOp(
                "add", {dup_outputs}, {name},
                {{"input_format",
                  std::vector<int>{0, (int)dup_outputs.size()}}})});
     }
-    insert_postion.sort(
+
+    insert_position.sort(
         [](const Pos& l, const Pos& r) { return l.first > r.first; });
-    for (auto& pos : insert_postion) {
+
+    for (auto& pos : insert_position) {
       net->InsertOp(pos.first, pos.second);
     }
 
@@ -148,6 +172,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   return net;
 }
 
+//! See the header for comments.
 extern std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
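
For readers skimming the diff, the duplicate-gradient handling above works as follows: every backward sub-op that writes the same gradient variable gets that output renamed to `name@RENAME@<uid>@<i>`, and an `add` operator is inserted near the last writer to sum the renamed parts back into `name`. Below is a minimal, self-contained sketch of that bookkeeping, not PaddlePaddle code: `MockOp`, the example op names, and the exact insertion offset are illustrative assumptions; only the rename pattern, the `dup_output_ops` map, and the descending sort of `insert_position` come from the diff.

```cpp
// Standalone sketch (not PaddlePaddle code): MockOp stands in for OperatorBase
// and plain strings stand in for gradient variables.
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct MockOp {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // Two backward ops in the (already reversed) backward net both write x@GRAD.
  std::vector<MockOp> backward_net = {
      {"mul_grad", {"out1@GRAD"}, {"x@GRAD"}},
      {"add_grad", {"out2@GRAD"}, {"x@GRAD"}},
  };

  // 1. Map each output gradient name to the indices of the ops that write it.
  std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
  for (size_t i = 0; i < backward_net.size(); ++i) {
    for (auto& out : backward_net[i].outputs) {
      dup_output_ops[out].emplace_back(i);
    }
  }

  // 2. Rename duplicated outputs and schedule an "add" op after the last writer.
  size_t uid = 0;
  using Pos = std::pair<size_t, MockOp>;
  std::list<Pos> insert_position;
  for (auto& kv : dup_output_ops) {
    const std::string& name = kv.first;
    const std::vector<size_t>& writers = kv.second;
    if (writers.size() <= 1) continue;
    std::vector<std::string> renamed;
    for (size_t i = 0; i < writers.size(); ++i) {
      renamed.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                        std::to_string(i));
      // Rename this writer's output so the partial gradients do not clobber
      // each other.
      for (auto& out : backward_net[writers[i]].outputs) {
        if (out == name) out = renamed.back();
      }
    }
    // Sum the renamed partial gradients back into the original variable.
    insert_position.push_back({writers.back(), MockOp{"add", renamed, {name}}});
    ++uid;
  }

  // 3. Insert at the largest offsets first (descending sort, as in the diff)
  //    so the earlier recorded offsets stay valid. Placing the add op right
  //    after the last writer is this sketch's own choice.
  insert_position.sort(
      [](const Pos& l, const Pos& r) { return l.first > r.first; });
  for (auto& pos : insert_position) {
    backward_net.insert(backward_net.begin() + pos.first + 1, pos.second);
  }

  for (auto& op : backward_net) {
    std::cout << op.type << " ->";
    for (auto& out : op.outputs) std::cout << " " << out;
    std::cout << "\n";
  }
}
```

The descending sort is the important design point: inserting the accumulation ops from the back of the net toward the front keeps every previously recorded operator offset valid while the insertions happen.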