// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"

namespace paddle {
namespace imperative {

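// Moves `src` into `dst` when gradients are only summed within the current
// graph; otherwise performs a deep copy (DenseTensor or SelectedRows) so that
// `src` remains valid for other consumers.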
static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<phi::DenseTensor>()) {
    auto& src_tensor = src->Get<phi::DenseTensor>();
    if (!dst->IsType<phi::DenseTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<phi::DenseTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

#ifdef PADDLE_WITH_XPU
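// In-place element-wise accumulation on XPU: dst = dst + src via xpu::add.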
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const phi::DenseTensor& src,
                         phi::DenseTensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = xpu::add<XPUType>(
      ctx->x_context(), x, y, y, static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      r,
      XPU_SUCCESS,
      platform::errors::External(
          "XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
#endif

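// Helpers that unwrap the underlying DenseTensor/SelectedRows from either a
// framework::Variable or a paddle::experimental::Tensor.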
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::experimental::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensor with value if a tensor is "
                              "NOT INITILIZED, it should just move instead of "
                              "calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

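// In-place accumulation: dst += src. The two tensors must have the same data
// type and number of elements; if dst lives on a different place, it is first
// copied to src's place, and the add kernel is dispatched by place and dtype.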
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal, Otherwise, the calculation results "
                        "will be incorrect."));

  // if src and dst are in different place, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }

#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
        platform::DeviceContextPool::Instance().Get(place));                   \
    phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_TENSOR_ADD(float, phi::GPUContext);
    PADDLE_TENSOR_ADD(double, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
  }

#define TENSOR_ADD_EIGEN(T)                                           \
  auto cpu_ctx = static_cast<phi::CPUContext*>(                       \
      platform::DeviceContextPool::Instance().Get(place));            \
  auto in = paddle::framework::EigenVector<T>::Flatten(src_tensor);   \
  auto out = paddle::framework::EigenVector<T>::Flatten(*dst_tensor); \
  auto& p = *(cpu_ctx->eigen_device());                               \
  out.device(p) = out + in;                                           \
  return;

  if (platform::is_cpu_place(place)) {
    PADDLE_TENSOR_ADD(float, phi::CPUContext);
    PADDLE_TENSOR_ADD(double, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
    if (data_type == framework::proto::VarType::BF16) {
      TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
    }
    if (data_type == framework::proto::VarType::FP16) {
      TENSOR_ADD_EIGEN(phi::dtype::float16);
    }
  }

#define PADDLE_TENSOR_ADD_CUSTOM(T)                              \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
    platform::CustomDeviceContext* ctx =                         \
        static_cast<platform::CustomDeviceContext*>(             \
            platform::DeviceContextPool::Instance().Get(place)); \
    phi::stream::Stream stream(place, ctx->stream());            \
    auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
    device->BlasAXPBY<T>(stream,                                 \
                         static_cast<size_t>(numel),             \
                         1.,                                     \
                         src_tensor.data<T>(),                   \
                         1.,                                     \
                         dst_tensor->mutable_data<T>(place));    \
    return;                                                      \
  }

  if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    PADDLE_TENSOR_ADD_CUSTOM(float);
    PADDLE_TENSOR_ADD_CUSTOM(double);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(
        cnnlAssignAdd(dev_ctx->cnnl_handle(),
                      static_cast<const void*>(&alpha),
                      src_tensor_desc.get(),
                      operators::GetBasePtr(&src_tensor),
                      nullptr,
                      0,
                      static_cast<const void*>(&beta),
                      dst_tensor_desc.get(),
                      operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::experimental::Tensor>(
    const paddle::experimental::Tensor& src, paddle::experimental::Tensor* dst);

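// Scatter-adds the rows of a SelectedRows gradient (src) into a dense tensor
// (dst) in place.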
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)       \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {     \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);          \
    phi::funcs::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                     \
            src_selected_rows,                                           \
            dst_tensor);                                                 \
    return;                                                              \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::experimental::Tensor& src,
                                      paddle::experimental::Tensor* dst);

template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)        \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    phi::funcs::SelectedRowsAddTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                   \
            src_selected_rows,                                         \
            src_tensor,                                                \
            dst_tensor);                                               \
    return;                                                            \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(
    const paddle::experimental::Tensor& src_selected_rows_var,
    const paddle::experimental::Tensor& src_tensor_var,
    paddle::experimental::Tensor* dst_tensor_var);

// Note(chenweihang): when two selected rows need to be added,
//   adding one to the other is not the same as merging the two selected rows
//   into one and then adding that to an empty selected rows; the latter is
//   the correct approach.
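// Merges two SelectedRows gradients into a freshly created variable and
// returns it.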
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)             \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);      \
    phi::funcs::scatter::MergeAdd<dev_ctx_type, cpp_type> merge_add; \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),               \
              src_selected_rows,                                     \
              dst_selected_rows);                                    \
    return dst_var;                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
#if defined(PADDLE_WITH_XPU)
    if (paddle::platform::is_xpu_place(place)) {
      PADDLE_SELECTED_ROWS_ADD(phi::XPUContext, float);
    } else {
#endif
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_XPU)
    }
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
    const paddle::experimental::Tensor& src1,
    const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

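// Accumulates the gradient held by `var` into `dst_var`, covering all four
// DenseTensor/SelectedRows combinations of source and destination.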
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<phi::DenseTensor>()) {
    if (src.IsType<phi::DenseTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<phi::DenseTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

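// Returns the place of the tensor held by `var` (DenseTensor or SelectedRows).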
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<phi::DenseTensor>()) {
    place = var->Var().Get<phi::DenseTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been fully calculated, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initizlized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<phi::DenseTensor>()) {
      if (src->IsType<phi::DenseTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<phi::DenseTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has not been initialized; just move without "
               "accumulating.";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

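// Runs the user-registered variable hooks on the leaf gradient's inner var
// once gradient summation has completed.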
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only a leaf gradient Tensor can be processed by "
                        "gradient hooks in the gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Gradient hooks can only be called after gradient summation is "
          "completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(),
      true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when call gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(),
      true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                    "by reduce hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Reduce hooks can only be called after the gradient "
                        "summation is completed in the current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Reduce hooks can only be called after the gradient "
                        "accumulation is completed in the current batch or "
                        "across batches."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

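// Accumulates each incoming gradient into the destination variable as soon as
// it arrives; stop-gradient leaves are filled with zeros instead.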
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // The type may be changed after an op runs (e.g. by VarTypeInference),
  // so synchronize the VariableWrapper's type with the underlying Variable.
  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

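// Buffers incoming gradients and, once all ref_cnt_ gradients have arrived,
// sums them in descending trace_id order (SelectedRows gradients first on GPU).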
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum SelectedRows gradients first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<phi::DenseTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<phi::DenseTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should already be empty here, but clear it just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}
}  // namespace imperative
}  // namespace paddle