// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/tensor_wrapper.h"

#include "paddle/pten/api/all.h"
#include "paddle/pten/common/layout.h"
#include "paddle/pten/core/tensor_meta.h"

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/variable.h"

PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true,
                            "retain grad for all tensor");

namespace egr {
/**
 * Implementation of Eager Utils.
**/

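// Returns the AutogradMeta of `target`; if the tensor does not carry one
// yet, a fresh AutogradMeta is created and attached to it first.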
AutogradMeta* EagerUtils::autograd_meta(paddle::experimental::Tensor* target) {
  auto* p_autograd_meta = target->get_autograd_meta();
  if (!p_autograd_meta) {
    auto p_autograd_meta_ptr = std::make_shared<AutogradMeta>();
    p_autograd_meta = p_autograd_meta_ptr.get();
    target->set_autograd_meta(p_autograd_meta_ptr);
  }
  return static_cast<AutogradMeta*>(p_autograd_meta);
}

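// Returns the existing AutogradMeta of `target`; raises a fatal error if
// the tensor has none.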
AutogradMeta* EagerUtils::unsafe_autograd_meta(
    const paddle::experimental::Tensor& target) {
  auto* p_autograd_meta = target.get_autograd_meta();
  PADDLE_ENFORCE(p_autograd_meta,
                 paddle::platform::errors::Fatal(
                     "Null autograd_meta gotten from unsafe_autograd_meta()"));
  return static_cast<AutogradMeta*>(p_autograd_meta);
}

std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta(
    const std::vector<paddle::experimental::Tensor>& targets) {
  std::vector<AutogradMeta*> metas;
  metas.reserve(targets.size());
  for (const paddle::experimental::Tensor& t : targets) {
    metas.emplace_back(unsafe_autograd_meta(t));
  }
  return metas;
}

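// Returns the AutogradMeta of `target`, or nullptr if the tensor has none.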
AutogradMeta* EagerUtils::nullable_autograd_meta(
    const paddle::experimental::Tensor& target) {
  auto* p_autograd_meta = target.get_autograd_meta();
  if (!p_autograd_meta) return nullptr;

  return static_cast<AutogradMeta*>(p_autograd_meta);
}

std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
    const std::vector<paddle::experimental::Tensor>& targets) {
  std::vector<AutogradMeta*> metas;
  metas.reserve(targets.size());
  for (const paddle::experimental::Tensor& t : targets) {
    metas.emplace_back(nullable_autograd_meta(t));
  }
  return metas;
}

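// Vector version of autograd_meta(): creates AutogradMeta on demand for
// every tensor in `targets`.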
std::vector<AutogradMeta*> EagerUtils::autograd_meta(
    std::vector<paddle::experimental::Tensor>* targets) {
  std::vector<AutogradMeta*> ret;
  ret.reserve(targets->size());

  // autograd_meta() can tolerate a missing meta here: it creates one on demand.
  for (size_t i = 0; i < targets->size(); i++) {
    auto* p_autograd_meta = autograd_meta(&((*targets)[i]));
    ret.emplace_back(p_autograd_meta);
  }
  return ret;
}

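// Returns the (slot id, rank-in-slot) pair recorded for `target`; the
// tensor must already carry autograd meta.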
std::pair<size_t, size_t> EagerUtils::OutRankInfo(
    const paddle::experimental::Tensor& target) {
  return unsafe_autograd_meta(target)->OutRankInfo();
}

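// Returns the grad node attached to `target`, or nullptr if the tensor
// carries no autograd meta.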
std::shared_ptr<GradNodeBase> EagerUtils::grad_node(
    const paddle::experimental::Tensor& target) {
  auto* meta = nullable_autograd_meta(target);
  if (meta) {
    return meta->GetMutableGradNode();
  } else {
    return nullptr;
  }
}

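// Records `grad_node` as the producer of the tensor(s) behind the given
// autograd meta(s), linking them into the backward graph.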
void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  for (const auto& autograd_meta : *autograd_metas) {
    autograd_meta->SetGradNode(grad_node);
  }
}

void EagerUtils::SetHistory(AutogradMeta* autograd_meta,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  autograd_meta->SetGradNode(grad_node);
}

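// Records each tensor's output position as (slot_id, index-within-slot).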
void EagerUtils::SetOutRankWithSlot(std::vector<AutogradMeta*>* targets,
                                    size_t slot_id) {
  // Set OutRankInfo for each target: rank runs from 0 to targets->size() - 1
  for (size_t i = 0; i < targets->size(); i++) {
    (*targets)[i]->SetSingleOutRankWithSlot(slot_id, i);
  }
}
void EagerUtils::SetOutRankWithSlot(AutogradMeta* target, size_t slot_id) {
  target->SetSingleOutRankWithSlot(slot_id, 0);
}
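
// Illustrative sketch (hypothetical, not part of this file): generated eager
// forward code might wire the helpers above together roughly like this,
// assuming an op output `out` and a freshly built grad node `node`:
//
//   paddle::experimental::Tensor out = /* op result */;
//   egr::AutogradMeta* out_meta = egr::EagerUtils::autograd_meta(&out);
//   egr::EagerUtils::SetOutRankWithSlot(out_meta, /*slot_id=*/0);
//   egr::EagerUtils::SetHistory(out_meta, node);
//   egr::EagerUtils::CheckAndRetainGrad(out);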

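// Wraps `tensor` into a shared egr::EagerTensor; the TrySyncToVars
// overloads below vectorize this and null-check pointer inputs.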
std::shared_ptr<egr::EagerTensor> EagerUtils::TrySyncToVar(
    const paddle::experimental::Tensor& tensor) {
  return std::make_shared<egr::EagerTensor>(tensor);
}

std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
    const paddle::experimental::Tensor& tensor) {
  return {TrySyncToVar(tensor)};
}

std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
    paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor,
      paddle::platform::errors::Fatal(
          "Should Not Pass Empty tensor pointer in, since only output can "
          "reach this, please check output value and make sure it's not null"));
  return {TrySyncToVar(*tensor)};
}

std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  std::vector<std::shared_ptr<EagerTensor>> res;
  size_t num = tensors.size();
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    auto* tensor = tensors[i];
    PADDLE_ENFORCE_NOT_NULL(
        tensor, paddle::platform::errors::Fatal(
                    "Tensor is null and cannot be copied. "
                    "We are tring to TrySyncToVars tensor from its "
                    "shared_ptr, this error may indicate some outputs "
                    "are nullptr"));
    res.emplace_back(TrySyncToVar(*tensor));
  }
  return res;
}

std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
    const std::vector<paddle::experimental::Tensor>& tensors) {
  std::vector<std::shared_ptr<EagerTensor>> res;
  size_t num = tensors.size();
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    res.emplace_back(TrySyncToVar(tensors[i]));
  }
  return res;
}

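// Creates `num` empty EagerTensors, each named with a freshly generated
// unique name.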
std::vector<std::shared_ptr<EagerTensor>> EagerUtils::CreateVars(
    const size_t num) {
  std::vector<std::shared_ptr<EagerTensor>> res;
  res.reserve(num);
  for (size_t i = 0; i < num; i++) {
    res.emplace_back(
        new EagerTensor(egr::Controller::Instance().GenerateUniqueName()));
  }
  return res;
}

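// Converts EagerTensors back into experimental::Tensors, keeping each
// tensor's underlying TensorBase and name; null entries are a fatal error.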
std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
    const std::vector<std::shared_ptr<EagerTensor>>& outs) {
  std::vector<paddle::experimental::Tensor> res;
  res.reserve(outs.size());
  for (const auto& out : outs) {
    PADDLE_ENFORCE_NOT_NULL(
        out.get(), paddle::platform::errors::Fatal(
                       "Eager Tensor %s is null and cannot be copied. "
                       "We are tring to Get Output tensor from its "
                       "shared_ptr, this error may indicate some outputs "
                       "are nullptr",
                       out->name()));
    res.emplace_back(out->GetTensorBase(), out->name());
  }
  return res;
}

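// Single-tensor version of GetOutputs().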
paddle::experimental::Tensor EagerUtils::GetOutput(
    const std::shared_ptr<EagerTensor>& out) {
  PADDLE_ENFORCE_NOT_NULL(
      out.get(), paddle::platform::errors::Fatal(
                     "Eager Tensor %s is null and cannot be copied. We "
                     "are tring to Get Output tensor from its shared_ptr, "
                     "this error may indicate output is nullptr",
                     out->name()));
  return paddle::experimental::Tensor(out->GetTensorBase(), out->name());
}
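
// Illustrative sketch (hypothetical, not part of this file): a legacy-op
// call path might round-trip through the helpers above roughly like this:
//
//   auto ins = egr::EagerUtils::TrySyncToVars(input_tensors);
//   auto outs = egr::EagerUtils::CreateVars(/*num=*/1);
//   // ... run the legacy operator over `ins` and `outs` ...
//   std::vector<paddle::experimental::Tensor> results =
//       egr::EagerUtils::GetOutputs(outs);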

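// Overwrites the destination tensor(s) in place with the given result(s):
// the EagerTensor overloads swap in the underlying TensorBase, while the
// Tensor overloads assign the whole tensor.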
void EagerUtils::OverwriteOutputs(const std::shared_ptr<EagerTensor>& out,
                                  paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor, paddle::platform::errors::Fatal(
                  "Tensor is null and cannot be copied. "
                  "We are tring to OverwriteOutput from its "
                  "shared_ptr, this error may indicate some outputs "
                  "are nullptr"));
  tensor->set_impl(out->GetTensorBase());
}

void EagerUtils::OverwriteOutputs(
    const std::vector<std::shared_ptr<EagerTensor>>& outs,
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  PADDLE_ENFORCE_EQ(
      outs.size(), tensors.size(),
      paddle::platform::errors::Fatal(
          "We are tring to OverwriteOutputs which passed in and it expected "
          "elements num of outs and origin outputs are equal, but we got outs "
          "size of: %d, and tensors passed in size is: %d",
          outs.size(), tensors.size()));
  for (size_t i = 0; i < outs.size(); i++) {
    OverwriteOutputs(outs[i], tensors[i]);
  }
}

void EagerUtils::OverwriteOutputs(const paddle::experimental::Tensor& out,
                                  paddle::experimental::Tensor* tensor) {
  PADDLE_ENFORCE_NOT_NULL(
      tensor, paddle::platform::errors::Fatal(
                  "Tensor is null and cannot be copied. "
                  "We are tring to OverwriteOutput from its "
                  "shared_ptr, this error may indicate some outputs "
                  "are nullptr"));
  *tensor = out;
}
void EagerUtils::OverwriteOutputs(
    const std::vector<paddle::experimental::Tensor>& outs,
    const std::vector<paddle::experimental::Tensor*>& tensors) {
  for (size_t i = 0; i < outs.size(); i++) {
    PADDLE_ENFORCE_NOT_NULL(
        tensors[i], paddle::platform::errors::Fatal(
                        "Tensor is null and cannot be copied. "
                        "We are tring to OverwriteOutput from its "
                        "shared_ptr, this error may indicate some outputs "
                        "are nullptr"));
    *tensors[i] = outs[i];
  }
}

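// Recovers the tensor(s) saved in `tw` for use in the backward pass,
// handing `grad_node` to the wrapper so it can rebuild autograd state.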
paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper(
    TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) {
  return tw->recover(grad_node);
}

std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
    std::vector<TensorWrapper>* tw,
    const std::shared_ptr<GradNodeBase>& grad_node) {
  std::vector<paddle::experimental::Tensor> ret;
  for (auto& t : *tw) {
    ret.emplace_back(t.recover(grad_node));
  }
  return ret;
}

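// If FLAGS_retain_grad_for_all_tensor is set, registers a retain-grad hook
// so the tensor's gradient is kept after backward.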
void EagerUtils::CheckAndRetainGrad(
    const paddle::experimental::Tensor& tensor) {
  VLOG(6) << "Check RetainGradForTensor: " << tensor.name();
  if (FLAGS_retain_grad_for_all_tensor) {
    VLOG(6) << "RetainGradForTensor: " << tensor.name();
    egr::egr_utils_api::RetainGradForTensor(tensor);
  }
}

void EagerUtils::CheckAndRetainGrad(
    const std::vector<paddle::experimental::Tensor>& tensors) {
  if (FLAGS_retain_grad_for_all_tensor) {
    for (auto& tensor : tensors) {
      VLOG(6) << "RetainGradForTensor: " << tensor.name();
      egr::egr_utils_api::RetainGradForTensor(tensor);
    }
  }
}

}  // namespace egr