// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"

namespace paddle {
namespace imperative {

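// Moves *src into *dst when the gradient is consumed only within this graph;
// when force_copy is set, makes a deep copy (DenseTensor or SelectedRows)
// instead, so that *src is left intact.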
static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<phi::DenseTensor>()) {
    auto& src_tensor = src->Get<phi::DenseTensor>();
    if (!dst->IsType<phi::DenseTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<phi::DenseTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

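// In-place element-wise add on XPU: *dst += src. For double, the operands are
// cast to float, added with xpu::add, and cast back to double.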
#ifdef PADDLE_WITH_XPU
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const phi::DenseTensor& src,
                         phi::DenseTensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = -1;
  int numel = static_cast<int>(src.numel());
  if (std::is_same<T, double>::value) {
    xpu::ctx_guard RAII_GUARD(ctx->x_context());
    float* x_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
    PADDLE_ENFORCE_XDNN_NOT_NULL(x_cast_to_fp32);
    float* y_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
    PADDLE_ENFORCE_XDNN_NOT_NULL(y_cast_to_fp32);
    r = xpu::cast<XPUType, float>(ctx->x_context(), x, x_cast_to_fp32, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::cast<XPUType, float>(ctx->x_context(), y, y_cast_to_fp32, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::add<float>(ctx->x_context(),
                        x_cast_to_fp32,
                        y_cast_to_fp32,
                        y_cast_to_fp32,
                        numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
    r = xpu::cast<float, XPUType>(ctx->x_context(), y_cast_to_fp32, y, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
  } else {
    r = xpu::add<XPUType>(ctx->x_context(), x, y, y, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
  }
}
#endif

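// Helpers that give uniform access to the underlying DenseTensor /
// SelectedRows of either a framework::Variable or a paddle::Tensor, so the
// accumulation routines below can be written once for both variable types.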
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensors that have a value; if a "
                              "tensor is NOT INITIALIZED, it should just be "
                              "moved instead of calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

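// In-place dense gradient accumulation: *dst += src. Dispatches on place and
// dtype to the phi Add kernel, an Eigen fallback for FP16/BF16 on CPU, the
// XPU functor above, or BlasAXPBY on custom devices.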
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
                    data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal, Otherwise, the calculation results "
                        "will be incorrect."));

  // if src and dst are on different places, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }

#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
        platform::DeviceContextPool::Instance().Get(place));                   \
    phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_TENSOR_ADD(float, phi::GPUContext);
    PADDLE_TENSOR_ADD(double, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
  }

#define TENSOR_ADD_EIGEN(T)                                \
  auto cpu_ctx = static_cast<phi::CPUContext*>(            \
      platform::DeviceContextPool::Instance().Get(place)); \
  auto in = phi::EigenVector<T>::Flatten(src_tensor);      \
  auto out = phi::EigenVector<T>::Flatten(*dst_tensor);    \
  auto& p = *(cpu_ctx->eigen_device());                    \
  out.device(p) = out + in;                                \
  return;

  if (platform::is_cpu_place(place)) {
    PADDLE_TENSOR_ADD(float, phi::CPUContext);
    PADDLE_TENSOR_ADD(double, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
    if (data_type == framework::proto::VarType::BF16) {
      TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
    }
    if (data_type == framework::proto::VarType::FP16) {
      TENSOR_ADD_EIGEN(phi::dtype::float16);
    }
  }

#define PADDLE_TENSOR_ADD_CUSTOM(T)                              \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
    platform::CustomDeviceContext* ctx =                         \
        static_cast<platform::CustomDeviceContext*>(             \
            platform::DeviceContextPool::Instance().Get(place)); \
    phi::stream::Stream stream(place, ctx->stream());            \
    auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
    device->BlasAXPBY<T>(stream,                                 \
                         static_cast<size_t>(numel),             \
                         1.,                                     \
                         src_tensor.data<T>(),                   \
                         1.,                                     \
                         dst_tensor->mutable_data<T>(place));    \
    return;                                                      \
  }

  if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    PADDLE_TENSOR_ADD_CUSTOM(float);
    PADDLE_TENSOR_ADD_CUSTOM(double);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
  }

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      XPUTensorAddFunctor<double>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::Tensor>(const paddle::Tensor& src,
                                        paddle::Tensor* dst);

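// Adds a SelectedRows gradient into a dense tensor in place:
// *dst (DenseTensor) += src (SelectedRows).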
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)       \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {     \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);          \
    phi::funcs::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                     \
            src_selected_rows,                                           \
            dst_tensor);                                                 \
    return;                                                              \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::Tensor& src,
                                      paddle::Tensor* dst);

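// Computes dst_tensor_var = src_selected_rows_var + src_tensor_var, writing
// the result into a freshly resized dense tensor.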
template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)        \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    phi::funcs::SelectedRowsAddTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                   \
            src_selected_rows,                                         \
            src_tensor,                                                \
            dst_tensor);                                               \
    return;                                                            \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(const paddle::Tensor& src_selected_rows_var,
                                    const paddle::Tensor& src_tensor_var,
                                    paddle::Tensor* dst_tensor_var);

// Note(chenweihang): when two selected rows need to be added,
//   adding one to the other is not equal to merging the two selected rows
//   into one and then adding it to an empty selected rows; the latter is correct
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)             \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);      \
    phi::funcs::scatter::MergeAdd<dev_ctx_type, cpp_type> merge_add; \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),               \
              src_selected_rows,                                     \
              dst_selected_rows);                                    \
    return dst_var;                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
#if defined(PADDLE_WITH_XPU)
    if (paddle::platform::is_xpu_place(place)) {
      PADDLE_SELECTED_ROWS_ADD(phi::XPUContext, float);
    } else {
#endif
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_XPU)
    }
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::Tensor> SelectedRowsMerge(
    const paddle::Tensor& src1, const paddle::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

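// Accumulates the gradient held by `var` into `dst_var`, handling every
// DenseTensor / SelectedRows combination. `unchange_input` forces the copy
// path so that `var` stays usable when it is still needed as a grad-op input.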
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<phi::DenseTensor>()) {
    if (src.IsType<phi::DenseTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<phi::DenseTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

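// Returns the place (device) on which the wrapped gradient tensor lives.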
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<phi::DenseTensor>()) {
    place = var->Var().Get<phi::DenseTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been fully calculated, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initialized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<phi::DenseTensor>()) {
      if (src->IsType<phi::DenseTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<phi::DenseTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

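// Runs the user-registered VariableWrapper hooks on the temporary inner
// gradient var once summation for the current graph has completed; each hook
// may return a new var that replaces the inner var.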
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can be dealt with by "
                        "gradient hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Can only call gradient hooks after sum gradient is completed."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Leaf Tensor's inner var is nullptr when "
                        "calling gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(),
      true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "calling gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

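// Runs the registered void (reduce) hooks after the accumulated gradient has
// been written back into the leaf var, i.e. once the inner var has already
// been consumed.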
void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can be dealt "
                                    "with by reduce hook in gradient "
                                    "accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Can only call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Can only call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batches."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

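// Accumulates one incoming gradient into the destination var as soon as it
// arrives. For vars that override stop_gradient, the gradient is simply
// filled with zeros instead of being accumulated.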
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // Type may be changed after OP run, such as VarTypeInference
  // so synchronize the VariableWrapper with the Variable.
  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

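// Buffers incoming gradients until all ref_cnt_ of them have arrived, then
// accumulates them in descending trace_id order; on GPU the SelectedRows
// gradients are summed before the dense ones.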
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum selected rows first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<phi::DenseTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<phi::DenseTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should be empty here, but clear it just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}
}  // namespace imperative
}  // namespace paddle