backward.cc 31.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
#include <queue>

#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
22 23
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
24

J
Jiabin Yang 已提交
25
#include "glog/logging.h"
26 27
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
J
Jiabin Yang 已提交
28
#include "paddle/phi/kernels/autotune/switch_autotune.h"
29 30 31

namespace egr {

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/*
* GeneralGrad is Helpper class to implement custom grad operation between
* outputs and inputs.
*
* **/
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
53 54 55
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();

        if (orig_to_copied_node_mapping_.count(target_node)) {
56
          target_node = orig_to_copied_node_mapping_[target_node].get();
57 58 59 60 61 62
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

63 64 65 66 67 68
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
69
          (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
70
        } else {  // normal input
71
          (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
72 73 74 75
        }
      }
    }
  }
76

77
  // Purify potential_startup_nodes_, remove nodes those are the same as
78 79 80
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
81
    if (input_target_nodes_inputmeta_map_.empty()) return;
82
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
83 84 85
    for (auto startup_op : potential_startup_nodes_) {
      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map_.end()) {
86 87 88 89 90
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
91
        potential_startup_nodes_.erase(nodes);
92 93 94
      }
    }
  }
95

96
  // Remove some nodes those doesn't need to be
97
  // stored in potential_stop_nodes_、potential_startup_nodes_
98
  void UpdateGraphInfo() {
99
    // Updated potential_sotp_nodes by depending_nodes_,
100
    // make sure the path from root to target_node is ok
101
    std::unordered_set<GradNodeBase*> startup_ops;
102 103
    VLOG(6) << "Running in UpdateGraphInfo";
    std::queue<GradNodeBase*> queue;
104 105
    for (auto& target_nodes_inputmeta_pair :
         input_target_nodes_inputmeta_map_) {
106 107
      queue.emplace(target_nodes_inputmeta_pair.first);
    }
108

109 110 111
    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop();
112 113
      if (!(depending_nodes_)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes_)[target_node];
114 115
        for (auto pre_nodes : precedding_nodes) {
          queue.emplace(pre_nodes);
116 117 118
          if (potential_stop_nodes_.find(pre_nodes) !=
              potential_stop_nodes_.end()) {
            potential_stop_nodes_.erase(pre_nodes);
119 120 121
          }
        }
      } else {  // startup_ops have no precedding nodes
122 123
        VLOG(6) << "Emplace startup_ops";
        startup_ops.emplace(target_node);
124 125
      }
    }
126
    // Purify potential_startup_nodes_ again, remove some
127
    // potential startup_nodes that unreach to input target nodes
128
    if (!startup_ops.empty()) {
129
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
130 131
      for (auto node : potential_startup_nodes_) {
        if (startup_ops.count(node) == 0) {
132 133 134 135 136 137 138
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
139
          potential_startup_nodes_.erase(node);
140 141
        }
      }
142
    }
143
  }
144

145
  // Get Graph Info Betweent input target GradNode and outputs,
146
  // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
147 148
  void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
149

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
    std::queue<GradNodeBase*> queue = init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
      queue.pop();

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
168
      // all the next_node will be marked in potential_stop_nodes_
169
      bool is_potential_stop_nodes =
170
          input_target_nodes_inputmeta_map_.count(node);
171 172

      // Find and append next nodes
173 174 175 176 177 178
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& metas =
          node->OutputMeta();
      for (const auto& meta_list : metas) {
        for (const GradSlotMeta& meta : meta_list) {
          const auto& edge = meta.GetEdge();
179 180 181 182 183 184 185 186 187 188 189
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
190
            potential_stop_nodes_.emplace(next_node);
191 192 193
          }

          // Update in_degree
194
          if (!node_in_degree_map.count(next_node)) {
195
            node_in_degree_map[next_node] = 0;
196
          }
197 198 199
          node_in_degree_map[next_node]++;

          // Record depending relationship
200
          (depending_nodes_)[next_node].emplace(node);
201 202
          queue.push(next_node);
        }
203 204
      }
    }
205
    // Update Graph Info, remove some nodes in
206
    // potential_stop_nodes_、potential_startup_nodes_、
207
    UpdateGraphInfo();
208 209
  }

210 211
  void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
    std::queue<GradNodeBase*> tmp_queue;
212
    for (auto nodes : potential_startup_nodes_) {
213 214 215
      tmp_queue.emplace(nodes);
    }
    tmp_queue.swap(*queue);
216 217
  }

218
  // Set result for input target grad_var when potential_startup_nodes_ is empty
219 220 221 222
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
223 224
    if (potential_startup_nodes_.size() == 0) {
      for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
225 226 227 228 229 230 231
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
232
          results_map_[input_target_node.first] = target_result;
233 234 235 236
        }
      }
    }
  }
237 238 239 240

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
241 242
    auto iter = GetInputTargetNodesInputMetaMap()->find(node);
    if (iter != GetInputTargetNodesInputMetaMap()->end()) {
243 244 245 246 247 248 249
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
250
      results_map_[node] = target_result;
251
    }
252 253 254 255 256 257 258 259 260 261 262 263 264 265
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
266 267 268

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
269
        target_node = orig_to_copied_node_mapping_[target_node].get();
270 271 272 273 274
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }
275

276 277
      auto iter = results_map_.find(target_node);
      if (iter != results_map_.end()) {
278 279 280 281 282 283 284 285 286 287 288 289 290
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
291 292
      }
    }
293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::queue<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
308
    // Purify potentialstartup_ops, remove those nodes that are the same as
309 310 311
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
312 313
    // Record the depending_nodes_ and
    // potential_stop_nodes_、potential_startup_nodes_
314 315 316 317 318 319 320 321 322
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
323
    return potential_stop_nodes_.count(node);
324 325 326 327
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
328
    return &no_grad_var_nodes_inputmeta_map_;
329 330 331
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
332 333
  GetInputTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map_;
334 335 336
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
337
    return &potential_stop_nodes_;
338 339 340
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
341
    return &potential_startup_nodes_;
342 343 344
  }

  void Clear() {
345 346 347 348 349 350
    no_grad_var_nodes_inputmeta_map_.clear();
    input_target_nodes_inputmeta_map_.clear();
    potential_startup_nodes_.clear();
    potential_stop_nodes_.clear();
    depending_nodes_.clear();
    results_map_.clear();
351 352 353 354 355 356
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
357
      return orig_to_copied_node_mapping_[orig_node.get()].get();
358 359 360 361
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
362
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

  void ReconstructBackwardGraph(
      const std::queue<GradNodeBase*>& orig_init_queue) {
    std::queue<GradNodeBase*> queue = orig_init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
      queue.pop();
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph,"
              "unable to find copied target for certain grad node."));
387
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();
388

389 390 391 392 393 394 395 396 397
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& orig_meta =
          orig_node->OutputMeta();
      paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
          copied_edges = copied_node->MutableOutputMeta();
      for (size_t i = 0; i < orig_meta.size(); i++) {
        for (size_t j = 0; j < orig_meta[i].size(); j++) {
          const Edge& orig_edge = orig_meta[i][j].GetEdge();
          Edge& copied_edge = copied_edges[i][j].GetMutableEdge();
398 399 400 401 402 403 404 405 406

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
407
                orig_to_copied_node_mapping_[orig_next_node.get()];
408 409 410 411

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
412
                copied_next_node;
413 414 415 416 417 418 419 420 421 422 423
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
          queue.push(orig_next_node.get());
        }
      }
    }
424 425
  }

426 427 428 429 430
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
431
      no_grad_var_nodes_inputmeta_map_;
432 433
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
434
      input_target_nodes_inputmeta_map_;
435
  // Record all the potential startup_nodes, will be changed.
436
  std::unordered_set<GradNodeBase*> potential_startup_nodes_;
437
  // Record all the potential stop nodes, will be changed.
438
  std::unordered_set<GradNodeBase*> potential_stop_nodes_;
439 440
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
441 442
      depending_nodes_;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;
443 444

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
445 446
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;
447

448 449
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
450

451 452
// Counts, for every grad node reachable from `init_queue`, how many edges of
// the backward graph point at it. A node only gets an entry if at least one
// reachable node has an edge to it, so pure start nodes are absent.
// NOTE: this whole pass could be dropped if in-degrees were maintained
// during the forward pass.
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::queue<GradNodeBase*>& init_queue) {
  std::unordered_map<GradNodeBase*, int> in_degree;
  std::unordered_set<GradNodeBase*> seen;
  std::queue<GradNodeBase*> pending(init_queue);

  // Breadth-first walk; the outgoing edges of each node are counted once.
  while (!pending.empty()) {
    GradNodeBase* current = pending.front();
    pending.pop();

    if (seen.count(current) != 0) continue;
    seen.insert(current);

    PADDLE_ENFORCE_NOT_NULL(
        current,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));

    const auto& slot_metas = current->OutputMeta();
    for (const auto& slot : slot_metas) {
      for (const GradSlotMeta& slot_meta : slot) {
        GradNodeBase* successor =
            slot_meta.GetEdge().GetMutableGradNode().get();
        // A leaf tensor without an AccumulationNode, or a dispensable input,
        // leaves the edge without a grad node.
        if (successor == nullptr) continue;

        // operator[] value-initializes a missing counter to 0.
        ++in_degree[successor];
        pending.push(successor);
      }
    }
  }

  return in_degree;
}

// Raises a Fatal error if `node`'s TensorWrappers have already been cleared.
// RunBackward clears them after running a node with retain_graph=false, so
// running backward twice over the same subgraph fails here.
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  const bool wrappers_cleared = node->IsTensorWrappersCleared();
  PADDLE_ENFORCE_NE(
      wrappers_cleared, true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

// Raises AlreadyExists if `inputs` contains the same tensor twice. Tensors
// are compared by the identity of their AutogradMeta. `is_input` only selects
// the word ("inputs" / "outputs") used in the error message.
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visited_ins;
  const std::string msg = is_input ? "inputs" : "outputs";
  // Bind by const reference; copying each Tensor here is unnecessary.
  for (const auto& in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visited_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visited_ins.insert(auto_grad_meta);
  }
}

// Storage for the singleton returned by GeneralGrad::Instance(). It is
// created during static initialization and never deleted in this file.
GeneralGrad* GeneralGrad::general_grad_{new GeneralGrad()};

// Core backward driver shared by Backward() and Grad().
//
// tensors      : outputs to differentiate from. Tensors without AutogradMeta,
//                without a grad node, or with stop_gradient=True are skipped.
// grad_tensors : initial gradients for `tensors`. Must be empty or the same
//                size as `tensors`. An empty list, or an uninitialized entry,
//                means "start from ones".
// retain_graph : when false, each executed node's TensorWrappers are cleared.
// create_graph : forwarded to every grad node and to gradient accumulation.
//                Results get stop_gradient = !create_graph.
// inputs       : when non-empty, runs in "general grad" mode (see
//                GeneralGrad) on a copy of the backward graph and returns
//                the gradients of these tensors. Otherwise returns {}.
// allow_unused : (general grad) return an empty Tensor instead of raising
//                for inputs that receive no gradient.
// no_grad_vars : (general grad) tensors whose gradient slot is zeroed in the
//                input buffer of their grad node before it runs.
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(6) << "Start Backward";

  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad keeps per-run state in a singleton; reset it first.
  const bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::queue<GradNodeBase*> queue;
  // Original (uncopied) start nodes, used to rebuild the copied graph.
  std::queue<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();
    if (!shared_grad_node || auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push(grad_node);
      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);
      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder: one per grad node, shared by every entry of
    // `tensors` whose grad node is the same.
    auto buffer_iter = node_input_buffers_dict.find(grad_node);
    if (buffer_iter == node_input_buffers_dict.end()) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      buffer_iter =
          node_input_buffers_dict
              .emplace(grad_node, std::make_unique<GradTensorHolder>(
                                      grad_node->InputMeta()))
              .first;
    }
    GradTensorHolder* input_buffer = buffer_iter->second.get();

    // Validate the size before reading grad_tensors[i]; otherwise a
    // non-empty grad_tensors shorter than tensors is read out of bounds.
    const bool has_grad_tensors = !grad_tensors.empty();
    if (has_grad_tensors) {
      PADDLE_ENFORCE(
          grad_tensors.size() == tensors.size(),
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors, "
              "grad_tensors should either have "
              "size = 0 or same size as tensors."));
    }
    if (has_grad_tensors && grad_tensors[i].initialized()) {
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << " with given grad tensor";
      // Deep copy
      input_buffer->CopyValueFromTensor(input_info.first, input_info.second,
                                        grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Forward tensor `tensor` only supplies tensor type, dtype and dims;
      // GradTensorHolder creates a matching tensor filled with 1.0.
      input_buffer->CopyValueFromTensor(input_info.first, input_info.second,
                                        tensor, /*fill_one=*/true);
    }

    // Prepare queue, potential startup_nodes
    queue.push(grad_node);
  }

  if (is_general_grad) {
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
  }

  VLOG(6) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(6) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string(node->name()),
        paddle::platform::TracerEventType::Operator, 1);

    // A node whose in-degree is still non-zero has predecessors that have
    // not run yet. It is dropped here and pushed again once its in-degree
    // reaches zero. It runs anyway when it is the only node left.
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop();
      continue;
    }
    queue.pop();

    // Run node: This is where Hook happens
    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
    PADDLE_ENFORCE_NE(
        node_input_buffer_iter, node_input_buffers_dict.end(),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    // Take ownership of the buffer and drop the now-empty entry right away,
    // so no iterator is kept alive across the node call.
    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffer_iter->second);
    node_input_buffers_dict.erase(node_input_buffer_iter);

    if (is_general_grad) {
      // Set input target grad_var from node_input_buffer by inputmeta
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);

      // no_grad_vars: zero the gradient slot that feeds such a variable
      if (!no_grad_vars.empty()) {
        auto* no_grad_map =
            GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap();
        auto iter = no_grad_map->find(node);
        if (iter != no_grad_map->end()) {
          VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
          auto rank_info = iter->second->OutRankInfo();
          node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                    rank_info.second);
        }
      }
    }

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                         kSlotSmallVectorSize>
        grad_output_tensors = (*node)(node_input_buffer->Buffers(),
                                      create_graph, is_general_grad);

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // Prepare GradTensorHolder for next node
    const auto& metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
                       metas.size(), grad_output_tensors.size()));

    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
        if (!edge.IsInitialized()) {
          continue;
        }
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since we make edge has as same rank as bwd outputs, we indexing them
        // with the same rank(i, j)
        auto next_node_shared = edge.GetMutableGradNode();

        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached, or it could also originate from
        // dispensable inputs.
        if (!next_node_shared || grad_output_tensors[i].empty()) {
          continue;
        }

        PADDLE_ENFORCE_LT(
            j, grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if (!grad_output_tensor.defined() ||
            !grad_output_tensor.initialized()) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        GradNodeBase* next_node = next_node_shared.get();
        auto next_buffer_iter = node_input_buffers_dict.find(next_node);
        if (next_buffer_iter == node_input_buffers_dict.end()) {
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          next_buffer_iter =
              node_input_buffers_dict
                  .emplace(next_node, std::make_unique<GradTensorHolder>(
                                          next_node->InputMeta()))
                  .first;
        }

        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;
        next_buffer_iter->second->add(edge_rank.first, edge_rank.second,
                                      grad_output_tensor, create_graph);

        // Update queue
        int& next_in_degree = node_in_degree_map[next_node];
        --next_in_degree;
        PADDLE_ENFORCE(
            next_in_degree >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s. "
                "Node's in-degree cannot be negative.",
                next_node->name()));

        // In general-grad mode, potential stop nodes are never scheduled.
        const bool is_stop_node =
            is_general_grad &&
            GeneralGrad::Instance().IsPotentialStopNodes(next_node);
        if (next_in_degree == 0 && !is_stop_node) {
          queue.push(next_node);
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

void Backward(
819
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
820 821 822 823 824 825
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(6) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
J
Jiabin Yang 已提交
826
  phi::autotune::AutoTuneStatus::Instance().Update();
827 828 829
}

std::vector<paddle::experimental::Tensor> Grad(
830
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
831 832 833 834 835
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(6) << "Run in Grad";
836 837 838 839

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

840 841 842
  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
843
}  // namespace egr