// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/partial_grad_engine.h"

#include <algorithm>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace imperative {

/**
 * This function prunes the graph to get the ops between `output_targets`
 * and `input_target_grads`.
 *
 *
 * The inputs are:
 *
 *  - input_target_grads: the input target grads. It may be changed.
 *  - output_targets: the output target vars. It may be changed.
 *
 *
 * The outputs are:
 *
 *  - startup_op_ptr: startup ops of the pruned graph.
 *  - pending_ops_ptr: contains all the pending ops of each op in the graph.
 *  - op_deps_ptr: the preceding op number of each op in the graph.
 *  - related_grad_vars_ptr: all grad vars in the pruned graph.
 */
static void GetGraphInfoBetweenTargets(
    std::unordered_set<VariableWrapper *> *input_target_grads,
    std::unordered_set<VarBase *> *output_targets,
Z
Zeng Jinle 已提交
60 61 62
    std::unordered_set<OpBase *> *startup_ops_ptr,
    std::unordered_map<OpBase *, std::unordered_set<OpBase *>> *pending_ops_ptr,
    std::unordered_map<OpBase *, size_t> *op_deps_ptr,
63 64 65 66 67
    std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
    const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
  /**
   * Step 1. Find the candidate startup grad ops, prepared for following BFS.
   */
Z
Zeng Jinle 已提交
68
  std::queue<std::pair<OpBase *, GradOpNode *>> q;
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
  std::unordered_set<GradOpNode *> visited;
  for (auto iter = output_targets->begin(); iter != output_targets->end();) {
    auto *output_target = *iter;
    PADDLE_ENFORCE_NOT_NULL(
        output_target,
        platform::errors::NotFound("output_target must not be nullptr"));
    if (output_target->OverridedStopGradient() ||
        output_target->GradVarBase() == nullptr ||
        output_target->GradVarBase()->GradNode() == nullptr) {
      VLOG(10) << output_target->Name()
               << " is pruned because it stops gradient or has no grad var";
      iter = output_targets->erase(iter);
      continue;
    }

    auto &grad_node = output_target->GradVarBase()->GradNode();
    if (visited.count(grad_node.get()) == 0) {
      for (auto &op : *grad_node) {
        q.emplace(&op, grad_node.get());
      }
    }
    ++iter;
  }

  /**
   * Step 2. BFS the graph and find all grad ops which generate the
   * input_target_grads. Notice that not all candidate startup ops
   * would be connected with input_target_grads, that is to say,
   * not all input_target_grads would be found.
   */
  std::unordered_set<VariableWrapper *> found_input_target_grads;
Z
Zeng Jinle 已提交
100 101
  std::unordered_set<OpBase *> endpoint_ops;
  std::unordered_map<OpBase *, std::unordered_set<OpBase *>> preceding_ops;
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
  while (!q.empty()) {
    auto op_node_pair = q.front();
    q.pop();

    auto *op = op_node_pair.first;
    auto *node = op_node_pair.second;

    for (auto &output_pair : op->GetOutsMap()) {
      if (!output_pair.second.IsGrad()) {
        VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var";
        continue;
      }

      for (auto &out_var : output_pair.second) {
        if (out_var && input_target_grads->count(out_var.get()) > 0) {
          VLOG(10) << "Found endpoint op " << op->Type() << " which generates "
                   << out_var->Name();
          found_input_target_grads.insert(out_var.get());
          endpoint_ops.emplace(op);
        }
      }
    }

    for (auto &pending_node : node->GradPendingNodes()) {
      if (visited.count(pending_node.get()) == 0) {
        for (auto &pending_op : *pending_node) {
          preceding_ops[&pending_op].insert(op);
          q.emplace(&pending_op, pending_node.get());
        }
      }
    }
  }

  /**
   * Step 3. Based on the found input_target_grads, BFS the graph in reverse
   * order. `target_vars` would record all grad vars in the graph, and
   * `startup_ops` would be the final startup ops of the graph.
   */
  *input_target_grads = found_input_target_grads;

  auto &pending_ops = *pending_ops_ptr;
  pending_ops.clear();

  auto &startup_ops = *startup_ops_ptr;
  startup_ops.clear();

  auto &op_deps = *op_deps_ptr;
  op_deps.clear();

  auto &target_vars = *related_grad_vars_ptr;
  target_vars = *input_target_grads;

Z
Zeng Jinle 已提交
154
  std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
  for (auto &endpoint_op : endpoint_ops) {
    op_queue.emplace(endpoint_op, nullptr);
  }

  while (!op_queue.empty()) {
    auto op_pair = op_queue.front();
    auto *op = op_pair.first;
    auto *pending_op = op_pair.second;

    op_queue.pop();

    bool is_valid = false;
    for (auto &output_pair : op->GetOutsMap()) {
      if (!output_pair.second.IsGrad()) {
        continue;
      }

      for (auto &out_var : output_pair.second) {
        if (out_var && target_vars.count(out_var.get()) > 0) {
          is_valid = true;
          break;
        }
      }

      if (is_valid) {
        break;
      }
    }

    if (!is_valid) {
      continue;
    }

    is_valid = false;
    for (auto &input_pair : op->GetInsMap()) {
      if (!input_pair.second.IsGrad()) {
        continue;
      }

      for (auto &in_var : input_pair.second) {
        if (in_var && no_grad_var_grad.count(in_var.get()) == 0) {
          target_vars.insert(in_var.get());
          is_valid = true;
        }
      }
    }

    if (!is_valid) {
      continue;
    }

    op_deps[op];
    if (pending_op) {
      VLOG(10) << "Pending op of " << op->Type() << " is "
               << pending_op->Type();
      pending_ops[op].insert(pending_op);
      ++op_deps[pending_op];
    } else {
      pending_ops[op];
    }

    auto iter = preceding_ops.find(op);
    if (iter != preceding_ops.end()) {
      for (auto &preceding_op : iter->second) {
        op_queue.emplace(preceding_op, op);
      }
    }
  }

  for (auto &pair : op_deps) {
    if (pair.second == 0) {
      auto *op = pair.first;
      VLOG(10) << "Found startup op " << op->Type();
      startup_ops.insert(op);
    }
  }

  /**
   * Step 4. Prune output_targets which is not the input of startup_ops
   */
  for (auto iter = output_targets->begin(); iter != output_targets->end();) {
    auto &grad_node = (*iter)->GradVarBase()->GradNode();
    bool is_valid = std::find_if(grad_node->begin(), grad_node->end(),
Z
Zeng Jinle 已提交
238
                                 [&](OpBase &op) {  // NOLINT
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517
                                   return startup_ops.count(&op) > 0;
                                 }) != grad_node->end();
    if (is_valid) {
      ++iter;
    } else {
      iter = output_targets->erase(iter);
    }
  }
}

// Debug helper: joins the types of all ops held by `node` with ','.
static std::string GradOpTypes(const GradOpNode &node) {
  std::vector<std::string> types;
  for (const auto &grad_op : node) {
    types.push_back(grad_op.Type());
  }
  return string::join_strings(types, ',');
}

// Debug helper: op types of the grad node that produces `var`'s gradient.
// Returns an empty string when `var` has no grad var, or when that grad var
// has no grad node.
static std::string GradOpTypes(const VarBase &var) {
  const auto &grad_var = var.GradVarBase();
  if (grad_var && grad_var->GradNode()) {
    return GradOpTypes(*(grad_var->GradNode()));
  }
  return "";
}

// Debug helper: for each grad node pending on `node`, the ','-joined types of
// its ops; the per-node strings are themselves joined with ','.
static std::string GradPendingOpTypes(const GradOpNode &node) {
  std::vector<std::string> pending_types;
  for (const auto &pending_node : node.GradPendingNodes()) {
    pending_types.push_back(GradOpTypes(*pending_node));
  }
  return string::join_strings(pending_types, ',');
}

// Makes `dst_var` hold a LoDTensor with the dims and data type of `ref_var`'s
// LoDTensor, allocated on `place`, with every element set to `value`.
static void FillConstantLike(const VariableWrapper &ref_var,
                             VariableWrapper *dst_var,
                             const platform::Place &place, float value) {
  const auto &src = ref_var.Var().Get<framework::LoDTensor>();
  auto *dst = dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
  auto &ctx = *platform::DeviceContextPool::Instance().Get(place);
  dst->Resize(src.dims());
  dst->mutable_data(place, ref_var.DataType());
  operators::math::set_constant(ctx, dst, value);
}

/**
 * A data structure for gradient accumulation
 */
class GradientAccumulationInfo {
 private:
  using PartialGradGradTraceIdPair =
      std::pair<std::weak_ptr<VariableWrapper> /*partial grad grad var*/,
                size_t /*trace_id*/>;

 public:
  explicit GradientAccumulationInfo(const std::shared_ptr<VariableWrapper> &var,
                                    bool sort_gradient, bool create_graph)
      : mapped_grad_var_(var.get()),
        sort_gradient_(sort_gradient),
        create_graph_(create_graph) {}

  void IncreaseTotalRefCnt() {
    ++total_ref_cnt_;

    // The gradient accumulator is needed only when total_ref_cnt_ > 1.
    // grad_var_ would be created only when total_ref_cnt_ > 1.
    if (total_ref_cnt_ > 1) {
      if (!grad_var_) {
        grad_var_ = std::make_shared<VarBase>(true, mapped_grad_var_->Name());
        grad_var_->SetOverridedStopGradient(false);
        if (sort_gradient_) {
          accumulator_.reset(
              new SortedGradientAccumulator(grad_var_->SharedVar().get()));
        } else {
          accumulator_.reset(
              new EagerGradientAccumulator(grad_var_->SharedVar().get()));
        }
        accumulator_->IncreaseRefCnt();
      }
      accumulator_->IncreaseRefCnt();
    }
  }

  size_t TotalRefCnt() { return total_ref_cnt_; }

  const std::shared_ptr<VarBase> &GradVarBase() const { return grad_var_; }

  std::shared_ptr<VariableWrapper> GradVar() const {
    return grad_var_ == nullptr ? nullptr : grad_var_->SharedVar();
  }

  VariableWrapper *MappedGradVar() { return mapped_grad_var_; }

  std::vector<std::shared_ptr<VariableWrapper>> SumGradient(
      std::shared_ptr<VariableWrapper> grad_var_partial, size_t trace_id,
      bool *is_finished, bool unchange_input = false) {
    PADDLE_ENFORCE_NOT_NULL(grad_var_partial,
                            platform::errors::PermissionDenied(
                                "Partial grad of %s would not be nullptr",
                                mapped_grad_var_->Name()));
    PADDLE_ENFORCE_GT(total_ref_cnt_, 1,
                      platform::errors::PermissionDenied(
                          "Gradient accumulation should not be called when "
                          "reference count is 1 or 0"));

    ++cur_ref_cnt_;
    PADDLE_ENFORCE_LE(cur_ref_cnt_, total_ref_cnt_,
                      platform::errors::PermissionDenied(
                          "Reference count overflows, this may be a bug"));

    *is_finished = (cur_ref_cnt_ == total_ref_cnt_);
    accumulator_->Add(grad_var_partial, trace_id, unchange_input);

    if (create_graph_) {
      VLOG(10) << "Store partial grad grad for double grad "
               << mapped_grad_var_->Name();
      partial_grad_grads_.emplace_back(grad_var_partial->GetWeakGradVar(),
                                       trace_id);
    }

    if (!(*is_finished) || !create_graph_) {
      return {};
    }

    if (sort_gradient_) {
      std::sort(partial_grad_grads_.begin(), partial_grad_grads_.end(),
                [](const PartialGradGradTraceIdPair &p1,
                   const PartialGradGradTraceIdPair &p2) {
                  return p1.second > p2.second;
                });
    }

    // Only when create_graph_ = True, the return value would be not empty
    std::vector<std::shared_ptr<VariableWrapper>> result;
    result.reserve(partial_grad_grads_.size());
    for (auto &pair : partial_grad_grads_) {
      if (auto var = pair.first.lock()) {
        result.emplace_back(var);
      }
    }
    return result;
  }

 private:
  std::shared_ptr<VarBase> grad_var_;
  VariableWrapper *mapped_grad_var_;
  std::unique_ptr<GradientAccumulator> accumulator_;
  std::vector<PartialGradGradTraceIdPair> partial_grad_grads_;
  size_t total_ref_cnt_{0};
  size_t cur_ref_cnt_{0};
  bool sort_gradient_;
  bool create_graph_;
};

class ReadyGradVarInfoMap {
 private:
  struct ReadyVarInfo {
    std::shared_ptr<VarBase> var;
    size_t cur_ref_cnt{0};
    size_t total_ref_cnt{0};
  };

 public:
  void IncreaseRefCnt(const VariableWrapper *var) {
    ++(vars_[var].total_ref_cnt);
  }

  std::shared_ptr<VarBase> Get(const VariableWrapper *var,
                               const platform::Place &place, bool *is_last) {
    auto iter = vars_.find(var);
    PADDLE_ENFORCE_EQ(
        iter != vars_.end(), true,
        platform::errors::NotFound("Variable %s not found, this may be a bug",
                                   var->Name()));
    auto &ready_var = iter->second;
    PADDLE_ENFORCE_LT(ready_var.cur_ref_cnt, ready_var.total_ref_cnt,
                      platform::errors::PermissionDenied(
                          "Reference count overflows for %s", var->Name()));

    if (ready_var.var == nullptr && ready_var.cur_ref_cnt == 0) {
      ready_var.var = std::make_shared<VarBase>(var->Name());
      VLOG(10) << "Fill zero for " << var->Name() << " because it is not ready";
      FillConstantLike(*var, ready_var.var->SharedVar().get(), place, 0.0f);
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          ready_var.var,
          platform::errors::NotFound(
              "%s is not found when reference count does not decreases to 0"));
    }

    if (++ready_var.cur_ref_cnt == ready_var.total_ref_cnt) {
      *is_last = true;
      return std::move(ready_var.var);  // move to set ready_var.var to nullptr
    } else {
      *is_last = false;
      return ready_var.var;
    }
  }

  // Set a var as a ready var.
  // If the var is one of target vars, store it inside `target_vars_` as well.
  bool Set(const VariableWrapper *mapped_var,
           const std::shared_ptr<VarBase> &var) {
    PADDLE_ENFORCE_NOT_NULL(
        var,
        platform::errors::PermissionDenied(
            "Cannot set nullptr as ready grad var for %s", mapped_var->Name()));
    {
      auto target_iter = target_vars_.find(mapped_var);
      if (target_iter != target_vars_.end()) {
        PADDLE_ENFORCE_EQ(
            target_iter->second, nullptr,
            platform::errors::PermissionDenied("Cannot set target var %s twice",
                                               mapped_var->Name()));
        target_iter->second = var;
      }
    }

    auto iter = vars_.find(mapped_var);
    if (iter != vars_.end()) {  // This var is ready for next op's input
      auto &ready_var = iter->second;
      PADDLE_ENFORCE_EQ(
          ready_var.var, nullptr,
          platform::errors::PermissionDenied("Cannot set target var %s twice",
                                             mapped_var->Name()));
      PADDLE_ENFORCE_EQ(
          ready_var.cur_ref_cnt, 0,
          platform::errors::PermissionDenied(
              "Reference count must be 0 when ready var %s is set",
              mapped_var->Name()));
      ready_var.var = var;
      return true;
    } else {
      VLOG(10) << "Do not record " << mapped_var->Name()
               << " because it is not input of any following ops";
      return false;
    }
  }

  void Clear() {
    vars_.clear();
    target_vars_.clear();
  }

  // Mark a var as target var
  void SetTarget(const VariableWrapper *var) {
    PADDLE_ENFORCE_EQ(target_vars_[var], nullptr,
                      platform::errors::PermissionDenied(
                          "Target var would not be generated when marking"));
  }

  // Get target var
  const std::shared_ptr<VarBase> &GetTarget(const VariableWrapper *var) const {
    auto iter = target_vars_.find(var);
    PADDLE_ENFORCE_EQ(iter != target_vars_.end(), true,
                      platform::errors::NotFound("Target var %s does not exist",
                                                 var->Name()));
    PADDLE_ENFORCE_NOT_NULL(
        iter->second, platform::errors::PermissionDenied(
                          "Target var %s should not be nullptr", var->Name()));
    return iter->second;
  }

 private:
  std::unordered_map<const VariableWrapper *, ReadyVarInfo> vars_;
  std::unordered_map<const VariableWrapper *, std::shared_ptr<VarBase>>
      target_vars_;
};

class PartialGradTask {
 public:
  PartialGradTask(const std::vector<std::shared_ptr<VarBase>> &input_targets,
                  const std::vector<std::shared_ptr<VarBase>> &output_targets,
                  const std::vector<std::shared_ptr<VarBase>> &output_grads,
                  const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
                  const platform::Place &place,
Z
Zeng Jinle 已提交
518 519
                  const detail::BackwardStrategy &strategy, bool create_graph,
                  bool retain_graph, bool allow_unused, bool only_inputs);
520 521 522 523

  std::vector<std::shared_ptr<VarBase>> Run();

 private:
Z
Zeng Jinle 已提交
524
  void RunEachOp(OpBase *op);
525 526 527 528 529 530 531 532 533 534 535 536

  void PrepareInitialReadyVarsMap(const OpBase *op);

  void PrepareInitialGradientAccumulators(const OpBase *op);

  std::vector<std::shared_ptr<VarBase>> CreateResult();

  bool IsValidGradVar(const std::shared_ptr<VariableWrapper> &var) const {
    return var && no_grad_var_grad_.count(var.get()) == 0;
  }

 private:
Z
Zeng Jinle 已提交
537 538 539
  std::unordered_set<OpBase *> startup_ops_;
  std::unordered_map<OpBase *, std::unordered_set<OpBase *>> pending_ops_;
  std::unordered_map<OpBase *, size_t> op_deps_;
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561

  ReadyGradVarInfoMap ready_grad_vars_;

  std::unordered_map<VariableWrapper *,
                     std::unique_ptr<GradientAccumulationInfo>>
      grad_accumulators_;

  std::vector<std::shared_ptr<GradOpNode>> double_grad_nodes_;

  std::vector<
      std::pair<GradientAccumulationInfo *, std::shared_ptr<VariableWrapper>>>
      grads_to_accumulate_;

  // Input targets that are reachable
  std::vector<std::shared_ptr<VarBase>> input_targets_;
  std::unordered_set<VariableWrapper *> input_target_grads_;

  std::unordered_set<VariableWrapper *> no_grad_var_grad_;
  std::vector<std::weak_ptr<VariableWrapper>> reset_stop_gradient_vars_;

  platform::Place place_;
  bool create_graph_;
Z
Zeng Jinle 已提交
562 563 564
  bool retain_graph_;
  bool allow_unused_;
  bool only_inputs_;
565 566 567 568 569 570 571 572 573
  detail::BackwardStrategy strategy_;
};

// Validates the targets, prunes the grad graph between `output_targets` and
// `input_targets`, and seeds the ready grad vars (or gradient accumulators)
// with the output grads: either the user-provided `output_grads` or a tensor
// of ones shaped like each output.
PartialGradTask::PartialGradTask(
    const std::vector<std::shared_ptr<VarBase>> &input_targets,
    const std::vector<std::shared_ptr<VarBase>> &output_targets,
    const std::vector<std::shared_ptr<VarBase>> &output_grads,
    const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
    const platform::Place &place, const detail::BackwardStrategy &strategy,
    bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
    : input_targets_(input_targets),
      place_(place),
      create_graph_(create_graph),
      retain_graph_(retain_graph),
      allow_unused_(allow_unused),
      only_inputs_(only_inputs),
      strategy_(strategy) {
  PADDLE_ENFORCE_EQ(only_inputs_, true,
                    platform::errors::Unimplemented(
                        "only_inputs=False is not supported yet"));

  for (auto &var : no_grad_vars) {
    if (var && var->GradVarBase()) {
      no_grad_var_grad_.insert(var->GradVarBase()->SharedVar().get());
    }
  }

  PADDLE_ENFORCE_EQ(
      input_targets.empty(), false,
      platform::errors::PermissionDenied("inputs can not be empty"));
  PADDLE_ENFORCE_EQ(
      output_targets.empty(), false,
      platform::errors::PermissionDenied("outputs can not be empty"));

  // Validate outputs: non-null, differentiable, unique, not in no_grad_set.
  std::unordered_set<VarBase *> out_set;
  for (auto &output : output_targets) {
    PADDLE_ENFORCE_NOT_NULL(output,
                            platform::errors::PermissionDenied(
                                "Variable inside outputs should not be null"));
    PADDLE_ENFORCE_EQ(
        output->GradVarBase() && !output->OverridedStopGradient(), true,
        platform::errors::PermissionDenied(
            "Variable %s inside outputs has no gradient", output->Name()));
    PADDLE_ENFORCE_EQ(
        out_set.count(output.get()), 0,
        platform::errors::AlreadyExists("outputs contain duplicate variable %s",
                                        output->Name()));
    PADDLE_ENFORCE_EQ(IsValidGradVar(output->GradVarBase()->SharedVar()), true,
                      platform::errors::PermissionDenied(
                          "outputs contain var that is inside no_grad_set"));

    out_set.insert(output.get());
  }

  // Validate inputs the same way, and collect vars that are both input and
  // output (their grad is the output grad itself).
  std::unordered_set<VarBase *> in_set;
  std::unordered_set<VariableWrapper *> one_grad_vars;
  for (auto &input : input_targets) {
    PADDLE_ENFORCE_NOT_NULL(input,
                            platform::errors::PermissionDenied(
                                "Variable inside inputs should not be null"));
    PADDLE_ENFORCE_EQ(
        input->GradVarBase() && !input->OverridedStopGradient(), true,
        platform::errors::PermissionDenied(
            "Variable %s inside inputs has no gradient", input->Name()));
    PADDLE_ENFORCE_EQ(
        in_set.count(input.get()), 0,
        platform::errors::AlreadyExists("inputs contain duplicate variable %s",
                                        input->Name()));
    in_set.insert(input.get());
    input_target_grads_.insert(input->GradVarBase()->SharedVar().get());

    PADDLE_ENFORCE_EQ(IsValidGradVar(input->GradVarBase()->SharedVar()), true,
                      platform::errors::PermissionDenied(
                          "inputs contain var that is inside no_grad_set"));

    // Record same vars between inputs and outputs
    if (out_set.count(input.get()) > 0) {
      one_grad_vars.insert(input->GradVarBase()->SharedVar().get());
    }
  }

  std::unordered_set<VariableWrapper *> related_grad_vars;
  GetGraphInfoBetweenTargets(&input_target_grads_, &out_set, &startup_ops_,
                             &pending_ops_, &op_deps_, &related_grad_vars,
                             no_grad_var_grad_);

  for (auto &op_pair : pending_ops_) {
    auto *op = op_pair.first;
    PrepareInitialReadyVarsMap(op);
    PrepareInitialGradientAccumulators(op);
  }

  for (auto &input_grad : input_target_grads_) {
    ready_grad_vars_.SetTarget(input_grad);
  }

  for (auto &one_grad : one_grad_vars) {
    VLOG(10) << "Add same in/out target " << one_grad->Name();
    input_target_grads_.insert(one_grad);
    ready_grad_vars_.SetTarget(one_grad);
  }

  VLOG(10) << "Valid op number " << pending_ops_.size();

  if (!output_grads.empty()) {
    PADDLE_ENFORCE_EQ(output_targets.size(), output_grads.size(),
                      platform::errors::InvalidArgument(
                          "grad_outputs number should be equal to outputs"));
  }

  // Seed the grad of every output that takes part in the pruned graph.
  for (size_t i = 0; i < output_targets.size(); ++i) {
    auto *mapped_out_grad_var =
        output_targets[i]->GradVarBase()->SharedVar().get();

    if (related_grad_vars.count(mapped_out_grad_var) == 0 &&
        one_grad_vars.count(mapped_out_grad_var) == 0) {
      VLOG(10) << mapped_out_grad_var->Name() << " should be None";
      continue;
    }

    std::shared_ptr<VariableWrapper> out_grad_var;
    bool unchange_input = false;
    if (output_grads.empty() || output_grads[i] == nullptr) {
      VLOG(10) << "Fill 1.0f for " << output_targets[i]->Name();
      out_grad_var = std::make_shared<VariableWrapper>(
          framework::GradVarName(output_targets[i]->Name()));
      FillConstantLike(*(output_targets[i]->SharedVar()), out_grad_var.get(),
                       place_, 1.0f);
    } else {
      VLOG(10) << "Use user provided grad var for "
               << output_targets[i]->Name();
      const auto &out_tensor =
          output_targets[i]->Var().Get<framework::LoDTensor>();
      const auto &grad_tensor =
          output_grads[i]->Var().Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(
          grad_tensor.dims(), out_tensor.dims(),
          platform::errors::InvalidArgument(
              "The %d-th grad_output's shape does not match the %d-th output",
              i, i));
      PADDLE_ENFORCE_EQ(grad_tensor.type(), out_tensor.type(),
                        platform::errors::InvalidArgument(
                            "The %d-th grad_output's data type does not "
                            "match the %d-th output",
                            i, i));
      out_grad_var = output_grads[i]->SharedVar();
      PADDLE_ENFORCE_EQ(IsValidGradVar(out_grad_var), true,
                        platform::errors::PermissionDenied(
                            "grad_outputs contain var inside no_grad_set"));

      if (out_grad_var->OverridedStopGradient()) {
        VLOG(10) << "Grad var " << out_grad_var->Name()
                 << " should reset stop gradient";
        reset_stop_gradient_vars_.emplace_back(out_grad_var);
      }

      // The user's tensor must not be modified in place by accumulation.
      unchange_input = true;
    }

    out_grad_var->SetOverridedStopGradient(false);
    auto grad_accumulator_iter = grad_accumulators_.find(mapped_out_grad_var);
    if (grad_accumulator_iter == grad_accumulators_.end()) {
      ready_grad_vars_.Set(mapped_out_grad_var,
                           std::make_shared<VarBase>(out_grad_var));
      VLOG(10) << "Fill 1.0f or user-provided gradient as ready var "
               << out_grad_var->Name();
    } else {
      auto &accumulator = grad_accumulator_iter->second;
      accumulator->IncreaseTotalRefCnt();
      bool is_finished = false;
      accumulator->SumGradient(out_grad_var, 0, &is_finished, unchange_input);
      PADDLE_ENFORCE_EQ(
          is_finished, false,
          platform::errors::Fatal("gradient accumulator should not finish"));
      VLOG(10) << "Add 1.0f or user-provided gradient to gradient accumulator"
               << out_grad_var->Name();
    }
  }
}

std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
  VLOG(10) << "Startup op number " << startup_ops_.size();
Z
Zeng Jinle 已提交
748
  std::queue<OpBase *> q;
749 750 751 752 753 754 755
  for (auto *op : startup_ops_) {
    q.push(op);
  }

  while (!q.empty()) {
    auto *op = q.front();
    q.pop();
Z
Zeng Jinle 已提交
756

757
    VLOG(10) << "Start to run " << op->Type();
Z
Zeng Jinle 已提交
758
    op->EnforceHasInOut();
759
    RunEachOp(op);
Z
Zeng Jinle 已提交
760 761 762
    if (!retain_graph_) {
      op->ClearBackwardTrace();
    }
763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787
    VLOG(10) << "End to run " << op->Type();

    auto iter = pending_ops_.find(op);
    if (iter == pending_ops_.end()) {
      VLOG(10) << "Finish running because " << op->Type()
               << " has no pending ops";
      continue;
    }

    for (auto &pending_op : iter->second) {
      auto dep_iter = op_deps_.find(pending_op);
      PADDLE_ENFORCE_EQ(
          dep_iter != op_deps_.end(), true,
          platform::errors::Fatal("Dependency number of %s does not exist",
                                  pending_op->Type()));
      if (--(dep_iter->second) == 0) {
        q.push(pending_op);
      }
    }
  }

  VLOG(10) << "Created " << double_grad_nodes_.size() << " double grad ops";
  return CreateResult();
}

void PartialGradTask::RunEachOp(OpBase *op) {
789 790 791 792 793 794 795 796 797
  // Prepare new inputs
  NameVarMap<VarBase> tmp_ins;
  for (auto &input_pair : op->GetInsMap()) {
    auto &new_inputs = tmp_ins[input_pair.first];
    new_inputs.reserve(input_pair.second.size());

    if (!input_pair.second.IsGrad()) {
      for (auto &fwd_var : input_pair.second) {
        if (fwd_var) {
798
          new_inputs.emplace_back(new VarBase(fwd_var));
799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827
          VLOG(10) << "Unpacked forward var " << fwd_var->Name()
                   << ", grad ops: " << GradOpTypes(*new_inputs.back());
        } else {
          new_inputs.emplace_back();
        }
      }
    } else {
      for (auto &grad_var : input_pair.second) {
        if (grad_var) {
          bool is_last;
          new_inputs.emplace_back(
              ready_grad_vars_.Get(grad_var.get(), op->place(), &is_last));
          VLOG(10) << "Got ready grad var " << grad_var->Name() << " "
                   << new_inputs.back().get();
        } else {
          new_inputs.emplace_back();
        }
      }
    }
  }

  // Prepare new outputs
  NameVarMap<VarBase> tmp_outs;
  for (auto &output_pair : op->GetOutsMap()) {
    auto &new_outputs = tmp_outs[output_pair.first];
    if (!output_pair.second.IsGrad()) {
      for (auto &fwd_var : output_pair.second) {
        // unpack forward var
        if (fwd_var) {
828
          new_outputs.emplace_back(new VarBase(fwd_var));
829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892
          VLOG(10) << "Unpacked forward var " << fwd_var->Name();
        } else {
          new_outputs.emplace_back();
        }
      }
    } else {
      for (auto &grad_var : output_pair.second) {
        if (IsValidGradVar(grad_var)) {
          VLOG(10) << "Creating output grad var " << grad_var->Name();
          auto new_grad_var_iter = grad_accumulators_.find(grad_var.get());
          PADDLE_ENFORCE_EQ(new_grad_var_iter != grad_accumulators_.end(), true,
                            platform::errors::Fatal(
                                "Cannot find gradient accumulator of %s %p",
                                grad_var->Name(), grad_var.get()));

          auto new_grad_var = std::make_shared<VarBase>(true, grad_var->Name());
          new_grad_var->SetOverridedStopGradient(false);
          if (new_grad_var_iter->second->TotalRefCnt() > 1) {
            grads_to_accumulate_.emplace_back(new_grad_var_iter->second.get(),
                                              new_grad_var->SharedVar());
          } else {
            PADDLE_ENFORCE_EQ(
                new_grad_var_iter->second->GradVar(), nullptr,
                platform::errors::AlreadyExists(
                    "When reference count is 1, the grad var should not be "
                    "created in gradient accumulator"));
            grad_accumulators_.erase(new_grad_var_iter);
            ready_grad_vars_.Set(grad_var.get(), new_grad_var);
          }
          VLOG(10) << "Created output grad var " << grad_var->Name();
          new_outputs.emplace_back(std::move(new_grad_var));
        } else {
          new_outputs.emplace_back();
        }
      }
    }
  }

  // Run op
  OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place());

  if (create_graph_) {
    auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs,
                                             op->Attrs(), op->place());
    if (double_grad_node) {
      VLOG(10) << "Create " << double_grad_node->size()
               << " double grad op(s) for " << op->Type()
               << ", pending ops: " << GradPendingOpTypes(*double_grad_node);
      double_grad_nodes_.emplace_back(std::move(double_grad_node));
    }
  }

  VLOG(10) << "There are " << grads_to_accumulate_.size() << " to sum gradient";

  // Gradient accumulation and add assign op
  for (auto &pair : grads_to_accumulate_) {
    auto *accumulator_info = pair.first;
    auto &grad_var = pair.second;

    bool is_finished = false;
    VLOG(10) << "Start to sum " << accumulator_info->MappedGradVar()->Name();
    auto partial_grad_grads = accumulator_info->SumGradient(
        std::move(grad_var), op->id(), &is_finished);

893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921
    if (!partial_grad_grads.empty()) {
      auto sum_grad_var_grad =
          accumulator_info->GradVarBase()->MutableGradVarBase();
      sum_grad_var_grad->SetOverridedStopGradient(false);

      auto assign_node = std::make_shared<GradOpNode>();
      sum_grad_var_grad->SetGradNode(assign_node);

      VLOG(10) << "Add " << partial_grad_grads.size() << " assign op for "
               << sum_grad_var_grad->Name();

      for (auto &grad_grad : partial_grad_grads) {
        auto *assign_op = &(assign_node->emplace_back());
        assign_op->SetType("assign");  // Can use "scale" as static graph mode
        assign_op->SetInput("X", {sum_grad_var_grad->SharedVar()}, true);
        assign_op->SetOutput("Out", {grad_grad}, true);
        assign_op->CheckAttrs();
        assign_op->SetId(OpBase::GenerateUniqueId());
        assign_op->SetPlace(op->place());

        if (auto grad_pending_node = grad_grad->GetGradNode()) {
          assign_node->InsertGradPendingNode(std::move(grad_pending_node));
        }
      }
      VLOG(10) << "Pending ops of assign is "
               << GradPendingOpTypes(*assign_node);
      double_grad_nodes_.emplace_back(assign_node);
    }

922 923 924 925 926 927
    if (is_finished) {
      VLOG(10) << "Sum has finished for "
               << accumulator_info->MappedGradVar()->Name() << " "
               << accumulator_info->GradVarBase();
      ready_grad_vars_.Set(accumulator_info->MappedGradVar(),
                           accumulator_info->GradVarBase());
928
      grad_accumulators_.erase(accumulator_info->MappedGradVar());
929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974
    }
  }

  grads_to_accumulate_.clear();
}

/**
 * Bumps the reference count in `ready_grad_vars_` once for every non-null
 * gradient input of `op`. Non-gradient input slots are ignored.
 */
void PartialGradTask::PrepareInitialReadyVarsMap(const OpBase *op) {
  for (auto &slot : op->GetInsMap()) {
    const auto &vars = slot.second;
    if (vars.IsGrad()) {
      for (auto &grad_var : vars) {
        if (grad_var != nullptr) {
          ready_grad_vars_.IncreaseRefCnt(grad_var.get());
        }
      }
    }
  }
}

/**
 * Ensures a GradientAccumulationInfo exists in `grad_accumulators_` for every
 * non-null gradient output of `op`, creating one on first sight, and records
 * one more expected contribution on it via IncreaseTotalRefCnt().
 */
void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) {
  for (auto &slot : op->GetOutsMap()) {
    const auto &vars = slot.second;
    if (!vars.IsGrad()) continue;

    for (auto &grad_var : vars) {
      if (!grad_var) continue;

      auto &info = grad_accumulators_[grad_var.get()];
      if (info == nullptr) {
        info.reset(new GradientAccumulationInfo(
            grad_var, strategy_.sorted_sum_gradient_, create_graph_));
      }
      info->IncreaseTotalRefCnt();
    }
  }
}

std::vector<std::shared_ptr<VarBase>> PartialGradTask::CreateResult() {
  std::vector<std::shared_ptr<VarBase>> result;
  result.reserve(input_targets_.size());
Z
Zeng Jinle 已提交
975 976
  for (size_t i = 0; i < input_targets_.size(); ++i) {
    auto &input_target = input_targets_[i];
977 978 979 980 981 982 983 984 985 986
    PADDLE_ENFORCE_NOT_NULL(
        input_target->GradVarBase(),
        platform::errors::InvalidArgument("input should have gradient"));
    auto *original_grad_var = input_target->GradVarBase()->SharedVar().get();
    auto iter = input_target_grads_.find(original_grad_var);
    if (iter != input_target_grads_.end()) {
      auto ready_var = ready_grad_vars_.GetTarget(original_grad_var);
      ready_var->SetOverridedStopGradient(!create_graph_);
      result.emplace_back(std::move(ready_var));
    } else {  // return None if it does not appear in the graph
Z
Zeng Jinle 已提交
987 988 989 990 991 992
      PADDLE_ENFORCE_EQ(allow_unused_, true,
                        platform::errors::InvalidArgument(
                            "The %d-th input does not appear in the backward "
                            "graph. Please check the input variable or set "
                            "allow_unused=True to get None result.",
                            i));
993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016
      result.emplace_back();
    }
  }

  for (auto &weak_var : reset_stop_gradient_vars_) {
    if (auto var = weak_var.lock()) {
      VLOG(10) << "Reset " << var->Name() << " stop gradient";
      var->SetOverridedStopGradient(!var->OverridedStopGradient());
    }
  }

  ready_grad_vars_.Clear();
  grad_accumulators_.clear();
  double_grad_nodes_.clear();
  reset_stop_gradient_vars_.clear();
  return result;
}

/**
 * Builds the engine by constructing a PartialGradTask from the given targets,
 * output grads, no-grad vars, place, strategy and flags. The task is owned by
 * the engine and released by Clear().
 */
PartialGradEngine::PartialGradEngine(
    const std::vector<std::shared_ptr<VarBase>> &input_targets,
    const std::vector<std::shared_ptr<VarBase>> &output_targets,
    const std::vector<std::shared_ptr<VarBase>> &output_grads,
    const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
    const platform::Place &place, const detail::BackwardStrategy &strategy,
    bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
    : task_(new PartialGradTask(input_targets, output_targets, output_grads,
                                no_grad_vars, place, strategy, create_graph,
                                retain_graph, allow_unused, only_inputs)) {}

PartialGradEngine::~PartialGradEngine() { Clear(); }
1023 1024 1025 1026 1027 1028

std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const {
  return results_;
}

void PartialGradEngine::Clear() {
1029 1030 1031 1032
  if (task_) {
    delete task_;
    task_ = nullptr;
  }
1033 1034 1035
}

void PartialGradEngine::Execute() {
1036 1037
  PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied(
                                     "PartialGradEngine has been destructed"));
1038
  VLOG(10) << "Starts to execute PartialGradEngine";
1039
  results_ = task_->Run();
1040 1041 1042 1043 1044
  Clear();
}

}  // namespace imperative
}  // namespace paddle