backward.cc 31.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
#include <queue>

#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
22 23
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
24

J
Jiabin Yang 已提交
25
#include "glog/logging.h"
26 27
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
J
Jiabin Yang 已提交
28
#include "paddle/phi/kernels/autotune/switch_autotune.h"
29 30 31

namespace egr {

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/*
* GeneralGrad is Helpper class to implement custom grad operation between
* outputs and inputs.
*
* **/
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
53 54 55
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();

        if (orig_to_copied_node_mapping_.count(target_node)) {
56
          target_node = orig_to_copied_node_mapping_[target_node].get();
57 58 59 60 61 62
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

63 64 65 66 67 68 69 70 71 72 73 74 75
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
          (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        } else {  // normal input
          (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        }
      }
    }
  }
76

77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
  // Purify potential_startup_nodes, remove nodes those are the same as
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
    if (input_target_nodes_inputmeta_map.empty()) return;
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
    for (auto startup_op : potential_startup_nodes) {
      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
        potential_startup_nodes.erase(nodes);
      }
    }
  }
95

96 97 98 99 100 101 102 103 104 105 106
  // Remove some nodes those doesn't need to be
  // stored in potential_stop_nodes、potential_startup_nodes
  void UpdateGraphInfo() {
    // Updated potential_sotp_nodes by depending_nodes,
    // make sure the path from root to target_node is ok
    std::unordered_set<GradNodeBase*> _startup_ops;
    VLOG(6) << "Running in UpdateGraphInfo";
    std::queue<GradNodeBase*> queue;
    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
      queue.emplace(target_nodes_inputmeta_pair.first);
    }
107

108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop();
      if (!(depending_nodes)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes)[target_node];
        for (auto pre_nodes : precedding_nodes) {
          queue.emplace(pre_nodes);
          if (potential_stop_nodes.find(pre_nodes) !=
              potential_stop_nodes.end()) {
            potential_stop_nodes.erase(pre_nodes);
          }
        }
      } else {  // startup_ops have no precedding nodes
        VLOG(6) << "Emplace _startup_ops";
        _startup_ops.emplace(target_node);
123 124
      }
    }
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
    // Purify potential_startup_nodes again, remove some
    // potential startup_nodes that unreach to input target nodes
    if (!_startup_ops.empty()) {
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
      for (auto node : potential_startup_nodes) {
        if (_startup_ops.count(node) == 0) {
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
          potential_startup_nodes.erase(node);
        }
      }
141
    }
142
  }
143

144 145 146 147
  // Get Graph Info Betweent input target GradNode and outputs,
  // record depending_nodes、potential_stop_nodes、potential_startup_nodes
  void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
148

149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
    std::queue<GradNodeBase*> queue = init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
      queue.pop();

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
      // all the next_node will be marked in potential_stop_nodes
      bool is_potential_stop_nodes =
          input_target_nodes_inputmeta_map.count(node);

      // Find and append next nodes
172 173 174 175 176 177
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& metas =
          node->OutputMeta();
      for (const auto& meta_list : metas) {
        for (const GradSlotMeta& meta : meta_list) {
          const auto& edge = meta.GetEdge();
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
            potential_stop_nodes.emplace(next_node);
          }

          // Update in_degree
          if (!node_in_degree_map.count(next_node))
            node_in_degree_map[next_node] = 0;
          node_in_degree_map[next_node]++;

          // Record depending relationship
          (depending_nodes)[next_node].emplace(node);
          queue.push(next_node);
        }
201 202
      }
    }
203 204 205
    // Update Graph Info, remove some nodes in
    // potential_stop_nodes、potential_startup_nodes、
    UpdateGraphInfo();
206 207
  }

208 209 210 211 212 213
  void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
    std::queue<GradNodeBase*> tmp_queue;
    for (auto nodes : potential_startup_nodes) {
      tmp_queue.emplace(nodes);
    }
    tmp_queue.swap(*queue);
214 215
  }

216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
  // Set result for input target grad_var when potential_startup_nodes is empty
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    if (potential_startup_nodes.size() == 0) {
      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
          results_map[input_target_node.first] = target_result;
231 232 233 234
        }
      }
    }
  }
235 236 237 238 239 240 241 242 243 244 245 246 247 248

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
      results_map[node] = target_result;
249
    }
250 251 252 253 254 255 256 257 258 259 260 261 262 263
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
264 265 266

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
267
        target_node = orig_to_copied_node_mapping_[target_node].get();
268 269 270 271 272
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288

      auto iter = results_map.find(target_node);
      if (iter != results_map.end()) {
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
289 290
      }
    }
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::queue<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
    // Purify potential_startup_ops, remove those nodes that are the same as
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
    // Record the depending_nodes and
    // potential_stop_nodes、potential_startup_nodes
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
    return potential_stop_nodes.count(node);
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
    return &no_grad_var_nodes_inputmeta_map;
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetInPutTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
    return &potential_stop_nodes;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
    return &potential_startup_nodes;
  }

  void Clear() {
    no_grad_var_nodes_inputmeta_map.clear();
    input_target_nodes_inputmeta_map.clear();
    potential_startup_nodes.clear();
    potential_stop_nodes.clear();
    depending_nodes.clear();
    results_map.clear();
349 350 351 352 353 354
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
355
      return orig_to_copied_node_mapping_[orig_node.get()].get();
356 357 358 359
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
360
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

  void ReconstructBackwardGraph(
      const std::queue<GradNodeBase*>& orig_init_queue) {
    std::queue<GradNodeBase*> queue = orig_init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
      queue.pop();
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph,"
              "unable to find copied target for certain grad node."));
385
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();
386

387 388 389 390 391 392 393 394 395
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& orig_meta =
          orig_node->OutputMeta();
      paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
          copied_edges = copied_node->MutableOutputMeta();
      for (size_t i = 0; i < orig_meta.size(); i++) {
        for (size_t j = 0; j < orig_meta[i].size(); j++) {
          const Edge& orig_edge = orig_meta[i][j].GetEdge();
          Edge& copied_edge = copied_edges[i][j].GetMutableEdge();
396 397 398 399 400 401 402 403 404

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
405
                orig_to_copied_node_mapping_[orig_next_node.get()];
406 407 408 409

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
410
                copied_next_node;
411 412 413 414 415 416 417 418 419 420 421
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
          queue.push(orig_next_node.get());
        }
      }
    }
422 423
  }

424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      no_grad_var_nodes_inputmeta_map;
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      input_target_nodes_inputmeta_map;
  // Record all the potential startup_nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_startup_nodes;
  // Record all the potential stop nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_stop_nodes;
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
      depending_nodes;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
441 442

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
443 444
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;
445

446 447
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
448

449 450
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::queue<GradNodeBase*>& init_queue) {
  // Counts, for every grad node reachable from `init_queue`, how many edges
  // point to it. Nodes that are never the target of an edge get no entry.
  // This pass could be removed entirely if in_degree were set during the
  // forward pass.
  std::unordered_map<GradNodeBase*, int> in_degrees;

  // Work on a copy so that the caller's queue is left untouched.
  std::queue<GradNodeBase*> pending = init_queue;
  std::unordered_set<GradNodeBase*> expanded;

  // Expand each node exactly once, in any order.
  while (!pending.empty()) {
    GradNodeBase* node = pending.front();
    pending.pop();

    const bool first_visit = expanded.insert(node).second;
    if (!first_visit) {
      continue;
    }

    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));

    // Follow every outgoing edge of the node.
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        slot_metas = node->OutputMeta();
    for (const auto& slot : slot_metas) {
      for (const GradSlotMeta& slot_meta : slot) {
        GradNodeBase* successor =
            slot_meta.GetEdge().GetMutableGradNode().get();
        // The successor may be nullptr for a leaf tensor with no
        // AccumulationNode attached, or for dispensable inputs.
        if (successor == nullptr) continue;

        // operator[] starts a missing entry at 0 before the increment.
        ++in_degrees[successor];
        pending.push(successor);
      }
    }
  }

  return in_degrees;
}

// Enforce GradNode has TensorWrappers as Input.
// Raises a Fatal error when node->IsTensorWrappersCleared() returns true;
// the error message names the node and suggests retain_graph=True.
// `node` is dereferenced without a null check.
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  PADDLE_ENFORCE_NE(
      node->IsTensorWrappersCleared(), true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

512 513 514 515 516 517 518 519 520 521 522 523
// Raises an AlreadyExists error if two tensors of `inputs` share the same
// AutogradMeta. `is_input` only selects the word used in the error message
// ("inputs" when true, "outputs" otherwise).
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visited_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  // Iterate by const reference: no Tensor copy per element.
  for (const auto& in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visited_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visited_ins.insert(auto_grad_meta);
  }
}

527 528
// Definition of the singleton returned by GeneralGrad::Instance(). The object
// is heap-allocated during static initialization; no matching delete appears
// in the code visible here.
GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();

529 530 531 532 533 534 535
// Runs the backward pass starting from `tensors`.
//
// tensors:      output tensors the backward pass starts from. Tensors without
//               a grad node, or with stop_gradient=True, are skipped.
// grad_tensors: initial gradients; must be empty or have the same size as
//               `tensors`. An uninitialized entry (or an empty vector) means
//               the initial gradient is filled with 1.0.
// retain_graph: when false, the TensorWrappers of every executed node are
//               cleared after the node has run.
// create_graph: forwarded to the grad nodes and to the gradient accumulation.
// inputs:       when non-empty, the "general grad" mode is used: the backward
//               graph is copied, only the part needed for `inputs` is run and
//               the gradients of `inputs` are returned.
// allow_unused / no_grad_vars: only used in general grad mode.
//
// Returns an empty vector when `inputs` is empty, otherwise the result of
// GeneralGrad::GetResults().
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(6) << "Start Backward";

  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  const bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::queue<GradNodeBase*> queue;
  std::queue<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(tensor);
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
        auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push(grad_node);

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }

    // The size check must come before grad_tensors[i] is read: a non-empty
    // grad_tensors that is shorter than tensors would otherwise be indexed
    // out of bounds.
    if (!grad_tensors.empty()) {
      PADDLE_ENFORCE(
          grad_tensors.size() == tensors.size(),
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors. "
              "grad_tensors should either have "
              "size = 0 or same size as tensors."));
    }
    const bool copy_from_grad_t =
        !grad_tensors.empty() && grad_tensors[i].initialized();
    if (copy_from_grad_t) {
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, true /*fill_one=true*/);
    }

    // Prepare queue, potential startup_nodes
    queue.push(grad_node);
  }

  if (is_general_grad) {
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
  }

  VLOG(6) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(6) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string((*node).name()) + " grad_node",
        paddle::platform::TracerEventType::Operator, 1);

    // A node that still has pending incoming gradients is dropped from the
    // queue here as long as other nodes are queued; it is pushed again below
    // once its in-degree reaches zero.
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop();
      continue;
    }
    queue.pop();

    // Run node: This is where Hook happens
    PADDLE_ENFORCE(
        node_input_buffers_dict.count(node),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffers_dict[node]);

    // Set input target grad_var from node_input_buffer by inputmeta
    // (is_general_grad already implies that inputs is not empty).
    if (is_general_grad) {
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);
    }

    // no_grad_vars
    if (!no_grad_vars.empty() && is_general_grad) {
      auto* no_grad_map =
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap();
      auto iter = no_grad_map->find(node);
      if (iter != no_grad_map->end()) {
        VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
        auto rank_info = (iter->second)->OutRankInfo();
        node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                  rank_info.second);
      }
    }

    VLOG(6) << "Running GradNode:" << node->name();

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                         kSlotSmallVectorSize>
        grad_output_tensors = (*node)(node_input_buffer->Buffers(),
                                      create_graph, is_general_grad);

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // TODO(jiabin): Should we erase it or find a more efficient way.
    node_input_buffers_dict.erase(node);

    // Prepare GradTensorHolder for next node
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
                       metas.size(), grad_output_tensors.size()));

    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
        if (!edge.IsInitialized()) {
          continue;
        }
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since we make edge has as same rank as bwd outputs, we indexing them
        // with the same rank(i, j)
        auto next_node_shared = edge.GetMutableGradNode();

        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node_shared || !next_node_shared.get() ||
            grad_output_tensors[i].empty()) {
          continue;
        }

        // Report the size of the slot that is actually being indexed.
        PADDLE_ENFORCE_LT(
            j, grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if ((!grad_output_tensor.defined() ||
             !grad_output_tensor.initialized())) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        auto* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;

        node_input_buffers_dict[next_node]->add(
            edge_rank.first, edge_rank.second, grad_output_tensor,
            create_graph);

        // Update queue
        node_in_degree_map[next_node]--;

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s. "
                "Node's in-degree cannot be negative.",
                next_node->name()));

        // In general grad mode, potential stop nodes are never scheduled.
        const bool is_potential_stop_node =
            is_general_grad &&
            GeneralGrad::Instance().IsPotentialStopNodes(next_node);
        if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
          queue.push(next_node);
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

811
void Backward(
812
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
813 814 815 816 817 818
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(6) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
J
Jiabin Yang 已提交
819
  phi::autotune::AutoTuneStatus::Instance().Update();
820 821 822
}

std::vector<paddle::experimental::Tensor> Grad(
823
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
824 825 826 827 828
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(6) << "Run in Grad";
829 830 831 832

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

833 834 835
  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
836
}  // namespace egr