// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/tensor_wrapper.h"

#include "paddle/phi/api/all.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/tensor_meta.h"

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/variable.h"

PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true,
                            "retain grad for all tensor");

namespace egr {
/**
 * Implementation of Eager Utils.
**/

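// Get the AutogradMeta of |target|. If the tensor does not carry one yet, a
// fresh AutogradMeta is created, attached to the tensor, and returned.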
AutogradMeta* EagerUtils::autograd_meta(paddle::experimental::Tensor* target) {
  auto* p_autograd_meta = target->get_autograd_meta();
  if (!p_autograd_meta) {
    auto p_autograd_meta_ptr = std::make_shared<AutogradMeta>();
    p_autograd_meta = p_autograd_meta_ptr.get();
    target->set_autograd_meta(p_autograd_meta_ptr);
  }
  return static_cast<AutogradMeta*>(p_autograd_meta);
}

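// Get the AutogradMeta of |target| without creating one; a missing meta is
// treated as a fatal error.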
AutogradMeta* EagerUtils::unsafe_autograd_meta(
    const paddle::experimental::Tensor& target) {
  auto* p_autograd_meta = target.get_autograd_meta();
  PADDLE_ENFORCE(p_autograd_meta,
                 paddle::platform::errors::Fatal(
                     "Null autograd_meta gotten from unsafe_autograd_meta()"));
  return static_cast<AutogradMeta*>(p_autograd_meta);
}

std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta(
    const std::vector<paddle::experimental::Tensor>& targets) {
  std::vector<AutogradMeta*> metas;
  metas.reserve(targets.size());
  for (const paddle::experimental::Tensor& t : targets) {
    metas.emplace_back(unsafe_autograd_meta(t));
  }
  return metas;
}

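// Get the AutogradMeta of |target| if it exists, otherwise return nullptr.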
AutogradMeta* EagerUtils::nullable_autograd_meta(
    const paddle::experimental::Tensor& target) {
  auto* p_autograd_meta = target.get_autograd_meta();
  if (!p_autograd_meta) return nullptr;

  return static_cast<AutogradMeta*>(p_autograd_meta);
}

std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
    const std::vector<paddle::experimental::Tensor>& targets) {
  std::vector<AutogradMeta*> metas;
  metas.reserve(targets.size());
  for (const paddle::experimental::Tensor& t : targets) {
    metas.emplace_back(nullable_autograd_meta(t));
  }
  return metas;
}

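// Vector form of autograd_meta(): get (and create if needed) the AutogradMeta
// of every tensor in |targets|.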
std::vector<AutogradMeta*> EagerUtils::autograd_meta(
    std::vector<paddle::experimental::Tensor>* targets) {
  std::vector<AutogradMeta*> ret;
  ret.reserve(targets->size());

  // For autograd_meta we can tolerate a missing (nullptr) meta;
  // autograd_meta() will create and attach one.
  for (size_t i = 0; i < targets->size(); i++) {
    auto* p_autograd_meta = autograd_meta(&((*targets)[i]));
    ret.emplace_back(p_autograd_meta);
  }
  return ret;
}

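// Return the (slot id, rank) pair recorded in the tensor's AutogradMeta; the
// meta must already exist.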
std::pair<size_t, size_t> EagerUtils::OutRankInfo(
    const paddle::experimental::Tensor& target) {
  return unsafe_autograd_meta(target)->OutRankInfo();
}

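// Return the grad node attached to |target|, or nullptr if the tensor carries
// no AutogradMeta.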
std::shared_ptr<GradNodeBase> EagerUtils::grad_node(
    const paddle::experimental::Tensor& target) {
  auto* meta = nullable_autograd_meta(target);
  if (meta) {
    return meta->GetMutableGradNode();
  } else {
    return nullptr;
  }
}

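// Record |grad_node| as the node that produced the tensor(s) owning the given
// AutogradMeta(s).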
void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  for (const auto& autograd_meta : *autograd_metas) {
    autograd_meta->SetGradNode(grad_node);
  }
}

void EagerUtils::SetHistory(AutogradMeta* autograd_meta,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  autograd_meta->SetGradNode(grad_node);
}

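// Record, on each AutogradMeta, which output slot and rank of the forward op
// the corresponding tensor occupies.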
void EagerUtils::SetOutRankWithSlot(std::vector<AutogradMeta*>* targets,
                                    size_t slot_id) {
  // Assign OutRankInfo (slot_id, i) to the i-th target.
  for (size_t i = 0; i < targets->size(); i++) {
    (*targets)[i]->SetSingleOutRankWithSlot(slot_id, i);
  }
}
void EagerUtils::SetOutRankWithSlot(AutogradMeta* target, size_t slot_id) {
  target->SetSingleOutRankWithSlot(slot_id, 0);
}

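// Wrap |tensor| into a freshly allocated egr::EagerVariable.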
std::shared_ptr<egr::EagerVariable> EagerUtils::TrySyncToVar(
    const paddle::experimental::Tensor& tensor) {
  return std::make_shared<egr::EagerVariable>(tensor);
}

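// TrySyncToVars overloads: wrap a single tensor, a non-null tensor pointer, or
// a list of tensors / tensor pointers into a vector of EagerVariables. Null
// pointers are treated as fatal errors.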
std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    const paddle::experimental::Tensor& tensor) {
  return {TrySyncToVar(tensor)};
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor,
      paddle::platform::errors::Fatal(
          "Should not pass a null tensor pointer in; since only outputs can "
          "reach this, please check the output value and make sure it is not "
          "null"));
  return {TrySyncToVar(*tensor)};
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  std::vector<std::shared_ptr<EagerVariable>> res;
  size_t num = tensors.size();
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    auto* tensor = tensors[i];
    PADDLE_ENFORCE_NOT_NULL(
        tensor, paddle::platform::errors::Fatal(
                    "Tensor is null and cannot be copied. "
                    "We are trying to TrySyncToVars the tensor from its "
                    "shared_ptr; this error may indicate some outputs "
                    "are nullptr"));
    res.emplace_back(TrySyncToVar(*tensor));
  }
  return res;
}

std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
    const std::vector<paddle::experimental::Tensor>& tensors) {
  std::vector<std::shared_ptr<EagerVariable>> res;
  size_t num = tensors.size();
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    res.emplace_back(TrySyncToVar(tensors[i]));
  }
  return res;
}

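// Create |num| empty EagerVariables, each named with a unique name generated
// by the global Controller.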
std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
    const size_t num) {
  std::vector<std::shared_ptr<EagerVariable>> res;
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    res.emplace_back(
        new EagerVariable(egr::Controller::Instance().GenerateUniqueName()));
  }
  return res;
}

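// Convert EagerVariables back into paddle::experimental::Tensors that share
// the underlying TensorBase and name; null variables are fatal errors.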
std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
    const std::vector<std::shared_ptr<EagerVariable>>& outs) {
  std::vector<paddle::experimental::Tensor> res;
  res.reserve(outs.size());
  for (const auto& out : outs) {
    PADDLE_ENFORCE_NOT_NULL(
        out.get(), paddle::platform::errors::Fatal(
                       "Eager Tensor %s is null and cannot be copied. "
                       "We are trying to get the output tensor from its "
                       "shared_ptr; this error may indicate some outputs "
                       "are nullptr",
                       out->name()));
    res.emplace_back(out->GetTensorBase(), out->name());
  }
  return res;
}

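// Single-output version of GetOutputs().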
paddle::experimental::Tensor EagerUtils::GetOutput(
    const std::shared_ptr<EagerVariable>& out) {
  PADDLE_ENFORCE_NOT_NULL(
      out.get(), paddle::platform::errors::Fatal(
                     "Eager Tensor %s is null and cannot be copied. We "
                     "are trying to get the output tensor from its "
                     "shared_ptr; this error may indicate the output is "
                     "nullptr",
                     out->name()));
  return paddle::experimental::Tensor(out->GetTensorBase(), out->name());
}

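// OverwriteOutputs overloads: write the result held in an EagerVariable (or
// in another tensor) into a pre-allocated output tensor in place. Null output
// pointers are fatal, and the vector forms expect |outs| and |tensors| to
// have the same size.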
void EagerUtils::OverwriteOutputs(const std::shared_ptr<EagerVariable>& out,
                                  paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor, paddle::platform::errors::Fatal(
                  "Tensor is null and cannot be copied. "
                  "We are trying to OverwriteOutputs from its "
                  "shared_ptr; this error may indicate some outputs "
                  "are nullptr"));
  tensor->set_impl(out->GetTensorBase());
}

void EagerUtils::OverwriteOutputs(
    const std::vector<std::shared_ptr<EagerVariable>>& outs,
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  PADDLE_ENFORCE_EQ(
      outs.size(), tensors.size(),
      paddle::platform::errors::Fatal(
          "OverwriteOutputs expects the number of outs and the number of "
          "original output tensors to be equal, but got outs size: %d and "
          "tensors size: %d",
          outs.size(), tensors.size()));
  for (size_t i = 0; i < outs.size(); i++) {
    OverwriteOutputs(outs[i], tensors[i]);
  }
}

void EagerUtils::OverwriteOutputs(const paddle::experimental::Tensor& out,
                                  paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor, paddle::platform::errors::Fatal(
                  "Tensor is null and cannot be copied. "
                  "We are trying to OverwriteOutputs from its "
                  "shared_ptr; this error may indicate some outputs "
                  "are nullptr"));
  *tensor = out;
}
void EagerUtils::OverwriteOutputs(
    const std::vector<paddle::experimental::Tensor>& outs,
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  for (size_t i = 0; i < outs.size(); i++) {
    PADDLE_ENFORCE_NOT_NULL(
        tensors[i], paddle::platform::errors::Fatal(
                        "Tensor is null and cannot be copied. "
                        "We are tring to OverwriteOutput from its "
                        "shared_ptr, this error may indicate some outputs "
                        "are nullptr"));
    *tensors[i] = outs[i];
  }
}

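// Recover the tensor(s) saved inside TensorWrapper(s) for use in grad
// computation.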
paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper(
    TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) {
  return tw->recover(grad_node);
}

std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
    std::vector<TensorWrapper>* tw,
    const std::shared_ptr<GradNodeBase>& grad_node) {
  std::vector<paddle::experimental::Tensor> ret;
  for (auto& t : *tw) {
    ret.emplace_back(t.recover(grad_node));
  }
  return ret;
}

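// If FLAGS_retain_grad_for_all_tensor is set, request that the gradient of
// the given tensor(s) be retained via egr_utils_api::RetainGradForTensor.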
void EagerUtils::CheckAndRetainGrad(
    const paddle::experimental::Tensor& tensor) {
  VLOG(6) << "Check RetainGradForTensor: " << tensor.name();
  if (FLAGS_retain_grad_for_all_tensor) {
    VLOG(6) << "RetainGradForTensor: " << tensor.name();
    egr::egr_utils_api::RetainGradForTensor(tensor);
  }
}

void EagerUtils::CheckAndRetainGrad(
    const std::vector<paddle::experimental::Tensor>& tensors) {
  if (FLAGS_retain_grad_for_all_tensor) {
    for (auto& tensor : tensors) {
      VLOG(6) << "RetainGradForTensor: " << tensor.name();
      egr::egr_utils_api::RetainGradForTensor(tensor);
    }
  }
}

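// Return the GradNodeAccumulation of a leaf tensor. If the tensor requires
// grad but has no grad node yet, a GradNodeAccumulation is created and
// attached; stop_gradient tensors yield nullptr, and a non-accumulation grad
// node is a fatal error.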
std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
    const paddle::experimental::Tensor& tensor) {
  auto* autograd_ptr = nullable_autograd_meta(tensor);
  if (!autograd_ptr) {
    return nullptr;
  }
  auto node_ptr = autograd_ptr->GetMutableGradNode();
  if (node_ptr) {
    if (!autograd_ptr->StopGradient()) {
      auto accumulation_ptr =
          std::dynamic_pointer_cast<GradNodeAccumulation>(node_ptr);
      if (accumulation_ptr) {
        return accumulation_ptr;
      } else {
        // Current GradNode is not an egr::GradNodeAccumulation
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "GetGradAccumulationNode should only be called on a leaf tensor, "
            "but target tensor: %s has a GradNode which is not a "
            "GradNodeAccumulation. This should not happen unless the target "
            "tensor was modified by some op that called SetHistory for it.",
            tensor.name()));
      }
    } else {
      // Current Tensor does not have grad since its stop_gradient is true.
      return nullptr;
    }
  } else {
    if (!autograd_ptr->StopGradient()) {
      VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name();
      autograd_ptr->SetGradNode(std::make_shared<egr::GradNodeAccumulation>());
      return autograd_ptr->GetMutableGradNode();
    } else {
      return nullptr;
    }
  }
}

}  // namespace egr