// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"

#include <deque>

#include "glog/logging.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"

namespace egr {

/*
 * GeneralGrad is a helper class to implement custom grad operations between
 * outputs and inputs.
 */
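// Rough call flow (an illustrative summary of the code below, not an exact
// trace): egr::Grad(outputs, inputs, ...) reaches RunBackward, which copies
// the reachable grad nodes (ReconstructBackwardGraph), prunes the graph
// between inputs and outputs (PreparedForGeneralGrad), runs the topological
// visit, captures the gradient arriving at each input's grad node, and
// finally returns those gradients via GeneralGrad::Instance().GetResults().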
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get the GradNodes and InputMeta info of inputs / no_grad_vars
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();
        VLOG(8) << "Get " << msg << "'s grad_node: " << target_node->name()
                << ", " << target_node << " with output rank info: "
                << auto_grad_meta->OutRankInfo().first << ", "
                << auto_grad_meta->OutRankInfo().second;
        if (is_no_grad_vars) {
          (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
          continue;
        }
        if (orig_to_copied_node_mapping_.count(target_node)) {
          target_node = orig_to_copied_node_mapping_[target_node].get();
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's "
                                    "stop_gradient=True.",
                                    msg,
                                    i));
        // normal input
        (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
      }
    }
  }

  // Purify potential_startup_nodes_, remove nodes that are the same as
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
    if (input_target_nodes_inputmeta_map_.empty()) return;
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
    for (auto startup_op : potential_startup_nodes_) {
      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map_.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
        potential_startup_nodes_.erase(nodes);
      }
    }
  }

  // Remove nodes that don't need to be
  // stored in potential_stop_nodes_ / potential_startup_nodes_
  void UpdateGraphInfo() {
    // Update potential_stop_nodes_ based on depending_nodes_,
    // make sure the path from root to target_node is ok
    std::unordered_set<GradNodeBase*> startup_ops;
    VLOG(6) << "Running in UpdateGraphInfo";
    std::deque<GradNodeBase*> queue;
    for (auto& target_nodes_inputmeta_pair :
         input_target_nodes_inputmeta_map_) {
      queue.push_back(target_nodes_inputmeta_pair.first);
    }

    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop_front();
      if (!(depending_nodes_)[target_node].empty()) {
        auto preceding_nodes = (depending_nodes_)[target_node];
        for (auto pre_nodes : preceding_nodes) {
          queue.push_back(pre_nodes);
          if (potential_stop_nodes_.find(pre_nodes) !=
              potential_stop_nodes_.end()) {
            potential_stop_nodes_.erase(pre_nodes);
          }
        }
      } else {  // startup_ops have no preceding nodes
        VLOG(6) << "Emplace startup_ops";
        startup_ops.emplace(target_node);
      }
    }
    // Purify potential_startup_nodes_ again, remove potential
    // startup nodes that cannot reach the input target nodes
    if (!startup_ops.empty()) {
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
      for (auto node : potential_startup_nodes_) {
        if (startup_ops.count(node) == 0) {
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
          potential_startup_nodes_.erase(node);
        }
      }
    }
  }

  // Get graph info between the input target GradNodes and the outputs,
  // record depending_nodes_, potential_stop_nodes_, potential_startup_nodes_
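  // Illustrative example (tensor/op names are made up): for a forward chain
  // x -> op1 -> y -> op2 -> z, Grad({z}, {y}) traverses from op2's grad node,
  // marks x's accumulation node as a potential stop node (no need to go past
  // the target y), keeps op2's grad node as the startup node, and later
  // captures dz/dy from the GradTensorHolder feeding op1's grad node.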
  void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
    VLOG(6) << "Running in GetGraphInfoBetweenTargets";

    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
    std::deque<GradNodeBase*> queue = init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
      queue.pop_front();

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check whether node is an input target node; if it is, all of its
      // next nodes will be marked in potential_stop_nodes_
      bool is_potential_stop_nodes =
          input_target_nodes_inputmeta_map_.count(node);

      // Find and append next nodes
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& metas =
          node->OutputMeta();
      for (const auto& meta_list : metas) {
        for (const GradSlotMeta& meta : meta_list) {
          const auto& edge = meta.GetEdge();
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is a leaf tensor with no
          // AccumulationNode attached,
          // or it could have originated from dispensable inputs
          if (!next_node) continue;

          // if node is an input target node, all of its next_nodes
          // will be inserted into potential_stop_nodes_
          if (is_potential_stop_nodes) {
            potential_stop_nodes_.emplace(next_node);
          }

          // Update in_degree
          if (!node_in_degree_map.count(next_node)) {
            node_in_degree_map[next_node] = 0;
          }
          node_in_degree_map[next_node]++;

          // Record depending relationship
          (depending_nodes_)[next_node].emplace(node);
          queue.push_back(next_node);
        }
      }
    }
    // Update graph info, remove some nodes from
    // potential_stop_nodes_ and potential_startup_nodes_
    UpdateGraphInfo();
  }

  void ModifyReadyQueue(std::deque<GradNodeBase*>* queue) {
    std::deque<GradNodeBase*> tmp_queue;
    for (auto nodes : potential_startup_nodes_) {
      tmp_queue.push_back(nodes);
    }
    tmp_queue.swap(*queue);
  }

  // Set result for input target grad_var when potential_startup_nodes_ is empty
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    if (potential_startup_nodes_.size() == 0) {
      for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
          results_map_[input_target_node.first] = target_result;
        }
      }
    }
  }

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
    auto iter = GetInputTargetNodesInputMetaMap()->find(node);
    if (iter != GetInputTargetNodesInputMetaMap()->end()) {
      VLOG(6) << "Get target result by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
      results_map_[node] = target_result;
    }
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused,
      bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
        target_node = orig_to_copied_node_mapping_[target_node].get();
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }

      auto iter = results_map_.find(target_node);
      if (iter != results_map_.end()) {
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused,
                          true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
      }
    }
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::deque<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
    // Purify potential_startup_nodes_, remove nodes that are the same as
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get graph info between the input target grad nodes and the outputs.
    // Record depending_nodes_,
    // potential_stop_nodes_ and potential_startup_nodes_
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1. input equals output; 2. input cannot reach output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
    return potential_stop_nodes_.count(node);
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
    return &no_grad_var_nodes_inputmeta_map_;
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetInputTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map_;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
    return &potential_stop_nodes_;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
    return &potential_startup_nodes_;
  }

  void Clear() {
    no_grad_var_nodes_inputmeta_map_.clear();
    input_target_nodes_inputmeta_map_.clear();
    potential_startup_nodes_.clear();
    potential_stop_nodes_.clear();
    depending_nodes_.clear();
    results_map_.clear();
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
      return orig_to_copied_node_mapping_[orig_node.get()].get();
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

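  // Rebuild the backward graph over copied nodes: BFS from the already-copied
  // startup nodes, copy every reachable grad node once, re-point the copied
  // edges at the copied nodes, and cut the edges that lead into no_grad_vars
  // so gradients are not propagated past them.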
  void ReconstructBackwardGraph(
      const std::deque<GradNodeBase*>& orig_init_queue) {
    std::deque<GradNodeBase*> queue = orig_init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
      queue.pop_front();
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph, "
              "unable to find copied target for certain grad node."));
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();

      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& orig_meta =
          orig_node->OutputMeta();
      paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
          copied_edges = copied_node->MutableOutputMeta();
      for (size_t i = 0; i < orig_meta.size(); i++) {
        for (size_t j = 0; j < orig_meta[i].size(); j++) {
          const Edge& orig_edge = orig_meta[i][j].GetEdge();
          Edge& copied_edge = copied_edges[i][j].GetMutableEdge();

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();

          if (no_grad_var_nodes_inputmeta_map_.count(orig_next_node.get()) &&
              (no_grad_var_nodes_inputmeta_map_[orig_next_node.get()]
                   ->OutRankInfo() == orig_edge.GetEdgeRankInfo())) {
            VLOG(3) << "Get no grad edge from grad_node: " << orig_node->name()
                    << " : " << orig_node << " to:" << orig_next_node->name()
                    << ", " << orig_next_node.get()
                    << " with output rank info: "
                    << orig_edge.GetEdgeRankInfo().first << ", "
                    << orig_edge.GetEdgeRankInfo().second;
            // Stop no grad var's preceding node
            copied_node->MutableOutputMeta()[i][j].SetStopGradient(true);
            copied_edge.Clear();
            continue;
          }
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
                orig_to_copied_node_mapping_[orig_next_node.get()];

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
                copied_next_node;
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
          queue.push_back(orig_next_node.get());
        }
      }
    }
  }

 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      no_grad_var_nodes_inputmeta_map_;
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      input_target_nodes_inputmeta_map_;
  // Record all the potential startup_nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_startup_nodes_;
  // Record all the potential stop nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_stop_nodes_;
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
      depending_nodes_;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;

  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};

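// BFS over the backward graph starting from init_queue: for every reachable
// grad node, count how many edges point to it (its in-degree). A node becomes
// ready to run once its in-degree drops to zero during the topological visit.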
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::deque<GradNodeBase*>& init_queue) {
  // Calculate in_degree for each node
  // We could completely remove this pass if in_degree were set during the
  // forward pass
  std::unordered_map<GradNodeBase*, int> node_in_degree_map;

  // Copy nodes
  std::deque<GradNodeBase*> queue = init_queue;
  std::unordered_set<GradNodeBase*> visited;

  // Visit each node exactly once in any order
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    queue.pop_front();

    if (visited.count(node)) {
      continue;
    }
    visited.insert(node);

    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got a null node when traversing the backward graph. This "
            "should not happen. Please check your code and contact us."));
    // Find and append next nodes
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    for (const auto& meta_list : metas) {
      for (const GradSlotMeta& meta : meta_list) {
        const auto& edge = meta.GetEdge();
        GradNodeBase* next_node = edge.GetMutableGradNode().get();
        // Next node could be nullptr if it is a leaf tensor with no
        // AccumulationNode attached,
        // or it could have originated from dispensable inputs
        if (!next_node) continue;

        // Update in_degree
        if (!node_in_degree_map.count(next_node))
          node_in_degree_map[next_node] = 0;
        node_in_degree_map[next_node]++;
        queue.push_back(next_node);
      }
    }
  }

  return node_in_degree_map;
}

// Enforce GradNode has TensorWrappers as Input
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  PADDLE_ENFORCE_NE(
      node->IsTensorWrappersCleared(),
      true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You have calculated backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visited_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  for (auto in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visited_ins.count(auto_grad_meta),
        0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.",
            msg,
            in.name(),
            msg));
    visited_ins.insert(auto_grad_meta);
  }
}

GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();

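// Core backward driver shared by Backward() and Grad() below:
//   1. Build a GradTensorHolder for every starting grad node and seed it with
//      the given grad_tensors (or with ones).
//   2. Compute the in-degree of every reachable grad node.
//   3. For Grad()-style calls (non-empty `inputs`), copy and prune the graph
//      through GeneralGrad.
//   4. Topologically visit ready nodes, run them, and route their output
//      grads into the GradTensorHolders of the next nodes.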
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph,
    bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(3) << "Start Backward";

  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::deque<GradNodeBase*> queue;
  std::deque<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
        auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push_back(grad_node);

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }
    bool copy_from_grad_t =
        grad_tensors.size() > 0 && grad_tensors[i].initialized();
    if (copy_from_grad_t) {
      PADDLE_ENFORCE(
          grad_tensors.size() == tensors.size(),
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors. "
              "grad_tensors should either have "
              "size = 0 or the same size as tensors."));
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << " with given grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, /*fill_one=*/true);
    }

    // Prepare queue, potential startup_nodes
    queue.push_back(grad_node);
  }

  if (is_general_grad) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GeneralGrad::Instance().GetTargetNodesInfo(no_grad_vars,
                                               true /* is_no_grad_vars */);
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
  }

  VLOG(3) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(
        inputs, no_grad_vars, &queue, node_input_buffers_dict);
  }

  VLOG(6) << "startup_ops' size is: " << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(3) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string((*node).name()),
        paddle::platform::TracerEventType::Operator,
        1);

    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop_front();
      continue;
    }
    queue.pop_front();

    // Run node: This is where Hook happens
    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
    PADDLE_ENFORCE_NE(
        node_input_buffer_iter,
        node_input_buffers_dict.end(),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffer_iter->second);

    // Set input target grad_var from node_input_buffer by inputmeta
    if (!inputs.empty() && is_general_grad) {
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);
    }

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                         kSlotSmallVectorSize>
        grad_output_tensors = (*node)(
            node_input_buffer->Buffers(), create_graph, is_general_grad);

    // retain_graph or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // TODO(jiabin): Should we erase it or find a more efficient way.
    node_input_buffers_dict.erase(node_input_buffer_iter);

    // Prepare GradTensorHolder for next node
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty (for leaf "
                       "node) or the same as the number of output grad "
                       "tensors, but got edges size: %d, grad_output size: %d",
                       metas.size(),
                       grad_output_tensors.size()));

    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
        if (!edge.IsInitialized()) {
          continue;
        }
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since we make each edge have the same rank as the bwd outputs,
        // we index them with the same rank (i, j)
        auto next_node_shared = edge.GetMutableGradNode();
        VLOG(3) << "Found pending node: " << next_node_shared->name() << ": "
                << next_node_shared.get();
        // Next node could be nullptr if it is a leaf tensor with no
        // AccumulationNode attached,
        // or it could have originated from dispensable inputs
        if (!next_node_shared || !next_node_shared.get() ||
            grad_output_tensors[i].empty()) {
          continue;
        }

        PADDLE_ENFORCE_LT(
            j,
            grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate an autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if ((!grad_output_tensor.defined() ||
             !grad_output_tensor.initialized())) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        auto* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;

        node_input_buffers_dict[next_node]->add(edge_rank.first,
                                                edge_rank.second,
                                                grad_output_tensor,
                                                create_graph);

        // Update queue
        node_in_degree_map[next_node]--;
        VLOG(6) << next_node->name()
                << " ref_cnt is: " << node_in_degree_map[next_node];

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s. "
                "Node's in-degree cannot be negative.",
                next_node->name()));

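        // Schedule nodes whose in-degree has dropped to zero (skipping
        // potential stop nodes in the partial-grad case). Accumulation nodes
        // go to the front of the queue, presumably so leaf gradients are
        // written out as early as possible; other nodes go to the back.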
        if (is_general_grad) {
          bool is_potential_stop_node =
              GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
          if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
            if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
              queue.push_front(std::move(next_node));
            } else {
              queue.push_back(std::move(next_node));
            }
          }
        } else {
          if (node_in_degree_map[next_node] == 0) {
            if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
              queue.push_front(std::move(next_node));
            } else {
              queue.push_back(std::move(next_node));
            }
          }
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

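// Backward(): run a full backward pass from `tensors` (the forward outputs),
// accumulating gradients into leaf tensors. A thin wrapper around RunBackward
// with no input targets and no no_grad_vars.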
void Backward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(3) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::UserDefined, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
  phi::autotune::AutoTuneStatus::Instance().Update();
}

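// Grad(): compute and return the gradients of `tensors` (the forward outputs)
// with respect to `inputs`, one result per input.
// Illustrative sketch of a caller (tensor names are made up, not from this
// file):
//   auto grads = egr::Grad({loss}, {x, y}, /*grad_tensors=*/{},
//                          /*retain_graph=*/false, /*create_graph=*/false,
//                          /*only_inputs=*/true, /*allow_unused=*/false,
//                          /*no_grad_vars=*/{});
//   // grads[i] is d(loss)/d(inputs[i]); an input that does not appear in
//   // the backward graph needs allow_unused=true and comes back as an
//   // empty tensor.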
std::vector<paddle::experimental::Tensor> Grad(
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph,
    bool create_graph,
    bool only_inputs,
    bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(3) << "Run in Grad";

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

  return RunBackward(tensors,
                     grad_tensors,
                     retain_graph,
                     create_graph,
                     inputs,
                     allow_unused,
                     no_grad_vars);
}
}  // namespace egr