// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

namespace paddle {
namespace imperative {

class VarBase;
class OpBase;

class BasicEngine : public Engine {
 public:
33
  void Init(VarBase* var, bool retain_graph = false);
34 35 36 37 38 39 40 41

  void Execute() override;

 private:
  void PrepareDeps();

  void CheckBackwardInputs(const OpBase& op);

42 43 44
  void PrepareGradAccumulators(
      const OpBase& op,
      const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes);
45 46 47 48 49 50

  void Clear();

 private:
  std::shared_ptr<GradOpNode> init_node_;
  std::unordered_map<GradOpNode*, size_t> node_deps_;
51 52 53 54 55 56 57 58 59 60 61
  // The input and output of Inplace op are the same. If only `var` is used
  // as the key, then the input and output of inplace op must be gradient
  // accumulated. Therefore, add the `grad_node` as the key to prevent the
  // problem of gradient accumulation in inplace op.
  std::unordered_map<std::shared_ptr<GradOpNode>,
                     std::unordered_map<VariableWrapper*,
                                        std::unique_ptr<GradientAccumulator>>>
      accumulators_with_grad_node_;
  // Leaf var doesn't have grad_node, and leaf var with `stop_gradient=False`
  // can't use Inplace strategy. If a var doesn't have grad_node, only use
  // `var` as the key.
62 63
  std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
      accumulators_;
64 65 66 67 68
  // The output grad var of Inplace grad op. Because Inplace grad op does not
  // use the Inplace strategy, a new output grad var needs to be created.
  std::vector<std::pair<std::shared_ptr<VariableWrapper>,
                        std::shared_ptr<VariableWrapper>>>
      inplace_output_grad_var_list_;
69 70
  std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
      need_accu_var_list_;
71 72 73
  // leaf_accumulators_ is only for leaf tensor(hooks/accumulate grad)
  std::unordered_set<GradientAccumulator*> leaf_accumulators_;

74
  bool retain_graph_;
75 76 77 78
};

}  // namespace imperative
}  // namespace paddle