// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/grad_node_info.h"
namespace egr {

// Shorthand so code in namespace egr can name
// paddle::experimental::AbstractAutogradMeta without full qualification.
using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta;
/**
 *
 * AutogradMeta records the backward info for a tensor. When we run the
 * computation graph eagerly, we cannot build a static paddle program like
 * static mode does, so we need a new method to record forward info in order
 * to trace backward once all forward computation has finished. This requires
 * our AutogradMeta class to record the following main members:
 *
 * 1. grad_op:
 * Grad_op indicates the grad operation of the forward op.
 *
 * 2. grad:
 * Grad is the gradient of the forward Tensor, which should be computed after
 * backward computation.
 *
 * NOTE: grad should only be available when the current tensor is a leaf
 * tensor; for a non-leaf tensor grad is only available when the user sets the
 * `retain_grad` option to `true`.
 *
 * TODO(jiabin) : support hooks
 * 3. hooks:
 * Hooks are computation logic attached only to the backward operation; they
 * are registered by the user and run before the accumulator.
 *
 * 4. overrided_stop_gradient_
 * This member is used to finish some auto-prune related work; it indicates
 * that the stop_gradient set by the user should override the result indicated
 * by the framework. All non-parameter tensors' stop_gradient properties should
 * be true. We will pass stop_gradient when we find one who needs it.
 *
 * NOTE: AutogradMeta inherits from AbstractAutogradMeta, which is defined
 * in tensor's deps; we did this to avoid an additional dependency on Autograd.
 * In eager execution, we will cast AbstractAutogradMeta to AutogradMeta to use
 * it.
 *
 * **/

// No other AutogradMeta class should be derivated from AbstractAutogradMeta.
// It's only used by
class AutogradMeta : public AbstractAutogradMeta {
 public:
  explicit AutogradMeta(const Edge& edge = Edge()) {
    out_slot_id_ = edge.GetEdgeRankInfo().first;
    out_rank_ = edge.GetEdgeRankInfo().second;
    grad_node_ = edge.GetMutableGradNode();
  }

  ~AutogradMeta() override = default;

71
  const paddle::experimental::Tensor& Grad() const {
72 73 74 75
    PADDLE_ENFORCE_NOT_NULL(
        grad_.get(),
        paddle::platform::errors::InvalidArgument(
            "Should Not get NULL from Grad pointer, since "
76
            "we should have default Tensor once we init AutoGradMeta. "
77 78 79 80 81
            "if you got this error may indicates framework error in "
            "PaddlePaddle"));
    return *(grad_.get());
  }

82
  paddle::experimental::Tensor* MutableGrad() { return grad_.get(); }
83

84
  std::weak_ptr<paddle::experimental::Tensor> WeakGrad() { return grad_; }
85 86 87 88 89 90 91 92

  void SetGradNode(const std::shared_ptr<GradNodeBase>& grad_node) {
    PADDLE_ENFORCE_NOT_NULL(
        grad_node.get(),
        paddle::platform::errors::InvalidArgument(
            "Should Not set NULL as GradNode pointer, since "
            "our default Edge and autogradMeta has nullptr for "
            "grad node. Set Nullptr will lead error."));
93

94 95 96 97 98 99 100 101 102
    grad_node_ = grad_node;
  }

  std::shared_ptr<GradNodeBase> GetMutableGradNode() const {
    return grad_node_;
  }

  GradNodeBase* GradNode() const { return grad_node_.get(); }

103 104
  void ResetGradNode() { grad_node_.reset(); }

105 106 107 108 109 110 111 112 113 114
  void SetSingleOutRankWithSlot(size_t slot_id, size_t rank) {
    out_slot_id_ = slot_id;
    out_rank_ = rank;
  }

  std::pair</* slot id */ size_t, /* rank in slot */ size_t> OutRankInfo()
      const {
    return std::make_pair(out_slot_id_, out_rank_);
  }

115
  bool IsInitialized() const { return grad_node_.get(); }
116 117 118 119 120 121 122 123 124 125 126 127 128 129

  // TODO(jiabin): This may cause error, since -1 still can indication true;
  bool StopGradient() const { return stop_gradient_ != 0; }

  int NumericStopGradient() const { return stop_gradient_; }

  void SetStopGradient(bool stop_gradient) {
    stop_gradient_ = static_cast<int>(stop_gradient);
  }

  bool Persistable() const { return persistable_; }

  void SetPersistable(bool persistable) { persistable_ = persistable; }

130
  bool RetainGrads() const { return retain_grads_; }
131 132 133

  void SetRetainGrads(bool value) { retain_grads_ = value; }

134 135
 private:
  // TODO(jiabin) :Should we use pointer instead of object?
136
  std::shared_ptr<paddle::experimental::Tensor> grad_{
137
      std::make_shared<paddle::experimental::Tensor>()};
138 139 140 141

  // GradNodeBase is base class of all grad op which is a
  // wrapper for grad op. This class will make grad op easy
  // to be traced.
142
  std::shared_ptr<GradNodeBase> grad_node_ = nullptr;
143 144 145

  /**
   * Why we need slot id here?
146
   * Because in paddle most of operators, inputs and outputs
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
   * are assemble in form of {"slot name", vector<tensor>}.
   * So its better for us to set a slot id to fit this format. **/
  size_t out_slot_id_;

  // output rank of forward op, this is a vital num, since
  // we are now trying to make our forward output is as same
  // sequence as backward input. In case of tracing backward
  // sequence we need to record output rank in slot here.
  size_t out_rank_;

  // TODO(jiabin) :Support hooks here and store it in AutogradMeta

  // Stop gradient flag to indicate should we compute backward
  int stop_gradient_{-1};

  bool persistable_{false};

164 165
  bool retain_grads_{false};

166 167 168 169
  // TODO(jiabin) :Support Quantum here and add cache mechanism as
  // VarCache defined in VarBase
};
}  // namespace egr