gradient_accumulator.cc 37.2 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"
16

J
Jiabin Yang 已提交
17 18 19
#include <algorithm>
#include <memory>
#include <utility>
20

21
#include "paddle/fluid/framework/convert_utils.h"
J
Jiabin Yang 已提交
22
#include "paddle/fluid/framework/lod_tensor.h"
23
#include "paddle/fluid/framework/selected_rows_utils.h"
J
Jiabin Yang 已提交
24
#include "paddle/fluid/imperative/layer.h"
25
#include "paddle/fluid/operators/math/selected_rows_functor.h"
26
#include "paddle/fluid/platform/bfloat16.h"
27
#include "paddle/fluid/platform/complex.h"
J
Jiabin Yang 已提交
28
#include "paddle/fluid/platform/device_context.h"
29
#include "paddle/fluid/platform/float16.h"
J
Jiabin Yang 已提交
30
#include "paddle/fluid/platform/profiler.h"
31 32
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
H
hong 已提交
33 34 35
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
36
#ifdef PADDLE_WITH_ASCEND_CL
37
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
38
#endif
F
fwenguang 已提交
39 40 41
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
J
Jiabin Yang 已提交
42 43 44 45

namespace paddle {
namespace imperative {

46 47 48
// Transfers the payload of `src` into `dst`.
//
// With `force_copy == false` the whole Variable is moved (leaving `src` in a
// moved-from state). Otherwise the LoDTensor or SelectedRows held by `src` is
// deep-copied with framework::TensorCopy onto the source tensor's own place;
// `dst` is cleared first when it currently holds a different type. Any other
// payload type raises PermissionDenied.
static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src,
                          bool force_copy) {
  if (force_copy) {
    VLOG(6) << "Copy occurs when sum gradients within this graph";
    if (src->IsType<framework::LoDTensor>()) {
      const auto& from_tensor = src->Get<framework::LoDTensor>();
      if (!dst->IsType<framework::LoDTensor>()) {
        dst->Clear();
      }
      auto* to_tensor = dst->GetMutable<framework::LoDTensor>();
      framework::TensorCopy(from_tensor, from_tensor.place(), to_tensor);
      to_tensor->set_lod(from_tensor.lod());
    } else if (src->IsType<phi::SelectedRows>()) {
      const auto& from_rows = src->Get<phi::SelectedRows>();
      if (!dst->IsType<phi::SelectedRows>()) {
        dst->Clear();
      }
      auto* to_rows = dst->GetMutable<phi::SelectedRows>();
      const auto& from_value = from_rows.value();
      framework::TensorCopy(from_value, from_value.place(),
                            to_rows->mutable_value());
      to_rows->set_rows(from_rows.rows());
      to_rows->set_height(from_rows.height());
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for sum gradient"));
    }
    return;
  }

  VLOG(6) << "Just Move Variable when sum gradients within this graph";
  *dst = std::move(*src);
}

J
Jiabin Yang 已提交
80 81 82 83 84 85
// Visitor used by TensorAdd (through platform::VisitPlace) to compute
// y += x over `numel` contiguous elements of type T on the visited place.
//
// Implemented places: CPUPlace (BLAS AXPY with alpha = 1), CUDAPlace when
// built with CUDA/HIP (BLAS AXPY), and XPUPlace when built with XPU
// (xpu::add). Every other place raises PermissionDenied. The functor only
// stores the raw `x` / `y` pointers it is given; it does not own them.
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
 public:
  TensorAddFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) const {
    auto* cpu_ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto cpu_blas =
        phi::funcs::GetBlas<platform::CPUDeviceContext, T>(*cpu_ctx);
    cpu_blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_XPU
  void operator()(const platform::XPUPlace& place) const {
    using XPUType = typename XPUTypeTrait<T>::Type;
    auto* xpu_ctx = dynamic_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    const auto* lhs = reinterpret_cast<const XPUType*>(x_);
    const auto* rhs = reinterpret_cast<const XPUType*>(y_);
    auto* out = reinterpret_cast<XPUType*>(y_);
    const int ret = xpu::add<XPUType>(xpu_ctx->x_context(), lhs, rhs, out,
                                      static_cast<int>(numel_));
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External("XPU add kernel return wrong value[%d %s]",
                                   ret, XPUAPIErrorMsg[ret]));
  }
#else
  void operator()(const platform::XPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void operator()(const platform::CUDAPlace& place) const {
    auto* gpu_ctx = dynamic_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto gpu_blas =
        phi::funcs::GetBlas<platform::CUDADeviceContext, T>(*gpu_ctx);
    gpu_blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

  // TODO(fwg): SUPPORT it. Rejected whether or not PADDLE_WITH_MLU is set.
  void operator()(const platform::MLUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

  // TODO(zhiqiu): SUPPORT it. Rejected whether or not PADDLE_WITH_ASCEND_CL
  // is set; NPU tensors are accumulated by TensorAdd's own NPU branch.
  void operator()(const platform::NPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

  void operator()(const platform::NPUPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

  // there is NO support in IPUPlace
  void operator()(const platform::IPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

  void operator()(const platform::CustomPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

 private:
  int64_t numel_;
  const T* x_;
  // Only the pointee is written by the const operator() overloads; the
  // pointer itself is never reassigned, so it does not need `mutable`.
  T* y_;
};

200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
#ifdef PADDLE_WITH_XPU
// Computes dst += src on an XPU device with xpu::add, viewing T through its
// XPU storage type. Writable storage for `dst` is obtained with
// mutable_data<T>(place). Raises External if the XPU kernel reports failure.
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const framework::Tensor& src, framework::Tensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto& pool = platform::DeviceContextPool::Instance();
  auto* xpu_ctx = dynamic_cast<platform::XPUDeviceContext*>(pool.Get(place));
  const auto* src_data = reinterpret_cast<const XPUType*>(src.data<T>());
  auto* dst_data = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  const int status = xpu::add<XPUType>(xpu_ctx->x_context(), src_data,
                                       dst_data, dst_data,
                                       static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      status, XPU_SUCCESS,
      platform::errors::External("XPU add kernel return wrong value[%d %s]",
                                 status, XPUAPIErrorMsg[status]));
}
#endif

218 219 220 221 222 223
// Element-wise dst += src using phi::funcs::ElementwiseAddTo on the
// DeviceContext that the global pool holds for `place`. The caller is
// responsible for picking a DeviceContext type that matches `place`.
template <typename DeviceContext, typename T>
void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
                   const platform::Place& place) {
  auto* typed_ctx = dynamic_cast<DeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  phi::funcs::ElementwiseAddTo<DeviceContext, T> add_to;
  add_to(typed_ctx, src, dst);
}

228 229 230
// Returns a mutable pointer to the TType payload of `dst`, obtained through
// framework::Variable::GetMutable.
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  return dst->GetMutable<TType>();
}

234 235 236
// Returns the implementation behind `dst` as a TType*. This is an unchecked
// static_cast: the caller must know dst->impl() really holds a TType.
template <typename TType>
TType* GetInnerMutableTensor(paddle::experimental::Tensor* dst) {
  auto* impl_base = dst->impl().get();
  return static_cast<TType*>(impl_base);
}

240 241 242
// Read-only access to the TType payload of `src` (framework::Variable::Get).
template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  const TType& payload = src.Get<TType>();
  return payload;
}

245 246 247 248 249 250 251 252 253
// Returns the implementation behind `src` as a TType&. Raises Fatal when
// `src` is not initialized; the TType downcast itself is unchecked.
template <typename TType>
TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
  const bool has_value = src.initialized();
  PADDLE_ENFORCE_EQ(
      has_value, true,
      platform::errors::Fatal("We only add tensor with value if a tensor is "
                              "NOT INITILIZED, it should just move instead of "
                              "calling this method."));
  return *static_cast<TType*>(src.impl().get());
}

256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
// Installs a freshly default-constructed TType as the implementation of
// `dst` and returns a pointer to it. Raises Fatal if `dst` already has an
// implementation (dst->defined()).
template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(), false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  auto fresh_impl = std::make_shared<TType>();
  TType* raw_impl = fresh_impl.get();
  // `dst` becomes the owner; raw_impl stays valid through that ownership.
  dst->set_impl(std::move(fresh_impl));
  return raw_impl;
}

// Returns a mutable TType payload of the Variable wrapped by `dst`, via
// Variable::GetMutable. Unlike the Tensor overload, this does not check
// that the wrapped variable is empty beforehand.
template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  framework::Variable* wrapped = dst->MutableVar();
  return wrapped->GetMutable<TType>();
}

273 274
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
275 276
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);
J
Jiabin Yang 已提交
277 278 279 280 281 282 283 284 285

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

286 287 288 289 290 291 292
  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(), numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel, dst_tensor->numel()));
J
Jiabin Yang 已提交
293

294
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
J
Jiabin Yang 已提交
295 296
  auto place = src_tensor.place();

297 298
  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
299 300 301 302 303
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal, Otherwise, the calculation results "
                        "will be incorrect."));

304 305 306 307
  // if src and dst are in different place, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }
308
#define PADDLE_TENSOR_ADD(cpp_type)                                  \
J
Jiabin Yang 已提交
309 310 311 312
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    TensorAddFunctor<cpp_type> func(                                 \
        numel, src_tensor.data<cpp_type>(),                          \
        dst_tensor->mutable_data<cpp_type>(place));                  \
313
    platform::VisitPlace(place, func);                               \
J
Jiabin Yang 已提交
314 315 316
    return;                                                          \
  }

317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif
341 342 343 344 345 346 347 348
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (platform::is_custom_place(place)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Gradient accumulation of data type (%s) on place (%s) is not "
        "supported in imperative mode",
        framework::DataTypeToString(data_type), place));
  }
#endif
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    return;
  }
#endif

F
fwenguang 已提交
366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd(
387
        dev_ctx->cnnl_handle(), static_cast<const void*>(&alpha),
F
fwenguang 已提交
388
        src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0,
389
        static_cast<const void*>(&beta), dst_tensor_desc.get(),
F
fwenguang 已提交
390 391 392 393 394
        operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif

395
  PADDLE_TENSOR_ADD(float);
396

H
hong 已提交
397 398
#ifndef PADDLE_WITH_XPU
  // NOTE(phlrain): xpu only support float
399
  PADDLE_TENSOR_ADD(double);
400 401
  // NOTE(chenweihang): only support complex grad tensor accumulated,
  // support selected rows if needed in the future
402 403
  PADDLE_TENSOR_ADD(platform::complex<float>);
  PADDLE_TENSOR_ADD(platform::complex<double>);
H
hong 已提交
404
#endif
J
Jiabin Yang 已提交
405

406
#undef PADDLE_TENSOR_ADD
J
Jiabin Yang 已提交
407

408 409
  if (data_type == framework::proto::VarType::FP16) {
    if (platform::is_gpu_place(place)) {
410
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
411 412 413 414 415 416 417 418 419 420 421 422 423
      return TensorAddImpl<platform::CUDADeviceContext, platform::float16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<platform::CPUDeviceContext, platform::float16>(
          src_tensor, dst_tensor, place);
    }
  }
424 425
  if (data_type == framework::proto::VarType::BF16) {
    if (platform::is_gpu_place(place)) {
426
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
427 428 429 430 431 432 433 434 435 436 437 438 439
      return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<platform::CPUDeviceContext, platform::bfloat16>(
          src_tensor, dst_tensor, place);
    }
  }
440 441 442 443
  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type), place));
J
Jiabin Yang 已提交
444 445
}

446 447
template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
448 449
template void TensorAdd<paddle::experimental::Tensor>(
    const paddle::experimental::Tensor& src, paddle::experimental::Tensor* dst);
450

451 452
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
453 454 455
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
456
  auto place = dst_tensor->place();
457 458
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
459 460 461 462 463 464 465 466 467 468 469 470
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)           \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {         \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);              \
    paddle::operators::math::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> \
        functor;                                                             \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,      \
            dst_tensor);                                                     \
    return;                                                                  \
  }

471
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
472 473 474 475 476 477 478
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, double);
479
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
480 481 482 483 484 485 486 487 488 489
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

490 491 492 493 494 495 496 497 498
template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::experimental::Tensor& src,
                                      paddle::experimental::Tensor* dst);

template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
499 500 501 502
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
503
  const auto& place = src_tensor.place();
504
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
505 506
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

507 508
  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
509
  dst_tensor->Resize(src_tensor.dims());
510 511
  dst_tensor->mutable_data(place, src_tensor.dtype());

512 513 514 515 516 517 518 519 520
#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)            \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {       \
    paddle::operators::math::SelectedRowsAddTensor<dev_ctx_type, cpp_type> \
        functor;                                                           \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,    \
            src_tensor, dst_tensor);                                       \
    return;                                                                \
  }

521
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
522 523 524 525 526 527 528
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double);
529
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
530 531 532 533 534 535 536 537 538 539
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

540 541 542 543 544 545 546 547 548 549 550 551
template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(
    const paddle::experimental::Tensor& src_selected_rows_var,
    const paddle::experimental::Tensor& src_tensor_var,
    paddle::experimental::Tensor* dst_tensor_var);

// Note(chenweihang): when two selected rows need to be added,
//   adding one to another is not equal to merging two selected rows
//   to one then add it to a empty selected rows, the after is correct
552 553 554
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
555 556 557 558
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);
559

560
  auto place = src_selected_rows1.value().place();
561 562
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
563 564
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

565
  std::vector<const phi::SelectedRows*> src_selected_rows;
566 567
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);
568 569

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
570 571
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());
572 573 574 575 576 577 578 579 580 581 582

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)                  \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {      \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);           \
    paddle::operators::math::scatter::MergeAdd<dev_ctx_type, cpp_type>    \
        merge_add;                                                        \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows, \
              dst_selected_rows);                                         \
    return dst_var;                                                       \
  }

583
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
584 585 586 587 588 589 590
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, double);
591
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
592 593 594 595 596 597 598 599 600
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

601 602 603 604 605 606
template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
    const paddle::experimental::Tensor& src1,
    const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

607
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
608
                        VariableWrapper* dst_var, bool unchange_input) {
609
  auto& src = var->Var();
610
  auto* dst = dst_var->MutableVar();
611 612
  if (dst->IsType<framework::LoDTensor>()) {
    if (src.IsType<framework::LoDTensor>()) {
613
      TensorAdd<framework::Variable>(src, dst);
614
    } else if (src.IsType<phi::SelectedRows>()) {
615 616 617 618 619 620 621 622
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<framework::LoDTensor>()) {
623 624 625 626 627 628 629 630 631
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
632
    } else if (src.IsType<phi::SelectedRows>()) {
633
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
634 635 636 637 638 639 640 641 642
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

643 644
// Place of the LoDTensor or SelectedRows held by `var`; any other payload
// raises InvalidArgument.
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  const auto& held = var->Var();
  if (held.IsType<framework::LoDTensor>()) {
    return held.Get<framework::LoDTensor>().place();
  }
  if (held.IsType<phi::SelectedRows>()) {
    return held.Get<phi::SelectedRows>().place();
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "only support LoDTensor and SelectedRows in dygraph"));
}

657 658
void GradientAccumulator::AccumulateGrad() {
  /**
659 660
   * If the leaf gradient has been calculated done, the inner_var_
   * should be added to the var_.
661 662 663 664 665 666 667 668 669 670
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(), true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true,
                    platform::errors::InvalidArgument(
671
                        "Interior var of Leaf tensor should be initialized."));
672 673 674
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
675 676 677
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initizlized, will accumulate on "
               "previous gradient.";
678 679
    if (dst->IsType<framework::LoDTensor>()) {
      if (src->IsType<framework::LoDTensor>()) {
680
        TensorAdd<framework::Variable>(*src, dst);
681
      } else if (src->IsType<phi::SelectedRows>()) {
682 683
        SelectedRowsAddToTensor(*src, dst);
      }
684
    } else if (dst->IsType<phi::SelectedRows>()) {
685 686 687
      if (src->IsType<framework::LoDTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
688
      } else if (src->IsType<phi::SelectedRows>()) {
689
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
690 691 692 693 694 695 696
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
697 698 699
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
700 701 702
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
703
    var_->SetIsEmpty(false);
704 705 706 707
  }
  inner_var_.reset();
}

708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can deal with by gradient "
                        "hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(), true,
      platform::errors::PreconditionNotMet(
          "Only can call gradient hooks after sum gradient completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(), true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when call gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(), true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
726 727
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
728 729 730 731
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
732 733
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
734
      tmp_var = (*hook_pair.second)(tmp_var);
L
Leo Chen 已提交
735
      CheckVar(inner_var_, tmp_var);
736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754
    }
    inner_var_ = tmp_var;
  }
}

void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(), true,
      platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                    "by reduce hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(), true,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(), false,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batchs."));
755 756
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
757
      VLOG(3) << "call gradient accumulator backward hooks.";
758
      (*hook)();
759 760 761 762
    }
  }
}

763 764
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id, bool unchange_input) {
765 766 767 768 769 770 771 772
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

773
  auto* dst_var = Var();
774
  platform::Place place = GetPlaceOfVar(var);
775 776 777
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
778
    } else {
779 780 781
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
782
    }
J
Jiabin Yang 已提交
783
  } else {
784 785 786
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
787
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
788 789 790 791
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
792 793
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
794
        tensor->mutable_data(place,
795
                             framework::TransToPhiDataType(var->DataType()));
796
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
797
      } else {
798 799
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
800
        tensor->mutable_data(place,
801
                             framework::TransToPhiDataType(var->DataType()));
802
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
803
      }
804
    }
J
Jiabin Yang 已提交
805
  }
806

807 808 809 810
  // Type may be changed after OP run, such as VarTypeInference
  // so synchronous VariableWrapper with Variable.
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
811
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
812
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
813
  }
814

815
  // Increase curent count
816
  IncreaseCurCnt();
J
Jiabin Yang 已提交
817 818
}

819 820 821
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id, bool unchange_input) {
  auto* dst_var = Var();
822
  platform::Place place = GetPlaceOfVar(var);
823
  if (!dst_var->OverridedStopGradient()) {
824
    if (ref_cnt_ == 1) {
825
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(),
826
                    unchange_input || var->HasGradNode());
827 828 829 830 831
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

832
      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);
833 834 835 836 837

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

838 839
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
840 841 842 843 844 845 846 847 848 849
      std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }
850

851
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
852
      if (paddle::platform::is_gpu_place(place)) {
853
        // sum selected rows firstly
854
        for (auto& var_info : tmp_grad_vars_) {
855
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
856
            continue;
857
          }
858

859 860
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
861 862
                          var_info.unchange_input);
          } else {
863
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
864
          }
865 866

          var_info.var = nullptr;
867 868
          // Increase count
          IncreaseCurCnt();
869 870 871 872 873 874 875 876 877 878
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
                            true, platform::errors::PermissionDenied(
                                      "Gradient var must be LoDTensor"));
879 880
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
881 882
                          var_info.unchange_input);
          } else {
883
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
884
          }
885 886

          var_info.var = nullptr;
887 888
          // Increase count
          IncreaseCurCnt();
889 890 891
        }
      } else {
#endif
892 893 894 895 896 897
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<framework::LoDTensor>() ||
898
                  var_info.var->Var().IsType<phi::SelectedRows>(),
899 900 901 902 903 904 905 906 907 908 909 910
              true, platform::errors::PermissionDenied("The type of Gradient "
                                                       "var must be LoDTensor "
                                                       "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
911
        }
912
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
913
      }
914
#endif
915
      tmp_grad_vars_.clear();
J
Jiabin Yang 已提交
916
    }
917
  } else {
918 919
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
920 921
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
922 923 924 925
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
926 927
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
928
        tensor->mutable_data(place,
929
                             framework::TransToPhiDataType(var->DataType()));
930
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
931
      } else {
932 933
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
934
        tensor->mutable_data(place,
935
                             framework::TransToPhiDataType(var->DataType()));
936
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
937
      }
J
Jiabin Yang 已提交
938
    }
939
    // looks like tmp_grad_vars will not have any member but just in case
J
Jiabin Yang 已提交
940 941
    tmp_grad_vars_.clear();
  }
942

943 944
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
945
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
946
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
947
  }
J
Jiabin Yang 已提交
948 949 950 951
}

}  // namespace imperative
}  // namespace paddle