// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"

namespace paddle {
namespace imperative {

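// Move `src` into `dst` when the source may be consumed. When `force_copy` is
// set, deep-copy the LoDTensor or SelectedRows instead so the source variable
// stays unchanged.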
static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<framework::LoDTensor>()) {
    auto& src_tensor = src->Get<framework::LoDTensor>();
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

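// In-place element-wise add on XPU: dst = dst + src, implemented via the
// xpu::add kernel for the corresponding XPU type.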
#ifdef PADDLE_WITH_XPU
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const framework::Tensor& src,
                         framework::Tensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = xpu::add<XPUType>(
      ctx->x_context(), x, y, y, static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      r,
      XPU_SUCCESS,
      platform::errors::External(
          "XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
#endif

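// Helpers that unify access to the underlying DenseTensor / SelectedRows of
// either a framework::Variable or a paddle::experimental::Tensor, so the
// accumulation routines below can be written once for both variable types.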
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::experimental::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensors with values; if a tensor "
                              "is NOT INITIALIZED, it should just be moved "
                              "instead of calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

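// Accumulate a dense gradient: dst = dst + src. Dispatches on the tensor's
// place (CPU/GPU/XPU/NPU/MLU/custom device) and data type, and raises an
// Unimplemented error when no suitable kernel is available.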
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal. Otherwise, the calculation results "
                        "will be incorrect."));

  // if src and dst are in different places, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }

#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
        platform::DeviceContextPool::Instance().Get(place));                   \
    phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_TENSOR_ADD(float, phi::GPUContext);
    PADDLE_TENSOR_ADD(double, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
  }

#define TENSOR_ADD_EIGEN(T)                                           \
  auto cpu_ctx = static_cast<phi::CPUContext*>(                       \
      platform::DeviceContextPool::Instance().Get(place));            \
  auto in = paddle::framework::EigenVector<T>::Flatten(src_tensor);   \
  auto out = paddle::framework::EigenVector<T>::Flatten(*dst_tensor); \
  auto& p = *(cpu_ctx->eigen_device());                               \
  out.device(p) = out + in;                                           \
  return;

  if (platform::is_cpu_place(place)) {
    PADDLE_TENSOR_ADD(float, phi::CPUContext);
    PADDLE_TENSOR_ADD(double, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
    if (data_type == framework::proto::VarType::BF16) {
      TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
    }
    if (data_type == framework::proto::VarType::FP16) {
      TENSOR_ADD_EIGEN(phi::dtype::float16);
    }
  }

#define PADDLE_TENSOR_ADD_CUSTOM(T)                              \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
    platform::CustomDeviceContext* ctx =                         \
        static_cast<platform::CustomDeviceContext*>(             \
            platform::DeviceContextPool::Instance().Get(place)); \
    phi::stream::Stream stream(place, ctx->stream());            \
    auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
    device->BlasAXPBY<T>(stream,                                 \
                         static_cast<size_t>(numel),             \
                         1.,                                     \
                         src_tensor.data<T>(),                   \
                         1.,                                     \
                         dst_tensor->mutable_data<T>(place));    \
    return;                                                      \
  }

  if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    PADDLE_TENSOR_ADD_CUSTOM(float);
    PADDLE_TENSOR_ADD_CUSTOM(double);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(
        cnnlAssignAdd(dev_ctx->cnnl_handle(),
                      static_cast<const void*>(&alpha),
                      src_tensor_desc.get(),
                      operators::GetBasePtr(&src_tensor),
                      nullptr,
                      0,
                      static_cast<const void*>(&beta),
                      dst_tensor_desc.get(),
                      operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::experimental::Tensor>(
    const paddle::experimental::Tensor& src, paddle::experimental::Tensor* dst);

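// Accumulate a SelectedRows gradient into a dense destination tensor in
// place. Only float and double are supported on CPU and GPU.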
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)           \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {         \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);              \
    paddle::operators::math::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> \
        functor;                                                             \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                         \
            src_selected_rows,                                               \
            dst_tensor);                                                     \
    return;                                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::experimental::Tensor& src,
                                      paddle::experimental::Tensor* dst);

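// Out-of-place variant: dst_tensor = src_tensor + src_selected_rows. The
// destination is resized to the dense input's shape before the addition.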
template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)            \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {       \
    paddle::operators::math::SelectedRowsAddTensor<dev_ctx_type, cpp_type> \
        functor;                                                           \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                       \
            src_selected_rows,                                             \
            src_tensor,                                                    \
            dst_tensor);                                                   \
    return;                                                                \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(
    const paddle::experimental::Tensor& src_selected_rows_var,
    const paddle::experimental::Tensor& src_tensor_var,
    paddle::experimental::Tensor* dst_tensor_var);

// Note(chenweihang): when two SelectedRows need to be added, adding one to
//   the other is not equivalent to merging the two SelectedRows into one and
//   adding the result to an empty SelectedRows; the latter is the correct
//   approach.
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)               \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);        \
    paddle::operators::math::scatter::MergeAdd<dev_ctx_type, cpp_type> \
        merge_add;                                                     \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                 \
              src_selected_rows,                                       \
              dst_selected_rows);                                      \
    return dst_var;                                                    \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
    const paddle::experimental::Tensor& src1,
    const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

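// Accumulate the gradient held in `var` into `dst_var`, choosing the proper
// kernel based on the LoDTensor / SelectedRows combination of the two sides.
// When `unchange_input` is true, the source variable is never modified.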
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<framework::LoDTensor>()) {
    if (src.IsType<framework::LoDTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<framework::LoDTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

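// Return the place of the LoDTensor or SelectedRows held by the wrapper;
// other variable types are not supported in dygraph.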
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<framework::LoDTensor>()) {
    place = var->Var().Get<framework::LoDTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

void GradientAccumulator::AccumulateGrad() {
  /**
   * Once the leaf gradient has been fully calculated, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initialized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<framework::LoDTensor>()) {
      if (src->IsType<framework::LoDTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<framework::LoDTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized; no accumulation is needed, "
           "just move.";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

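// Run the registered VariableWrapper hooks on the inner gradient var of a
// leaf tensor. A hook may return a replacement var, so the final result is
// written back to inner_var_ after all hooks have run.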
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can be dealt with by "
                        "gradient hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Gradient hooks can only be called after sum gradient completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(),
      true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when calling gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(),
      true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "calling gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

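// Run the registered void hooks (reduce hooks) once the gradient accumulation
// of this leaf var has completed and inner_var_ has been released.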
void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can be dealt "
                                    "with by reduce hook in gradient "
                                    "accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Reduce hooks can only be called after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Reduce hooks can only be called after the "
                        "gradient accumulation is completed in "
                        "current batch or across batches."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

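// Eager accumulation: each incoming gradient is added to the destination as
// soon as it arrives. For vars that override stop_gradient, the gradient is
// just zero-filled (when not already initialized).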
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // Type may be changed after OP run, such as VarTypeInference,
  // so synchronize the VariableWrapper with the Variable.
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

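// Sorted accumulation: gradients are buffered in tmp_grad_vars_ until all
// ref_cnt_ contributions have arrived, then summed in descending trace_id
// order. On CUDA/HIP builds, SelectedRows gradients are summed before dense
// ones when the place is a GPU place.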
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum SelectedRows gradients first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<framework::LoDTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ is expected to be empty here, but clear it just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}

}  // namespace imperative
}  // namespace paddle