// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"

#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

#include "glog/logging.h"

/**
 * Implementation of GradNodeBase, Edge and InputBuffer.
**/
namespace egr {

GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) {
  bwd_in_meta_.resize(bwd_in_slot_num);
  bwd_out_meta_.resize(bwd_out_slot_num);
  // adj_edges_ has the same number of slots as the backward outputs
  adj_edges_.resize(bwd_out_slot_num);
}
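
// A minimal usage sketch (illustrative only; `SomeGradNode`, `x_meta` and
// `out_meta` are hypothetical names, not framework code): a forward op
// typically constructs its grad node with the backward slot counts, then
// wires up the slot metas and edges:
//
//   auto node = std::make_shared<SomeGradNode>(/*bwd_in_slot_num=*/1,
//                                              /*bwd_out_slot_num=*/1);
//   node->SetGradInMeta(out_meta, /*slot_rank=*/0);  // meta of a fwd output
//   node->SetGradOutMeta(x_meta, /*slot_rank=*/0);   // meta of a fwd input
//   node->AddEdges(x_meta, /*slot_id=*/0);           // link to x's grad node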

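// Record edges from this node's backward-output slot `slot_id` to the grad
// nodes of the given forward-input metas. Null or stop_gradient metas are
// skipped; metas without a grad node get a GradNodeAccumulation attached
// first.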
void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
  PADDLE_ENFORCE_LT(
      slot_id, adj_edges_.size(),
      paddle::platform::errors::InvalidArgument(
          "Given slot id is out of range of adj_edges outer size; "
          "adj_edges is designed to have the same size as the grad "
          "inputs' slot num."));
  for (const auto& meta : *metas) {
    // adj_edges_ has the same rank as the forward inputs, and records each
    // input's output rank from the op that produced it.
    if (meta && !meta->StopGradient()) {
      auto node = meta->GetMutableGradNode();
      if (node) {
        adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                         meta->OutRankInfo());
      } else {
        // No grad node recorded for this input yet, so attach an
        // accumulation node as its terminal grad node.
        meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>());
        adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                         meta->OutRankInfo());
      }
    }
  }
}

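// Same as the overload above, but for a single AutogradMeta.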
void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
  PADDLE_ENFORCE_LT(
      slot_id, adj_edges_.size(),
      paddle::platform::errors::InvalidArgument(
          "Given slot id is out of range of adj_edges outer size; "
          "adj_edges is designed to have the same size as the grad "
          "inputs' slot num."));
  if (meta && !meta->StopGradient()) {
    VLOG(6) << "Add Edges for slot: " << slot_id;
    auto node = meta->GetMutableGradNode();
    if (node) {
      adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                       meta->OutRankInfo());
    } else {
      // No grad node recorded for this input yet, so attach an accumulation
      // node as its terminal grad node.
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>());
      adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                       meta->OutRankInfo());
    }
  }
}

const std::vector<GradSlotMeta>& GradNodeBase::InputMeta() const {
  return bwd_in_meta_;
}

const std::vector<GradSlotMeta>& GradNodeBase::OutputMeta() const {
  return bwd_out_meta_;
}

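// Initialize the grad-input meta of slot `slot_rank` from the autograd metas
// of the corresponding forward outputs, recording their stop_gradient flags.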
void GradNodeBase::SetGradInMeta(std::vector<AutogradMeta*>* fwd_out,
                                 size_t slot_rank) {
  size_t slot_size = fwd_out->size();
  PADDLE_ENFORCE_LE(
      slot_rank, (bwd_in_meta_.size() - 1),
      paddle::platform::errors::InvalidArgument(
          "Slot rank should be less than the size of bwd_in_meta_, since "
          "bwd_in_meta_ is designed to hold the same number of slots as "
          "the backward inputs."));
  auto& meta = bwd_in_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_in_meta should only be initialized once; "
                        "additional initialization is forbidden. If you got "
                        "this error, it indicates a bug in the framework."));
  // Init the stop_gradient vector before use to avoid push_back
  meta.Init(slot_size);
  for (size_t i = 0; i < slot_size; i++) {
    PADDLE_ENFORCE_NOT_NULL((*fwd_out)[i],
                            paddle::platform::errors::PreconditionNotMet(
                                "SetGradInMeta should only be called while "
                                "autograd_meta is not null. If you got this "
                                "error, it indicates a bug in the framework."));
    if ((*fwd_out)[i]->StopGradient()) {
      // Set stop_gradient only when it is true, since the default value is
      // already false.
      meta.SetStopGradient(i, (*fwd_out)[i]->StopGradient());
    }
  }
}

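// Single-tensor variant: slot `slot_rank` holds exactly one grad input.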
void GradNodeBase::SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank) {
  PADDLE_ENFORCE_LE(
      slot_rank, (bwd_in_meta_.size() - 1),
      paddle::platform::errors::InvalidArgument(
          "Slot rank should be less than the size of bwd_in_meta_, since "
          "bwd_in_meta_ is designed to hold the same number of slots as "
          "the backward inputs."));
  auto& meta = bwd_in_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_in_meta should only be initialized once; "
                        "additional initialization is forbidden. If you got "
                        "this error, it indicates a bug in the framework."));
  // Init the stop_gradient vector before use to avoid push_back
  VLOG(7) << "Init bwd_in_meta_ with slot rank: " << slot_rank;
  meta.Init(1);
  meta.SetStopGradient(0, fwd_out->StopGradient());
}

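// Initialize the grad-output meta of slot `slot_rank` from the autograd metas
// of the corresponding forward inputs; a null meta is treated as
// stop_gradient.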
void GradNodeBase::SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in,
                                  size_t slot_rank) {
  size_t slot_size = fwd_in->size();
  PADDLE_ENFORCE_LE(
      slot_rank, (bwd_out_meta_.size() - 1),
      paddle::platform::errors::InvalidArgument(
          "Slot rank should be less than the size of bwd_out_meta_, since "
          "bwd_out_meta_ is designed to hold the same number of slots as "
          "the backward outputs."));
  auto& meta = bwd_out_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_out_meta should only be initialized once; "
                        "additional initialization is forbidden. If you got "
                        "this error, it indicates a bug in the framework."));
  // Init the stop_gradient vector before use to avoid push_back
  meta.Init(slot_size);
  for (size_t i = 0; i < slot_size; i++) {
    if (!(*fwd_in)[i]) {
      // A null autograd meta is treated as stop_gradient for this slot entry.
      meta.SetStopGradient(i, true);
      continue;
    }
    if ((*fwd_in)[i]->StopGradient()) {
      // Set stop_gradient only when it is true, since the default value is
      // already false.
      meta.SetStopGradient(i, (*fwd_in)[i]->StopGradient());
    }
  }
}

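// Single-tensor variant: slot `slot_rank` holds exactly one grad output.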
void GradNodeBase::SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank) {
  PADDLE_ENFORCE_LE(
      (slot_rank + 1), bwd_out_meta_.size(),
      paddle::platform::errors::InvalidArgument(
          "Slot rank should be less than the size of bwd_out_meta_, since "
          "bwd_out_meta_ is designed to hold the same number of slots as "
          "the backward outputs."));
  auto& meta = bwd_out_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_out_meta should only be initialized once; "
                        "additional initialization is forbidden. If you got "
                        "this error, it indicates a bug in the framework."));
  // Init the stop_gradient vector before use to avoid push_back
  meta.Init(1);
  if (fwd_in) {
    meta.SetStopGradient(0, fwd_in->StopGradient());
  } else {
    // A null autograd meta is treated as stop_gradient for this slot.
    meta.SetStopGradient(0, true);
  }
}

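// Shortcut for the common single-input, single-output case: initialize slot 0
// of both bwd_in_meta_ and bwd_out_meta_ with slot size 1 and the default
// (false) stop_gradient.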
void GradNodeBase::SetDefaultGradInOutMeta() {
  PADDLE_ENFORCE((bwd_out_meta_.size() == 1) && (bwd_in_meta_.size() == 1),
                 paddle::platform::errors::PreconditionNotMet(
                     "The default grad meta setter only supports a single "
                     "input and a single output; other input/output sizes "
                     "should be configured with the explicit setters."));
  // Default stop_gradient is false; slot id is 0 and slot size is 1.
  bwd_out_meta_[0].Init(1);
  bwd_in_meta_[0].Init(1);
}

const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
  return adj_edges_;
}

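// Register a hook to be applied to the grad tensor at (slot_id, rank); see
// ApplyGradientHooks below.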
void GradNodeBase::RegisterGradientHook(
    size_t slot_id, size_t rank,
    const std::function<paddle::experimental::Tensor(
        const paddle::experimental::Tensor&)>& hook) {
  gradient_hooks_.emplace_back(std::make_tuple(slot_id, rank, hook));
}

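// Apply all registered hooks to the incoming grad tensors. Hooks registered
// on the same (slot_id, rank) are chained, each receiving the previous hook's
// output; tensors without hooks are passed through unchanged.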
std::vector<std::vector<paddle::experimental::Tensor>>
GradNodeBase::ApplyGradientHooks(
    const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
  std::vector<std::vector<paddle::experimental::Tensor>> outs(tensors.size());
  for (auto& tuple : gradient_hooks_) {
    size_t slot_id = std::get<0>(tuple);
    size_t rank = std::get<1>(tuple);
    std::function<paddle::experimental::Tensor(
        const paddle::experimental::Tensor&)>& hook = std::get<2>(tuple);

    PADDLE_ENFORCE(slot_id < tensors.size(),
                   paddle::platform::errors::Fatal(
                       "The slot_id of a registered hook should be smaller "
                       "than the slot size of grad_tensors."));

    PADDLE_ENFORCE(rank < tensors[slot_id].size(),
                   paddle::platform::errors::Fatal(
                       "The rank of a registered hook on slot %d should be "
                       "smaller than the number of tensors in that slot of "
                       "grad_tensors.",
                       slot_id));

    std::vector<paddle::experimental::Tensor>& slot_out = outs[slot_id];
    slot_out.resize(tensors[slot_id].size());
    paddle::experimental::Tensor& out = slot_out[rank];
    if (!out.defined() || !out.initialized()) {
      VLOG(8) << "Run Hook for tensor: " << tensors[slot_id][rank].name();
      out = hook(tensors[slot_id][rank]);
    } else {
      // If more than one hook is registered, the input to the next hook func
      // should be the output of the previous hook
      out = hook(out);
    }
  }

  for (size_t i = 0; i < outs.size(); i++) {
    if (outs[i].empty() && (!tensors[i].empty())) {
      outs[i].resize(tensors[i].size());
    }
    // TODO(Jiabin): Optimize this if we only add hook slot by slot
    for (size_t j = 0; j < outs[i].size(); j++) {
      if (!outs[i][j].defined() || !outs[i][j].initialized()) {
        outs[i][j] = tensors[i][j];
      }
    }
  }

  return outs;
}

}  // namespace egr