// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"

namespace paddle {
namespace imperative {

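// Moves the gradient from `src` into `dst` when the source will not be reused
// in this graph; otherwise performs a deep copy (LoDTensor or SelectedRows) so
// the source variable stays intact.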
static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<framework::LoDTensor>()) {
    auto& src_tensor = src->Get<framework::LoDTensor>();
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

#ifdef PADDLE_WITH_XPU
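// In-place element-wise accumulation on XPU: dst = dst + src via xpu::add.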
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const phi::DenseTensor& src,
                         phi::DenseTensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = xpu::add<XPUType>(
      ctx->x_context(), x, y, y, static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      r,
      XPU_SUCCESS,
      platform::errors::External(
          "XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
#endif

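// The helpers below provide uniform access to the underlying DenseTensor /
// SelectedRows, whether the gradient is held in a framework::Variable, a
// paddle::experimental::Tensor, or a VariableWrapper.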
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::experimental::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensor with value if a tensor is "
                              "NOT INITILIZED, it should just move instead of "
                              "calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

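// In-place dense accumulation: dst = dst + src. Dispatches on place and data
// type to the phi Add kernel (CPU/GPU), Eigen (CPU fp16/bf16), or the
// XPU/NPU/MLU/custom-device paths below, and throws if no kernel matches.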
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal, Otherwise, the calculation results "
                        "will be incorrect."));

  // If src and dst are in different places, copy dst to src's place.
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }

#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
        platform::DeviceContextPool::Instance().Get(place));                   \
    phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_TENSOR_ADD(float, phi::GPUContext);
    PADDLE_TENSOR_ADD(double, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
  }

#define TENSOR_ADD_EIGEN(T)                                           \
  auto cpu_ctx = static_cast<phi::CPUContext*>(                       \
      platform::DeviceContextPool::Instance().Get(place));            \
  auto in = paddle::framework::EigenVector<T>::Flatten(src_tensor);   \
  auto out = paddle::framework::EigenVector<T>::Flatten(*dst_tensor); \
  auto& p = *(cpu_ctx->eigen_device());                               \
  out.device(p) = out + in;                                           \
  return;

  if (platform::is_cpu_place(place)) {
    PADDLE_TENSOR_ADD(float, phi::CPUContext);
    PADDLE_TENSOR_ADD(double, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
    if (data_type == framework::proto::VarType::BF16) {
      TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
    }
    if (data_type == framework::proto::VarType::FP16) {
      TENSOR_ADD_EIGEN(phi::dtype::float16);
    }
  }

#define PADDLE_TENSOR_ADD_CUSTOM(T)                              \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
    platform::CustomDeviceContext* ctx =                         \
        static_cast<platform::CustomDeviceContext*>(             \
            platform::DeviceContextPool::Instance().Get(place)); \
    phi::stream::Stream stream(place, ctx->stream());            \
    auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
    device->BlasAXPBY<T>(stream,                                 \
                         static_cast<size_t>(numel),             \
                         1.,                                     \
                         src_tensor.data<T>(),                   \
                         1.,                                     \
                         dst_tensor->mutable_data<T>(place));    \
    return;                                                      \
  }

  if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    PADDLE_TENSOR_ADD_CUSTOM(float);
    PADDLE_TENSOR_ADD_CUSTOM(double);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

#ifdef PADDLE_WITH_MLU
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(
        cnnlAssignAdd(dev_ctx->cnnl_handle(),
                      static_cast<const void*>(&alpha),
                      src_tensor_desc.get(),
                      operators::GetBasePtr(&src_tensor),
                      nullptr,
                      0,
                      static_cast<const void*>(&beta),
                      dst_tensor_desc.get(),
                      operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::experimental::Tensor>(
    const paddle::experimental::Tensor& src, paddle::experimental::Tensor* dst);

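// Adds a SelectedRows gradient `src` into the dense tensor held by `dst`.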
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)       \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {     \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);          \
    phi::funcs::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                     \
            src_selected_rows,                                           \
            dst_tensor);                                                 \
    return;                                                              \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::experimental::Tensor& src,
                                      paddle::experimental::Tensor* dst);

template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
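  // dst_tensor_var = src_selected_rows_var + src_tensor_var; the result is
  // written into a dense tensor resized to the dense input's shape.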
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)        \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    phi::funcs::SelectedRowsAddTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                   \
            src_selected_rows,                                         \
            src_tensor,                                                \
            dst_tensor);                                               \
    return;                                                            \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(
    const paddle::experimental::Tensor& src_selected_rows_var,
    const paddle::experimental::Tensor& src_tensor_var,
    paddle::experimental::Tensor* dst_tensor_var);

// Note(chenweihang): when two SelectedRows need to be added, adding one to
//   the other is not the same as merging the two into one and then adding
//   that to an empty SelectedRows; the latter is the correct approach.
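// Merges two SelectedRows gradients into a newly created variable via MergeAdd
// and returns it.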
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)             \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);      \
    phi::funcs::scatter::MergeAdd<dev_ctx_type, cpp_type> merge_add; \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),               \
              src_selected_rows,                                     \
              dst_selected_rows);                                    \
    return dst_var;                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
    const paddle::experimental::Tensor& src1,
    const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

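// Accumulates the gradient in `var` into `dst_var`, handling every combination
// of LoDTensor and SelectedRows operands; when `unchange_input` is true the
// source variable is left unmodified.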
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<framework::LoDTensor>()) {
    if (src.IsType<framework::LoDTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<framework::LoDTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

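// Returns the place of the LoDTensor or SelectedRows held by `var`.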
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<framework::LoDTensor>()) {
    place = var->Var().Get<framework::LoDTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been fully calculated, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initizlized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<framework::LoDTensor>()) {
      if (src->IsType<framework::LoDTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<framework::LoDTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

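// Applies the user-registered VariableWrapper hooks to the inner gradient var
// once the gradient summation of a leaf Tensor has completed.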
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can deal with by gradient "
                        "hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Only can call gradient hooks after sum gradient completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(),
      true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when call gradient hook."));
  PADDLE_ENFORCE_EQ(
625 626
      inner_var_->Var().IsInitialized(),
      true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

void GradientAccumulator::CallReduceHooks() {
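  // Reduce (void) hooks run only after the accumulated gradient has been
  // merged back into var_, i.e. the inner var has already been consumed.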
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                    "by reduce hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batchs."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

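// Accumulates gradient `var` into the leaf gradient immediately, in the order
// gradients arrive; if stop_gradient is overridden, an uninitialized
// destination gradient is zero-filled instead of accumulated.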
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // The type may be changed after the op runs (e.g. by VarTypeInference),
  // so synchronize the VariableWrapper with the Variable.
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

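// When ref_cnt_ is 1 the gradient is moved/copied directly; otherwise incoming
// gradients are buffered until all ref_cnt_ of them have arrived, then
// accumulated sorted by descending trace_id (SelectedRows first on CUDA/HIP).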
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum selected rows firstly
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<framework::LoDTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should already be empty here, but clear it just in case.
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}

}  // namespace imperative
}  // namespace paddle