diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index c966f97c2d5b553f6ab67bb2f7aac27108b80409..1e20789a1f1b520e33c99b0f8740fbbcf2e792fa 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -28,15 +28,15 @@ namespace paddle { namespace framework { static inline std::unique_ptr<OperatorBase> CreateGradOp( - const OperatorBase& op, - const std::unordered_set<std::string>& no_grad_set) { + const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set, + std::unordered_map<std::string, std::string>* grad_to_var) { OpDescBind op_desc; op_desc.SetInputMap(op.Inputs()); op_desc.SetOutputMap(op.Outputs()); op_desc.SetType(op.Type()); op_desc.SetAttrMap(op.Attrs()); auto& info = OpInfoMap::Instance().Get(op.Type()); - auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set); + auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var); std::vector<std::unique_ptr<OperatorBase>> grad_ops; grad_ops.reserve(grad_descs.size()); std::transform(grad_descs.begin(), grad_descs.end(), @@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() { // See Backward.h for details static std::unique_ptr<OperatorBase> BackwardRecursive( const OperatorBase& forwardOp, - std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { + std::unordered_set<std::string>& no_grad_names, + std::unordered_map<std::string, std::string>* grad_to_var, + size_t& uniq_id) { // If all input gradients of forwarding operator do not need to calculate, // just return an NOP. Not return null ptr because NOP does not take // too much time for calculation, but it is useful for simplifying logic. @@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it, ++local_op_id) { auto& fwd = *it; - auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); + auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id); ForEachVarName(bwd->Outputs(), [&dup_output_ops, local_op_id](const std::string& out) { dup_output_ops[out].emplace_back(local_op_id); @@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( } } else { std::unique_ptr<OperatorBase> grad_op( - CreateGradOp(forwardOp, no_grad_names)); + CreateGradOp(forwardOp, no_grad_names, grad_to_var)); ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( const std::string& grad_input) { @@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( *static_cast<const OperatorBase*>(&rnnop.stepnet()); // create stepnet's gradient op rnn_grad_op->set_stepnet( - BackwardRecursive(stepnet_op, no_grad_names, uniq_id)); + BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id)); } if (net->ops_.empty()) { // Current no aux op is added to network @@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward( no_grad_names.insert(name + kGradVarSuffix); } size_t uid = 0; - return BackwardRecursive(forwardOp, no_grad_names, uid); + std::unordered_map<std::string, std::string> grad_to_var; + return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid); } // ==================================== // @@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names, std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( const std::unique_ptr<OpDescBind>& op_desc, - std::unordered_set<std::string>& no_grad_vars) { + std::unordered_set<std::string>* no_grad_vars, + std::unordered_map<std::string, std::string>* grad_to_var) { std::vector<std::unique_ptr<OpDescBind>> grad_op_descs; // All input gradients of forwarding operator do not need to calculate. const std::vector<std::string>& inputs = op_desc->InputArgumentNames(); - if (AllGradInSet(inputs, no_grad_vars)) { + if (AllGradInSet(inputs, *no_grad_vars)) { return grad_op_descs; // empty vector } // All output gradients of forwarding operator do not need to calculate. const std::vector<std::string>& outputs = op_desc->OutputArgumentNames(); - if (AllGradInSet(outputs, no_grad_vars)) { + if (AllGradInSet(outputs, *no_grad_vars)) { for (const std::string& name : inputs) { - no_grad_vars.insert(GradVarName(name)); + no_grad_vars->insert(GradVarName(name)); } return grad_op_descs; // empty vector } grad_op_descs = OpInfoMap::Instance() .Get(op_desc->Type()) - .GradOpMaker()(*op_desc, no_grad_vars); + .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var); std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops; for (auto& desc : grad_op_descs) { for (const std::string& in_name : desc->InputArgumentNames()) { - if (no_grad_vars.count(in_name)) { + if (no_grad_vars->count(in_name)) { std::string prefix = in_name.substr( 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); std::string new_name = prefix + kZeroVarSuffix; @@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ProgramDescBind& program_desc, int block_idx, - std::unordered_set<std::string>& no_grad_vars) { + std::unordered_set<std::string>* no_grad_vars, + std::unordered_map<std::string, std::string>* grad_to_var) { BlockDescBind* cur_block = program_desc.Block(block_idx); std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; @@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( std::vector<std::unique_ptr<OpDescBind>> backward_descs; for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { std::vector<std::unique_ptr<OpDescBind>> op_grads = - MakeOpGrad(*it, no_grad_vars); + MakeOpGrad(*it, no_grad_vars, grad_to_var); if ((*it)->Type() == "recurrent") { PADDLE_ENFORCE_EQ( op_grads.size(), size_t(1), "rnn_op's gradient process should contain only one op."); int step_block_idx = (*it)->GetBlockAttr("stop_block"); - auto backward_block_op_descs = - MakeBlockBackward(program_desc, step_block_idx, no_grad_vars); + auto backward_block_op_descs = MakeBlockBackward( + program_desc, step_block_idx, no_grad_vars, grad_to_var); BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block); for (auto& ptr : backward_block_op_descs) { backward_block->ops_.push_back(std::move(ptr)); @@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc, no_grad_var_names.insert(GradVarName(name)); } const int root_block_idx = 0; - auto backward_op_descs = - MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names); + std::unordered_map<std::string, std::string> grad_to_var; + auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx, + &no_grad_var_names, &grad_to_var); auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_; for (auto& ptr : backward_op_descs) { forw_op_descs.push_back(std::move(ptr)); diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 3437e89923da8de79eeaa88d0466cf7eb0b5926d..9d453e1d6f42077df3886d9645e1ab59eaf1aa1d 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -35,7 +35,8 @@ class BlockDescBind { public: friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ProgramDescBind &program_desc, int block_idx, - std::unordered_set<std::string> &no_grad_vars); + std::unordered_set<std::string> *no_grad_vars, + std::unordered_map<std::string, std::string> *grad_to_var); friend void AppendBackward( ProgramDescBind &program_desc, diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index ca8584b78ab081138e0d73b8a71ae4cc111a1b4c..ed7c5f17b0854809bde923276f36440cce193a88 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> { void operator()(const char* op_type, OpInfo* info) const { info->grad_op_maker_ = []( const OpDescBind& fwd_op, - const std::unordered_set<std::string>& no_grad_set) { - T maker(fwd_op, no_grad_set); + const std::unordered_set<std::string>& no_grad_set, + std::unordered_map<std::string, std::string>* grad_to_var) { + T maker(fwd_op, no_grad_set, grad_to_var); return maker(); }; } diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h index d7366b11ec94403e0d8d5d8a3485896f0dc691c0..1219e0487531b19b00adde5a9aa2bde51bfc0aa8 100644 --- a/paddle/framework/grad_op_desc_maker.h +++ b/paddle/framework/grad_op_desc_maker.h @@ -25,8 +25,9 @@ class GradOpDescMakerBase { public: explicit GradOpDescMakerBase( const OpDescBind& fwd_op, - const std::unordered_set<std::string>& no_grad_set) - : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {} + const std::unordered_set<std::string>& no_grad_set, + std::unordered_map<std::string, std::string>* grad_to_var) + : fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {} virtual ~GradOpDescMakerBase() = default; virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0; @@ -37,12 +38,17 @@ class GradOpDescMakerBase { std::vector<std::string> ret_val; auto var_names = this->Input(name); ret_val.reserve(var_names.size()); - std::transform( - var_names.begin(), var_names.end(), std::back_inserter(ret_val), - [this](const std::string& fwd_var_name) -> std::string { - auto g_name = GradVarName(fwd_var_name); - return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName; - }); + std::transform(var_names.begin(), var_names.end(), + std::back_inserter(ret_val), + [this](const std::string& fwd_var_name) -> std::string { + auto g_name = GradVarName(fwd_var_name); + if (no_grad_set_.count(g_name)) { + return kEmptyVarName; + } else { + (*this->grad_to_var_)[g_name] = fwd_var_name; + return g_name; + } + }); if (!drop_empty_grad) { return ret_val; } @@ -95,6 +101,7 @@ class GradOpDescMakerBase { private: const OpDescBind& fwd_op_; const std::unordered_set<std::string>& no_grad_set_; + std::unordered_map<std::string, std::string>* grad_to_var_; }; class SingleGradOpDescMaker : public GradOpDescMakerBase { diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index 7e1b79c97b5b4c3f0292fbfa205a8ca541702fbc..0d1564a7510ddf0106ff417fb0b487ddbde1ac2e 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -37,7 +37,8 @@ using OpCreator = std::function<OperatorBase*( const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>( - const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/)>; + const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/, + std::unordered_map<std::string, std::string>* /*grad_to_var*/)>; } // namespace framework } // namespace paddle