// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/tensor_wrapper.h"

#include "paddle/phi/api/all.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/tensor_meta.h"

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/variable.h"

PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true,
                            "Whether to retain grad for all tensors");

namespace egr {
/**
 * Implementation of Eager Utils.
**/

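// Returns the AutogradMeta attached to the target tensor, creating and
// attaching a fresh AutogradMeta first if the tensor does not carry one yet,
// so the returned pointer is never null.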
AutogradMeta* EagerUtils::autograd_meta(paddle::experimental::Tensor* target) {
  auto* p_autograd_meta = target->get_autograd_meta();
  if (!p_autograd_meta) {
    auto p_autograd_meta_ptr = std::make_shared<AutogradMeta>();
    p_autograd_meta = p_autograd_meta_ptr.get();
    target->set_autograd_meta(p_autograd_meta_ptr);
  }
  return static_cast<AutogradMeta*>(p_autograd_meta);
}

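// Returns the AutogradMeta of the target tensor and fails with a Fatal error
// if it is missing; the vector overload below applies the same check to every
// element.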
AutogradMeta* EagerUtils::unsafe_autograd_meta(
    const paddle::experimental::Tensor& target) {
  auto* p_autograd_meta = target.get_autograd_meta();
  PADDLE_ENFORCE(p_autograd_meta,
                 paddle::platform::errors::Fatal(
                     "Got null autograd_meta from unsafe_autograd_meta()"));
  return static_cast<AutogradMeta*>(p_autograd_meta);
}

std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta(
    const std::vector<paddle::experimental::Tensor>& targets) {
  std::vector<AutogradMeta*> metas;
  metas.reserve(targets.size());
  for (const paddle::experimental::Tensor& t : targets) {
    metas.emplace_back(unsafe_autograd_meta(t));
  }
  return metas;
}

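// Returns the AutogradMeta of the target tensor, or nullptr if the tensor has
// none; unlike autograd_meta(), this never creates a new meta.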
AutogradMeta* EagerUtils::nullable_autograd_meta(
    const paddle::experimental::Tensor& target) {
  auto* p_autograd_meta = target.get_autograd_meta();
  if (!p_autograd_meta) return nullptr;

  return static_cast<AutogradMeta*>(p_autograd_meta);
}

std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
    const std::vector<paddle::experimental::Tensor>& targets) {
  std::vector<AutogradMeta*> metas;
  metas.reserve(targets.size());
  for (const paddle::experimental::Tensor& t : targets) {
    metas.emplace_back(nullable_autograd_meta(t));
  }
  return metas;
}

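// Returns AutogradMeta pointers for every tensor in *targets, creating the
// meta for any tensor that does not have one yet.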
std::vector<AutogradMeta*> EagerUtils::autograd_meta(
    std::vector<paddle::experimental::Tensor>* targets) {
  std::vector<AutogradMeta*> ret;
  ret.reserve(targets->size());

  // autograd_meta() tolerates a null meta and creates it on demand.
  for (size_t i = 0; i < targets->size(); i++) {
    auto* p_autograd_meta = autograd_meta(&((*targets)[i]));
    ret.emplace_back(p_autograd_meta);
  }
  return ret;
}

std::pair<size_t, size_t> EagerUtils::OutRankInfo(
    const paddle::experimental::Tensor& target) {
  return unsafe_autograd_meta(target)->OutRankInfo();
}

std::shared_ptr<GradNodeBase> EagerUtils::grad_node(
    const paddle::experimental::Tensor& target) {
  auto* meta = nullable_autograd_meta(target);
  if (meta) {
    return meta->GetMutableGradNode();
  } else {
    return nullptr;
  }
}

void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  for (const auto& autograd_meta : *autograd_metas) {
    autograd_meta->SetGradNode(grad_node);
  }
}

void EagerUtils::SetHistory(AutogradMeta* autograd_meta,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  autograd_meta->SetGradNode(grad_node);
}

void EagerUtils::SetOutRankWithSlot(std::vector<AutogradMeta*>* targets,
                                    size_t slot_id) {
  // Assign each target OutRankInfo (slot_id, i) for i in [0, targets->size())
  for (size_t i = 0; i < targets->size(); i++) {
    (*targets)[i]->SetSingleOutRankWithSlot(slot_id, i);
  }
}
void EagerUtils::SetOutRankWithSlot(AutogradMeta* target, size_t slot_id) {
  target->SetSingleOutRankWithSlot(slot_id, 0);
}

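// TrySyncToVar / TrySyncToVars wrap API-level tensors into egr::EagerVariable
// objects; the overloads below accept a single tensor, a tensor pointer, or
// vectors of either, and reject null pointers.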
std::shared_ptr<egr::EagerVariable> EagerUtils::TrySyncToVar(
    const paddle::experimental::Tensor& tensor) {
  return std::make_shared<egr::EagerVariable>(tensor);
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    const paddle::experimental::Tensor& tensor) {
  return {TrySyncToVar(tensor)};
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor,
      paddle::platform::errors::Fatal(
          "Should not pass a null tensor pointer in; only outputs can reach "
          "this, so please check the output value and make sure it is not "
          "null"));
  return {TrySyncToVar(*tensor)};
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  std::vector<std::shared_ptr<EagerVariable>> res;
  size_t num = tensors.size();
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    auto* tensor = tensors[i];
    PADDLE_ENFORCE_NOT_NULL(
        tensor, paddle::platform::errors::Fatal(
                    "Tensor is null and cannot be copied. "
                    "We are trying to TrySyncToVars a tensor from its "
                    "shared_ptr, this error may indicate some outputs "
                    "are nullptr"));
    res.emplace_back(TrySyncToVar(*tensor));
  }
  return res;
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    const std::vector<paddle::experimental::Tensor>& tensors) {
  std::vector<std::shared_ptr<EagerVariable>> res;
  size_t num = tensors.size();
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    res.emplace_back(TrySyncToVar(tensors[i]));
  }
  return res;
}

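// Creates num empty EagerVariables, each named with a freshly generated
// unique name from the global Controller.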
std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
    const size_t num) {
  std::vector<std::shared_ptr<EagerVariable>> res;
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    res.emplace_back(
        new EagerVariable(egr::Controller::Instance().GenerateUniqueName()));
  }
  return res;
}

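// GetOutput / GetOutputs convert EagerVariables back into API-level
// paddle::experimental::Tensor objects, either returning new tensors or
// writing into caller-provided ones; most overloads reject null inputs with a
// Fatal error.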
std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
    const std::vector<std::shared_ptr<EagerVariable>>& outs) {
  std::vector<paddle::experimental::Tensor> res;
  res.reserve(outs.size());
  for (const auto& out : outs) {
    PADDLE_ENFORCE_NOT_NULL(
        out.get(), paddle::platform::errors::Fatal(
                       "Eager Tensor %s is null and cannot be copied. "
                       "We are trying to get an output tensor from its "
                       "shared_ptr, this error may indicate some outputs "
                       "are nullptr",
                       out->name()));
    res.emplace_back(out->GetTensorBase(), out->name());
  }
  return res;
}

paddle::experimental::Tensor EagerUtils::GetOutput(
    const std::shared_ptr<EagerVariable>& out) {
  PADDLE_ENFORCE_NOT_NULL(
      out.get(), paddle::platform::errors::Fatal(
                     "Eager Tensor %s is null and cannot be copied. We "
                     "are trying to get an output tensor from its shared_ptr, "
                     "this error may indicate the output is nullptr",
                     out->name()));
  return paddle::experimental::Tensor(out->GetTensorBase(), out->name());
}

void EagerUtils::GetOutput(const std::shared_ptr<EagerVariable>& out,
                           paddle::experimental::Tensor* out_var) {
  PADDLE_ENFORCE_NOT_NULL(
      out_var, paddle::platform::errors::Fatal(
                   "Tensor is null and cannot be copied. "
                   "We are trying to OverwriteOutput from its "
                   "shared_ptr, this error may indicate some outputs "
                   "are nullptr"));
  out_var->set_impl(out->GetTensorBase());
}

void EagerUtils::GetOutputs(
    const std::vector<std::shared_ptr<EagerVariable>>& outs,
    std::vector<paddle::experimental::Tensor>* result) {
  for (size_t i = 0; i < outs.size(); i++) {
    result->emplace_back(outs[i]->GetTensorBase());
  }
}

void EagerUtils::GetOutputs(
    const std::vector<std::shared_ptr<EagerVariable>>& outs,
    const std::vector<paddle::experimental::Tensor*>& out_var) {
  for (size_t i = 0; i < outs.size(); i++) {
    PADDLE_ENFORCE_NOT_NULL(
        out_var[i], paddle::platform::errors::Fatal(
                        "Tensor is null and cannot be copied. "
                        "We are trying to OverwriteOutput from its "
                        "shared_ptr, this error may indicate some outputs "
                        "are nullptr"));
    out_var[i]->set_impl(outs[i]->GetTensorBase());
  }
}

void EagerUtils::GetOutputs(const std::shared_ptr<EagerVariable>& out,
                            std::vector<paddle::experimental::Tensor>* result) {
  result->emplace_back(out->GetTensorBase());
}

void EagerUtils::GetOutputs(
    const std::shared_ptr<EagerVariable>& out,
    const std::vector<paddle::experimental::Tensor*>& out_var) {
  PADDLE_ENFORCE_NOT_NULL(
      out_var[0], paddle::platform::errors::Fatal(
                      "Tensor is null and cannot be copied. "
                      "We are trying to OverwriteOutput from its "
                      "shared_ptr, this error may indicate some outputs "
                      "are nullptr"));
  out_var[0]->set_impl(out->GetTensorBase());
}

void EagerUtils::Output2Result(
    const std::vector<paddle::experimental::Tensor*>& out_var,
    std::vector<paddle::experimental::Tensor>* result) {
  result->reserve(out_var.size());
  for (size_t i = 0; i < out_var.size(); i++) {
    result->emplace_back(*out_var[i]);
  }
}

paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper(
    TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) {
  return tw->recover(grad_node);
}

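// Vector overload: recovers every TensorWrapper in *tw back into a tensor,
// forwarding grad_node to each recover() call.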
std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
    std::vector<TensorWrapper>* tw,
    const std::shared_ptr<GradNodeBase>& grad_node) {
  std::vector<paddle::experimental::Tensor> ret;
  for (auto& t : *tw) {
    ret.emplace_back(t.recover(grad_node));
  }
  return ret;
}

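// When FLAGS_retain_grad_for_all_tensor is enabled, calls
// egr::egr_utils_api::RetainGradForTensor on the given tensor(s).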
void EagerUtils::CheckAndRetainGrad(
    const paddle::experimental::Tensor& tensor) {
  VLOG(6) << "Check RetainGradForTensor: " << tensor.name();
  if (FLAGS_retain_grad_for_all_tensor) {
    VLOG(6) << "RetainGradForTensor: " << tensor.name();
    egr::egr_utils_api::RetainGradForTensor(tensor);
  }
}

void EagerUtils::CheckAndRetainGrad(
    const std::vector<paddle::experimental::Tensor>& tensors) {
  if (FLAGS_retain_grad_for_all_tensor) {
    for (auto& tensor : tensors) {
      VLOG(6) << "RetainGradForTensor: " << tensor.name();
      egr::egr_utils_api::RetainGradForTensor(tensor);
    }
  }
}

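// Returns the GradNodeAccumulation of a leaf tensor, creating one if the
// tensor requires grad but has no grad node yet. Returns nullptr for
// stop_gradient tensors, and throws if the tensor already holds a grad node
// that is not a GradNodeAccumulation (i.e. it is not a leaf).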
std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
    const paddle::experimental::Tensor& tensor) {
  auto* autograd_ptr = nullable_autograd_meta(tensor);
  if (!autograd_ptr) {
    return nullptr;
  }
  auto node_ptr = autograd_ptr->GetMutableGradNode();
  if (node_ptr && node_ptr.get()) {
    if (!autograd_ptr->StopGradient()) {
      auto accumulation_ptr =
          std::dynamic_pointer_cast<GradNodeAccumulation>(node_ptr);
      if (accumulation_ptr) {
        return accumulation_ptr;
      } else {
        // Current GradNode is not an egr::GradNodeAccumulation
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "GetGradAccumulationNode should only be called on leaf tensors, "
            "but target tensor %s has a GradNode which is not a "
            "GradNodeAccumulation. This should not happen unless the target "
            "tensor was modified by some ops and SetHistory was called on it.",
            tensor.name()));
      }
    } else {
      // Current tensor does not have grad since its stop_gradient is true.
      return nullptr;
    }
  } else {
    if (!autograd_ptr->StopGradient()) {
      VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name();
      autograd_ptr->SetGradNode(std::make_shared<egr::GradNodeAccumulation>());
      return autograd_ptr->GetMutableGradNode();
    } else {
      return nullptr;
    }
  }
}

}  // namespace egr