// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/basic_engine.h"

#include <algorithm>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/math_function.h"
DECLARE_bool(sort_sum_gradient);

namespace paddle {
namespace imperative {

40 41 42 43
void BasicEngine::Init(
    const std::vector<std::shared_ptr<VarBase>>& tensors,
    const std::vector<std::shared_ptr<VarBase>>& grad_tensors,
    bool retain_graph) {
44
  retain_graph_ = retain_graph;
45

46 47 48 49 50 51 52
  PADDLE_ENFORCE_EQ(
      tensors.size(), grad_tensors.size(),
      platform::errors::Unavailable(
          "The size of tensors do not equal the size of grad_tensors,"
          "the size of tensors is %s, but the size of grad_tensors is %s.",
          tensors.size(), grad_tensors.size()));

C
chentianyu03 已提交
53 54 55 56
  PADDLE_ENFORCE_EQ(accumulators_.empty(), true,
                    platform::errors::AlreadyExists(
                        "Accumulators are not empty before preparing it for "
                        "backward network execution."));
57 58 59 60
  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
                    platform::errors::AlreadyExists(
                        "Accumulators with grad_node as the key are not empty "
                        "before preparing it for backward network execution."));
C
chentianyu03 已提交
61

62 63 64 65 66
  for (size_t i = 0; i < tensors.size(); ++i) {
    auto var = tensors[i];
    auto grad_tensor = grad_tensors[i];

    auto init_node = var->GradVarBase()->GradNode();
C
chentianyu03 已提交
67

68 69 70 71 72 73 74 75 76 77 78 79 80 81
    PADDLE_ENFORCE_EQ(
        var->GradVarBase()->GraphIsFreed(), false,
        platform::errors::Unavailable(
            "%s trying to backward through the same graph a second "
            "time, but this graph have already been freed. Please "
            "specify Tensor.backward(retain_graph=True) when "
            "calling backward at the first time.",
            var->Name()));

    if (!retain_graph) {
      VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
              << " because of retain_graph=False when calling backward";
      var->GradVarBase()->SetGraphIsFreed(true);
    }
82

83 84 85 86 87 88
    if (init_node == nullptr || var->OverridedStopGradient()) {
      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
                 "stop_gradient=True: "
              << var->Name();
      continue;
    }
89

90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
    VLOG(3) << "Init node of backward";

    PADDLE_ENFORCE_EQ(
        var->HasGradVar(), true,
        platform::errors::NotFound("Tensor %s has no gradient", var->Name()));

    auto& fwd_var = var->Var().Get<framework::LoDTensor>();
    auto* grad_var =
        var->GradVarBase()->MutableVar()->GetMutable<framework::LoDTensor>();
    VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
            << " as stop_gradient false";
    var->GradVarBase()->InnerSetOverridedStopGradient(false);
    auto* dev_ctx =
        platform::DeviceContextPool::Instance().Get(fwd_var.place());
    if (grad_tensor == nullptr) {
      grad_var->Resize(fwd_var.dims());
      grad_var->mutable_data(fwd_var.place(), fwd_var.type());
107
      phi::funcs::set_constant(*dev_ctx, grad_var, 1.0);
108 109 110 111 112 113
    } else {
      paddle::framework::TensorCopy(
          grad_tensor->Var().Get<framework::LoDTensor>(), fwd_var.place(),
          *dev_ctx, grad_var);
    }

C
chentianyu03 已提交
114
    VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
115 116 117
    auto& accumulator =
        accumulators_with_grad_node_[init_grad_var->GetGradNode()]
                                    [init_grad_var];
C
chentianyu03 已提交
118 119 120 121 122 123 124
    if (!accumulator) {
      if (FLAGS_sort_sum_gradient) {
        accumulator.reset(new SortedGradientAccumulator(init_grad_var));
      } else {
        accumulator.reset(new EagerGradientAccumulator(init_grad_var));
      }
    }
125 126
    accumulator->IncreaseRefCnt();
    accumulator->IncreaseCurCnt();
C
chentianyu03 已提交
127

128 129
    init_nodes_.push_back(init_node);
  }
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
}

void BasicEngine::CheckBackwardInputs(const OpBase& op) {
  for (auto& pair : op.GetInsMap()) {
    if (!pair.second.IsGrad()) {
      continue;
    }

    for (auto& var : pair.second) {
      if (!var) {
        continue;
      }

      auto* inner_var = var->MutableVar();
      framework::Tensor* tensor = nullptr;
      if (!inner_var->IsInitialized() ||
          inner_var->IsType<framework::LoDTensor>()) {
        tensor = inner_var->GetMutable<framework::LoDTensor>();
      }

      if (tensor && !tensor->IsInitialized()) {
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place());
152 153 154 155
        // NOTE(zhiqiu): since grad variable is ungenerated, so the dtype is not
        // correct. var->DataType() returns the default dtype, which is float32.
        // Here, we use the type of the corresponding forward datatype.

156 157
        tensor->mutable_data(
            op.place(), framework::TransToPtenDataType(var->ForwardDataType()));
158 159 160
        VLOG(6) << "Set ungenerated Grad: " << var->Name()
                << " as zero with dtype "
                << framework::DataTypeToString(var->ForwardDataType());
161
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
162 163 164 165 166
      }
    }
  }
}

167 168 169
void BasicEngine::PrepareGradAccumulators(
    const OpBase& op,
    const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes) {
170 171 172 173 174 175 176 177
  for (const auto& pair : op.GetOutsMap()) {
    if (!pair.second.IsGrad()) {
      continue;
    }

    for (const auto& var : pair.second) {
      if (!var) continue;

178
      bool find_grad_node_of_var = false;
179
      if (grad_pending_nodes.size()) {
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
        // Because Inplace op overwrites the grad_node of the input grad_var. So
        // only the information of grad_pending_node can be used to find the
        // grad_node of grad_var.
        for (auto& grad_pending_node : grad_pending_nodes) {
          PADDLE_ENFORCE_NOT_NULL(
              grad_pending_node,
              platform::errors::NotFound("Grad pending node is nullptr."));
          for (auto& grad_pending_op : *grad_pending_node) {
            VLOG(6) << "Determine whether var (" << var->Name()
                    << ") is the input var of grad_pending_op ("
                    << grad_pending_op.Type() << ").";
            grad_pending_op.EnforceHasInOut();
            for (const auto& grad_pending_op_ins_pair :
                 grad_pending_op.GetInsMap()) {
              if (!grad_pending_op_ins_pair.second.IsGrad()) {
                continue;
              }
              for (const auto& pending_in_var :
                   grad_pending_op_ins_pair.second) {
                if (var == pending_in_var) {
                  VLOG(6) << "Var (" << var->Name()
                          << ") is the input var of grad_pending_op ("
                          << grad_pending_op.Type() << ").";
                  find_grad_node_of_var = true;
                  break;
                }
              }
              if (find_grad_node_of_var) {
                break;
              }
            }
          }
212

213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
          if (find_grad_node_of_var) {
            auto& accumulator =
                accumulators_with_grad_node_[grad_pending_node][var.get()];

            if (!accumulator) {
              if (FLAGS_sort_sum_gradient) {
                accumulator.reset(new SortedGradientAccumulator(var.get()));
              } else {
                accumulator.reset(new EagerGradientAccumulator(var.get()));
              }
            }

            accumulator->IncreaseRefCnt();

            VLOG(3) << "Prepare to acccumulate variable grad " << var->Name()
                    << "(" << var.get()
                    << ") that has grad node with reference count "
                    << accumulator->RefCnt();
            break;
          }
        }
234 235 236 237 238 239 240 241 242 243 244
        if (!find_grad_node_of_var) {
          // Special case: `set_value` is inplace op, and it can change
          // the var with `stop_gradient=True` to the var with
          // `stop_gradient=False `.
          // This inplace var has grad_node (the inplace op), but it
          // isn't the input of grad_pending_op.
          VLOG(6) << "No grad node corresponding to grad Tensor ("
                  << var->Name() << ") was found.";
        }
      }

245
      if (!grad_pending_nodes.size() || !find_grad_node_of_var) {
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
        auto& accumulator = accumulators_[var.get()];
        if (!accumulator) {
          if (FLAGS_sort_sum_gradient) {
            accumulator.reset(new SortedGradientAccumulator(var.get()));
          } else {
            accumulator.reset(new EagerGradientAccumulator(var.get()));
          }
        }

        accumulator->IncreaseRefCnt();

        VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
                << var.get()
                << ") that don't have grad node  with reference count "
                << accumulator->RefCnt();
261
      }
262 263 264 265 266 267 268
    }
  }
}

void BasicEngine::PrepareDeps() {
  PADDLE_ENFORCE_EQ(
      node_deps_.empty(), true,
269 270
      platform::errors::AlreadyExists("Op deps are not empty before preparing "
                                      "it for backward network execution."));
271 272 273 274

  std::queue<GradOpNode*> q;
  std::unordered_set<GradOpNode*> visited;

275 276 277 278
  for (size_t i = 0; i < init_nodes_.size(); ++i) {
    q.push(init_nodes_[i].get());
    visited.insert(init_nodes_[i].get());
  }
279 280 281 282 283

  while (!q.empty()) {
    auto* cur_node = q.front();
    q.pop();

284 285
    const auto& grad_pending_nodes = cur_node->GradPendingNodes();

286
    for (auto& cur_op : *cur_node) {
Z
Zeng Jinle 已提交
287
      cur_op.EnforceHasInOut();
288
      PrepareGradAccumulators(cur_op, grad_pending_nodes);
289 290 291 292 293
    }

    for (auto& grad_pending_node : grad_pending_nodes) {
      PADDLE_ENFORCE_NOT_NULL(
          grad_pending_node,
294
          platform::errors::NotFound("Grad pending node is nullptr."));
295 296 297 298 299 300 301 302 303
      ++node_deps_[grad_pending_node.get()];
      if (visited.count(grad_pending_node.get()) == 0) {
        visited.insert(grad_pending_node.get());
        q.push(grad_pending_node.get());
      }
    }
  }
}

304 305 306 307 308 309
static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
    const NameVarMap<VariableWrapper>& bwd_ins, const std::string& op_type) {
  std::shared_ptr<NameVarMap<VariableWrapper>> tmp_ins_ptr = nullptr;
  for (const auto& pair : bwd_ins) {
    for (size_t i = 0; i < pair.second.size(); ++i) {
      auto& var = pair.second[i];
310
      if (var->HasVariableWrapperHook()) {
311 312 313
        if (tmp_ins_ptr == nullptr) {
          tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
        }
314 315 316
        VLOG(3) << "Call " << var->GetVariableWrapperHooks().size()
                << " hooks of " << op_type << "'s input `" << pair.first
                << "`'s var `" << var->Name() << "`.";
317
        auto tmp_var = var;
318
        for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
319 320 321 322 323 324 325 326 327
          tmp_var = (*hook_pair.second)(tmp_var);
        }
        (*tmp_ins_ptr)[pair.first][i] = tmp_var;
      }
    }
  }
  return tmp_ins_ptr;
}

// Reports whether `var` already holds an initialized LoDTensor.
// PerformBackwardInplace() requires this before reusing an input's buffer
// for an output.
static bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
  auto* holder = var->MutableVar();
  if (!holder->IsInitialized() || !holder->IsType<framework::LoDTensor>()) {
    return false;
  }
  return holder->GetMutable<framework::LoDTensor>()->IsInitialized();
}

// Lets a backward op write its outputs into its inputs' buffers when the
// op's registered `infer_inplace_` allows it.
//
// For every (input slot -> output slot) pair returned by infer_inplace(true),
// the output LoDTensor shares the input's buffer (and takes its dims) only
// if all of the following hold:
//   - the first var of the input slot exists, is referenced only by `ins`
//     (use_count() == 1), and holds an initialized LoDTensor;
//   - the first var of the output slot exists and is of LOD_TENSOR type.
// Pairs that do not qualify are skipped.
static void PerformBackwardInplace(const std::string& op_type,
                                   const NameVarMap<VariableWrapper>& ins,
                                   NameVarMap<VariableWrapper>* outs) {
  auto& infer_inplace =
      paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
  if (!infer_inplace) {
    return;
  }

  auto in_to_outs = infer_inplace(true);
  for (auto& pair : in_to_outs) {
    // Look the slots up directly; slot names are unique map keys.
    framework::LoDTensor* in_tensor = nullptr;
    auto in_iter = ins.find(pair.first);
    // The input slot must hold at least one var.
    if (in_iter != ins.end() && in_iter->second.size() > 0 &&
        in_iter->second[0]) {
      const auto& in_var = in_iter->second[0];
      VLOG(10) << in_iter->first << " use_count: " << in_var.use_count();
      // The refcount of the var to be inplaced should be 1.
      if (in_var.use_count() == 1 && IsInputCanInplace(in_var)) {
        in_tensor = in_var->MutableVar()->GetMutable<framework::LoDTensor>();
      }
    }
    if (!in_tensor) {
      continue;
    }

    framework::LoDTensor* out_tensor = nullptr;
    auto out_iter = outs->find(pair.second);
    if (out_iter != outs->end() && out_iter->second.size() > 0 &&
        out_iter->second[0]) {
      auto& out_var = out_iter->second[0];
      if (out_var->Type() == framework::proto::VarType::LOD_TENSOR) {
        out_tensor = out_var->MutableVar()->GetMutable<framework::LoDTensor>();
      }
    }
    if (!out_tensor) {
      continue;
    }

    out_tensor->ShareBufferWith(*in_tensor);
    out_tensor->Resize(in_tensor->dims());
    VLOG(4) << "Inplace performed in op " << op_type << ": " << pair.second
            << " -> " << pair.first;
  }
}

390
void BasicEngine::Execute() {
391
  if (init_nodes_.empty()) {
392 393 394 395 396 397
    return;
  }

  PrepareDeps();
  // Start execute Computation graph
  std::queue<std::shared_ptr<GradOpNode>> q;
398
  for (size_t i = 0; i < init_nodes_.size(); ++i) {
C
chentianyu03 已提交
399 400 401
    if (node_deps_[init_nodes_[i].get()] == 0) {
      q.push(std::move(init_nodes_[i]));
    }
402
  }
403 404 405 406 407 408 409

  size_t op_num = 0;

  while (!q.empty()) {
    auto shared_cur_node = std::move(q.front());
    q.pop();

410 411
    auto& inplace_grad_name_map = shared_cur_node->InplaceGradNameMap();

412
    for (auto& cur_op : *shared_cur_node) {
413 414
      platform::RecordEvent op_type_record_event(cur_op.Type());

415 416 417 418 419
      ++op_num;

      // CheckBackWardInput
      CheckBackwardInputs(cur_op);

420
      // Step 1: Run Backward OP
421 422 423
      auto& bwd_ins = cur_op.GetInsMap();
      auto& bwd_outs = cur_op.GetOutsMap();

424 425 426 427 428 429 430
      /**
       * [ Why need temporary outputs here? ]
       *
       * - construct the temp output map, avoid to disrupt graph
       * - replace the element in the map by temp var, because a
       *   var may be coresponding to several grad var in one op
       */
431
      NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
432

433 434 435 436 437 438 439 440 441 442
      for (auto& pair : tmp_outs) {
        if (!pair.second.IsGrad()) {
          continue;
        }

        for (auto& var : pair.second) {
          if (!var) {
            continue;
          }

443
          const auto& grad_pending_nodes = shared_cur_node->GradPendingNodes();
444 445 446
          std::unordered_map<VariableWrapper*,
                             std::unique_ptr<GradientAccumulator>>::iterator
              iter;
447
          bool flag_find_grad = false;
448
          if (grad_pending_nodes.size()) {
449 450
            VLOG(10) << "Find gradient of var (" << var->Name()
                     << ") with grad_node.";
451
            for (auto& grad_pending_node : grad_pending_nodes) {
452 453 454 455 456 457 458 459 460 461
              const auto& iter_grad_node =
                  accumulators_with_grad_node_.find(grad_pending_node);
              if (iter_grad_node != accumulators_with_grad_node_.end()) {
                iter = iter_grad_node->second.find(var.get());
                if (iter != iter_grad_node->second.end()) {
                  flag_find_grad = true;
                  break;
                }
              }
            }
462
            if (!flag_find_grad) {
463 464
              VLOG(6) << "Cannot find gradient of variable " << var->Name()
                      << " in accumulators_with_grad_node_";
465 466
            }
          }
467
          if (!grad_pending_nodes.size() || !flag_find_grad) {
468 469 470
            VLOG(10) << "Find gradient of var (" << var->Name()
                     << ") with no grad_node.";
            iter = accumulators_.find(var.get());
471
            PADDLE_ENFORCE_EQ(
472
                iter != accumulators_.end(), true,
473 474 475
                platform::errors::NotFound(
                    "Cannot find gradient of variable %s", var->Name()));
          }
476

477 478
          // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor,
          // it should be orderly and not reapeated.
479
          if (var->IsLeafGrad()) {
480 481 482 483
            if (std::find(leaf_accumulators_.begin(), leaf_accumulators_.end(),
                          iter->second.get()) == leaf_accumulators_.end()) {
              leaf_accumulators_.push_back(iter->second.get());
            }
484 485 486 487

            if (iter->second->HasInnerVar()) {
              var = iter->second->InnerVar();
            }
488 489
          }

490 491 492
          if (var->OverridedStopGradient() || iter->second->RefCnt() > 1) {
            auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
            tmp_var->SetType(var->Type());
493
            tmp_var->SetForwardDataType(var->ForwardDataType());
494 495 496 497
            var = tmp_var;
            need_accu_var_list_.emplace_back(iter->second.get(), var);
            VLOG(10) << "create temporary var of " << var->Name()
                     << " for sum gradient within this graph!";
498
          } else if (!inplace_grad_name_map.empty() &&
499 500
                     inplace_grad_name_map.count(pair.first) &&
                     bwd_ins.count(inplace_grad_name_map.at(pair.first))) {
501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517
            // When calculate Inplace grad op, create a new output var.
            // If a tmp var has been created, there is no need to create it
            // again.
            for (auto& in_var :
                 bwd_ins.at(inplace_grad_name_map.at(pair.first))) {
              if (in_var == var) {
                auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
                tmp_var->SetType(var->Type());
                tmp_var->SetForwardDataType(var->ForwardDataType());
                inplace_output_grad_var_list_.emplace_back(var, tmp_var);
                var = tmp_var;
                VLOG(10) << "Inplace grad op does not use the Inplace "
                            "strategy, a temporary output var ("
                         << var->Name() << ") will be created.";
                break;
              }
            }
518
          }
519 520 521
        }
      }

522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546
      VLOG(4) << "Check whether there is any inplace operation affecting "
                 "gradient calculation.";
      for (auto& pair : bwd_ins) {
        for (auto& var_wrapper : pair.second) {
          auto wrapper_version_snapshot = var_wrapper->InplaceVersionSnapshot();
          auto tensor_version =
              var_wrapper->MutableVar()->CurrentInplaceVersion();
          PADDLE_ENFORCE_EQ(
              tensor_version, wrapper_version_snapshot,
              platform::errors::PermissionDenied(
                  "Tensor '%s' used in gradient computation in grad op '%s' "
                  "has been "
                  "modified by an inplace operation. "
                  "Its version is %s but the expected version is %s. "
                  "Please fix your code to void calling an inplace operator "
                  "after using the Tensor which will used in gradient "
                  "computation.",
                  var_wrapper->Name(), cur_op.Type(), tensor_version,
                  wrapper_version_snapshot));

          VLOG(6) << " The version of Tensor '" << var_wrapper->Name()
                  << "' is [ " << wrapper_version_snapshot << " ]";
        }
      }

547 548 549 550 551 552 553 554 555 556 557 558 559
      /**
       * [ Why need temporary inputs here? ]
       *
       * - Hook execution should not change original input tensor.
       *   User can register hook for Tensor's gradient, It is expected
       *   that the hook only affects the gradient of the backward
       *   propagation, and does not affect the gradient value input
       *   as the hook.
       * - use `tmp_ins_ptr`, only copy bwd_ins when the var in bwd_ins
       *   hold hooks
       */
      auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());

560 561 562 563
      if (!tmp_ins_ptr) {
        PerformBackwardInplace(cur_op.Type(), bwd_ins, &tmp_outs);
      }

564 565
      {
        VLOG(3) << "Start to execute grad op " << cur_op.Type();
566 567 568
        try {
          if (tmp_ins_ptr == nullptr) {
            OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
569
                        cur_op.DefaultAttrsMap(), cur_op.place());
570 571
          } else {
            OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
572 573
                        cur_op.Attrs(), cur_op.DefaultAttrsMap(),
                        cur_op.place());
574 575 576 577 578 579 580
          }
        } catch (platform::EnforceNotMet& exception) {
          Clear();
          throw std::move(exception);
        } catch (std::exception& ex) {
          Clear();
          PADDLE_THROW(platform::errors::External("%s", ex.what()));
581
        }
582 583
      }

584 585 586 587 588 589 590
      // Function Post Hook
      if (cur_op.HasVoidFunctionPostHook()) {
        for (const auto& hook : cur_op.GetVoidFunctionPostHooks()) {
          (*hook)();
        }
      }

591 592 593 594
      for (auto& pair : inplace_output_grad_var_list_) {
        *pair.first = std::move(*pair.second);
      }

595 596 597 598 599 600 601 602 603 604
      // Step 2: Sum Gradient of This graph
      for (auto& pair : need_accu_var_list_) {
        pair.first->SumGrad(std::move(pair.second), cur_op.id());
      }

      // Step 3: Call Hooks && Sum Gradient with Pre-Graph && Call BackwardHooks
      for (auto* accumulator : leaf_accumulators_) {
        if (!accumulator->SumGradCompleted()) {
          continue;
        }
605 606
        // 1. Call Hooks for `inner_var_`
        accumulator->CallGradientHooks();
607

608
        // 2. Sum Gradient `inner_var_` to `var_` of Current or Previous Graph
609 610
        accumulator->AccumulateGrad();

611 612
        // 3. Call backward Hooks for `var_`
        accumulator->CallReduceHooks();
613 614
      }

615
      need_accu_var_list_.clear();
616
      inplace_output_grad_var_list_.clear();
617
      leaf_accumulators_.clear();
618

619
      if (!retain_graph_) {
620
        VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
621 622
        cur_op.ClearBackwardTrace();
      }
623 624 625 626
    }

    // Step 3: Collect ready ops
    for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
627 628 629
      PADDLE_ENFORCE_NOT_NULL(
          grad_pending_node,
          platform::errors::NotFound("Grad pending node is nullptr."));
630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645
      auto iter = node_deps_.find(grad_pending_node.get());
      if (iter == node_deps_.end()) {
        continue;
      }

      if (--(iter->second) == 0) {
        q.push(grad_pending_node);
      }
    }
  }
  Clear();

  VLOG(1) << "Backward op number: " << op_num;
}

void BasicEngine::Clear() {
646
  init_nodes_.clear();
647 648
  node_deps_.clear();
  accumulators_.clear();
649
  accumulators_with_grad_node_.clear();
650
  need_accu_var_list_.clear();
651
  leaf_accumulators_.clear();
652 653 654 655
}

}  // namespace imperative
}  // namespace paddle