// autograd_meta.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <utility>

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/grad_node_info.h"
namespace egr {

using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta;
/**
 *
 * AutogradMeta records the backward information for a tensor. When we run
 * the computation graph eagerly, we cannot build a static paddle program as
 * static mode does, so we need a new way to record forward information in
 * order to trace backward once all forward computation has finished. This
 * requires the AutogradMeta class to record the following main members:
 *
 * 1. grad_op:
 * grad_op indicates the grad operation of the forward op.
 *
 * 2. grad:
 * grad is the gradient of the forward Tensor, which should be computed after
 * backward computation.
 *
 * NOTE: grad should only be available when the current tensor is a leaf
 * tensor; for a non-leaf tensor, grad is only available when the user sets
 * the `retain_grad` option to `true`.
 *
 * TODO(jiabin) : support hooks
 * 3. hooks:
 * Hooks are computation logic attached only to the backward operation; they
 * are registered by the user and run before the accumulator.
 *
 * 4. overrided_stop_gradient_:
 * This member is used for auto-prune related work: it indicates that a
 * stop_gradient value set by the user should override the result indicated
 * by the framework. The stop_gradient property of every non-parameter tensor
 * should be true. We will pass stop_gradient on when we find a tensor that
 * needs it.
 *
 * NOTE: AutogradMeta inherits from AbstractAutogradMeta, which is defined in
 * the tensor's dependencies; we did this to avoid an additional dependency
 * on Autograd. In eager execution, we cast AbstractAutogradMeta to
 * AutogradMeta to use it.
 *
 * **/

// No class other than AutogradMeta should be derived from
// AbstractAutogradMeta.
class AutogradMeta : public AbstractAutogradMeta {
 public:
  explicit AutogradMeta(const Edge& edge = Edge()) {
    out_slot_id_ = edge.GetEdgeRankInfo().first;
    out_rank_ = edge.GetEdgeRankInfo().second;
    grad_node_ = edge.GetMutableGradNode();
  }

  ~AutogradMeta() override = default;

78
  const paddle::experimental::Tensor& Grad() const {
79 80 81 82
    PADDLE_ENFORCE_NOT_NULL(
        grad_.get(),
        paddle::platform::errors::InvalidArgument(
            "Should Not get NULL from Grad pointer, since "
83
            "we should have default Tensor once we init AutoGradMeta. "
84 85 86 87 88
            "if you got this error may indicates framework error in "
            "PaddlePaddle"));
    return *(grad_.get());
  }

89
  paddle::experimental::Tensor* MutableGrad() { return grad_.get(); }
90

91
  std::weak_ptr<paddle::experimental::Tensor> WeakGrad() { return grad_; }
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129

  void SetGradNode(const std::shared_ptr<GradNodeBase>& grad_node) {
    PADDLE_ENFORCE_NOT_NULL(
        grad_node.get(),
        paddle::platform::errors::InvalidArgument(
            "Should Not set NULL as GradNode pointer, since "
            "our default Edge and autogradMeta has nullptr for "
            "grad node. Set Nullptr will lead error."));
    grad_node_ = grad_node;
  }

  std::shared_ptr<GradNodeBase> GetMutableGradNode() const {
    return grad_node_;
  }

  GradNodeBase* GradNode() const { return grad_node_.get(); }

  void SetSingleOutRankWithSlot(size_t slot_id, size_t rank) {
    out_slot_id_ = slot_id;
    out_rank_ = rank;
  }

  std::pair</* slot id */ size_t, /* rank in slot */ size_t> OutRankInfo()
      const {
    return std::make_pair(out_slot_id_, out_rank_);
  }

  bool IsInitialized() { return grad_node_.get(); }

  // TODO(jiabin): This may cause error, since -1 still can indication true;
  bool StopGradient() const { return stop_gradient_ != 0; }

  int NumericStopGradient() const { return stop_gradient_; }

  void SetStopGradient(bool stop_gradient) {
    stop_gradient_ = static_cast<int>(stop_gradient);
  }

130 131 132 133 134 135
  void WeakSetStopGradient(bool stop_gradient) {
    if (stop_gradient_ == -1) {
      stop_gradient_ = static_cast<int>(stop_gradient);
    }
  }

136 137 138 139
  bool Persistable() const { return persistable_; }

  void SetPersistable(bool persistable) { persistable_ = persistable; }

140 141 142 143
  bool RetainGrads() { return retain_grads_; }

  void SetRetainGrads(bool value) { retain_grads_ = value; }

144 145
 private:
  // TODO(jiabin) :Should we use pointer instead of object?
146 147 148
  std::shared_ptr<paddle::experimental::Tensor> grad_{
      std::make_shared<paddle::experimental::Tensor>(
          egr::Controller::Instance().GenerateUniqueName("@grad"))};
149 150 151 152

  // GradNodeBase is base class of all grad op which is a
  // wrapper for grad op. This class will make grad op easy
  // to be traced.
153
  std::shared_ptr<GradNodeBase> grad_node_ = nullptr;
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174

  /**
   * Why we need slot id here?
   * Because in paddle most of our operators inputs and outputs
   * are assemble in form of {"slot name", vector<tensor>}.
   * So its better for us to set a slot id to fit this format. **/
  size_t out_slot_id_;

  // output rank of forward op, this is a vital num, since
  // we are now trying to make our forward output is as same
  // sequence as backward input. In case of tracing backward
  // sequence we need to record output rank in slot here.
  size_t out_rank_;

  // TODO(jiabin) :Support hooks here and store it in AutogradMeta

  // Stop gradient flag to indicate should we compute backward
  int stop_gradient_{-1};

  bool persistable_{false};

175 176
  bool retain_grads_{false};

177 178 179 180
  // TODO(jiabin) :Support Quantum here and add cache mechanism as
  // VarCache defined in VarBase
};
}  // namespace egr