grad_node_info.cc 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/grad_node_info.h"

#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "glog/logging.h"

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"

/**
 * Implementation of GradNodeBase, Edge and GradTensorHolder.
**/
namespace egr {

// Creates a node with `bwd_in_slot_num` backward-input slots and
// `bwd_out_slot_num` backward-output slots. Only the slot counts are fixed
// here; per-slot metas are filled in later via SetGradInMeta / SetGradOutMeta
// (or SetDefaultGradInOutMeta) and edges via AddEdges.
GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) {
  VLOG(6) << "Construct GradNodeBase";
  bwd_in_meta_.resize(bwd_in_slot_num);
  bwd_out_meta_.resize(bwd_out_slot_num);
  // adj_edges_ has one entry per backward-output slot; AddEdges(..., slot_id)
  // bounds-checks slot_id against this size.
  adj_edges_.resize(bwd_out_slot_num);
}

40 41 42 43 44 45 46 47 48 49 50
void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
  PADDLE_ENFORCE_LT(
      slot_id, adj_edges_.size(),
      paddle::platform::errors::InvalidArgument(
          "Given slot id is out of range of adj_edges outter size, "
          "adj_edges is designed to has the same size of grad "
          "inputs's slot num."));
  for (const auto& meta : *metas) {
    // adj_edges has as same rank as fwd inputs, and record it's output rank
    // from
    // its pre-ops
51
    if (meta && !meta->StopGradient()) {
52
      auto node = meta->GetMutableGradNode();
J
Jiabin Yang 已提交
53 54 55
      if (node && node.get()) {
        VLOG(6) << "Add Edges for slot: " << slot_id
                << " which is: " << meta->GetMutableGradNode()->name();
56 57 58
        adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                         meta->OutRankInfo());
      } else {
59
        meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
J
Jiabin Yang 已提交
60 61
        VLOG(6) << "Add Edges for slot: " << slot_id
                << " which is: " << meta->GetMutableGradNode()->name();
62 63 64
        adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                         meta->OutRankInfo());
      }
65
    }
66 67 68
  }
}

69
void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
70 71 72 73 74 75
  PADDLE_ENFORCE_LT(
      slot_id, adj_edges_.size(),
      paddle::platform::errors::InvalidArgument(
          "Given slot id is out of range of adj_edges outter size, "
          "adj_edges is designed to has the same size of grad "
          "inputs's slot num."));
76
  if (meta && !meta->StopGradient()) {
77
    auto node = meta->GetMutableGradNode();
J
Jiabin Yang 已提交
78
    if (node && node.get()) {
79 80
      VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
              << this->name() << " to " << meta->GetMutableGradNode()->name();
81 82 83
      adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                       meta->OutRankInfo());
    } else {
84
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
85 86
      VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
              << this->name() << " to " << meta->GetMutableGradNode()->name();
87 88
      adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                       meta->OutRankInfo());
89
    }
90
  }
91 92 93 94 95 96 97 98 99 100
}

const std::vector<GradSlotMeta>& GradNodeBase::InputMeta() const {
  return bwd_in_meta_;
}

const std::vector<GradSlotMeta>& GradNodeBase::OutputMeta() const {
  return bwd_out_meta_;
}

101
void GradNodeBase::SetGradInMeta(std::vector<AutogradMeta*>* fwd_out,
102
                                 size_t slot_rank) {
103
  size_t slot_size = fwd_out->size();
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
  PADDLE_ENFORCE_LE(
      slot_rank, (bwd_in_meta_.size() - 1),
      paddle::platform::errors::InvalidArgument(
          "Slot Rank should less equal than bwd_in_meta_ size, since "
          "bwd_in_meta_ is designed to hold as same num as backward "
          "inputs."));
  auto& meta = bwd_in_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_in_meta should only be init once, addition "
                        "initialization for it is forbidden. If you got this "
                        "error, it indicates bugs in framework."));
  // Init stop gradient vector before use to avoid push back
  meta.Init(slot_size);
  for (size_t i = 0; i < slot_size; i++) {
119
    PADDLE_ENFORCE_NOT_NULL((*fwd_out)[i],
120 121 122 123
                            paddle::platform::errors::PreconditionNotMet(
                                "Bwd_in_meta should only be called while "
                                "autograd_meta is not null. If you got this "
                                "error, it indicates bugs in framework."));
124
    if ((*fwd_out)[i]->StopGradient()) {
125 126
      // Set Stop Gradient only when its true or non-initialized autograd_meta,
      // since all default value is false.
127
      meta.SetStopGradient(i, (*fwd_out)[i]->StopGradient());
128 129 130 131
    }
  }
}

132
void GradNodeBase::SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank) {
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
  PADDLE_ENFORCE_LE(
      slot_rank, (bwd_in_meta_.size() - 1),
      paddle::platform::errors::InvalidArgument(
          "Slot Rank should less equal than bwd_in_meta_ size, since "
          "bwd_in_meta_ is designed to hold as same num as backward "
          "inputs."));
  auto& meta = bwd_in_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_in_meta should only be init once, Additional "
                        "initialization for it is forbidden. If you got this "
                        "error, it indicates bugs in framework."));
  // Init stop gradient vector before use to avoid push back
  VLOG(7) << "Init bwd_in_meta_ with slot rank: " << slot_rank;
  meta.Init(1);
148
  meta.SetStopGradient(0, fwd_out->StopGradient());
149 150
}

151
void GradNodeBase::SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in,
152
                                  size_t slot_rank) {
153
  size_t slot_size = fwd_in->size();
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
  PADDLE_ENFORCE_LE(
      slot_rank, (bwd_out_meta_.size() - 1),
      paddle::platform::errors::InvalidArgument(
          "Slot Rank should less equal than bwd_out_meta_ size, "
          "since bwd_out_meta_ is designed to hold as same num as "
          "backward outputs."));
  auto& meta = bwd_out_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_out_meta should only be init once. Additional "
                        "initialization for it is forbidden. If you got this "
                        "error, it indicates bugs in framework."));
  // Init stop gradient vector before use to avoid push back
  meta.Init(slot_size);
  for (size_t i = 0; i < slot_size; i++) {
169
    if (!(*fwd_in)[i]) {
170 171 172
      meta.SetStopGradient(i, true);
      continue;
    }
173
    if ((*fwd_in)[i]->StopGradient()) {
174 175
      // Set Stop Gradient only when its true or non-initialized autograd_meta,
      // since all default value is false.
176
      meta.SetStopGradient(i, (*fwd_in)[i]->StopGradient());
177 178 179 180
    }
  }
}

181
void GradNodeBase::SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank) {
182 183 184 185 186 187 188 189 190 191 192 193 194 195
  PADDLE_ENFORCE_LE(
      (slot_rank + 1), bwd_out_meta_.size(),
      paddle::platform::errors::InvalidArgument(
          "Slot Rank should less equal than bwd_out_meta_ size, "
          "since bwd_out_meta_ is designed to hold as same num as "
          "backward outputs."));
  auto& meta = bwd_out_meta_.at(slot_rank);
  PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
                    paddle::platform::errors::PreconditionNotMet(
                        "Bwd_out_meta should only be init once. Additional "
                        "initialization for it is forbidden. If you got this "
                        "error, it indicates bugs in framework."));
  // Init stop gradient vector before use to avoid push back
  meta.Init(1);
196 197 198 199 200
  if (fwd_in) {
    meta.SetStopGradient(0, fwd_in->StopGradient());
  } else {
    meta.SetStopGradient(0, true);
  }
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
}

void GradNodeBase::SetDefaultGradInOutMeta() {
  PADDLE_ENFORCE((bwd_out_meta_.size() == 1) && (bwd_in_meta_.size() == 1),
                 paddle::platform::errors::PreconditionNotMet(
                     "We can only support 1 input and 1 output in default grad "
                     "meta setter, other size of inputs and outputs should "
                     "create with Setter and Getters"));
  // Default stop_gradient is false and slot id is 0, slot size is 1;
  bwd_out_meta_[0].Init(1);
  bwd_in_meta_[0].Init(1);
}

const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
  return adj_edges_;
}

218 219 220 221 222
int64_t GradNodeBase::RegisterGradientHook(
    size_t slot_id, size_t rank, std::shared_ptr<egr::TensorHook>&& hook) {
  gradient_hooks_.emplace(next_hook_id_,
                          std::make_tuple(slot_id, rank, std::move(hook)));
  return next_hook_id_++;
223 224
}

225 226 227 228
std::vector<std::vector<paddle::experimental::Tensor>>
GradNodeBase::ApplyGradientHooks(
    const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
  std::vector<std::vector<paddle::experimental::Tensor>> outs(tensors.size());
229 230 231 232 233
  for (auto& hook_pair : gradient_hooks_) {
    size_t slot_id = std::get<0>(hook_pair.second);
    size_t rank = std::get<1>(hook_pair.second);

    auto hook = std::get<2>(hook_pair.second);
234 235 236 237 238 239 240 241 242 243 244 245

    PADDLE_ENFORCE(slot_id < tensors.size(),
                   paddle::platform::errors::Fatal(
                       "Slot_id from registered hook should be smaller than "
                       "slot size of grad_tensors"));

    PADDLE_ENFORCE(rank < tensors[slot_id].size(),
                   paddle::platform::errors::Fatal(
                       "rank of slot %d from registered hook should be smaller "
                       "than rank size of grad_tensors",
                       slot_id));

246
    std::vector<paddle::experimental::Tensor>& slot_out = outs[slot_id];
247
    slot_out.resize(tensors[slot_id].size());
248
    paddle::experimental::Tensor& out = slot_out[rank];
249
    if (!out.defined() || !out.initialized()) {
250
      out = (*hook)(tensors[slot_id][rank]);
251
    } else {
252
      // If more than one hook is registered, the input to the next hook func
253
      // should be the output of the previous hook
254
      out = (*hook)(out);
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
    }
  }

  for (size_t i = 0; i < outs.size(); i++) {
    if (outs[i].empty() && (!tensors[i].empty())) {
      outs[i].resize(tensors[i].size());
    }
    // TODO(Jiabin): Optimize this if we only add hook slot by slot
    for (size_t j = 0; j < outs[i].size(); j++) {
      if (!outs[i][j].defined() || !outs[i][j].initialized()) {
        outs[i][j] = tensors[i][j];
      }
    }
  }

  return outs;
}

}  // namespace egr