From 2c46666e7b24078b54a495166a489c9d5c71ba53 Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Fri, 13 Oct 2017 15:24:24 -0700
Subject: [PATCH] Add grad_name_map to record correspondences between vars and
 grad_vars (#4794)

* Add grad_name_map

* Fix bug

* Fix bug

* Follow comments
---
 paddle/framework/backward.cc           | 46 +++++++++++++++-----------
 paddle/framework/block_desc.h          |  3 +-
 paddle/framework/details/op_registry.h |  5 +--
 paddle/framework/grad_op_desc_maker.h  | 23 ++++++++-----
 paddle/framework/type_defs.h           |  3 +-
 5 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index c966f97c2d..1e20789a1f 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -28,15 +28,15 @@ namespace paddle {
 namespace framework {
 
 static inline std::unique_ptr<OperatorBase> CreateGradOp(
-    const OperatorBase& op,
-    const std::unordered_set<std::string>& no_grad_set) {
+    const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   OpDescBind op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
+  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() {
 //  See Backward.h for details
 static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
+    std::unordered_set<std::string>& no_grad_names,
+    std::unordered_map<std::string, std::string>* grad_to_var,
+    size_t& uniq_id) {
   //  If all input gradients of forwarding operator do not need to calculate,
   //  just return an NOP. Not return null ptr because NOP does not take
   //  too much time for calculation, but it is useful for simplifying logic.
@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
       auto& fwd = *it;
-      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
+      auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
       ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     }
   } else {
     std::unique_ptr<OperatorBase> grad_op(
-        CreateGradOp(forwardOp, no_grad_names));
+        CreateGradOp(forwardOp, no_grad_names, grad_to_var));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
           *static_cast<const OperatorBase*>(&rnnop.stepnet());
       // create stepnet's gradient op
       rnn_grad_op->set_stepnet(
-          BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
+          BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
     }
 
     if (net->ops_.empty()) {  // Current no aux op is added to network
@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward(
     no_grad_names.insert(name + kGradVarSuffix);
   }
   size_t uid = 0;
-  return BackwardRecursive(forwardOp, no_grad_names, uid);
+  std::unordered_map<std::string, std::string> grad_to_var;
+  return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
 }
 
 // ====================================  //
@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names,
 
 std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
-    std::unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
   // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
-  if (AllGradInSet(inputs, no_grad_vars)) {
+  if (AllGradInSet(inputs, *no_grad_vars)) {
     return grad_op_descs;  // empty vector
   }
   // All output gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
-  if (AllGradInSet(outputs, no_grad_vars)) {
+  if (AllGradInSet(outputs, *no_grad_vars)) {
     for (const std::string& name : inputs) {
-      no_grad_vars.insert(GradVarName(name));
+      no_grad_vars->insert(GradVarName(name));
     }
     return grad_op_descs;  // empty vector
   }
 
   grad_op_descs = OpInfoMap::Instance()
                       .Get(op_desc->Type())
-                      .GradOpMaker()(*op_desc, no_grad_vars);
+                      .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
 
   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
     for (const std::string& in_name : desc->InputArgumentNames()) {
-      if (no_grad_vars.count(in_name)) {
+      if (no_grad_vars->count(in_name)) {
         std::string prefix = in_name.substr(
             0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
         std::string new_name = prefix + kZeroVarSuffix;
@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
 
 std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
-    std::unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   BlockDescBind* cur_block = program_desc.Block(block_idx);
   std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
   for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
-        MakeOpGrad(*it, no_grad_vars);
+        MakeOpGrad(*it, no_grad_vars, grad_to_var);
 
     if ((*it)->Type() == "recurrent") {
       PADDLE_ENFORCE_EQ(
           op_grads.size(), size_t(1),
           "rnn_op's gradient process should contain only one op.");
       int step_block_idx = (*it)->GetBlockAttr("stop_block");
-      auto backward_block_op_descs =
-          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
+      auto backward_block_op_descs = MakeBlockBackward(
+          program_desc, step_block_idx, no_grad_vars, grad_to_var);
       BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
       for (auto& ptr : backward_block_op_descs) {
         backward_block->ops_.push_back(std::move(ptr));
@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc,
     no_grad_var_names.insert(GradVarName(name));
   }
   const int root_block_idx = 0;
-  auto backward_op_descs =
-      MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
+  std::unordered_map<std::string, std::string> grad_to_var;
+  auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
+                                             &no_grad_var_names, &grad_to_var);
   auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
   for (auto& ptr : backward_op_descs) {
     forw_op_descs.push_back(std::move(ptr));
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index 3437e89923..9d453e1d6f 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -35,7 +35,8 @@ class BlockDescBind {
  public:
   friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
       ProgramDescBind &program_desc, int block_idx,
-      std::unordered_set<std::string> &no_grad_vars);
+      std::unordered_set<std::string> *no_grad_vars,
+      std::unordered_map<std::string, std::string> *grad_to_var);
 
   friend void AppendBackward(
       ProgramDescBind &program_desc,
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index ca8584b78a..ed7c5f17b0 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
     info->grad_op_maker_ = [](
         const OpDescBind& fwd_op,
-        const std::unordered_set<std::string>& no_grad_set) {
-      T maker(fwd_op, no_grad_set);
+        const std::unordered_set<std::string>& no_grad_set,
+        std::unordered_map<std::string, std::string>* grad_to_var) {
+      T maker(fwd_op, no_grad_set, grad_to_var);
       return maker();
     };
   }
diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h
index d7366b11ec..1219e04875 100644
--- a/paddle/framework/grad_op_desc_maker.h
+++ b/paddle/framework/grad_op_desc_maker.h
@@ -25,8 +25,9 @@ class GradOpDescMakerBase {
  public:
   explicit GradOpDescMakerBase(
       const OpDescBind& fwd_op,
-      const std::unordered_set<std::string>& no_grad_set)
-      : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}
+      const std::unordered_set<std::string>& no_grad_set,
+      std::unordered_map<std::string, std::string>* grad_to_var)
+      : fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
 
   virtual ~GradOpDescMakerBase() = default;
   virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
@@ -37,12 +38,17 @@ class GradOpDescMakerBase {
     std::vector<std::string> ret_val;
     auto var_names = this->Input(name);
     ret_val.reserve(var_names.size());
-    std::transform(
-        var_names.begin(), var_names.end(), std::back_inserter(ret_val),
-        [this](const std::string& fwd_var_name) -> std::string {
-          auto g_name = GradVarName(fwd_var_name);
-          return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
-        });
+    std::transform(var_names.begin(), var_names.end(),
+                   std::back_inserter(ret_val),
+                   [this](const std::string& fwd_var_name) -> std::string {
+                     auto g_name = GradVarName(fwd_var_name);
+                     if (no_grad_set_.count(g_name)) {
+                       return kEmptyVarName;
+                     } else {
+                       (*this->grad_to_var_)[g_name] = fwd_var_name;
+                       return g_name;
+                     }
+                   });
     if (!drop_empty_grad) {
       return ret_val;
     }
@@ -95,6 +101,7 @@ class GradOpDescMakerBase {
  private:
   const OpDescBind& fwd_op_;
   const std::unordered_set<std::string>& no_grad_set_;
+  std::unordered_map<std::string, std::string>* grad_to_var_;
 };
 
 class SingleGradOpDescMaker : public GradOpDescMakerBase {
diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h
index 7e1b79c97b..0d1564a751 100644
--- a/paddle/framework/type_defs.h
+++ b/paddle/framework/type_defs.h
@@ -37,7 +37,8 @@ using OpCreator = std::function<OperatorBase*(
     const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
 
 using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
-    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/)>;
+    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/,
+    std::unordered_map<std::string, std::string>* /*grad_to_var*/)>;
 
 }  // namespace framework
 }  // namespace paddle
-- 
GitLab