backward.cc 31.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
#include <queue>

#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
22 23
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
24

J
Jiabin Yang 已提交
25
#include "glog/logging.h"
26 27
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
J
Jiabin Yang 已提交
28
#include "paddle/phi/kernels/autotune/switch_autotune.h"
29 30 31

namespace egr {

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/*
* GeneralGrad is Helpper class to implement custom grad operation between
* outputs and inputs.
*
* **/
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
53 54 55
        auto* target_node = auto_grad_meta->GetMutableGradNode().get();

        if (orig_to_copied_node_mapping_.count(target_node)) {
56
          target_node = orig_to_copied_node_mapping_[target_node].get();
57 58 59 60 61 62
        } else {
          VLOG(6) << "Unable to find target node in "
                     "orig_to_copied_node_mapping_, likely indicating an "
                     "unused input";
        }

63 64 65 66 67 68 69 70 71 72 73 74 75
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
          (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        } else {  // normal input
          (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        }
      }
    }
  }
76

77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
  // Purify potential_startup_nodes, remove nodes those are the same as
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
    if (input_target_nodes_inputmeta_map.empty()) return;
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
    for (auto startup_op : potential_startup_nodes) {
      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
        potential_startup_nodes.erase(nodes);
      }
    }
  }
95

96 97 98 99 100 101 102 103 104 105 106
  // Remove some nodes those doesn't need to be
  // stored in potential_stop_nodes、potential_startup_nodes
  void UpdateGraphInfo() {
    // Updated potential_sotp_nodes by depending_nodes,
    // make sure the path from root to target_node is ok
    std::unordered_set<GradNodeBase*> _startup_ops;
    VLOG(6) << "Running in UpdateGraphInfo";
    std::queue<GradNodeBase*> queue;
    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
      queue.emplace(target_nodes_inputmeta_pair.first);
    }
107

108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop();
      if (!(depending_nodes)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes)[target_node];
        for (auto pre_nodes : precedding_nodes) {
          queue.emplace(pre_nodes);
          if (potential_stop_nodes.find(pre_nodes) !=
              potential_stop_nodes.end()) {
            potential_stop_nodes.erase(pre_nodes);
          }
        }
      } else {  // startup_ops have no precedding nodes
        VLOG(6) << "Emplace _startup_ops";
        _startup_ops.emplace(target_node);
123 124
      }
    }
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
    // Purify potential_startup_nodes again, remove some
    // potential startup_nodes that unreach to input target nodes
    if (!_startup_ops.empty()) {
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
      for (auto node : potential_startup_nodes) {
        if (_startup_ops.count(node) == 0) {
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
          potential_startup_nodes.erase(node);
        }
      }
141
    }
142
  }
143

144 145 146 147
  // Get Graph Info Betweent input target GradNode and outputs,
  // record depending_nodes、potential_stop_nodes、potential_startup_nodes
  void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
148

149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
    std::queue<GradNodeBase*> queue = init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
      queue.pop();

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
      // all the next_node will be marked in potential_stop_nodes
      bool is_potential_stop_nodes =
          input_target_nodes_inputmeta_map.count(node);

      // Find and append next nodes
172 173 174 175 176 177
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& metas =
          node->OutputMeta();
      for (const auto& meta_list : metas) {
        for (const GradSlotMeta& meta : meta_list) {
          const auto& edge = meta.GetEdge();
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
            potential_stop_nodes.emplace(next_node);
          }

          // Update in_degree
          if (!node_in_degree_map.count(next_node))
            node_in_degree_map[next_node] = 0;
          node_in_degree_map[next_node]++;

          // Record depending relationship
          (depending_nodes)[next_node].emplace(node);
          queue.push(next_node);
        }
201 202
      }
    }
203 204 205
    // Update Graph Info, remove some nodes in
    // potential_stop_nodes、potential_startup_nodes、
    UpdateGraphInfo();
206 207
  }

208 209 210 211 212 213
  void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
    std::queue<GradNodeBase*> tmp_queue;
    for (auto nodes : potential_startup_nodes) {
      tmp_queue.emplace(nodes);
    }
    tmp_queue.swap(*queue);
214 215
  }

216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
  // Set result for input target grad_var when potential_startup_nodes is empty
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    if (potential_startup_nodes.size() == 0) {
      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
          results_map[input_target_node.first] = target_result;
231 232 233 234
        }
      }
    }
  }
235 236 237 238 239 240 241 242 243 244 245 246 247 248

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
      results_map[node] = target_result;
249
    }
250 251 252 253 254 255 256 257 258 259 260 261 262 263
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
264 265 266

      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
      if (orig_to_copied_node_mapping_.count(target_node)) {
267
        target_node = orig_to_copied_node_mapping_[target_node].get();
268 269 270 271 272
      } else {
        VLOG(6) << "Unable to find target node in "
                   "orig_to_copied_node_mapping_, likely indicating an unused "
                   "input";
      }
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288

      auto iter = results_map.find(target_node);
      if (iter != results_map.end()) {
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
289 290
      }
    }
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::queue<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
    // Purify potential_startup_ops, remove those nodes that are the same as
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
    // Record the depending_nodes and
    // potential_stop_nodes、potential_startup_nodes
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
    return potential_stop_nodes.count(node);
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
    return &no_grad_var_nodes_inputmeta_map;
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetInPutTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
    return &potential_stop_nodes;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
    return &potential_startup_nodes;
  }

  void Clear() {
    no_grad_var_nodes_inputmeta_map.clear();
    input_target_nodes_inputmeta_map.clear();
    potential_startup_nodes.clear();
    potential_stop_nodes.clear();
    depending_nodes.clear();
    results_map.clear();
349 350 351 352 353 354
    copied_grad_nodes_.clear();
    orig_to_copied_node_mapping_.clear();
  }

  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
355
      return orig_to_copied_node_mapping_[orig_node.get()].get();
356 357 358 359
    }
    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

    // Save node and update mapping
360
    orig_to_copied_node_mapping_[orig_node.get()] = copied_node;
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384
    copied_grad_nodes_.push_back(copied_node);

    return copied_node.get();
  }

  void ReconstructBackwardGraph(
      const std::queue<GradNodeBase*>& orig_init_queue) {
    std::queue<GradNodeBase*> queue = orig_init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // BFS and recursively copy the grad nodes
    while (!queue.empty()) {
      GradNodeBase* orig_node = queue.front();
      queue.pop();
      if (visited.count(orig_node)) {
        continue;
      }
      visited.insert(orig_node);

      PADDLE_ENFORCE(
          orig_to_copied_node_mapping_.count(orig_node),
          paddle::platform::errors::Fatal(
              "Cannot reconstruct backward graph,"
              "unable to find copied target for certain grad node."));
385
      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get();
386

387 388 389 390 391 392 393 394 395
      const paddle::small_vector<std::vector<GradSlotMeta>,
                                 kSlotSmallVectorSize>& orig_meta =
          orig_node->OutputMeta();
      paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
          copied_edges = copied_node->MutableOutputMeta();
      for (size_t i = 0; i < orig_meta.size(); i++) {
        for (size_t j = 0; j < orig_meta[i].size(); j++) {
          const Edge& orig_edge = orig_meta[i][j].GetEdge();
          Edge& copied_edge = copied_edges[i][j].GetMutableEdge();
396 397 398 399 400 401 402 403 404

          std::shared_ptr<GradNodeBase> orig_next_node =
              orig_edge.GetMutableGradNode();
          if (!orig_next_node) continue;

          // Copy Next Node
          std::shared_ptr<GradNodeBase> copied_next_node;
          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
            copied_next_node =
405
                orig_to_copied_node_mapping_[orig_next_node.get()];
406 407 408 409

          } else {
            copied_next_node = orig_next_node->Copy();
            orig_to_copied_node_mapping_[orig_next_node.get()] =
410
                copied_next_node;
411 412 413 414 415 416 417 418 419 420 421
            copied_grad_nodes_.push_back(copied_next_node);
          }

          // Update Edge's Grad Node
          copied_edge.SetGradNode(copied_next_node);

          // Update BFS queue
          queue.push(orig_next_node.get());
        }
      }
    }
422 423
  }

424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      no_grad_var_nodes_inputmeta_map;
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      input_target_nodes_inputmeta_map;
  // Record all the potential startup_nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_startup_nodes;
  // Record all the potential stop nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_stop_nodes;
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
      depending_nodes;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
441 442

  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
443 444
  std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
      orig_to_copied_node_mapping_;
445

446 447
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
448

449 450
// Computes the in-degree (number of incoming edges) of every grad node that
// is reachable from the nodes in `init_queue`.
//
// Nodes without incoming edges do not appear in the returned map. This pass
// could be removed entirely if in_degree were set during the forward pass.
// Raises a Fatal error if a null node is met while traversing.
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::queue<GradNodeBase*>& init_queue) {
  std::unordered_map<GradNodeBase*, int> in_degrees;
  std::unordered_set<GradNodeBase*> seen;

  // Work on a copy so that the caller's queue stays untouched.
  std::queue<GradNodeBase*> pending = init_queue;

  // Expand every node exactly once; the visiting order does not matter.
  while (!pending.empty()) {
    GradNodeBase* node = pending.front();
    pending.pop();

    // insert().second is false when the node has been expanded before.
    if (!seen.insert(node).second) {
      continue;
    }

    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));

    // Follow every outgoing edge of the node.
    for (const auto& slot_metas : node->OutputMeta()) {
      for (const GradSlotMeta& slot_meta : slot_metas) {
        GradNodeBase* next_node =
            slot_meta.GetEdge().GetMutableGradNode().get();
        // The next node can be nullptr for a leaf tensor with no
        // AccumulationNode attached, or for dispensable inputs.
        if (next_node == nullptr) continue;

        // operator[] value-initializes a missing counter to 0.
        ++in_degrees[next_node];
        pending.push(next_node);
      }
    }
  }

  return in_degrees;
}

// Enforce GradNode has TensorWrappers as Input.
//
// Raises a Fatal error (naming the node) if the node reports that its
// TensorWrappers have been cleared; the error text points at the likely
// cause, a second backward over the same subgraph without retain_graph=True.
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  PADDLE_ENFORCE_NE(
      node->IsTensorWrappersCleared(), true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

512 513 514 515 516 517 518 519 520 521 522 523
// Checks that no tensor appears twice in `inputs`.
//
// Two tensors are considered duplicates when they share the same
// AutogradMeta. `is_input` only selects the wording of the error message
// ("inputs" when true, "outputs" otherwise).
// Raises AlreadyExists on the first duplicate found.
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visited_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  // Iterate by const reference: the tensors are only inspected, so there is
  // no reason to copy each one.
  for (const auto& in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visited_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visited_ins.insert(auto_grad_meta);
  }
}

527 528
// Definition of the shared GeneralGrad instance returned by
// GeneralGrad::Instance(). It is allocated once here and never deleted.
GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();

529 530 531 532 533 534 535
// Runs the backward pass starting from `tensors`.
//
// tensors:      output tensors the backward pass starts from.
// grad_tensors: initial gradients for `tensors`; must be empty or have the
//               same size as `tensors`. An uninitialized entry (or an empty
//               vector) means the gradient is filled with 1.0.
// retain_graph: when false, the TensorWrappers of every executed node are
//               cleared after it has run.
// create_graph: forwarded to the grad nodes and to gradient accumulation.
// inputs:       when non-empty, the run is a "general grad" run: the graph is
//               copied, restricted by GeneralGrad, and the gradients of
//               `inputs` are returned.
// allow_unused: general grad only; see GeneralGrad::GetResults.
// no_grad_vars: general grad only; the gradient flowing into these tensors
//               is zeroed.
//
// Returns the gradients of `inputs` for a general grad run, otherwise an
// empty vector.
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(6) << "Start Backward";

  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  // Validate the sizes once, BEFORE grad_tensors is indexed below; otherwise
  // a shorter grad_tensors would be read out of bounds.
  PADDLE_ENFORCE(
      grad_tensors.empty() || grad_tensors.size() == tensors.size(),
      paddle::platform::errors::Fatal(
          "Detected size mismatch between tensors and grad_tensors. "
          "grad_tensors should either have "
          "size = 0 or same size as tensors."));

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::queue<GradNodeBase*> queue;
  std::queue<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (shared_grad_node == nullptr || auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push(grad_node);

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }

    // Safe to index: sizes were validated before the loop.
    const bool copy_from_grad_t =
        !grad_tensors.empty() && grad_tensors[i].initialized();
    if (copy_from_grad_t) {
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << " with given grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, true /*fill_one=true*/);
    }

    // Prepare queue, potential startup_nodes
    queue.push(grad_node);
  }

  if (is_general_grad) {
    // Copy Backward Graph
    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
  }

  VLOG(6) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(6) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string((*node).name()) + " grad_node",
        paddle::platform::TracerEventType::Operator, 1);

    // A node that still waits for other gradients is dropped here; it is
    // pushed again once its in-degree reaches zero.
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop();
      continue;
    }
    queue.pop();

    // Run node: This is where Hook happens
    PADDLE_ENFORCE(
        node_input_buffers_dict.count(node),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffers_dict[node]);

    // Set input target grad_var from node_input_buffer by inputmeta
    if (is_general_grad) {
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);
    }

    // no_grad_vars
    if (!no_grad_vars.empty() && is_general_grad) {
      auto* no_grad_map =
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap();
      auto iter = no_grad_map->find(node);
      if (iter != no_grad_map->end()) {
        VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
        auto rank_info = (iter->second)->OutRankInfo();
        node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                  rank_info.second);
      }
    }

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                         kSlotSmallVectorSize>
        grad_output_tensors = (*node)(node_input_buffer->Buffers(),
                                      create_graph, is_general_grad);

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // TODO(jiabin): Should we erase it or find a more efficient way.
    node_input_buffers_dict.erase(node);

    // Prepare GradTensorHolder for next node
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
                       metas.size(), grad_output_tensors.size()));

    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
        if (!edge.IsInitialized()) {
          continue;
        }
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since we make edge has as same rank as bwd outputs, we indexing them
        // with the same rank(i, j)
        auto next_node_shared = edge.GetMutableGradNode();

        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node_shared || grad_output_tensors[i].empty()) {
          continue;
        }

        // Report the size of the slot that is actually being indexed.
        PADDLE_ENFORCE_LT(
            j, grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if ((!grad_output_tensor.defined() ||
             !grad_output_tensor.initialized())) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        auto* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;

        node_input_buffers_dict[next_node]->add(
            edge_rank.first, edge_rank.second, grad_output_tensor,
            create_graph);

        // Update queue
        node_in_degree_map[next_node]--;

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s. "
                "Node's in-degree cannot be negative.",
                next_node->name()));

        // In a general grad run, potential stop nodes are never scheduled.
        const bool is_potential_stop_node =
            is_general_grad &&
            GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
        if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
          queue.emplace(next_node);
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

817
void Backward(
818
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
819 820 821 822 823 824
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(6) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
J
Jiabin Yang 已提交
825
  phi::autotune::AutoTuneStatus::Instance().Update();
826 827 828
}

std::vector<paddle::experimental::Tensor> Grad(
829
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
830 831 832 833 834
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(6) << "Run in Grad";
835 836 837 838

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

839 840 841
  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
842
}  // namespace egr