backward.cc 32.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
16

17
#include <deque>
18

19 20
#include "glog/logging.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
21 22 23 24 25 26
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
27 28
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
J
Jiabin Yang 已提交
29
#include "paddle/phi/kernels/autotune/switch_autotune.h"
30 31 32

namespace egr {

33
/*
34 35 36 37
 * GeneralGrad is Helpper class to implement custom grad operation between
 * outputs and inputs.
 *
 * **/
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
54 55 56
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();

        if (orig_to_copied_node_mapping_.count(target_node)) {
57
          target_node = orig_to_copied_node_mapping_[target_node].get();
58 59 60 61 62 63
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

64 65 66 67 68 69
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
70
          (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
71
        } else {  // normal input
72
          (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
73 74 75 76
        }
      }
    }
  }
77

78
  // Purify potential_startup_nodes_, remove nodes those are the same as
79 80 81
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
82
    if (input_target_nodes_inputmeta_map_.empty()) return;
83
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
84 85 86
    for (auto startup_op : potential_startup_nodes_) {
      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map_.end()) {
87 88 89 90 91
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
92
        potential_startup_nodes_.erase(nodes);
93 94 95
      }
    }
  }
96

97
  // Remove some nodes those doesn't need to be
98
  // stored in potential_stop_nodes_、potential_startup_nodes_
99
  void UpdateGraphInfo() {
100
    // Updated potential_sotp_nodes by depending_nodes_,
101
    // make sure the path from root to target_node is ok
102
    std::unordered_set<GradNodeBase*> startup_ops;
103
    VLOG(6) << "Running in UpdateGraphInfo";
104
    std::deque<GradNodeBase*> queue;
105 106
    for (auto& target_nodes_inputmeta_pair :
         input_target_nodes_inputmeta_map_) {
107
      queue.push_back(target_nodes_inputmeta_pair.first);
108
    }
109

110 111
    while (!queue.empty()) {
      auto* target_node = queue.front();
112
      queue.pop_front();
113 114
      if (!(depending_nodes_)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes_)[target_node];
115
        for (auto pre_nodes : precedding_nodes) {
116
          queue.push_back(pre_nodes);
117 118 119
          if (potential_stop_nodes_.find(pre_nodes) !=
              potential_stop_nodes_.end()) {
            potential_stop_nodes_.erase(pre_nodes);
120 121 122
          }
        }
      } else {  // startup_ops have no precedding nodes
123 124
        VLOG(6) << "Emplace startup_ops";
        startup_ops.emplace(target_node);
125 126
      }
    }
127
    // Purify potential_startup_nodes_ again, remove some
128
    // potential startup_nodes that unreach to input target nodes
129
    if (!startup_ops.empty()) {
130
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
131 132
      for (auto node : potential_startup_nodes_) {
        if (startup_ops.count(node) == 0) {
133 134 135 136 137 138 139
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
140
          potential_startup_nodes_.erase(node);
141 142
        }
      }
143
    }
144
  }
145

146
  // Get Graph Info Betweent input target GradNode and outputs,
147
  // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
148
  void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
149
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
150

151 152 153 154
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
155
    std::deque<GradNodeBase*> queue = init_queue;
156 157 158 159 160
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
161
      queue.pop_front();
162 163 164 165 166 167 168

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
169
      // all the next_node will be marked in potential_stop_nodes_
170
      bool is_potential_stop_nodes =
171
          input_target_nodes_inputmeta_map_.count(node);
172 173

      // Find and append next nodes
174 175 176 177 178 179
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& metas =
          node->OutputMeta();
      for (const auto& meta_list : metas) {
        for (const GradSlotMeta& meta : meta_list) {
          const auto& edge = meta.GetEdge();
180 181 182 183 184 185 186 187 188 189 190
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
191
            potential_stop_nodes_.emplace(next_node);
192 193 194
          }

          // Update in_degree
195
          if (!node_in_degree_map.count(next_node)) {
196
            node_in_degree_map[next_node] = 0;
197
          }
198 199 200
          node_in_degree_map[next_node]++;

          // Record depending relationship
201
          (depending_nodes_)[next_node].emplace(node);
202
          queue.push_back(next_node);
203
        }
204 205
      }
    }
206
    // Update Graph Info, remove some nodes in
207
    // potential_stop_nodes_、potential_startup_nodes_、
208
    UpdateGraphInfo();
209 210
  }

211 212
  void ModifyReadyQueue(std::deque<GradNodeBase*>* queue) {
    std::deque<GradNodeBase*> tmp_queue;
213
    for (auto nodes : potential_startup_nodes_) {
214
      tmp_queue.push_back(nodes);
215 216
    }
    tmp_queue.swap(*queue);
217 218
  }

219
  // Set result for input target grad_var when potential_startup_nodes_ is empty
220 221 222 223
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
224 225
    if (potential_startup_nodes_.size() == 0) {
      for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
226 227 228 229 230 231 232
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
233
          results_map_[input_target_node.first] = target_result;
234 235 236 237
        }
      }
    }
  }
238 239 240 241

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
242 243
    auto iter = GetInputTargetNodesInputMetaMap()->find(node);
    if (iter != GetInputTargetNodesInputMetaMap()->end()) {
244 245 246 247 248 249 250
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
251
      results_map_[node] = target_result;
252
    }
253 254 255 256 257 258 259 260 261 262 263 264 265 266
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
267 268 269

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
270
        target_node = orig_to_copied_node_mapping_[target_node].get();
271 272 273 274 275
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }
276

277 278
      auto iter = results_map_.find(target_node);
      if (iter != results_map_.end()) {
279 280 281 282 283 284 285 286 287 288 289 290 291
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
292 293
      }
    }
294 295 296 297 298 299 300
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
301
      std::deque<GradNodeBase*>* queue,
302 303 304 305 306 307 308
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
309
    // Purify potentialstartup_ops, remove those nodes that are the same as
310 311 312
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
313 314
    // Record the depending_nodes_ and
    // potential_stop_nodes_、potential_startup_nodes_
315 316 317 318 319 320 321 322 323
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
324
    return potential_stop_nodes_.count(node);
325 326 327 328
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
329
    return &no_grad_var_nodes_inputmeta_map_;
330 331 332
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
333 334
  GetInputTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map_;
335 336 337
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
338
    return &potential_stop_nodes_;
339 340 341
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
342
    return &potential_startup_nodes_;
343 344 345
  }

  void Clear() {
346 347 348 349 350 351
    no_grad_var_nodes_inputmeta_map_.clear();
    input_target_nodes_inputmeta_map_.clear();
    potential_startup_nodes_.clear();
    potential_stop_nodes_.clear();
    depending_nodes_.clear();
    results_map_.clear();
352 353 354 355 356 357
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
358
      return orig_to_copied_node_mapping_[orig_node.get()].get();
359 360 361 362
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
363
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
364 365 366 367 368 369
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

  void ReconstructBackwardGraph(
370 371
      const std::deque<GradNodeBase*>& orig_init_queue) {
    std::deque<GradNodeBase*> queue = orig_init_queue;
372 373 374 375 376
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
377
      queue.pop_front();
378 379 380 381 382 383 384 385 386 387
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph,"
              "unable to find copied target for certain grad node."));
388
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();
389

390 391 392 393 394 395 396 397 398
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& orig_meta =
          orig_node->OutputMeta();
      paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
          copied_edges = copied_node->MutableOutputMeta();
      for (size_t i = 0; i < orig_meta.size(); i++) {
        for (size_t j = 0; j < orig_meta[i].size(); j++) {
          const Edge& orig_edge = orig_meta[i][j].GetEdge();
          Edge& copied_edge = copied_edges[i][j].GetMutableEdge();
399 400 401 402 403 404 405 406 407

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
408
                orig_to_copied_node_mapping_[orig_next_node.get()];
409 410 411 412

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
413
                copied_next_node;
414 415 416 417 418 419 420
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
421
          queue.push_back(orig_next_node.get());
422 423 424
        }
      }
    }
425 426
  }

427 428 429 430 431
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
432
      no_grad_var_nodes_inputmeta_map_;
433 434
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
435
      input_target_nodes_inputmeta_map_;
436
  // Record all the potential startup_nodes, will be changed.
437
  std::unordered_set<GradNodeBase*> potential_startup_nodes_;
438
  // Record all the potential stop nodes, will be changed.
439
  std::unordered_set<GradNodeBase*> potential_stop_nodes_;
440 441
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
442 443
      depending_nodes_;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;
444 445

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
446 447
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;
448

449 450
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
451

452
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
453
    const std::deque<GradNodeBase*>& init_queue) {
454
  // Calculate in_degree for each node
455 456
  // We can completely remove this pass, if in_degree were set during forward
  // pass
457 458 459
  std::unordered_map<GradNodeBase*, int> node_in_degree_map;

  // Copy nodes
460
  std::deque<GradNodeBase*> queue = init_queue;
461 462 463 464 465
  std::unordered_set<GradNodeBase*> visited;

  // Visit each node exactly once in any order
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
466
    queue.pop_front();
467 468 469 470 471 472

    if (visited.count(node)) {
      continue;
    }
    visited.insert(node);

473 474 475 476 477
    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));
478
    // Find and append next nodes
479 480 481 482 483
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    for (const auto& meta_list : metas) {
      for (const GradSlotMeta& meta : meta_list) {
        const auto& edge = meta.GetEdge();
484 485 486 487 488 489 490 491 492 493
        GradNodeBase* next_node = edge.GetMutableGradNode().get();
        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node) continue;

        // Update in_degree
        if (!node_in_degree_map.count(next_node))
          node_in_degree_map[next_node] = 0;
        node_in_degree_map[next_node]++;
494
        queue.push_back(next_node);
495 496 497
      }
    }
  }
498

499
  return node_in_degree_map;
500 501 502 503 504 505 506 507 508 509 510 511 512 513 514
}

// Guard against running `node` after its TensorWrappers were released.
// RunBackward clears them after running a node when retain_graph is false,
// so a second backward pass over the same subgraph lands here.
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  const bool tensor_wrappers_cleared = node->IsTensorWrappersCleared();
  PADDLE_ENFORCE_NE(
      tensor_wrappers_cleared, true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

515 516 517 518 519 520 521 522 523 524 525 526
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visisted_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  for (auto in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visisted_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visisted_ins.insert(auto_grad_meta);
527 528 529
  }
}

530 531
GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();

532 533 534 535 536 537 538
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
539
  VLOG(3) << "Start Backward";
540

541 542 543 544
  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

545 546
  // GeneralGrad
  bool is_general_grad = !inputs.empty();
547
  if (is_general_grad) GeneralGrad::Instance().Clear();
548

549 550 551
  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
552 553
  std::deque<GradNodeBase*> queue;
  std::deque<GradNodeBase*> orig_queue;
554 555 556
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
557
    const paddle::experimental::Tensor& tensor = tensors[i];
558

559 560 561 562 563 564 565
    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
566 567 568 569 570 571
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
572 573 574 575 576 577 578 579 580 581
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
        auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

582
    // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
583
    GradNodeBase* grad_node = shared_grad_node.get();
584 585
    if (is_general_grad) {
      // Save orig grad node
586
      orig_queue.push_back(grad_node);
587 588 589 590 591 592 593

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }
594 595 596

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
597 598
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
599 600 601
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }
602 603 604
    bool copy_from_grad_t =
        grad_tensors.size() > 0 && grad_tensors[i].initialized();
    if (copy_from_grad_t) {
605 606 607 608 609
      PADDLE_ENFORCE(
          grad_tensors.size() == tensors.size(),
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors"
              "grad_tensors should either have "
610
              "size = 0 or same size as tensors."));
611 612
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";
613

614 615 616
      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
617 618 619 620 621 622 623
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
624
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
625
          input_info.first, input_info.second, tensor, /*fill_one=*/true);
626 627
    }

628
    // Prepare queue, potential startup_nodes
629
    queue.push_back(grad_node);
630 631 632 633 634
  }

  if (is_general_grad) {
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
635 636
  }

637
  VLOG(3) << "Update In degree Map for backward";
638 639 640 641
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

642 643 644 645
  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
646 647
  }

648
  VLOG(6) << " startup_ops' size is :" << queue.size();
649

650 651 652
  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
653
  //    |- Check and capture target result
654 655 656
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
657
  VLOG(3) << "Run Backward";
658 659
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
660
    VLOG(6) << "Running GradNode:" << node->name();
661

662
    paddle::platform::RecordEvent node_record_event(
663
        std::string((*node).name()),
664 665
        paddle::platform::TracerEventType::Operator, 1);

666
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
667
      queue.pop_front();
668 669
      continue;
    }
670
    queue.pop_front();
671

672
    // Run node: This is where Hook happens
673 674 675
    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
    PADDLE_ENFORCE_NE(
        node_input_buffer_iter, node_input_buffers_dict.end(),
676
        paddle::platform::errors::Fatal(
677
            "Unable to find next node in the GradTensorHolder \n"
678
            "Trying to run Node without configuring its GradTensorHolder."));
679 680

    std::unique_ptr<GradTensorHolder> node_input_buffer =
681
        std::move(node_input_buffer_iter->second);
682

683 684 685 686
    // Set input target grad_var from node_input_buffer by inputmeta
    if (!inputs.empty() && is_general_grad) {
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);
687 688 689
    }

    // no_grad_vars
690 691 692 693 694 695 696 697 698 699
    if (!no_grad_vars.empty() && is_general_grad) {
      auto iter =
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->find(node);
      if (iter !=
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->end()) {
        VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
        auto rank_info = (iter->second)->OutRankInfo();
        node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                  rank_info.second);
      }
700 701
    }

702
    // Check input
703 704
    EnforceGradNodeHasInput(node);

705
    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
706
    // Run Pre Backward Node and get outputs
707 708 709 710
    paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                         kSlotSmallVectorSize>
        grad_output_tensors = (*node)(node_input_buffer->Buffers(),
                                      create_graph, is_general_grad);
711 712 713 714 715 716 717 718

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

719
    // TODO(jiabin): Should we erase it or find a more efficient way.
720
    node_input_buffers_dict.erase(node_input_buffer_iter);
721 722

    // Prepare GradTensorHolder for next node
723 724 725
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
726 727
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
728 729
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
730
                       metas.size(), grad_output_tensors.size()));
731

732 733 734
    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
J
Jiabin Yang 已提交
735 736 737
        if (!edge.IsInitialized()) {
          continue;
        }
738 739
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since we make edge has as same rank as bwd outputs, we indexing them
740
        // with the same rank(i, j)
741
        auto next_node_shared = edge.GetMutableGradNode();
742
        VLOG(3) << "Found pending node: " << next_node_shared->name();
743 744 745
        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
746 747 748 749
        if (!next_node_shared || !next_node_shared.get() ||
            grad_output_tensors[i].empty()) {
          continue;
        }
750

751 752 753 754 755 756 757
        PADDLE_ENFORCE_LT(
            j, grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors.size()));
758 759
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];
760 761 762

        if ((!grad_output_tensor.defined() ||
             !grad_output_tensor.initialized())) {
763 764
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
765
        }
766

767 768 769 770
        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

771 772 773 774 775 776 777 778 779 780
        auto* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

781 782
        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;
783

784
        node_input_buffers_dict[next_node]->add(
785 786
            edge_rank.first, edge_rank.second, grad_output_tensor,
            create_graph);
787 788 789

        // Update queue
        node_in_degree_map[next_node]--;
790

791 792 793 794
        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s"
795
                "Node's in-degree cannot be negative.",
796
                next_node->name()));
797

798 799 800 801
        if (is_general_grad) {
          bool is_potential_stop_node =
              GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
          if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
802 803 804 805 806
            if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
              queue.push_front(std::move(next_node));
            } else {
              queue.push_back(std::move(next_node));
            }
807 808 809
          }
        } else {
          if (node_in_degree_map[next_node] == 0) {
810 811 812 813 814
            if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
              queue.push_front(std::move(next_node));
            } else {
              queue.push_back(std::move(next_node));
            }
815
          }
816 817 818 819
        }
      }
    }
  }
820

821 822
  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
823 824
}

825
void Backward(
826
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
827 828
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
829
  VLOG(3) << "Run in Backward";
830 831 832
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
J
Jiabin Yang 已提交
833
  phi::autotune::AutoTuneStatus::Instance().Update();
834 835 836
}

std::vector<paddle::experimental::Tensor> Grad(
837
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
838 839 840 841
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
842
  VLOG(3) << "Run in Grad";
843 844 845 846

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

847 848 849
  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
850
}  // namespace egr