backward.cc 30.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
#include <queue>

#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
22 23
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
24

J
Jiabin Yang 已提交
25
#include "glog/logging.h"
26 27
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
J
Jiabin Yang 已提交
28
#include "paddle/phi/kernels/autotune/switch_autotune.h"
29 30 31

namespace egr {

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/*
* GeneralGrad is Helpper class to implement custom grad operation between
* outputs and inputs.
*
* **/
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
53 54 55
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();

        if (orig_to_copied_node_mapping_.count(target_node)) {
56
          target_node = orig_to_copied_node_mapping_[target_node].get();
57 58 59 60 61 62
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

63 64 65 66 67 68 69 70 71 72 73 74 75
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
          (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        } else {  // normal input
          (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        }
      }
    }
  }
76

77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
  // Purify potential_startup_nodes, remove nodes those are the same as
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
    if (input_target_nodes_inputmeta_map.empty()) return;
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
    for (auto startup_op : potential_startup_nodes) {
      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
        potential_startup_nodes.erase(nodes);
      }
    }
  }
95

96 97 98 99 100 101 102 103 104 105 106
  // Remove some nodes those doesn't need to be
  // stored in potential_stop_nodes、potential_startup_nodes
  void UpdateGraphInfo() {
    // Updated potential_sotp_nodes by depending_nodes,
    // make sure the path from root to target_node is ok
    std::unordered_set<GradNodeBase*> _startup_ops;
    VLOG(6) << "Running in UpdateGraphInfo";
    std::queue<GradNodeBase*> queue;
    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
      queue.emplace(target_nodes_inputmeta_pair.first);
    }
107

108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop();
      if (!(depending_nodes)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes)[target_node];
        for (auto pre_nodes : precedding_nodes) {
          queue.emplace(pre_nodes);
          if (potential_stop_nodes.find(pre_nodes) !=
              potential_stop_nodes.end()) {
            potential_stop_nodes.erase(pre_nodes);
          }
        }
      } else {  // startup_ops have no precedding nodes
        VLOG(6) << "Emplace _startup_ops";
        _startup_ops.emplace(target_node);
123 124
      }
    }
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
    // Purify potential_startup_nodes again, remove some
    // potential startup_nodes that unreach to input target nodes
    if (!_startup_ops.empty()) {
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
      for (auto node : potential_startup_nodes) {
        if (_startup_ops.count(node) == 0) {
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
          potential_startup_nodes.erase(node);
        }
      }
141
    }
142
  }
143

144 145 146 147
  // Get Graph Info Betweent input target GradNode and outputs,
  // record depending_nodes、potential_stop_nodes、potential_startup_nodes
  void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
148

149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
    std::queue<GradNodeBase*> queue = init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
      queue.pop();

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
      // all the next_node will be marked in potential_stop_nodes
      bool is_potential_stop_nodes =
          input_target_nodes_inputmeta_map.count(node);

      // Find and append next nodes
      const std::vector<std::vector<Edge>>& edges = node->GetEdges();
      for (const auto& edge_list : edges) {
        for (const Edge& edge : edge_list) {
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
            potential_stop_nodes.emplace(next_node);
          }

          // Update in_degree
          if (!node_in_degree_map.count(next_node))
            node_in_degree_map[next_node] = 0;
          node_in_degree_map[next_node]++;

          // Record depending relationship
          (depending_nodes)[next_node].emplace(node);
          queue.push(next_node);
        }
198 199
      }
    }
200 201 202
    // Update Graph Info, remove some nodes in
    // potential_stop_nodes、potential_startup_nodes、
    UpdateGraphInfo();
203 204
  }

205 206 207 208 209 210
  void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
    std::queue<GradNodeBase*> tmp_queue;
    for (auto nodes : potential_startup_nodes) {
      tmp_queue.emplace(nodes);
    }
    tmp_queue.swap(*queue);
211 212
  }

213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
  // Set result for input target grad_var when potential_startup_nodes is empty
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    if (potential_startup_nodes.size() == 0) {
      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
          results_map[input_target_node.first] = target_result;
228 229 230 231
        }
      }
    }
  }
232 233 234 235 236 237 238 239 240 241 242 243 244 245

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
      results_map[node] = target_result;
246
    }
247 248 249 250 251 252 253 254 255 256 257 258 259 260
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
261 262 263

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
264
        target_node = orig_to_copied_node_mapping_[target_node].get();
265 266 267 268 269
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }
270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285

      auto iter = results_map.find(target_node);
      if (iter != results_map.end()) {
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
286 287
      }
    }
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::queue<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
    // Purify potential_startup_ops, remove those nodes that are the same as
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
    // Record the depending_nodes and
    // potential_stop_nodes、potential_startup_nodes
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
    return potential_stop_nodes.count(node);
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
    return &no_grad_var_nodes_inputmeta_map;
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetInPutTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
    return &potential_stop_nodes;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
    return &potential_startup_nodes;
  }

  void Clear() {
    no_grad_var_nodes_inputmeta_map.clear();
    input_target_nodes_inputmeta_map.clear();
    potential_startup_nodes.clear();
    potential_stop_nodes.clear();
    depending_nodes.clear();
    results_map.clear();
346 347 348 349 350 351
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
352
      return orig_to_copied_node_mapping_[orig_node.get()].get();
353 354 355 356
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
357
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

  void ReconstructBackwardGraph(
      const std::queue<GradNodeBase*>& orig_init_queue) {
    std::queue<GradNodeBase*> queue = orig_init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
      queue.pop();
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph,"
              "unable to find copied target for certain grad node."));
382
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();
383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399

      const std::vector<std::vector<Edge>>& orig_edges = orig_node->GetEdges();
      std::vector<std::vector<Edge>>& copied_edges =
          copied_node->GetMutableEdges();
      for (size_t i = 0; i < orig_edges.size(); i++) {
        for (size_t j = 0; j < orig_edges[i].size(); j++) {
          const Edge& orig_edge = orig_edges[i][j];
          Edge& copied_edge = copied_edges[i][j];

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
400
                orig_to_copied_node_mapping_[orig_next_node.get()];
401 402 403 404

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
405
                copied_next_node;
406 407 408 409 410 411 412 413 414 415 416
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
          queue.push(orig_next_node.get());
        }
      }
    }
417 418
  }

419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      no_grad_var_nodes_inputmeta_map;
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      input_target_nodes_inputmeta_map;
  // Record all the potential startup_nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_startup_nodes;
  // Record all the potential stop nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_stop_nodes;
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
      depending_nodes;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
436 437

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
438 439
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;
440

441 442
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
443

444 445
// Count how many edges point at every GradNode reachable from `init_queue`
// (its in-degree in the backward graph). Each reachable node is expanded
// exactly once; nodes that are never the target of an edge get no entry.
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::queue<GradNodeBase*>& init_queue) {
  // We can completely remove this pass, if in_degree were set during forward
  // pass
  std::unordered_map<GradNodeBase*, int> in_degrees;
  std::unordered_set<GradNodeBase*> expanded;
  std::queue<GradNodeBase*> pending(init_queue);

  while (!pending.empty()) {
    GradNodeBase* current = pending.front();
    pending.pop();

    // insert().second is false when `current` was already expanded.
    if (!expanded.insert(current).second) continue;

    PADDLE_ENFORCE_NOT_NULL(
        current,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));

    const std::vector<std::vector<Edge>>& slots = current->GetEdges();
    for (const auto& slot_edges : slots) {
      for (const Edge& edge : slot_edges) {
        GradNodeBase* successor = edge.GetMutableGradNode().get();
        // A successor is null for a leaf tensor without an AccumulationNode,
        // or for an edge originating from a dispensable input.
        if (successor == nullptr) continue;

        // operator[] value-initializes a missing count to 0 first.
        ++in_degrees[successor];
        pending.push(successor);
      }
    }
  }

  return in_degrees;
}

// Enforce GradNode has TensorWrappers as Input: raise a Fatal error when the
// TensorWrappers of `node` have already been cleared, e.g. by a previous
// backward/grad call made without retain_graph=True.
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  const auto wrappers_cleared = node->IsTensorWrappersCleared();
  PADDLE_ENFORCE_NE(
      wrappers_cleared, true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

505 506 507 508 509 510 511 512 513 514 515 516
// Raise an AlreadyExists error if two entries of `inputs` share the same
// AutogradMeta, i.e. the same tensor is passed twice. `is_input` only selects
// the wording ("inputs" / "outputs") used in the error message.
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visited_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  // Iterate by const reference; copying every Tensor here was unnecessary.
  for (const auto& in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visited_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visited_ins.insert(auto_grad_meta);
  }
}

520 521
// Storage for the shared GeneralGrad instance returned by
// GeneralGrad::Instance(). It is allocated once here during static
// initialization and is not deleted anywhere in this file.
GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();

522 523 524 525 526 527 528
// Run the backward pass starting from `tensors`.
//
// tensors / grad_tensors: the outputs to differentiate and, optionally, the
//   gradients used to seed them. Entries without an initialized grad tensor
//   are seeded with 1.0. grad_tensors must be empty or have the same size as
//   tensors.
// retain_graph: when false, each node's TensorWrappers are cleared after it
//   runs.
// create_graph: forwarded to every grad node call and to gradient
//   accumulation.
// inputs / allow_unused / no_grad_vars: used by `Grad`. A non-empty `inputs`
//   switches to GeneralGrad mode: the backward graph is copied and pruned,
//   the buffered gradient slot of each no_grad_var is set to zeros, and the
//   gradients of `inputs` are returned.
// Returns an empty vector when not in GeneralGrad mode.
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(6) << "Start Backward";

  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  // Validate before the loop below indexes grad_tensors[i]: previously the
  // size check ran only after grad_tensors[i] was read, so a non-empty
  // grad_tensors shorter than tensors was accessed out of bounds.
  PADDLE_ENFORCE(
      grad_tensors.empty() || grad_tensors.size() == tensors.size(),
      paddle::platform::errors::Fatal(
          "Detected size mismatch between tensors and grad_tensors, "
          "grad_tensors should either have "
          "size = 0 or same size as tensors."));

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::queue<GradNodeBase*> queue;
  // Original (uncopied) starting nodes; only filled in GeneralGrad mode.
  std::queue<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (!shared_grad_node || auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push(grad_node);

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }

    // Safe to index: grad_tensors is empty or the same size as tensors.
    bool copy_from_grad_t =
        !grad_tensors.empty() && grad_tensors[i].initialized();
    if (copy_from_grad_t) {
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, true /*fill_one=true*/);
    }

    // Prepare queue, potential startup_nodes
    queue.push(grad_node);
  }

  if (is_general_grad) {
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
  }

  VLOG(6) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(6) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string((*node).name()) + " grad_node",
        paddle::platform::TracerEventType::Operator, 1);

    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop();
      continue;
    }
    queue.pop();

    // Run node: This is where Hook happens
    PADDLE_ENFORCE(
        node_input_buffers_dict.count(node),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffers_dict[node]);

    if (is_general_grad) {
      // Set input target grad_var from node_input_buffer by inputmeta
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);

      // no_grad_vars: zero the buffered gradient slot of the no_grad_var
      if (!no_grad_vars.empty()) {
        auto* no_grad_var_map =
            GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap();
        auto iter = no_grad_var_map->find(node);
        if (iter != no_grad_var_map->end()) {
          VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
          auto rank_info = (iter->second)->OutRankInfo();
          node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                    rank_info.second);
        }
      }
    }

    VLOG(6) << "Running GradNode:" << node->name();

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
        (*node)(node_input_buffer->Buffers(), create_graph, is_general_grad);

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // TODO(jiabin): Should we erase it or find a more efficient way.
    node_input_buffers_dict.erase(node);

    // Prepare GradTensorHolder for next node
    const std::vector<std::vector<Edge>>& edges = node->GetEdges();
    PADDLE_ENFORCE(edges.size() == grad_output_tensors.size() || edges.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
                       edges.size(), grad_output_tensors.size()));

    for (size_t i = 0; i < edges.size(); i++) {
      for (size_t j = 0; j < edges[i].size(); j++) {
        const Edge& edge = edges[i][j];
        if (!edge.IsInitialized()) {
          continue;
        }

        // Since we make edge has as same rank as bwd outputs, we indexing them
        // with the same rank(i, j)
        auto edge_rank = edge.GetEdgeRankInfo();
        auto next_node_shared = edge.GetMutableGradNode();

        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node_shared || grad_output_tensors[i].empty()) {
          continue;
        }

        PADDLE_ENFORCE_LT(
            j, grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if (!grad_output_tensor.defined() ||
            !grad_output_tensor.initialized()) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        GradNodeBase* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;

        node_input_buffers_dict[next_node]->add(
            edge_rank.first, edge_rank.second, grad_output_tensor,
            create_graph);

        // Update queue
        node_in_degree_map[next_node]--;

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s, "
                "Node's in-degree cannot be negative.",
                next_node->name()));

        // A node is ready once all of its inputs have been accumulated. In
        // GeneralGrad mode, potential stop nodes are never scheduled.
        const bool is_ready = node_in_degree_map[next_node] == 0;
        const bool is_stop_node =
            is_general_grad &&
            GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node) >
                0;
        if (is_ready && !is_stop_node) {
          queue.push(next_node);
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

807
void Backward(
808
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
809 810 811 812 813 814
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(6) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
J
Jiabin Yang 已提交
815
  phi::autotune::AutoTuneStatus::Instance().Update();
816 817 818
}

std::vector<paddle::experimental::Tensor> Grad(
819
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
820 821 822 823 824
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(6) << "Run in Grad";
825 826 827 828

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

829 830 831
  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
832
}  // namespace egr