// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif

namespace paddle {
namespace imperative {

static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just move the variable when summing gradients in this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy the variable when summing gradients in this graph";
  if (src->IsType<framework::LoDTensor>()) {
    auto& src_tensor = src->Get<framework::LoDTensor>();
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

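// TensorAddFunctor computes y[i] += x[i] for `numel` elements. It is a
// place-visitor: platform::VisitPlace dispatches to the operator() overload
// matching the runtime place, so each device (CPU/XPU/GPU/custom) uses its
// own add kernel, while unsupported places throw PermissionDenied.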
template <typename T>
class TensorAddFunctor
    : public std::unary_function<const platform::Place&, void> {
 public:
  TensorAddFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) const {
    phi::CPUContext* ctx = dynamic_cast<phi::CPUContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_XPU
  void operator()(const platform::XPUPlace& place) const {
    using XPUType = typename XPUTypeTrait<T>::Type;
    platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    int r = xpu::add<XPUType>(ctx->x_context(),
                              reinterpret_cast<const XPUType*>(x_),
                              reinterpret_cast<const XPUType*>(y_),
                              reinterpret_cast<XPUType*>(y_),
                              static_cast<int>(numel_));
    PADDLE_ENFORCE_EQ(
        r,
        XPU_SUCCESS,
        platform::errors::External(
            "XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
  }
#else
  void operator()(const platform::XPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void operator()(const platform::CUDAPlace& place) const {
    phi::GPUContext* ctx = dynamic_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#ifdef PADDLE_WITH_MLU
  void operator()(const platform::MLUPlace& place) const {
    // TODO(fwg): SUPPORT it
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#else
  void operator()(const platform::MLUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  void operator()(const platform::NPUPlace& place) const {
    // TODO(zhiqiu): SUPPORT it
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#else
  void operator()(const platform::NPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

  void operator()(const platform::NPUPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  // there is NO support in IPUPlace
  void operator()(const platform::IPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  void operator()(const platform::CustomPlace& place) const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    platform::CustomDeviceContext* ctx =
        dynamic_cast<platform::CustomDeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    phi::stream::Stream stream(place, ctx->stream());
    auto device = phi::DeviceManager::GetDeviceWithPlace(place);
    device->BlasAXPBY<T>(stream, static_cast<size_t>(numel_), 1., x_, 1., y_);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
#endif
  }

 private:
  int64_t numel_;
  const T* x_;
  mutable T* y_;
};

#ifdef PADDLE_WITH_XPU
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const framework::Tensor& src,
                         framework::Tensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = xpu::add<XPUType>(
      ctx->x_context(), x, y, y, static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      r,
      XPU_SUCCESS,
      platform::errors::External(
          "XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
#endif

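// TensorAddImpl adds src to dst elementwise through phi's ElementwiseAddTo
// functor; it is used below for float16/bfloat16 tensors, which are not
// covered by the BLAS-based TensorAddFunctor path.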
template <typename DeviceContext, typename T>
void TensorAddImpl(const framework::Tensor& src,
                   framework::Tensor* dst,
                   const platform::Place& place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  paddle::platform::DeviceContext* ctx = pool.Get(place);
  auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
  phi::funcs::ElementwiseAddTo<DeviceContext, T> func;
  func(dev_ctx, src, dst);
}

template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::experimental::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensors that hold values; if a "
                              "tensor is NOT INITIALIZED, it should just be "
                              "moved instead of calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

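// TensorAdd accumulates the dense tensor held by `src` into the one held by
// `dst` (dst += src). Both sides must already hold tensors with the same
// number of elements and the same data type; an empty src is a no-op.
//
// A minimal usage sketch (assuming two framework::Variable objects that hold
// initialized LoDTensors of matching shape and dtype):
//
//   paddle::imperative::TensorAdd<framework::Variable>(src_var, &dst_var);
//
// The data type is inspected at runtime and dispatched either through the
// PADDLE_TENSOR_ADD macro below or through a device-specific branch.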
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal, Otherwise, the calculation results "
                        "will be incorrect."));

  // if src and dst are on different places, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }
#define PADDLE_TENSOR_ADD(cpp_type)                                  \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    TensorAddFunctor<cpp_type> func(                                 \
        numel,                                                       \
        src_tensor.data<cpp_type>(),                                 \
        dst_tensor->mutable_data<cpp_type>(place));                  \
    platform::VisitPlace(place, func);                               \
    return;                                                          \
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(
        cnnlAssignAdd(dev_ctx->cnnl_handle(),
                      static_cast<const void*>(&alpha),
                      src_tensor_desc.get(),
                      operators::GetBasePtr(&src_tensor),
                      nullptr,
                      0,
                      static_cast<const void*>(&beta),
                      dst_tensor_desc.get(),
                      operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif

  PADDLE_TENSOR_ADD(float);

#ifndef PADDLE_WITH_XPU
  // NOTE(phlrain): xpu only support float
  PADDLE_TENSOR_ADD(double);
  // NOTE(chenweihang): only dense complex grad tensors are supported for
  // accumulation; add SelectedRows support if it is needed in the future
  PADDLE_TENSOR_ADD(platform::complex<float>);
  PADDLE_TENSOR_ADD(platform::complex<double>);
#endif

#undef PADDLE_TENSOR_ADD

  if (data_type == framework::proto::VarType::FP16) {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      return TensorAddImpl<phi::GPUContext, platform::float16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<phi::CPUContext, platform::float16>(
          src_tensor, dst_tensor, place);
    }
  }
  if (data_type == framework::proto::VarType::BF16) {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      return TensorAddImpl<phi::GPUContext, platform::bfloat16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<phi::CPUContext, platform::bfloat16>(
          src_tensor, dst_tensor, place);
    }
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::experimental::Tensor>(
    const paddle::experimental::Tensor& src, paddle::experimental::Tensor* dst);

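// SelectedRowsAddToTensor scatter-adds the rows of the SelectedRows gradient
// held by `src` into the dense tensor held by `dst`, in place. Only float and
// double are dispatched; other data types fall through to InvalidArgument.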
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)           \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {         \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);              \
    paddle::operators::math::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> \
        functor;                                                             \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                         \
            src_selected_rows,                                               \
            dst_tensor);                                                     \
    return;                                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::experimental::Tensor& src,
                                      paddle::experimental::Tensor* dst);

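// SelectedRowsAddTensor writes src_selected_rows + src_tensor into
// dst_tensor_var as a new dense tensor (the destination is resized and
// overwritten rather than accumulated into).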
template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)            \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {       \
    paddle::operators::math::SelectedRowsAddTensor<dev_ctx_type, cpp_type> \
        functor;                                                           \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                       \
            src_selected_rows,                                             \
            src_tensor,                                                    \
            dst_tensor);                                                   \
    return;                                                                \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(
    const paddle::experimental::Tensor& src_selected_rows_var,
    const paddle::experimental::Tensor& src_tensor_var,
    paddle::experimental::Tensor* dst_tensor_var);

// Note(chenweihang): when two selected rows need to be added,
//   adding one directly to the other is not equivalent to merging the two
//   selected rows into a new one; the latter (merging) is correct
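//   e.g. if both inputs contain an entry for row 2, the merged result must
//   hold a single row-2 entry whose values are summed, which is what MergeAdd
//   guarantees and a naive append of the row lists would not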
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)               \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);        \
    paddle::operators::math::scatter::MergeAdd<dev_ctx_type, cpp_type> \
        merge_add;                                                     \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                 \
              src_selected_rows,                                       \
              dst_selected_rows);                                      \
    return dst_var;                                                    \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
    const paddle::experimental::Tensor& src1,
    const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

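// VariableWrapperAdd accumulates `var` into `dst_var`, dispatching on the
// (LoDTensor | SelectedRows) type combination of the two sides. When
// `unchange_input` is true the source variable is preserved (it may still be
// consumed by another grad op); otherwise it may be moved from or modified.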
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<framework::LoDTensor>()) {
    if (src.IsType<framework::LoDTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<framework::LoDTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<framework::LoDTensor>()) {
    place = var->Var().Get<framework::LoDTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

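// AccumulateGrad folds the temporary inner_var_ (which collected the gradient
// sum for the current backward pass) into the persistent leaf gradient var_.
// It runs only for leaf gradients whose in-graph summation has completed.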
void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been fully calculated, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initizlized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<framework::LoDTensor>()) {
      if (src->IsType<framework::LoDTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<framework::LoDTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

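// CallGradientHooks runs the user-registered VariableWrapper hooks on the
// inner gradient var after in-graph summation finishes; each hook receives
// the current temporary gradient and returns the (possibly replaced) one.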
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can deal with by gradient "
                        "hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Only can call gradient hooks after sum gradient completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(),
      true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when call gradient hook."));
  PADDLE_ENFORCE_EQ(
759 760
      inner_var_->Var().IsInitialized(),
      true,
761 762 763
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

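// CallReduceHooks runs the void (reduce) hooks once the gradient has been
// fully accumulated into var_, i.e. after inner_var_ has been consumed.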
void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                    "by reduce hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batchs."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

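// EagerGradientAccumulator adds every incoming gradient as soon as it
// arrives: the first contribution is moved (or copied) into the accumulator,
// and later contributions are added on top via VariableWrapperAdd.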
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // Type may be changed after OP run, such as VarTypeInference
  // so synchronous VariableWrapper with Variable.
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

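// SortedGradientAccumulator buffers all ref_cnt_ gradient contributions
// first, then sums them in descending trace_id order; on CUDA/HIP devices
// SelectedRows gradients are summed before dense ones.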
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum selected rows firstly
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<framework::LoDTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should have no members here, but clear it just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}

}  // namespace imperative
}  // namespace paddle