// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif

namespace paddle {
namespace imperative {

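// Moves the gradient in `src` into `dst` when it is safe to do so, and falls
// back to a deep copy (TensorCopy) when `force_copy` is set, e.g. because
// `src` is still needed as an input of another grad op.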
static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<framework::LoDTensor>()) {
    auto& src_tensor = src->Get<framework::LoDTensor>();
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<pten::SelectedRows>()) {
    auto& src_selected_rows = src->Get<pten::SelectedRows>();
    if (!dst->IsType<pten::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<pten::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

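// Place visitor that performs y = x + y element-wise on whichever place the
// gradient lives on (CPU/CUDA via BLAS AXPY, XPU via xpu::add); places without
// an implementation raise PermissionDenied.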
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
 public:
  TensorAddFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) const {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_XPU
  void operator()(const platform::XPUPlace& place) const {
    using XPUType = typename XPUTypeTrait<T>::Type;
    platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    int r = xpu::add<XPUType>(
        ctx->x_context(), reinterpret_cast<const XPUType*>(x_),
        reinterpret_cast<const XPUType*>(y_), reinterpret_cast<XPUType*>(y_),
        static_cast<int>(numel_));
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU add kernel return wrong value[%d %s]",
                                   r, XPUAPIErrorMsg[r]));
  }
#else
  void operator()(const platform::XPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void operator()(const platform::CUDAPlace& place) const {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#ifdef PADDLE_WITH_MLU
  void operator()(const platform::MLUPlace& place) const {
    // TODO(fwg): SUPPORT it
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#else
  void operator()(const platform::MLUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  void operator()(const platform::NPUPlace& place) const {
    // TODO(zhiqiu): SUPPORT it
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#else
  void operator()(const platform::NPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

  void operator()(const platform::NPUPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  // there is NO support in IPUPlace
  void operator()(const platform::IPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

 private:
  int64_t numel_;
  const T* x_;
  mutable T* y_;
};

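// XPU-only helper used by TensorAdd below: computes dst = src + dst through
// the xpu::add kernel.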
#ifdef PADDLE_WITH_XPU
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const framework::Tensor& src, framework::Tensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = xpu::add<XPUType>(ctx->x_context(), x, y, y,
                            static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      r, XPU_SUCCESS,
      platform::errors::External("XPU add kernel return wrong value[%d %s]", r,
                                 XPUAPIErrorMsg[r]));
}
#endif

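// Fallback used for float16 accumulation: performs dst = src + dst through
// operators::math::ElementwiseAddTo on the device context of `place`.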
template <typename DeviceContext, typename T>
void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
                   const platform::Place& place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  paddle::platform::DeviceContext* ctx = pool.Get(place);
  auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
  operators::math::ElementwiseAddTo<DeviceContext, T> func;
  func(dev_ctx, src, dst);
}

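// The GetInnerDstTensor/GetInnerSrcTensor overloads expose the gradients held
// in either egr::EagerTensor or framework::Variable as pten::DenseTensor
// handles, so that TensorAdd below can be written once for both types.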
std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(egr::EagerTensor* dst) {
  std::shared_ptr<pten::DenseTensor> dst_tensor =
      std::dynamic_pointer_cast<pten::DenseTensor>(dst->impl());
  return dst_tensor;
}

std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(
    const egr::EagerTensor& src) {
  std::shared_ptr<pten::DenseTensor> dst_tensor =
      std::dynamic_pointer_cast<pten::DenseTensor>(src.impl());
  return dst_tensor;
}

std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  return std::make_shared<pten::DenseTensor>(*dst_tensor);
}

std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(
    const framework::Variable& src) {
  auto& src_tensor = src.Get<framework::LoDTensor>();
  return std::make_shared<pten::DenseTensor>(src_tensor);
}

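// TensorAdd computes dst_tensor += src_tensor, dispatching on data type and
// place. A minimal usage sketch (illustrative only; `src_var` and `dst_var`
// are assumed to hold initialized float LoDTensors of the same shape):
//
//   paddle::imperative::TensorAdd<paddle::framework::Variable>(src_var,
//                                                              &dst_var);
//
// which ends up in TensorAddFunctor<float> via platform::VisitPlace.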
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  std::shared_ptr<pten::DenseTensor> d_tensor = GetInnerDstTensor(dst);
  std::shared_ptr<pten::DenseTensor> s_tensor = GetInnerSrcTensor(src);

  auto* dst_tensor = d_tensor.get();
  auto& src_tensor = *s_tensor.get();

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(), numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel, dst_tensor->numel()));

  auto data_type = src_tensor.type();
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal. Otherwise, the calculation results "
                        "will be incorrect."));

#ifdef PADDLE_WITH_XPU
  // if src and dst are on different places, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }
#endif

#define PADDLE_TENSOR_ADD(cpp_type)                                  \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    TensorAddFunctor<cpp_type> func(                                 \
        numel, src_tensor.data<cpp_type>(),                          \
        dst_tensor->mutable_data<cpp_type>(place));                  \
    platform::VisitPlace(place, func);                               \
    return;                                                          \
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    return;
  }
#endif

  PADDLE_TENSOR_ADD(float);

#ifndef PADDLE_WITH_XPU
  // NOTE(phlrain): xpu only supports float
  PADDLE_TENSOR_ADD(double);
  // NOTE(chenweihang): only complex grad tensor accumulation is supported here;
  // add SelectedRows support if it is needed in the future
  PADDLE_TENSOR_ADD(platform::complex<float>);
  PADDLE_TENSOR_ADD(platform::complex<double>);
#endif

#undef PADDLE_TENSOR_ADD

  if (data_type == framework::proto::VarType::FP16) {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      return TensorAddImpl<platform::CUDADeviceContext, platform::float16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<platform::CPUDeviceContext, platform::float16>(
          src_tensor, dst_tensor, place);
    }
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type), place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<egr::EagerTensor>(const egr::EagerTensor& src,
                                          egr::EagerTensor* dst);

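// Accumulates a SelectedRows gradient into an existing dense LoDTensor
// gradient in place (dst += src), dispatching by data type and place.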
void SelectedRowsAddToTensor(const framework::Variable& src,
                             framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  auto& src_selected_rows = src.Get<pten::SelectedRows>();
  auto place = dst_tensor->place();
  auto data_type = src_selected_rows.value().type();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)           \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {         \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);              \
    paddle::operators::math::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> \
        functor;                                                             \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,      \
            dst_tensor);                                                     \
    return;                                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

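// Computes dst = src_selected_rows + src_tensor into a freshly resized dense
// output variable; unlike SelectedRowsAddToTensor above, it does not
// accumulate in place.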
void SelectedRowsAddTensor(const framework::Variable& src_selected_rows_var,
                           const framework::Variable& src_tensor_var,
                           framework::Variable* dst_tensor_var) {
  const auto& src_selected_rows =
      src_selected_rows_var.Get<pten::SelectedRows>();
  const auto& src_tensor = src_tensor_var.Get<framework::LoDTensor>();
  const auto& place = src_tensor.place();
  auto data_type = src_tensor.type();
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  auto* dst_tensor = dst_tensor_var->GetMutable<framework::LoDTensor>();
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, data_type);

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)            \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {       \
    paddle::operators::math::SelectedRowsAddTensor<dev_ctx_type, cpp_type> \
        functor;                                                           \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,    \
            src_tensor, dst_tensor);                                       \
    return;                                                                \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

// Note(chenweihang): when two SelectedRows need to be added, adding one to the
//   other is not equivalent to merging both into an empty SelectedRows; the
//   latter (done below via MergeAdd) is the correct approach.
std::shared_ptr<VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2) {
  auto& src_selected_rows1 = src1.Get<pten::SelectedRows>();
  auto& src_selected_rows2 = src2.Get<pten::SelectedRows>();
  auto place = src_selected_rows1.value().place();
  auto data_type = src_selected_rows1.value().type();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const pten::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);
  auto dst_var = std::make_shared<VariableWrapper>("Temp");
  auto* dst_selected_rows =
      dst_var->MutableVar()->GetMutable<pten::SelectedRows>();

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)                  \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {      \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);           \
    paddle::operators::math::scatter::MergeAdd<dev_ctx_type, cpp_type>    \
        merge_add;                                                        \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows, \
              dst_selected_rows);                                         \
    return dst_var;                                                       \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

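// Eager-mode entry point: accumulates src_tensor's gradient into dst_tensor,
// selecting the dense or SelectedRows path from the runtime types of both
// variables.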
void VariableAdd(const egr::EagerTensor& src_tensor,
                 egr::EagerTensor* dst_tensor) {
  auto& src = src_tensor.Var();
  auto* dst = dst_tensor->MutableVar();

  if (dst->IsType<paddle::framework::LoDTensor>()) {
    if (src.IsType<paddle::framework::LoDTensor>()) {
      paddle::imperative::TensorAdd<paddle::framework::Variable>(src, dst);
    } else if (src.IsType<pten::SelectedRows>()) {
      paddle::imperative::SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          paddle::framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<paddle::framework::LoDTensor>()) {
      paddle::framework::Variable new_dst;
      paddle::imperative::SelectedRowsAddTensor(*dst, src, &new_dst);
      *dst = std::move(new_dst);
    } else {
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          paddle::framework::ToTypeName(dst->Type())));
    }
  }
}

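// Same dispatch as VariableAdd, but for VariableWrapper; when `unchange_input`
// is true the source variable must stay intact (it may still feed another
// grad op), so a copying path is used instead of moving from it.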
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var, bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<framework::LoDTensor>()) {
    if (src.IsType<framework::LoDTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<pten::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<framework::LoDTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<pten::SelectedRows>()) {
      auto temp = SelectedRowsMerge(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

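// Returns the place of the gradient held by `var`, for both LoDTensor and
// SelectedRows gradients.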
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<framework::LoDTensor>()) {
    place = var->Var().Get<framework::LoDTensor>().place();
  } else if (var->Var().IsType<pten::SelectedRows>()) {
    place = var->Var().Get<pten::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

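// Folds the summed gradient kept in inner_var_ back into the leaf gradient
// var_ once SumGrad has completed for this leaf, then releases inner_var_.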
void GradientAccumulator::AccumulateGrad() {
  /**
   * If the gradient of this leaf has been fully summed, the result stored in
   * inner_var_ should be accumulated into var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(), true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initialized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<framework::LoDTensor>()) {
      if (src->IsType<framework::LoDTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<pten::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<pten::SelectedRows>()) {
      if (src->IsType<framework::LoDTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<pten::SelectedRows>()) {
        auto temp = SelectedRowsMerge(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can deal with by gradient "
                        "hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(), true,
      platform::errors::PreconditionNotMet(
          "Only can call gradient hooks after sum gradient completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(), true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when call gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(), true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(), true,
      platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                    "by reduce hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(), true,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(), false,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batches."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

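// Accumulates one incoming gradient eagerly: the first gradient is moved (or
// copied) into the accumulator, later ones are added via VariableWrapperAdd;
// stop-gradient vars are simply filled with zeros of the right shape.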
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id, bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // Type may be changed after an OP runs (e.g. by VarTypeInference), so
  // synchronize the VariableWrapper type with the underlying Variable.
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<pten::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

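// Buffers incoming gradients in tmp_grad_vars_ until ref_cnt_ of them have
// arrived, then sums them in descending trace_id order; on CUDA places the
// SelectedRows gradients are accumulated before the dense ones.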
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id, bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum the SelectedRows gradients first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<pten::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
                            true, platform::errors::PermissionDenied(
                                      "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<framework::LoDTensor>() ||
                  var_info.var->Var().IsType<pten::SelectedRows>(),
              true, platform::errors::PermissionDenied("The type of Gradient "
                                                       "var must be LoDTensor "
                                                       "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ is expected to be empty here, but clear it just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<pten::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}

}  // namespace imperative
}  // namespace paddle