// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif

namespace paddle {
namespace imperative {

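// Move the gradient variable `src` into `dst` when a copy is not forced;
// otherwise deep-copy the underlying LoDTensor or SelectedRows so that
// `src` is left intact for later use.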
static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src,
                          bool force_copy) {
  if (!force_copy) {
    VLOG(6) << "Just Move Variable when sum gradients within this graph";
    *dst = std::move(*src);
    return;
  }

  VLOG(6) << "Copy occurs when sum gradients within this graph";
  if (src->IsType<framework::LoDTensor>()) {
    auto& src_tensor = src->Get<framework::LoDTensor>();
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
    framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
    dst_tensor->set_lod(src_tensor.lod());
  } else if (src->IsType<pten::SelectedRows>()) {
    auto& src_selected_rows = src->Get<pten::SelectedRows>();
    if (!dst->IsType<pten::SelectedRows>()) {
      dst->Clear();
    }
    auto* dst_selected_rows = dst->GetMutable<pten::SelectedRows>();
    framework::TensorCopy(src_selected_rows.value(),
                          src_selected_rows.value().place(),
                          dst_selected_rows->mutable_value());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    dst_selected_rows->set_height(src_selected_rows.height());
  } else {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Only support LoDTensor and SelectedRows for sum gradient"));
  }
}

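// Place visitor that computes y = x + y over `numel` elements. Accumulation
// is implemented for CPUPlace and CUDAPlace via BLAS AXPY and for XPUPlace
// via the xpu::add kernel; every other place raises PermissionDenied.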
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
 public:
  TensorAddFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) const {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_XPU
  void operator()(const platform::XPUPlace& place) const {
    using XPUType = typename XPUTypeTrait<T>::Type;
    platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    int r = xpu::add<XPUType>(
        ctx->x_context(), reinterpret_cast<const XPUType*>(x_),
        reinterpret_cast<const XPUType*>(y_), reinterpret_cast<XPUType*>(y_),
        static_cast<int>(numel_));
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU add kernel return wrong value[%d %s]",
                                   r, XPUAPIErrorMsg[r]));
  }
#else
  void operator()(const platform::XPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void operator()(const platform::CUDAPlace& place) const {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#ifdef PADDLE_WITH_MLU
  void operator()(const platform::MLUPlace& place) const {
    // TODO(fwg): SUPPORT it
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#else
  void operator()(const platform::MLUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  void operator()(const platform::NPUPlace& place) const {
    // TODO(zhiqiu): SUPPORT it
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#else
  void operator()(const platform::NPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
#endif

  void operator()(const platform::NPUPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }
  // there is NO support in IPUPlace
  void operator()(const platform::IPUPlace& place) const {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Gradient accumulation on place (%s) "
        "is not supported in imperative mode",
        place));
  }

 private:
  int64_t numel_;
  const T* x_;
  mutable T* y_;
};

#ifdef PADDLE_WITH_XPU
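// In-place accumulation on XPU: *dst = src + *dst via the xpu::add kernel.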
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
                         const framework::Tensor& src, framework::Tensor* dst) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
  int r = xpu::add<XPUType>(ctx->x_context(), x, y, y,
                            static_cast<int>(src.numel()));
  PADDLE_ENFORCE_EQ(
      r, XPU_SUCCESS,
      platform::errors::External("XPU add kernel return wrong value[%d %s]", r,
                                 XPUAPIErrorMsg[r]));
}
#endif

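// Fallback accumulation that dispatches ElementwiseAddTo on the device
// context of `place`: *dst += src.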
template <typename DeviceContext, typename T>
void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
                   const platform::Place& place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  paddle::platform::DeviceContext* ctx = pool.Get(place);
  auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
  operators::math::ElementwiseAddTo<DeviceContext, T> func;
  func(dev_ctx, src, dst);
}

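// Accumulate the LoDTensor in `src` into the LoDTensor in `dst`:
// *dst_tensor += src_tensor. Device-specific paths (NPU, XPU) are handled
// first, then PADDLE_TENSOR_ADD covers the common dtypes through the place
// visitor, and a float16 fallback handles CUDA and CPU places.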
void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  auto& src_tensor = src.Get<framework::LoDTensor>();

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(
      dst_tensor->numel(), numel,
      platform::errors::PreconditionNotMet(
          "The number of elements of source tensor and destination tensor "
          "should be equal, but got the number of elements of source tensor is "
          "%zu and the number of elements of destination tensor is %zu.",
          numel, dst_tensor->numel()));

  auto data_type = src_tensor.type();
  auto place = src_tensor.place();

  PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
                    platform::errors::PreconditionNotMet(
                        "The data type of source tensor and destination tensor "
                        "should be equal, Otherwise, the calculation results "
                        "will be incorrect."));

#ifdef PADDLE_WITH_XPU
  // if src and dst are in different places, copy dst to src's place
  if (dst_tensor->place() != place) {
    paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
  }
#endif

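// PADDLE_TENSOR_ADD(T) expands to an early-returning branch: when the
// runtime data type matches T, visit `place` with TensorAddFunctor and
// return.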
#define PADDLE_TENSOR_ADD(cpp_type)                                  \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    TensorAddFunctor<cpp_type> func(                                 \
        numel, src_tensor.data<cpp_type>(),                          \
        dst_tensor->mutable_data<cpp_type>(place));                  \
    platform::VisitPlace(place, func);                               \
    return;                                                          \
  }

#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type == framework::DataTypeTrait<double>::DataType()) {
      dst_tensor->mutable_data<double>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    const auto& runner = operators::NpuOpRunner(
        "Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
    runner.Run(dev_ctx->stream());
    return;
  }
#endif

#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(place)) {
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
    return;
  }
#endif

  PADDLE_TENSOR_ADD(float);

#ifndef PADDLE_WITH_XPU
  // NOTE(phlrain): xpu only supports float
  PADDLE_TENSOR_ADD(double);
  // NOTE(chenweihang): only complex grad tensor accumulation is supported;
  // add SelectedRows support if needed in the future
  PADDLE_TENSOR_ADD(platform::complex<float>);
  PADDLE_TENSOR_ADD(platform::complex<double>);
#endif

#undef PADDLE_TENSOR_ADD

  if (data_type == framework::proto::VarType::FP16) {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      return TensorAddImpl<platform::CUDADeviceContext, platform::float16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<platform::CPUDeviceContext, platform::float16>(
          src_tensor, dst_tensor, place);
    }
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type), place));
}

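// Accumulate a SelectedRows gradient (`src`) into a dense LoDTensor
// gradient (`dst`) in place. Only float and double are supported.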
void SelectedRowsAddToTensor(const framework::Variable& src,
                             framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  auto& src_selected_rows = src.Get<pten::SelectedRows>();
  auto place = dst_tensor->place();
  auto data_type = src_selected_rows.value().type();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)           \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {         \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);              \
    paddle::operators::math::SelectedRowsAddToTensor<dev_ctx_type, cpp_type> \
        functor;                                                             \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,      \
            dst_tensor);                                                     \
    return;                                                                  \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));
}

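// Compute dst_tensor = src_selected_rows + src_tensor, writing the result
// into a freshly resized dense tensor inside `dst_tensor_var`. Only float
// and double are supported.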
static void SelectedRowsAddTensor(
    const framework::Variable& src_selected_rows_var,
    const framework::Variable& src_tensor_var,
    framework::Variable* dst_tensor_var) {
  const auto& src_selected_rows =
      src_selected_rows_var.Get<pten::SelectedRows>();
  const auto& src_tensor = src_tensor_var.Get<framework::LoDTensor>();
  const auto& place = src_tensor.place();
  auto data_type = src_tensor.type();
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);

  auto* dst_tensor = dst_tensor_var->GetMutable<framework::LoDTensor>();
  dst_tensor->Resize(src_tensor.dims());
  dst_tensor->mutable_data(place, data_type);

#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type)            \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {       \
    paddle::operators::math::SelectedRowsAddTensor<dev_ctx_type, cpp_type> \
        functor;                                                           \
    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,    \
            src_tensor, dst_tensor);                                       \
    return;                                                                \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsAddToTensor",
      framework::DataTypeToString(data_type)));

#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}

// Note(chenweihang): when two selected rows need to be added,
//   adding one to the other is not equivalent to merging the two selected
//   rows into one and then adding it to an empty selected rows; the latter
//   is correct
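// Merge two SelectedRows variables into a new temporary VariableWrapper
// holding their element-wise sum (rows are combined via scatter::MergeAdd).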
std::shared_ptr<VariableWrapper> SelectedRowsMerge(
    const framework::Variable& src1, const framework::Variable& src2) {
  auto& src_selected_rows1 = src1.Get<pten::SelectedRows>();
  auto& src_selected_rows2 = src2.Get<pten::SelectedRows>();
  auto place = src_selected_rows1.value().place();
  auto data_type = src_selected_rows1.value().type();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

  std::vector<const pten::SelectedRows*> src_selected_rows;
  src_selected_rows.emplace_back(&src_selected_rows1);
  src_selected_rows.emplace_back(&src_selected_rows2);
  auto dst_var = std::make_shared<VariableWrapper>("Temp");
  auto* dst_selected_rows =
      dst_var->MutableVar()->GetMutable<pten::SelectedRows>();

#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)                  \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {      \
    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);           \
    paddle::operators::math::scatter::MergeAdd<dev_ctx_type, cpp_type>    \
        merge_add;                                                        \
    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows, \
              dst_selected_rows);                                         \
    return dst_var;                                                       \
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (paddle::platform::is_gpu_place(place)) {
    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, double);
  } else {
#endif
    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, float);
    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, double);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  }
#endif

#undef PADDLE_SELECTED_ROWS_ADD

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not supported data type %s for SelectedRowsMerge",
      framework::DataTypeToString(data_type)));
}

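// Accumulate `var` into `dst_var`, handling all four combinations of
// LoDTensor / SelectedRows operands. When `unchange_input` is true, the
// source variable is never modified.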
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
                        VariableWrapper* dst_var, bool unchange_input) {
  auto& src = var->Var();
  auto* dst = dst_var->MutableVar();
  if (dst->IsType<framework::LoDTensor>()) {
    if (src.IsType<framework::LoDTensor>()) {
      TensorAdd(src, dst);
    } else if (src.IsType<pten::SelectedRows>()) {
      SelectedRowsAddToTensor(src, dst);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  } else {
    if (src.IsType<framework::LoDTensor>()) {
      if (unchange_input) {
        framework::Variable new_dst;
        SelectedRowsAddTensor(*dst, src, &new_dst);
        *dst = std::move(new_dst);
      } else {
        auto* src_mutable = var->MutableVar();
        SelectedRowsAddToTensor(*dst, src_mutable);
        *dst = std::move(*(var->MutableVar()));
      }
    } else if (src.IsType<pten::SelectedRows>()) {
      auto temp = SelectedRowsMerge(src, *dst);
      *dst = std::move(*(temp->MutableVar()));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unexpected branch, output variable type is %s",
          framework::ToTypeName(dst->Type())));
    }
  }
}

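// Return the place where the variable's data resides, for both LoDTensor
// and SelectedRows variables.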
static platform::Place GetPlaceOfVar(
    const std::shared_ptr<VariableWrapper>& var) {
  platform::Place place;
  if (var->Var().IsType<framework::LoDTensor>()) {
    place = var->Var().Get<framework::LoDTensor>().place();
  } else if (var->Var().IsType<pten::SelectedRows>()) {
    place = var->Var().Get<pten::SelectedRows>().place();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "only support LoDTensor and SelectedRows in dygraph"));
  }
  return place;
}

void GradientAccumulator::AccumulateGrad() {
  /**
   * If the leaf gradient has been calculated done, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
    return;
  }
  PADDLE_ENFORCE_EQ(HasInnerVar(), true,
                    platform::errors::InvalidArgument(
                        "Leaf tensor should have inner var to store results of "
                        "this auto-grad"));
  PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true,
                    platform::errors::InvalidArgument(
                        "Interior var of Leaf tensor should be initialized."));
  auto* src = inner_var_->MutableVar();
  auto* dst = var_->MutableVar();
  if (!var_->IsEmpty()) {
    VLOG(6) << "Leaf Var(" << var_->Name()
            << ")'s Gradient has been initizlized, will accumulate on "
               "previous gradient.";
    if (dst->IsType<framework::LoDTensor>()) {
      if (src->IsType<framework::LoDTensor>()) {
        TensorAdd(*src, dst);
      } else if (src->IsType<pten::SelectedRows>()) {
        SelectedRowsAddToTensor(*src, dst);
      }
    } else if (dst->IsType<pten::SelectedRows>()) {
      if (src->IsType<framework::LoDTensor>()) {
        SelectedRowsAddToTensor(*dst, src);
        *dst = std::move(*src);
      } else if (src->IsType<pten::SelectedRows>()) {
        auto temp = SelectedRowsMerge(*src, *dst);
        *dst = std::move(*(temp->MutableVar()));
      }
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only support LoDTensor and SelectedRows for gradient var"));
    }
  } else {
    VLOG(6)
        << "Leaf Var(" << var_->Name()
        << ")'s Gradient has not been initialized, not accumulate. Just move";
    *(dst) = std::move(*src);
    var_->SetType(inner_var_->Type());
    var_->SetDataType(inner_var_->DataType());
    var_->SetIsEmpty(false);
  }
  inner_var_.reset();
}

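// Apply the VariableWrapper hooks registered on the leaf var to the
// accumulated inner gradient; each hook may return a replacement var, which
// then becomes the new inner_var_.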
void GradientAccumulator::CallGradientHooks() {
  PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true,
                    platform::errors::Unavailable(
                        "Only leaf gradient Tensor can deal with by gradient "
                        "hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(
      SumGradCompleted(), true,
      platform::errors::PreconditionNotMet(
          "Only can call gradient hooks after sum gradient completed."));
  PADDLE_ENFORCE_EQ(
      HasInnerVar(), true,
      platform::errors::PreconditionNotMet(
          "Leaf Tensor's inner var is nullptr when call gradient hook."));
  PADDLE_ENFORCE_EQ(
      inner_var_->Var().IsInitialized(), true,
      platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                           "is not initialized when "
                                           "call gradient hook."));
  if (var_->HasVariableWrapperHook()) {
    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
    auto tmp_var = inner_var_;
    VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
    }
    inner_var_ = tmp_var;
  }
}

void GradientAccumulator::CallReduceHooks() {
  PADDLE_ENFORCE_EQ(
      var_->IsLeafGrad(), true,
      platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                    "by reduce hook in gradient accumulator."));
  PADDLE_ENFORCE_EQ(SumGradCompleted(), true,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the gradient "
                        "summation is completed in current batch."));
  PADDLE_ENFORCE_EQ(HasInnerVar(), false,
                    platform::errors::PreconditionNotMet(
                        "Only can call reduce hooks after the "
                        "gradient accumulation is completed in "
                        "current batch or across batchs."));
  if (var_->HasVoidHook()) {
    for (const auto& hook : var_->GetVoidHooks()) {
      VLOG(3) << "call gradient accumulator backward hooks.";
      (*hook)();
    }
  }
}

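// Eager accumulation: each incoming gradient is applied as soon as it
// arrives. The first contribution is moved (or copied) into the destination
// var and later ones are added via VariableWrapperAdd. If the destination
// var has stop_gradient set, it is only zero-initialized when still empty.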
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                       size_t trace_id, bool unchange_input) {
  /**
   * If var has grad node, it indicates that this var would be an input
   * of a grad op. Therefore, it should not be changed.
   */
  if (var->HasGradNode()) {
    unchange_input = true;
  }

  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (CurCnt() == 0) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input);
    } else {
      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      VariableWrapperAdd(var, dst_var, unchange_input);
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }

  // The type may be changed after the op runs (e.g. by VarTypeInference),
  // so synchronize the VariableWrapper with the Variable.
  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<pten::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }

  // Increase current count
  IncreaseCurCnt();
}

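// Sorted accumulation: when more than one gradient contributes, incoming
// gradients are buffered in tmp_grad_vars_ until all ref_cnt_ contributions
// have arrived, then sorted by descending trace_id and accumulated. On CUDA
// places, SelectedRows gradients are summed before dense tensors.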
void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id, bool unchange_input) {
  auto* dst_var = Var();
  platform::Place place = GetPlaceOfVar(var);
  if (!dst_var->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(),
                    unchange_input || var->HasGradNode());
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

      VLOG(6) << "Sum Gradient for: " << dst_var->Name()
              << " within this graph.";
      std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
                [](const SavedVarInfo& info1, const SavedVarInfo& info2) {
                  return info1.trace_id > info2.trace_id;
                });

      for (auto& var_info : tmp_grad_vars_) {
        if (var_info.var->HasGradNode()) {
          var_info.unchange_input = true;
        }
      }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (paddle::platform::is_gpu_place(place)) {
        // sum SelectedRows gradients first
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var->Var().IsType<pten::SelectedRows>()) {
            continue;
          }

          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }

        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }

          PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
                            true, platform::errors::PermissionDenied(
                                      "Gradient var must be LoDTensor"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }

          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
      } else {
#endif
        for (auto& var_info : tmp_grad_vars_) {
          if (!var_info.var) {
            continue;
          }
          PADDLE_ENFORCE_EQ(
              var_info.var->Var().IsType<framework::LoDTensor>() ||
                  var_info.var->Var().IsType<pten::SelectedRows>(),
              true, platform::errors::PermissionDenied("The type of Gradient "
                                                       "var must be LoDTensor "
                                                       "or SelectedRows"));
          if (CurCnt() == 0) {
            MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(),
                          var_info.unchange_input);
          } else {
            VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input);
          }
          var_info.var = nullptr;
          // Increase count
          IncreaseCurCnt();
        }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      }
#endif
      tmp_grad_vars_.clear();
    }
  } else {
    if (!dst_var->Var().IsInitialized() ||
        !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      if (!dst_var->Var().IsInitialized()) {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << dst_var->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor =
            dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should already be empty here, but clear it just in case
    tmp_grad_vars_.clear();
  }

  if (dst_var->Var().IsType<framework::LoDTensor>()) {
    dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
  } else if (dst_var->Var().IsType<pten::SelectedRows>()) {
    dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
  }
}

}  // namespace imperative
}  // namespace paddle