// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"

namespace paddle {
namespace imperative {

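// Move `src` into `dst` when force_copy is false (the gradient is only
// consumed inside this graph); otherwise deep-copy the DenseTensor or
// SelectedRows payload so that `src` is left unchanged.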
static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just move variable when summing gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when summing gradients within this graph";
  if (src->IsType<phi::DenseTensor>()) {
    auto& src_tensor = src->Get<phi::DenseTensor>();
    if (!dst->IsType<phi::DenseTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<phi::DenseTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

#ifdef PADDLE_WITH_XPU
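// In-place element-wise add on XPU: dst = dst + src. For double, the
// operands are cast to float, accumulated, and cast back.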
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const phi::DenseTensor& src,
                         phi::DenseTensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = -1;
  int numel = static_cast<int>(src.numel());
  if (std::is_same<T, double>::value) {
    xpu::ctx_guard RAII_GUARD(ctx->x_context());
    float* x_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
    PADDLE_ENFORCE_XDNN_NOT_NULL(x_cast_to_fp32);
    float* y_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
    PADDLE_ENFORCE_XDNN_NOT_NULL(y_cast_to_fp32);
    r = xpu::cast<XPUType, float>(ctx->x_context(), x, x_cast_to_fp32, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::cast<XPUType, float>(ctx->x_context(), y, y_cast_to_fp32, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::add<float>(ctx->x_context(),
                        x_cast_to_fp32,
                        y_cast_to_fp32,
                        y_cast_to_fp32,
                        numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
    r = xpu::cast<float, XPUType>(ctx->x_context(), y_cast_to_fp32, y, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
  } else {
    r = xpu::add<XPUType>(ctx->x_context(), x, y, y, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
  }
}
#endif

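// The helpers below give uniform access to the underlying DenseTensor or
// SelectedRows, whether the gradient is stored in a framework::Variable or
// in a paddle::experimental::Tensor.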
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::experimental::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensors that carry values; if a "
                              "tensor is NOT INITIALIZED, it should just be "
                              "moved instead of calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

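// Accumulate a dense gradient: *dst += src. Dispatches on the tensor's
// place (CPU/GPU/XPU/NPU/MLU/custom device) and element type; both tensors
// must already have the same numel and dtype.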
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal. Otherwise, the calculation results "
                        "will be incorrect."));

  // if src and dst are in different places, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }

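// Dispatch to phi::AddKernel for a matching element type on the given
// device context.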
#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
        platform::DeviceContextPool::Instance().Get(place));                   \
    phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_TENSOR_ADD(float, phi::GPUContext);
    PADDLE_TENSOR_ADD(double, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
  }

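// Eigen-based element-wise add, used below for bf16/fp16 on CPU.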
#define TENSOR_ADD_EIGEN(T)                                \
  auto cpu_ctx = static_cast<phi::CPUContext*>(            \
      platform::DeviceContextPool::Instance().Get(place)); \
  auto in = phi::EigenVector<T>::Flatten(src_tensor);      \
  auto out = phi::EigenVector<T>::Flatten(*dst_tensor);    \
  auto& p = *(cpu_ctx->eigen_device());                    \
  out.device(p) = out + in;                                \
  return;

  if (platform::is_cpu_place(place)) {
    PADDLE_TENSOR_ADD(float, phi::CPUContext);
    PADDLE_TENSOR_ADD(double, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
    if (data_type == framework::proto::VarType::BF16) {
      TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
    }
    if (data_type == framework::proto::VarType::FP16) {
      TENSOR_ADD_EIGEN(phi::dtype::float16);
    }
  }

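// On custom (plugin) devices, accumulate through the device's BlasAXPBY:
// dst = 1.0 * src + 1.0 * dst.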
#define PADDLE_TENSOR_ADD_CUSTOM(T)                              \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
    platform::CustomDeviceContext* ctx =                         \
        static_cast<platform::CustomDeviceContext*>(             \
            platform::DeviceContextPool::Instance().Get(place)); \
    phi::stream::Stream stream(place, ctx->stream());            \
    auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
    device->BlasAXPBY<T>(stream,                                 \
                         static_cast<size_t>(numel),             \
                         1.,                                     \
                         src_tensor.data<T>(),                   \
                         1.,                                     \
                         dst_tensor->mutable_data<T>(place));    \
    return;                                                      \
  }

  if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    PADDLE_TENSOR_ADD_CUSTOM(float);
    PADDLE_TENSOR_ADD_CUSTOM(double);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      XPUTensorAddFunctor<double>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(
        cnnlAssignAdd(dev_ctx->cnnl_handle(),
                      static_cast<const void*>(&alpha),
                      src_tensor_desc.get(),
                      operators::GetBasePtr(&src_tensor),
                      nullptr,
                      0,
                      static_cast<const void*>(&beta),
                      dst_tensor_desc.get(),
                      operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::experimental::Tensor>(
    const paddle::experimental::Tensor& src, paddle::experimental::Tensor* dst);

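// Accumulate a SelectedRows gradient into an existing dense tensor in place.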
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)       \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {     \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);          \
    phi::funcs::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                     \
            src_selected_rows,                                           \
            dst_tensor);                                                 \
    return;                                                              \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::experimental::Tensor& src,
                                      paddle::experimental::Tensor* dst);

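// dst_tensor = src_selected_rows + src_tensor, written into a freshly
// resized dense tensor; the inputs are not modified.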
template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)        \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    phi::funcs::SelectedRowsAddTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                   \
            src_selected_rows,                                         \
            src_tensor,                                                \
            dst_tensor);                                               \
    return;                                                            \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(
    const paddle::experimental::Tensor& src_selected_rows_var,
    const paddle::experimental::Tensor& src_tensor_var,
    paddle::experimental::Tensor* dst_tensor_var);

// Note(chenweihang): when two selected rows need to be added,
//   adding one to the other is not the same as merging the two selected rows
//   into one and then adding that to an empty selected rows; the latter is
//   correct
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)             \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);      \
    phi::funcs::scatter::MergeAdd<dev_ctx_type, cpp_type> merge_add; \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),               \
              src_selected_rows,                                     \
              dst_selected_rows);                                    \
    return dst_var;                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
#if defined(PADDLE_WITH_XPU)
    if (paddle::platform::is_xpu_place(place)) {
      PADDLE_SELECTED_ROWS_ADD(phi::XPUContext, float);
    } else {
#endif
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_XPU)
    }
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
    const paddle::experimental::Tensor& src1,
    const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

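// Accumulate the gradient held by `var` into `dst_var`, covering every
// combination of DenseTensor and SelectedRows on either side. When
// `unchange_input` is set, `var` must not be modified, so a temporary
// variable is used where an in-place update would otherwise consume it.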
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<phi::DenseTensor>()) {
    if (src.IsType<phi::DenseTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<phi::DenseTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

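// Return the place (device) of the variable's underlying tensor.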
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<phi::DenseTensor>()) {
    place = var->Var().Get<phi::DenseTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been fully calculated, the inner_var_
   * should be added to var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initialized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<phi::DenseTensor>()) {
      if (src->IsType<phi::DenseTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<phi::DenseTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized; do not accumulate, just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

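// Run the user-registered VariableWrapper hooks on the inner gradient var
// after summation; each hook may return a replacement for the gradient.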
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can be dealt with by "
                        "gradient hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Can only call gradient hooks after gradient summation is "
          "completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(),
      true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when calling gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(),
      true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "calling gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

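// Run the registered void (reduce) hooks once the accumulated gradient has
// been written back to var_ and the inner var has been released.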
void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can be dealt "
                                    "with by reduce hook in gradient "
                                    "accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Can only call reduce hooks after the gradient "
                        "summation is completed in the current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Can only call reduce hooks after the "
                        "gradient accumulation is completed in the "
                        "current batch or across batches."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

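// Eager accumulation: each incoming gradient is immediately moved into or
// added onto the accumulated var, in arrival order. For stop-gradient vars,
// the gradient is just set to zeros if it has not been initialized yet.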
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // Type may be changed after the OP runs, e.g. by VarTypeInference,
  // so synchronize the VariableWrapper with the Variable.
  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

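// Sorted accumulation: incoming gradients are buffered until all ref_cnt_
// of them have arrived, then summed in descending trace_id order. On GPU,
// SelectedRows gradients are summed before dense ones.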
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum SelectedRows gradients first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<phi::DenseTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<phi::DenseTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // looks like tmp_grad_vars_ will not have any members, but clear just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}
}  // namespace imperative
}  // namespace paddle