// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/phi/api/all.h"

namespace egr {
/**
 * GradNodeBase is the base class of all grad nodes, which is what should be
 * used by eager execution. We define most of the backward autograd members
 * here, and for each Operator, they should hold their own forward Inputs as
 * TensorWrapper.
 *
 * The GradNodeBase will be held in autograd_meta, and it is also a member of
 * Edge, which indicates an edge of the backward graph.
 *
 * TODO:(yangzhanlue) GradNodeBase will also be in charge of getting the
 * correct input from GradOpDescMaker to GradNodeBase.
 *
 * NOTE: GradNodeBase has a method named run; this method should be overridden
 * by the specific derived class. It will prepare backward inputs and double
 * backward's dependencies, and then call the C++ API of the backward kernel
 * functions to finish the backward computation.
 *
 * NOTE: GradNodeBase holds its own inputs and outputs.
 *
 * Edge is defined to describe the dependencies of backward. An Edge is the
 * link between two nodes; it contains a Node and the rank of this Node (the
 * rank indicates which grad input this edge belongs to).
 * */
class Edge;
class AutogradMeta;

/**
 * GradSlotMeta is used to record forward Tensor info for backward, since
 * paddle has lots of operators whose backward logic depends on whether certain
 * inputs or outputs exist. So, we need a meta info to record these needs.
 * **/
class GradSlotMeta {
 public:
  GradSlotMeta() = default;
  bool IsStopGradient() const { return stop_gradient_; }
  void SetStopGradient(bool stop_gradient = true) {
    stop_gradient_ = stop_gradient;
  }

  void SetTensorMeta(const phi::DenseTensorMeta& meta) {
    meta_ = std::make_shared<phi::DenseTensorMeta>(meta);
  }
  bool HasTensorMeta() const { return meta_ && meta_.get(); }
  const phi::DenseTensorMeta& GetTensorMeta() const {
    if (!HasTensorMeta()) {
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "meta_ of GradSlotMeta has not been initialized yet. "
          "You're expected to check Edge availability with HasTensorMeta() "
          "before calling GetTensorMeta() interface."));
    }
    return *meta_.get();
  }

  void SetPlace(const phi::Place& place) { place_ = place; }
  const phi::Place& GetPlace() const { return place_; }

 private:
  bool stop_gradient_{false};
  phi::Place place_;
  std::shared_ptr<phi::DenseTensorMeta> meta_ = nullptr;
};
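
/**
 * Illustrative usage sketch (not part of the original header): how a backward
 * pass might consult a GradSlotMeta before materializing the corresponding
 * grad tensor. The variables `fwd_meta` and `fwd_place` are hypothetical.
 *
 *   GradSlotMeta slot_meta;
 *   slot_meta.SetTensorMeta(fwd_meta);  // phi::DenseTensorMeta of a forward tensor
 *   slot_meta.SetPlace(fwd_place);
 *   if (!slot_meta.IsStopGradient() && slot_meta.HasTensorMeta()) {
 *     const phi::DenseTensorMeta& meta = slot_meta.GetTensorMeta();
 *     // use `meta` and slot_meta.GetPlace() to create or validate the grad tensor
 *   }
 * **/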

class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
 public:
  GradNodeBase() { VLOG(6) << "Construct GradNodeBase"; }
  GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num);
  // TODO(jiabin): Should we have other constructor here?
  virtual ~GradNodeBase() { VLOG(6) << "Destruct GradNodeBase"; }

  /**
   * operator() is designed to contain the real backward execution logic; it
   * should be overridden by the derived class defined for each operator. It
   * accepts a vector of vectors of Tensor which contains the grad inputs of
   * the current operator.
   *
   * Note: why do we construct backward inputs and outputs as vector of vector
   * of paddle::experimental::Tensor?
   * Since all paddle op inputs/outputs are composed in the form of
   * {"Slot name", vector<Var>}, a vector of vector is the better choice to
   * fit this format.
   * **/
  virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
      bool create_graph = false) = 0;

  virtual void ClearTensorWrappers() = 0;

  /**
   * Self-Copy interface designed for use in DoubleGrad
   * **/
  virtual std::shared_ptr<GradNodeBase> Copy() const = 0;

  /**
   * AddEdges is designed to set input tensors' backward Nodes as the current
   * node's Edges.
   * This method should be called in forward code, and it is also needed for
   * double backward's dependency computation.
   *
   * This one is called slot by slot.
   * **/
  void AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id);
  void AddEdges(AutogradMeta* meta, size_t slot_id);

  // adj_edges were moved inside OutputMeta(), so they are no longer directly
  // accessible from GradNodeBase.
  // To access Edges, get GradSlotMeta by calling OutputMeta(), then use
  // slot_meta.GetEdge().

  /**
   * Get Input Meta of current Grad node**/
  const std::vector<std::vector<GradSlotMeta>>& InputMeta() const;
  /**
   * Get Output Meta of current Grad node**/
  const std::vector<std::vector<GradSlotMeta>>& OutputMeta() const;
  /**
   * Set bwd ins and outs info with forward vars
   * **/

  void SetGradInMeta(const std::vector<paddle::experimental::Tensor>& fwd_out,
                     size_t slot_rank);
  void SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
                     size_t slot_rank);

  void SetGradOutMeta(const std::vector<paddle::experimental::Tensor>& fwd_in,
                      size_t slot_rank);
  void SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
                      size_t slot_rank);
  /**
   * Default setters for Grad in/out meta; these should be used for some
   * special Nodes which will not be created by users.
   * **/
  void SetDefaultGradInOutMeta();
  /**
   * Register GradientHook
   * **/
  int64_t RegisterGradientHook(size_t slot_id, size_t rank,
                               std::shared_ptr<egr::TensorHook>&& hook);

  /**
  * Remove GradientHook
  * **/
  bool RemoveGradientHook(const int64_t& hook_id) {
    auto remove_cnt = gradient_hooks_.erase(hook_id);
    if (remove_cnt == 0) {
      return false;
    }
    return true;
  }
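
  /**
   * Illustrative usage sketch (not part of the original header), assuming
   * `node` is a shared_ptr to some concrete grad node and `my_hook` is a
   * std::shared_ptr<egr::TensorHook> built elsewhere (see hooks.h):
   *
   *   size_t slot_id = 0, rank = 0;
   *   int64_t hook_id =
   *       node->RegisterGradientHook(slot_id, rank, std::move(my_hook));
   *   // ... later, detach the hook again; returns false if the id is unknown.
   *   bool removed = node->RemoveGradientHook(hook_id);
   * **/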

  /**
   * Apply GradientHook
   * **/
  inline bool GradientHooksRegistered() { return !gradient_hooks_.empty(); }

  std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks(
      const std::vector<std::vector<paddle::experimental::Tensor>>& tensors);

  /**
   * Handle Complex - Real Type Promotion
   * **/
  void HandleComplexGradToRealGrad(
      std::vector<std::vector<paddle::experimental::Tensor>>* out_grads);
  bool NeedComplexToRealConversion() { return need_complex_to_real_; }

  virtual std::string name() { return "GradNodeBase"; }

  /**
   * GetEdges is designed to get all edges of the current node**/
  const std::vector<std::vector<Edge>>& GetEdges() const;
  std::vector<std::vector<Edge>>& GetMutableEdges();

  /**
   * The following interfaces are designed for no_need_buffer
   * **/
  bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; }

  void SetIsTensorWrappersCleared(bool is_tensor_wrappers_cleared) {
    is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared;
  }

 private:
  // TODO(zhanlve): Merge adj_edges_ into GradOutMeta
  // Edges record the backward-related node info, which indicates all edges
  // linked by this Grad Node.
  // Why we need vector<vector<Edge>>: Edges have the same rank as the bwd
  // outputs.
  std::vector<std::vector<Edge>> adj_edges_;

  // bwd_out_meta_ is used to record Grad output info for backward
  std::vector<std::vector<GradSlotMeta>> bwd_out_meta_;

  // bwd_in_meta_ is used to record Grad input info for backward
  std::vector<std::vector<GradSlotMeta>> bwd_in_meta_;
  // Gradient Hooks
  // Users may register a list of hooks which will be called in order during
  // backward
  // Each entry consists of one pair of
  // <hook_id, <slot_id, rank, std::shared_ptr<TensorHook>>>
  std::map<int64_t, std::tuple<
                        /* slot id */ size_t, /* rank */ size_t,
                        /* hook */ std::shared_ptr<TensorHook>>>
      gradient_hooks_;

  // We handle complex to real conversion only if any complex GradIn is involved
  bool need_complex_to_real_ = false;
  int64_t next_hook_id_{0};
  bool is_tensor_wrappers_cleared_ = false;
};
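
/**
 * Illustrative sketch of a derived grad node (not part of the original
 * header). The class name and its internals are hypothetical; a node
 * generated for a real operator would also hold its forward inputs as
 * TensorWrapper members and set its in/out meta via SetGradInMeta /
 * SetGradOutMeta.
 *
 *   class MyOpGradNode : public GradNodeBase {
 *    public:
 *     MyOpGradNode() : GradNodeBase(1, 1) {}  // 1 bwd input slot, 1 bwd output slot
 *
 *     std::vector<std::vector<paddle::experimental::Tensor>> operator()(
 *         std::vector<std::vector<paddle::experimental::Tensor>>& grads,
 *         bool create_graph = false) override {
 *       // grads[slot][rank] carries the incoming grad of forward output `slot`;
 *       // return the grads of the forward inputs in the same {slot, rank} layout.
 *       std::vector<std::vector<paddle::experimental::Tensor>> input_grads(1);
 *       // ... compute input_grads[0] from grads[0] and the saved forward inputs ...
 *       return input_grads;
 *     }
 *
 *     void ClearTensorWrappers() override {}  // release saved forward tensors here
 *
 *     std::shared_ptr<GradNodeBase> Copy() const override {
 *       return std::make_shared<MyOpGradNode>(*this);
 *     }
 *
 *     std::string name() override { return "MyOpGradNode"; }
 *   };
 * **/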

class Edge {
 public:
  // Default constructor for Edges in order to construct it for AutogradMeta
  Edge() : in_slot_id_(0), in_rank_(0), grad_node_(nullptr) {}

  // In real use cases we should create an Edge from the grad node and the
  // input rank, which indicate which edge it is.
  // Since we have a slot design in operators, we will have to locate an edge
  // with slot and rank.
  Edge(const std::shared_ptr<GradNodeBase>& grad_node, size_t in_slot_id,
       size_t in_rank)
      : in_slot_id_(in_slot_id), in_rank_(in_rank), grad_node_(grad_node) {}

  Edge(const std::shared_ptr<GradNodeBase>& grad_node,
       const std::pair</* slot_id */ size_t, /* rank */ size_t>& rank_info)
      : in_slot_id_(rank_info.first),
        in_rank_(rank_info.second),
        grad_node_(grad_node) {}

  GradNodeBase* GetGradNode() const { return grad_node_.get(); }

  std::shared_ptr<GradNodeBase> GetMutableGradNode() const {
    return grad_node_;
  }

  void SetGradNode(const std::shared_ptr<GradNodeBase>& node) {
    VLOG(6) << "Resetting Edge's Grad Node";
    grad_node_ = node;
  }

  std::pair<size_t, size_t> GetEdgeRankInfo() const {
    return std::make_pair(in_slot_id_, in_rank_);
  }

  void SetEdgeRankInfo(size_t slot_id, size_t in_rank) {
    in_slot_id_ = slot_id;
    in_rank_ = in_rank;
  }

  void SetEdgeRankInfo(
      const std::pair</* slot_id */ size_t, /* rank */ size_t>& edge_rank) {
    in_slot_id_ = edge_rank.first;
    in_rank_ = edge_rank.second;
  }

  // Currently we use grad_node_ to identify if an edge is initialized.
  bool IsInitialized() const { return grad_node_ != nullptr; }

 private:
  size_t in_slot_id_;
  size_t in_rank_;
  std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};
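
/**
 * Illustrative sketch (not part of the original header). `node` stands for a
 * hypothetical shared_ptr to some concrete grad node created during the
 * forward pass; the edge below points at slot 0, rank 1 of that node's grad
 * inputs.
 *
 *   std::shared_ptr<GradNodeBase> node = ...;
 *   Edge edge(node, 0, 1);
 *   if (edge.IsInitialized()) {
 *     GradNodeBase* next_node = edge.GetGradNode();
 *     auto rank_info = edge.GetEdgeRankInfo();  // {slot_id, rank} == {0, 1}
 *   }
 * **/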

inline void CheckTensor(const paddle::experimental::Tensor& pre,
                        const paddle::experimental::Tensor& post) {
  if (!pre.initialized() && post.initialized()) {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "The tensors before and after the hook are not consistent"));
  }
  if (pre.initialized() && post.initialized()) {
    VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
            << paddle::framework::DataType2String(post.dtype());
    PADDLE_ENFORCE_EQ(
        pre.dtype(), post.dtype(),
        paddle::platform::errors::PermissionDenied(
            "The dtype of tensor before(%s) and after(%s) hook are not "
            "consistent",
            paddle::framework::DataType2String(pre.dtype()),
            paddle::framework::DataType2String(post.dtype())));
    PADDLE_ENFORCE_EQ(
        pre.inner_place(), post.inner_place(),
        paddle::platform::errors::PermissionDenied(
            "The place of tensor before(%s) and after(%s) "
            "hook are not consistent",
            pre.inner_place().DebugString(), post.inner_place().DebugString()));
  }
}
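
/**
 * Illustrative usage sketch (not part of the original header): after a
 * gradient hook runs, the tensor it returns can be validated against the
 * original grad so the hook does not silently change dtype or place.
 * `run_hook` is a hypothetical wrapper around the registered TensorHook.
 *
 *   paddle::experimental::Tensor hooked_grad = run_hook(grad);
 *   CheckTensor(grad, hooked_grad);  // throws if dtype or place differ
 * **/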

}  // namespace egr