// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"

namespace paddle {
namespace imperative {

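// Move the gradient held by `src` into `dst` when it is only consumed inside
// the current graph; otherwise deep-copy it so the source variable is left
// untouched. Supports DenseTensor (LoDTensor) and SelectedRows gradients.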
static void MoveOrCopyVar(framework::Variable* dst,
                          framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<phi::DenseTensor>()) {
    auto& src_tensor = src->Get<phi::DenseTensor>();
    if (!dst->IsType<phi::DenseTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<phi::DenseTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src->Get<phi::SelectedRows>();
    if (!dst->IsType<phi::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<phi::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

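// In-place element-wise add on XPU: *dst = *dst + src. Double inputs are
// routed through fp32 casts before and after the add; other types use
// xpu::add directly.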
#ifdef PADDLE_WITH_XPU
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const phi::DenseTensor& src,
                         phi::DenseTensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = -1;
  int numel = static_cast<int>(src.numel());
  if (std::is_same<T, double>::value) {
    xpu::ctx_guard RAII_GUARD(ctx->x_context());
    float* x_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
    PADDLE_ENFORCE_XDNN_NOT_NULL(x_cast_to_fp32);
    float* y_cast_to_fp32 = RAII_GUARD.alloc<float>(numel);
    PADDLE_ENFORCE_XDNN_NOT_NULL(y_cast_to_fp32);
    r = xpu::cast<XPUType, float>(ctx->x_context(), x, x_cast_to_fp32, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::cast<XPUType, float>(ctx->x_context(), y, y_cast_to_fp32, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::add<float>(ctx->x_context(),
                        x_cast_to_fp32,
                        y_cast_to_fp32,
                        y_cast_to_fp32,
                        numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
    r = xpu::cast<float, XPUType>(ctx->x_context(), y_cast_to_fp32, y, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
  } else {
    r = xpu::add<XPUType>(ctx->x_context(), x, y, y, numel);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
  }
}
#endif

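// The helpers below give uniform access to the underlying DenseTensor or
// SelectedRows of either a framework::Variable or a paddle::Tensor, so the
// accumulation routines that follow can be written once for both kinds of
// gradient holders.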
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<TType>();
  return dst_tensor;
}

template <typename TType>
TType* GetInnerMutableTensor(paddle::Tensor* dst) {
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
const TType& GetInnerTensor(const framework::Variable& src) {
  return src.Get<TType>();
}

template <typename TType>
TType& GetInnerTensor(const paddle::Tensor& src) {
  PADDLE_ENFORCE_EQ(
      src.initialized(),
      true,
      platform::errors::Fatal("We only add tensor with value if a tensor is "
                              "NOT INITIALIZED, it should just move instead of "
                              "calling this method."));
  auto* src_tensor = static_cast<TType*>(src.impl().get());
  return *src_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::Tensor* dst) {
  PADDLE_ENFORCE_EQ(
      dst->defined(),
      false,
      platform::errors::Fatal(
          "The underlying Tensor implementation should be nullptr"));
  dst->set_impl(std::make_shared<TType>());
  auto* dst_tensor = static_cast<TType*>(dst->impl().get());
  return dst_tensor;
}

template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
  auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
  return dst_tensor;
}

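// Dense gradient accumulation: dst = dst + src, dispatched to the proper
// device kernel (CPU/GPU/XPU/custom device) according to the tensor's place
// and data type.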
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(),
      numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel,
          dst_tensor->numel()));

  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto place = src_tensor.place();

  // If src and dst are in different places, copy dst to src's place.
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }

  // AddKernel already supports inputs of different dtypes. For AMP master_grad,
  // the dtypes of the source and destination tensors may differ, so the check
  // requiring input dtypes to be the same has been removed.
#define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
        platform::DeviceContextPool::Instance().Get(place));                   \
    phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
    return;                                                                    \
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_TENSOR_ADD(float, phi::GPUContext);
    PADDLE_TENSOR_ADD(double, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
    PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
  }

#define TENSOR_ADD_EIGEN(T)                                \
  auto cpu_ctx = static_cast<phi::CPUContext*>(            \
      platform::DeviceContextPool::Instance().Get(place)); \
  auto in = phi::EigenVector<T>::Flatten(src_tensor);      \
  auto out = phi::EigenVector<T>::Flatten(*dst_tensor);    \
  auto& p = *(cpu_ctx->eigen_device());                    \
  out.device(p) = out + in;                                \
  return;

  if (platform::is_cpu_place(place)) {
    PADDLE_TENSOR_ADD(float, phi::CPUContext);
    PADDLE_TENSOR_ADD(double, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
    PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
    if (data_type == framework::proto::VarType::BF16) {
      TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
    }
    if (data_type == framework::proto::VarType::FP16) {
      TENSOR_ADD_EIGEN(phi::dtype::float16);
    }
  }

#define PADDLE_TENSOR_ADD_CUSTOM(T)                              \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
    platform::CustomDeviceContext* ctx =                         \
        static_cast<platform::CustomDeviceContext*>(             \
            platform::DeviceContextPool::Instance().Get(place)); \
    phi::stream::Stream stream(place, ctx->stream());            \
    auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
    device->BlasAXPBY<T>(stream,                                 \
                         static_cast<size_t>(numel),             \
                         1.,                                     \
                         src_tensor.data<T>(),                   \
                         1.,                                     \
                         dst_tensor->mutable_data<T>(place));    \
    return;                                                      \
  }

  if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    PADDLE_TENSOR_ADD_CUSTOM(float);
    PADDLE_TENSOR_ADD_CUSTOM(double);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
    PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
  }

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      XPUTensorAddFunctor<double>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type),
          place));
    }
    return;
  }
#endif

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type),
      place));
}

template void TensorAdd<framework::Variable>(const framework::Variable& src,
                                             framework::Variable* dst);
template void TensorAdd<paddle::Tensor>(const paddle::Tensor& src,
                                        paddle::Tensor* dst);

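// Add a SelectedRows gradient into a dense destination tensor in place.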
template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
  phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src);
  auto place = dst_tensor->place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)       \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {     \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);          \
    phi::funcs::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                     \
            src_selected_rows,                                           \
            dst_tensor);                                                 \
    return;                                                              \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

template void SelectedRowsAddToTensor(const framework::Variable& src,
                                      framework::Variable* dst);
template void SelectedRowsAddToTensor(const paddle::Tensor& src,
                                      paddle::Tensor* dst);

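// Add a SelectedRows gradient and a dense tensor, writing the result into a
// separate destination tensor variable.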
template <typename VarType>
void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
                           const VarType& src_tensor_var,
                           VarType* dst_tensor_var) {
  const phi::SelectedRows& src_selected_rows =
      GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
  const phi::DenseTensor& src_tensor =
      GetInnerTensor<phi::DenseTensor>(src_tensor_var);
  const auto& place = src_tensor.place();
  auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  phi::DenseTensor* dst_tensor =
      GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, src_tensor.dtype());

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)        \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {   \
    phi::funcs::SelectedRowsAddTensor<dev_ctx_type, cpp_type> functor; \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),                   \
            src_selected_rows,                                         \
            src_tensor,                                                \
            dst_tensor);                                               \
    return;                                                            \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::GPUContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(phi::CPUContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

template void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var);
template void SelectedRowsAddTensor(const paddle::Tensor& src_selected_rows_var,
                                    const paddle::Tensor& src_tensor_var,
                                    paddle::Tensor* dst_tensor_var);

// Note(chenweihang): when two SelectedRows need to be added, adding one to
//   the other is not equivalent to merging the two into one and then adding
//   that to an empty SelectedRows; the latter is the correct approach.
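// SelectedRowsMerge therefore merges src1 and src2 into a newly created
// destination variable and returns it.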
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
                                                 const VarType& src2) {
  const phi::SelectedRows& src_selected_rows1 =
      GetInnerTensor<phi::SelectedRows>(src1);
  const phi::SelectedRows& src_selected_rows2 =
      GetInnerTensor<phi::SelectedRows>(src2);

  auto place = src_selected_rows1.value().place();
  auto data_type =
      framework::TransToProtoVarType(src_selected_rows1.value().dtype());
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const phi::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);

  auto dst_var = std::make_shared<ReturnVarType>("Temp");
  phi::SelectedRows* dst_selected_rows =
      GetEmptyInnerTensor<phi::SelectedRows>(dst_var.get());

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)             \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);      \
    phi::funcs::scatter::MergeAdd<dev_ctx_type, cpp_type> merge_add; \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)),               \
              src_selected_rows,                                     \
              dst_selected_rows);                                    \
    return dst_var;                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, float);
    PADDLE_SELECTED_ROWS_ADD(phi::GPUContext, double);
  } else {
#endif
#if defined(PADDLE_WITH_XPU)
    if (paddle::platform::is_xpu_place(place)) {
      PADDLE_SELECTED_ROWS_ADD(phi::XPUContext, float);
    } else {
#endif
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, float);
      PADDLE_SELECTED_ROWS_ADD(phi::CPUContext, double);
#if defined(PADDLE_WITH_XPU)
    }
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

template std::shared_ptr<paddle::Tensor> SelectedRowsMerge(
    const paddle::Tensor& src1, const paddle::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2);

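// Accumulate the gradient held by `var` into `dst_var`, picking the proper
// dense/SelectedRows combination of the routines above. When `unchange_input`
// is true the source variable is preserved.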
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var,
                        bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<phi::DenseTensor>()) {
    if (src.IsType<phi::DenseTensor>()) {
      TensorAdd<framework::Variable>(src, dst);
    } else if (src.IsType<phi::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<phi::DenseTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<phi::SelectedRows>()) {
      auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

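// Return the place (device) of the DenseTensor or SelectedRows held by `var`.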
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<phi::DenseTensor>()) {
    place = var->Var().Get<phi::DenseTensor>().place();
  } else if (var->Var().IsType<phi::SelectedRows>()) {
    place = var->Var().Get<phi::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

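// Once gradient summation for this backward pass has completed, add the
// temporarily accumulated gradient (inner_var_) onto the persistent leaf
// gradient (var_), or simply move it there if the leaf gradient is empty.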
void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been fully calculated, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initialized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<phi::DenseTensor>()) {
      if (src->IsType<phi::DenseTensor>()) {
        TensorAdd<framework::Variable>(*src, dst);
      } else if (src->IsType<phi::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<phi::SelectedRows>()) {
      if (src->IsType<phi::DenseTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<phi::SelectedRows>()) {
        auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

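// Run the user-registered VariableWrapper hooks on the inner gradient of a
// leaf tensor after gradient summation has finished.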
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(),
                    true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can be dealt with by "
                        "gradient hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(),
      true,
      platform::errors::PreconditionNotMet(
          "Only can call gradient hooks after sum gradient completed."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Leaf Tensor's inner var is nullptr when "
                        "call gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(),
      true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
      CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

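// Run the registered void (reduce) hooks once the leaf gradient has been
// fully accumulated into var_.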
void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(),
      true,
      platform::errors::Unavailable("Only leaf gradient Tensor can be dealt "
                                    "with by reduce hook in gradient "
                                    "accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batches."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

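// Eager accumulation: the first incoming gradient is moved or copied into the
// destination, and every later one is added on top. For variables with
// stop_gradient set, the gradient is filled with zeros if it has not been
// initialized.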
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id,
                                       bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // The type may be changed after the op runs (e.g. by VarTypeInference),
  // so synchronize the VariableWrapper's type with the Variable's.
  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

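// Sorted accumulation: incoming gradients are buffered until all ref_cnt_ of
// them have arrived, then sorted by trace_id and accumulated in order. On GPU
// builds, SelectedRows gradients are summed before dense ones.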
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id,
                                        bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(),
                    var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(),
                tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum SelectedRows first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<phi::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<phi::DenseTensor>(),
                            true,
                            platform::errors::PermissionDenied(
                                "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<phi::DenseTensor>() ||
                  var_info.var->Var().IsType<phi::SelectedRows>(),
              true,
              platform::errors::PermissionDenied("The type of Gradient "
                                                 "var must be LoDTensor "
                                                 "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(),
                          var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        VLOG(6) << "Dims of " << dst_var->Name()
                << " is set as: " << var->Var().Get<phi::DenseTensor>().dims();
        tensor->Resize(var->Var().Get<phi::DenseTensor>().dims());
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
        tensor->mutable_data(place,
                             framework::TransToPhiDataType(var->DataType()));
        phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should have no members at this point, but clear it just
    // in case.
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<phi::DenseTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<phi::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}
}  // namespace imperative
}  // namespace paddle