backward.cc 32.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
16

17
#include <deque>
18

19 20
#include "glog/logging.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
21 22 23 24 25 26
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
27 28
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
J
Jiabin Yang 已提交
29
#include "paddle/phi/kernels/autotune/switch_autotune.h"
30 31 32

namespace egr {

33
/*
34 35 36 37
 * GeneralGrad is Helpper class to implement custom grad operation between
 * outputs and inputs.
 *
 * **/
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
54 55 56
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();

        if (orig_to_copied_node_mapping_.count(target_node)) {
57
          target_node = orig_to_copied_node_mapping_[target_node].get();
58 59 60 61 62 63
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

64 65 66 67
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
68 69
                                    msg,
                                    i));
70
        if (is_no_grad_vars) {
71
          (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
72
        } else {  // normal input
73
          (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
74 75 76 77
        }
      }
    }
  }
78

79
  // Purify potential_startup_nodes_, remove nodes those are the same as
80 81 82
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
83
    if (input_target_nodes_inputmeta_map_.empty()) return;
84
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
85 86 87
    for (auto startup_op : potential_startup_nodes_) {
      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map_.end()) {
88 89 90 91 92
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
93
        potential_startup_nodes_.erase(nodes);
94 95 96
      }
    }
  }
97

98
  // Remove some nodes those doesn't need to be
99
  // stored in potential_stop_nodes_、potential_startup_nodes_
100
  void UpdateGraphInfo() {
101
    // Updated potential_sotp_nodes by depending_nodes_,
102
    // make sure the path from root to target_node is ok
103
    std::unordered_set<GradNodeBase*> startup_ops;
104
    VLOG(6) << "Running in UpdateGraphInfo";
105
    std::deque<GradNodeBase*> queue;
106 107
    for (auto& target_nodes_inputmeta_pair :
         input_target_nodes_inputmeta_map_) {
108
      queue.push_back(target_nodes_inputmeta_pair.first);
109
    }
110

111 112
    while (!queue.empty()) {
      auto* target_node = queue.front();
113
      queue.pop_front();
114 115
      if (!(depending_nodes_)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes_)[target_node];
116
        for (auto pre_nodes : precedding_nodes) {
117
          queue.push_back(pre_nodes);
118 119 120
          if (potential_stop_nodes_.find(pre_nodes) !=
              potential_stop_nodes_.end()) {
            potential_stop_nodes_.erase(pre_nodes);
121 122 123
          }
        }
      } else {  // startup_ops have no precedding nodes
124 125
        VLOG(6) << "Emplace startup_ops";
        startup_ops.emplace(target_node);
126 127
      }
    }
128
    // Purify potential_startup_nodes_ again, remove some
129
    // potential startup_nodes that unreach to input target nodes
130
    if (!startup_ops.empty()) {
131
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
132 133
      for (auto node : potential_startup_nodes_) {
        if (startup_ops.count(node) == 0) {
134 135 136 137 138 139 140
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
141
          potential_startup_nodes_.erase(node);
142 143
        }
      }
144
    }
145
  }
146

147
  // Get Graph Info Betweent input target GradNode and outputs,
148
  // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
149
  void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
150
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
151

152 153 154 155
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
156
    std::deque<GradNodeBase*> queue = init_queue;
157 158 159 160 161
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
162
      queue.pop_front();
163 164 165 166 167 168 169

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
170
      // all the next_node will be marked in potential_stop_nodes_
171
      bool is_potential_stop_nodes =
172
          input_target_nodes_inputmeta_map_.count(node);
173 174

      // Find and append next nodes
175 176 177 178 179 180
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& metas =
          node->OutputMeta();
      for (const auto& meta_list : metas) {
        for (const GradSlotMeta& meta : meta_list) {
          const auto& edge = meta.GetEdge();
181 182 183 184 185 186 187 188 189 190 191
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
192
            potential_stop_nodes_.emplace(next_node);
193 194 195
          }

          // Update in_degree
196
          if (!node_in_degree_map.count(next_node)) {
197
            node_in_degree_map[next_node] = 0;
198
          }
199 200 201
          node_in_degree_map[next_node]++;

          // Record depending relationship
202
          (depending_nodes_)[next_node].emplace(node);
203
          queue.push_back(next_node);
204
        }
205 206
      }
    }
207
    // Update Graph Info, remove some nodes in
208
    // potential_stop_nodes_、potential_startup_nodes_、
209
    UpdateGraphInfo();
210 211
  }

212 213
  void ModifyReadyQueue(std::deque<GradNodeBase*>* queue) {
    std::deque<GradNodeBase*> tmp_queue;
214
    for (auto nodes : potential_startup_nodes_) {
215
      tmp_queue.push_back(nodes);
216 217
    }
    tmp_queue.swap(*queue);
218 219
  }

220
  // Set result for input target grad_var when potential_startup_nodes_ is empty
221 222 223 224
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
225 226
    if (potential_startup_nodes_.size() == 0) {
      for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
227 228 229 230 231 232 233
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
234
          results_map_[input_target_node.first] = target_result;
235 236 237 238
        }
      }
    }
  }
239 240 241 242

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
243 244
    auto iter = GetInputTargetNodesInputMetaMap()->find(node);
    if (iter != GetInputTargetNodesInputMetaMap()->end()) {
245 246 247 248 249 250 251
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
252
      results_map_[node] = target_result;
253
    }
254 255 256 257
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
258 259
      bool allow_unused,
      bool create_graph) {
260 261 262 263 264 265 266 267 268
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
269 270 271

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
272
        target_node = orig_to_copied_node_mapping_[target_node].get();
273 274 275 276 277
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }
278

279 280
      auto iter = results_map_.find(target_node);
      if (iter != results_map_.end()) {
281 282 283 284 285 286
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
287 288
        PADDLE_ENFORCE_EQ(allow_unused,
                          true,
289 290 291 292 293 294
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
295 296
      }
    }
297 298 299 300 301 302 303
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
304
      std::deque<GradNodeBase*>* queue,
305 306 307 308 309 310 311
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
312
    // Purify potentialstartup_ops, remove those nodes that are the same as
313 314 315
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
316 317
    // Record the depending_nodes_ and
    // potential_stop_nodes_、potential_startup_nodes_
318 319 320 321 322 323 324 325 326
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
327
    return potential_stop_nodes_.count(node);
328 329 330 331
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
332
    return &no_grad_var_nodes_inputmeta_map_;
333 334 335
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
336 337
  GetInputTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map_;
338 339 340
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
341
    return &potential_stop_nodes_;
342 343 344
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
345
    return &potential_startup_nodes_;
346 347 348
  }

  void Clear() {
349 350 351 352 353 354
    no_grad_var_nodes_inputmeta_map_.clear();
    input_target_nodes_inputmeta_map_.clear();
    potential_startup_nodes_.clear();
    potential_stop_nodes_.clear();
    depending_nodes_.clear();
    results_map_.clear();
355 356 357 358 359 360
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
361
      return orig_to_copied_node_mapping_[orig_node.get()].get();
362 363 364 365
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
366
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
367 368 369 370 371 372
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

  void ReconstructBackwardGraph(
373 374
      const std::deque<GradNodeBase*>& orig_init_queue) {
    std::deque<GradNodeBase*> queue = orig_init_queue;
375 376 377 378 379
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
380
      queue.pop_front();
381 382 383 384 385 386 387 388 389 390
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph,"
              "unable to find copied target for certain grad node."));
391
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();
392

393 394 395 396 397 398 399 400 401
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& orig_meta =
          orig_node->OutputMeta();
      paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
          copied_edges = copied_node->MutableOutputMeta();
      for (size_t i = 0; i < orig_meta.size(); i++) {
        for (size_t j = 0; j < orig_meta[i].size(); j++) {
          const Edge& orig_edge = orig_meta[i][j].GetEdge();
          Edge& copied_edge = copied_edges[i][j].GetMutableEdge();
402 403 404 405 406 407 408 409 410

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
411
                orig_to_copied_node_mapping_[orig_next_node.get()];
412 413 414 415

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
416
                copied_next_node;
417 418 419 420 421 422 423
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
424
          queue.push_back(orig_next_node.get());
425 426 427
        }
      }
    }
428 429
  }

430 431 432 433 434
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
435
      no_grad_var_nodes_inputmeta_map_;
436 437
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
438
      input_target_nodes_inputmeta_map_;
439
  // Record all the potential startup_nodes, will be changed.
440
  std::unordered_set<GradNodeBase*> potential_startup_nodes_;
441
  // Record all the potential stop nodes, will be changed.
442
  std::unordered_set<GradNodeBase*> potential_stop_nodes_;
443 444
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
445 446
      depending_nodes_;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;
447 448

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
449 450
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;
451

452 453
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
454

455
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
456
    const std::deque<GradNodeBase*>& init_queue) {
457
  // Calculate in_degree for each node
458 459
  // We can completely remove this pass, if in_degree were set during forward
  // pass
460 461 462
  std::unordered_map<GradNodeBase*, int> node_in_degree_map;

  // Copy nodes
463
  std::deque<GradNodeBase*> queue = init_queue;
464 465 466 467 468
  std::unordered_set<GradNodeBase*> visited;

  // Visit each node exactly once in any order
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
469
    queue.pop_front();
470 471 472 473 474 475

    if (visited.count(node)) {
      continue;
    }
    visited.insert(node);

476 477 478 479 480
    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));
481
    // Find and append next nodes
482 483 484 485 486
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    for (const auto& meta_list : metas) {
      for (const GradSlotMeta& meta : meta_list) {
        const auto& edge = meta.GetEdge();
487 488 489 490 491 492 493 494 495 496
        GradNodeBase* next_node = edge.GetMutableGradNode().get();
        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node) continue;

        // Update in_degree
        if (!node_in_degree_map.count(next_node))
          node_in_degree_map[next_node] = 0;
        node_in_degree_map[next_node]++;
497
        queue.push_back(next_node);
498 499 500
      }
    }
  }
501

502
  return node_in_degree_map;
503 504 505 506 507 508
}

// Enforce GradNode has TensorWrappers as Input
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  PADDLE_ENFORCE_NE(
509 510
      node->IsTensorWrappersCleared(),
      true,
511 512 513 514 515 516 517 518
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

519 520 521 522 523 524 525
// Raise AlreadyExists if the same tensor, identified by its AutogradMeta
// pointer, appears more than once in `inputs`. `is_input` only selects the
// wording ("inputs" or "outputs") used in the error message.
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visisted_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  // Bind by const reference: the loop only inspects each tensor, so copying
  // every Tensor per iteration is pure overhead.
  for (const auto& in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visisted_ins.count(auto_grad_meta),
        0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.",
            msg,
            in.name(),
            msg));
    visisted_ins.insert(auto_grad_meta);
  }
}

// Storage for the instance returned by GeneralGrad::Instance(). It is
// heap-allocated once during static initialization and is not freed anywhere
// in this file.
GeneralGrad* GeneralGrad::general_grad_{new GeneralGrad()};

// Execute the backward pass that starts from `tensors`.
//
// * Plain backward (`inputs` empty): gradients are propagated through the
//   backward graph of `tensors` and an empty vector is returned.
// * General grad (`inputs` non-empty): the reachable backward graph is first
//   copied (GeneralGrad::ReconstructBackwardGraph), only the part between
//   `tensors` and `inputs` is run, and GeneralGrad::GetResults() returns one
//   gradient per entry of `inputs`. The input-buffer entry that belongs to a
//   `no_grad_vars` tensor is set to zeros before its node runs.
//
// Tensors without autograd meta, without a GradNode, or with
// stop_gradient=True are skipped. `grad_tensors` must be empty or have the
// same size as `tensors`; an empty vector or an uninitialized entry seeds the
// corresponding output gradient with ones. When `retain_graph` is false the
// TensorWrappers of every executed node are cleared, so running those nodes
// again fails in EnforceGradNodeHasInput.
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph,
    bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(3) << "Start Backward";

  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::deque<GradNodeBase*> queue;
  std::deque<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (!shared_grad_node || auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push_back(grad_node);

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }

    // Validate the size of grad_tensors BEFORE indexing it with `i`:
    // a non-empty but shorter grad_tensors would otherwise be read out of
    // bounds.
    if (!grad_tensors.empty()) {
      PADDLE_ENFORCE(
          grad_tensors.size() == tensors.size(),
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors, "
              "grad_tensors should either have "
              "size = 0 or same size as tensors."));
    }
    bool copy_from_grad_t =
        !grad_tensors.empty() && grad_tensors[i].initialized();
    if (copy_from_grad_t) {
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, /*fill_one=*/true);
    }

    // Prepare queue, potential startup_nodes
    queue.push_back(grad_node);
  }

  if (is_general_grad) {
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
  }

  VLOG(3) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(
        inputs, no_grad_vars, &queue, node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(3) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string((*node).name()),
        paddle::platform::TracerEventType::Operator,
        1);

    // Unless it is the last queued node, a node whose in-degree is not yet
    // zero is dropped here; it is enqueued again below once its in-degree
    // reaches zero.
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop_front();
      continue;
    }
    queue.pop_front();

    // Run node: This is where Hook happens
    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
    PADDLE_ENFORCE_NE(
        node_input_buffer_iter,
        node_input_buffers_dict.end(),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffer_iter->second);

    // Set input target grad_var from node_input_buffer by inputmeta
    if (!inputs.empty() && is_general_grad) {
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);
    }

    // no_grad_vars
    if (!no_grad_vars.empty() && is_general_grad) {
      auto iter =
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->find(node);
      if (iter !=
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->end()) {
        VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
        auto rank_info = (iter->second)->OutRankInfo();
        node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                  rank_info.second);
      }
    }

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                         kSlotSmallVectorSize>
        grad_output_tensors = (*node)(
            node_input_buffer->Buffers(), create_graph, is_general_grad);

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // Erase before any insertion below: inserting into the map may rehash
    // and invalidate node_input_buffer_iter.
    // TODO(jiabin): Should we erase it or find a more efficient way.
    node_input_buffers_dict.erase(node_input_buffer_iter);

    // Prepare GradTensorHolder for next node
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
                       metas.size(),
                       grad_output_tensors.size()));

    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
        if (!edge.IsInitialized()) {
          continue;
        }
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since we make edge has as same rank as bwd outputs, we indexing them
        // with the same rank(i, j)
        auto next_node_shared = edge.GetMutableGradNode();

        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node_shared) {
          continue;
        }
        // Logged only after the null check so name() is never called on null.
        VLOG(3) << "Found pending node: " << next_node_shared->name();
        if (grad_output_tensors[i].empty()) {
          continue;
        }

        PADDLE_ENFORCE_LT(
            j,
            grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if ((!grad_output_tensor.defined() ||
             !grad_output_tensor.initialized())) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        auto* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;

        node_input_buffers_dict[next_node]->add(edge_rank.first,
                                                edge_rank.second,
                                                grad_output_tensor,
                                                create_graph);

        // Update queue
        node_in_degree_map[next_node]--;

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s. "
                "Node's in-degree cannot be negative.",
                next_node->name()));

        // A node becomes ready once all of its producers have run; in general
        // grad mode potential stop nodes are never scheduled.
        const bool is_ready =
            node_in_degree_map[next_node] == 0 &&
            !(is_general_grad &&
              GeneralGrad::Instance().IsPotentialStopNodes(next_node));
        if (is_ready) {
          // GradNodeAccumulation nodes are scheduled ahead of the others.
          if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
            queue.push_front(next_node);
          } else {
            queue.push_back(next_node);
          }
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

void Backward(
839
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
840 841
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
842
  VLOG(3) << "Run in Backward";
843 844 845
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
J
Jiabin Yang 已提交
846
  phi::autotune::AutoTuneStatus::Instance().Update();
847 848 849
}

std::vector<paddle::experimental::Tensor> Grad(
850
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
851 852
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
853 854 855 856
    bool retain_graph,
    bool create_graph,
    bool only_inputs,
    bool allow_unused,
857
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
858
  VLOG(3) << "Run in Grad";
859 860 861 862

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

863 864 865 866 867 868 869
  return RunBackward(tensors,
                     grad_tensors,
                     retain_graph,
                     create_graph,
                     inputs,
                     allow_unused,
                     no_grad_vars);
870
}
871
}  // namespace egr