// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"
#include <queue>

#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

#include "glog/logging.h"

namespace egr {

/*
* GeneralGrad is Helpper class to implement custom grad operation between
* outputs and inputs.
*
* **/
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs's / no_grad_vars's GradNodes and InputMeta Info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
        auto target_node = auto_grad_meta->GetMutableGradNode().get();
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
          (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        } else {  // normal input
          (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        }
      }
    }
  }
67

68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
  // Purify potential_startup_nodes, remove nodes those are the same as
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
    if (input_target_nodes_inputmeta_map.empty()) return;
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
    for (auto startup_op : potential_startup_nodes) {
      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
        potential_startup_nodes.erase(nodes);
      }
    }
  }
86

87 88 89 90 91 92 93 94 95 96 97
  // Remove some nodes those doesn't need to be
  // stored in potential_stop_nodes、potential_startup_nodes
  void UpdateGraphInfo() {
    // Updated potential_sotp_nodes by depending_nodes,
    // make sure the path from root to target_node is ok
    std::unordered_set<GradNodeBase*> _startup_ops;
    VLOG(6) << "Running in UpdateGraphInfo";
    std::queue<GradNodeBase*> queue;
    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
      queue.emplace(target_nodes_inputmeta_pair.first);
    }
98

99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop();
      if (!(depending_nodes)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes)[target_node];
        for (auto pre_nodes : precedding_nodes) {
          queue.emplace(pre_nodes);
          if (potential_stop_nodes.find(pre_nodes) !=
              potential_stop_nodes.end()) {
            potential_stop_nodes.erase(pre_nodes);
          }
        }
      } else {  // startup_ops have no precedding nodes
        VLOG(6) << "Emplace _startup_ops";
        _startup_ops.emplace(target_node);
114 115
      }
    }
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
    // Purify potential_startup_nodes again, remove some
    // potential startup_nodes that unreach to input target nodes
    if (!_startup_ops.empty()) {
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
      for (auto node : potential_startup_nodes) {
        if (_startup_ops.count(node) == 0) {
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
        }
      }
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
          potential_startup_nodes.erase(node);
        }
      }
132
    }
133
  }
134

135 136 137 138
  // Get Graph Info Betweent input target GradNode and outputs,
  // record depending_nodes、potential_stop_nodes、potential_startup_nodes
  void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
    VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
139

140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
    // Calculate in_degree for each node
    std::unordered_map<GradNodeBase*, int> node_in_degree_map;

    // Copy nodes
    std::queue<GradNodeBase*> queue = init_queue;
    std::unordered_set<GradNodeBase*> visited;

    // Visit each node exactly once in any order
    while (!queue.empty()) {
      GradNodeBase* node = queue.front();
      queue.pop();

      if (visited.count(node)) {
        continue;
      }
      visited.insert(node);

      // Check node is target_nodes or not, if node is not target_node,
      // all the next_node will be marked in potential_stop_nodes
      bool is_potential_stop_nodes =
          input_target_nodes_inputmeta_map.count(node);

      // Find and append next nodes
      const std::vector<std::vector<Edge>>& edges = node->GetEdges();
      for (const auto& edge_list : edges) {
        for (const Edge& edge : edge_list) {
          GradNodeBase* next_node = edge.GetMutableGradNode().get();

          // Next node could be nullptr if it is leaf tensor with no
          // AccumulationNode attached
          // Or it could also originated from dispensable inputs
          if (!next_node) continue;

          // if node not in input_target_nodes,
          // all the next_nodes of current node will be inserted to
          // potential_stop_node
          if (is_potential_stop_nodes) {
            potential_stop_nodes.emplace(next_node);
          }

          // Update in_degree
          if (!node_in_degree_map.count(next_node))
            node_in_degree_map[next_node] = 0;
          node_in_degree_map[next_node]++;

          // Record depending relationship
          (depending_nodes)[next_node].emplace(node);
          queue.push(next_node);
        }
189 190
      }
    }
191 192 193
    // Update Graph Info, remove some nodes in
    // potential_stop_nodes、potential_startup_nodes、
    UpdateGraphInfo();
194 195
  }

196 197 198 199 200 201
  void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
    std::queue<GradNodeBase*> tmp_queue;
    for (auto nodes : potential_startup_nodes) {
      tmp_queue.emplace(nodes);
    }
    tmp_queue.swap(*queue);
202 203
  }

204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
  // Set result for input target grad_var when potential_startup_nodes is empty
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    if (potential_startup_nodes.size() == 0) {
      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
          results_map[input_target_node.first] = target_result;
219 220 221 222
        }
      }
    }
  }
223 224 225 226 227 228 229 230 231 232 233 234 235 236

  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
      VLOG(6) << "Get target result by by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair, first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
      results_map[node] = target_result;
237
    }
238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};

    std::vector<paddle::experimental::Tensor> results;
    results.reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto& input = inputs[i];
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
      auto target_node = auto_grad_meta->GetMutableGradNode().get();

      auto iter = results_map.find(target_node);
      if (iter != results_map.end()) {
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
        tensor_auto_grad_meta->SetStopGradient(!create_graph);
        results.emplace_back(iter->second);
      } else {
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
269 270
      }
    }
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
    Clear();
    return results;
  }

  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::queue<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars's GradNodes and InputMeta Info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs's GradNodes and InputMeta Info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
    // Purify potential_startup_ops, remove those nodes that are the same as
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get Graph Info Betweent input target gradnode and outputs
    // Record the depending_nodes and
    // potential_stop_nodes、potential_startup_nodes
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. Queue is empty only when
    // 1.input equals to output. 2.input can not reach to output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
    return potential_stop_nodes.count(node);
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
    return &no_grad_var_nodes_inputmeta_map;
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetInPutTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
    return &potential_stop_nodes;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
    return &potential_startup_nodes;
  }

  void Clear() {
    no_grad_var_nodes_inputmeta_map.clear();
    input_target_nodes_inputmeta_map.clear();
    potential_startup_nodes.clear();
    potential_stop_nodes.clear();
    depending_nodes.clear();
    results_map.clear();
329 330
  }

331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      no_grad_var_nodes_inputmeta_map;
  // inputs's GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      input_target_nodes_inputmeta_map;
  // Record all the potential startup_nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_startup_nodes;
  // Record all the potential stop nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_stop_nodes;
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
      depending_nodes;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
// Count the incoming edges of every node reachable from `init_queue` in the
// backward graph. A node gets an entry only if some reachable node has an
// edge to it, so pure start nodes may be absent from the returned map.
//
// We can completely remove this pass, if in_degree were set during forward
// pass.
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::queue<GradNodeBase*>& init_queue) {
  std::unordered_map<GradNodeBase*, int> in_degree;

  // Work on a copy; the caller's queue is left untouched.
  std::queue<GradNodeBase*> pending(init_queue);
  std::unordered_set<GradNodeBase*> expanded;

  while (!pending.empty()) {
    GradNodeBase* node = pending.front();
    pending.pop();

    // Each node contributes its outgoing edges exactly once.
    if (!expanded.insert(node).second) continue;

    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));

    for (const auto& edge_list : node->GetEdges()) {
      for (const Edge& edge : edge_list) {
        GradNodeBase* next_node = edge.GetMutableGradNode().get();
        // The target may be null: a leaf tensor without an AccumulationNode,
        // or an edge coming from a dispensable input.
        if (next_node == nullptr) continue;

        // operator[] value-initializes a missing count to 0.
        ++in_degree[next_node];
        pending.push(next_node);
      }
    }
  }
  return in_degree;
}

// Enforce GradNode has TensorWrappers as Input.
//
// Raises a Fatal error naming `node` when node->IsTensorWrappersCleared()
// returns true. RunBackward calls node->ClearTensorWrappers() after running a
// node when retain_graph is false. Presumably (judging by the names) a second
// backward/grad pass over the same subgraph therefore fails here with the
// retain_graph hint below.
void EnforceGradNodeHasInput(GradNodeBase* node) {
  VLOG(6) << "Running in EnforceGradNodeHasInput";
  PADDLE_ENFORCE_NE(
      node->IsTensorWrappersCleared(), true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

// Raise an AlreadyExists error if two tensors in `inputs` share the same
// AutogradMeta. `is_input` only selects the wording ("inputs" or "outputs")
// used in the error message.
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visisted_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  // Iterate by const reference. A by-value loop copied every Tensor, with
  // its shared handles and name string, just to read it.
  for (const auto& in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visisted_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visisted_ins.insert(auto_grad_meta);
  }
}

// Backing storage for GeneralGrad::Instance(). The object is heap-allocated
// once at static-initialization time, and nothing in this file deletes it.
GeneralGrad* GeneralGrad::general_grad_{new GeneralGrad()};

// Shared implementation of Backward() and Grad().
//
// Seeds the grad nodes of `tensors` with `grad_tensors` (or with ones when
// `grad_tensors` is empty). It then runs the backward graph, scheduling a
// node once its in-degree reaches zero, i.e. once all of its incoming
// gradients have been accumulated. When `retain_graph` is false, every
// executed node drops its TensorWrappers.
//
// GeneralGrad mode (`inputs` non-empty):
//   - only the part of the graph leading to the grad nodes of `inputs` runs;
//   - the buffer entries that belong to `no_grad_vars` are replaced by zeros;
//   - one tensor per input is returned (`allow_unused` decides whether an
//     unreached input is an error or yields an empty tensor).
// Otherwise an empty vector is returned.
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph = false,
    const std::vector<paddle::experimental::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
  VLOG(6) << "Start Backward";
  // *Gradient Hook should happen at node-level
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();

  // GeneralGrad::Instance() is shared across calls and accumulates per-call
  // state (target maps, potential startup/stop nodes, results). Clear it on
  // every exit path, including when an enforce below fails, so a failed call
  // cannot leave stale state behind for the next Grad() call.
  struct GeneralGradStateGuard {
    bool enabled;
    ~GeneralGradStateGuard() {
      if (enabled) GeneralGrad::Instance().Clear();
    }
  } general_grad_state_guard{is_general_grad};

  // grad_tensors is either empty (every output is seeded with ones) or holds
  // exactly one seed per output. It is checked up front so that a mismatch is
  // reported even when every output is skipped below.
  PADDLE_ENFORCE(
      grad_tensors.empty() || grad_tensors.size() == tensors.size(),
      paddle::platform::errors::Fatal(
          "Detected size mismatch between tensors and grad_tensors, "
          "grad_tensors should either have "
          "size = 0 or same size as tensors."));

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::queue<GradNodeBase*> queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(tensor);
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(2) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (shared_grad_node == nullptr || auto_grad_meta->StopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
    GradNodeBase* grad_node = shared_grad_node.get();

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(6) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }

    if (!grad_tensors.empty()) {
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << " with given grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize tensor with 1.0
      // Forward Tensor "tensor" is passed to indicate tensortype, datatype and
      // dims
      // GradTensorHolder will initialize another tensor with same tensortype,
      // datatype and dims but filled with 1.0
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, true /*fill_one=true*/);
    }

    // Prepare queue, potential startup_nodes
    queue.push(grad_node);
    if (is_general_grad) {
      GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
    }
  }

  VLOG(6) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  VLOG(6) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    // NOTE(review): while other nodes are pending, a node whose in-degree is
    // not yet zero is dropped here. It is queued again below once its last
    // producer brings its in-degree to zero. A lone remaining node runs
    // regardless of its in-degree.
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop();
      continue;
    }
    queue.pop();

    // Open the profiler event only for nodes that actually run, so skipped
    // queue entries do not show up as operator events.
    paddle::platform::RecordEvent node_record_event(
        std::string(typeid(*node).name()) + " grad_node",
        paddle::platform::TracerEventType::Operator, 1);

    // Run node: This is where Hook happens
    PADDLE_ENFORCE(
        node_input_buffers_dict.count(node),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffers_dict[node]);

    if (is_general_grad) {
      // Set input target grad_var from node_input_buffer by inputmeta
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);

      // no_grad_vars: zero the buffer entry that belongs to a no_grad var.
      if (!no_grad_vars.empty()) {
        auto* no_grad_var_nodes =
            GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap();
        auto iter = no_grad_var_nodes->find(node);
        if (iter != no_grad_var_nodes->end()) {
          VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
          auto rank_info = (iter->second)->OutRankInfo();
          node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                    rank_info.second);
        }
      }
    }

    VLOG(6) << "Running GradNode:" << node->name();

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
        (*node)(node_input_buffer->Buffers(), create_graph);

    // retain_grad or not
    if (!retain_graph) {
      VLOG(6)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // TODO(jiabin): Should we erase it or find a more efficient way.
    node_input_buffers_dict.erase(node);

    // Prepare GradTensorHolder for next node
    const std::vector<std::vector<Edge>>& edges = node->GetEdges();
    PADDLE_ENFORCE(edges.size() == grad_output_tensors.size() || edges.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
                       edges.size(), grad_output_tensors.size()));

    for (size_t i = 0; i < edges.size(); i++) {
      for (size_t j = 0; j < edges[i].size(); j++) {
        const Edge& edge = edges[i][j];
        if (!edge.IsInitialized()) {
          continue;
        }

        // Edges share the (slot, rank) layout of the backward outputs, so
        // both are indexed with (i, j).
        auto edge_rank = edge.GetEdgeRankInfo();
        auto next_node_shared = edge.GetMutableGradNode();

        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node_shared || grad_output_tensors[i].empty()) {
          continue;
        }

        PADDLE_ENFORCE_LT(
            j, grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate autoprune or autograd api error. ",
                grad_output_tensors[i].size()));
        paddle::experimental::Tensor& grad_output_tensor =
            grad_output_tensors[i][j];

        if (!grad_output_tensor.defined() ||
            !grad_output_tensor.initialized()) {
          VLOG(6) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }
        VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        GradNodeBase* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(6) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }
        VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;
        node_input_buffers_dict[next_node]->add(
            edge_rank.first, edge_rank.second, grad_output_tensor);

        // Update queue
        node_in_degree_map[next_node]--;

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s, "
                "Node's in-degree cannot be negative.",
                next_node->name()));

        // In GeneralGrad mode, nodes beyond the input targets (potential stop
        // nodes) are never scheduled.
        const bool is_potential_stop_node =
            is_general_grad &&
            GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
        if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
          queue.emplace(next_node);
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

void Backward(
689
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
690 691 692 693 694 695 696 697 698
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(6) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::Operator, 1);
  RunBackward(tensors, grad_tensors, retain_graph);
}

std::vector<paddle::experimental::Tensor> Grad(
699
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
700 701 702 703 704
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(6) << "Run in Grad";
705 706 707 708

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

709 710 711
  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
}  // namespace egr