/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <math.h>  // for sqrt in CPU and CUDA
#include <Eigen/Dense>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

namespace scatter = paddle::operators::math::scatter;

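// Reads a scalar float hyper-parameter (e.g. Beta1Tensor, Beta2Tensor or
// EpsilonTensor) from a tensor, copying it to CPU first when the tensor lives
// on a GPU or XPU device.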
static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
  const float* tensor_data = tensor->data<float>();
  framework::Tensor cpu_tensor;
  if (platform::is_gpu_place(tensor->place())) {
    paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(),
                                      &cpu_tensor);
    tensor_data = cpu_tensor.data<float>();
  }
  if (platform::is_xpu_place(tensor->place())) {
    paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(),
                                      &cpu_tensor);
    tensor_data = cpu_tensor.data<float>();
  }
  return tensor_data[0];
}

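// Operator definition: shape inference, the expected kernel type and the
// per-variable kernel type are declared here and implemented in the
// corresponding .cc file.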
class AdamOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

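// Tag types selecting the dense Adam flavour: the GPUAdam functor updates one
// element per operator()(i) call, while the CPUAdam functor updates the whole
// parameter in a single vectorized Eigen pass.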
struct GPUAdam;
struct CPUAdam;

template <typename T, typename Flavour>
class AdamFunctor;

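// Per-element dense Adam update:
//   mom1  = beta1 * mom1 + (1 - beta1) * g
//   mom2  = beta2 * mom2 + (1 - beta2) * g * g
//   lr_t  = lr * sqrt(1 - beta2_pow) / (1 - beta1_pow)
//   param = param - lr_t * mom1 / (sqrt(mom2) + epsilon * sqrt(1 - beta2_pow))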
template <typename T>
class AdamFunctor<T, GPUAdam> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

 public:
  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
              T* mom2_out, const T* lr, const T* grad, const T* param,
              T* param_out)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // Merge all memory accesses together.
    T g = grad_[i];
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    T p = param_[i];

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = p;
  }
};

template <typename T>
class AdamFunctor<T, CPUAdam> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

 public:
  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
              T* mom2_out, const T* lr, const T* grad, const T* param,
              T* param_out)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out) {}

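  // Vectorized dense update over the whole parameter using Eigen array maps;
  // the arithmetic matches the per-element GPUAdam functor above.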
  void operator()(size_t numel) const {
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> g{
        grad_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom1{
        moment1_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom2{
        moment2_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> param{
        param_, static_cast<Eigen::Index>(numel)};

    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param_out{
        param_out_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment1_out{
        moment1_out_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment2_out{
        moment2_out_, static_cast<Eigen::Index>(numel)};

    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    moment1_out = beta1_ * mom1 + (1 - beta1_) * g;
    moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g;
    param_out = param -
                lr * (moment1_out /
                      (moment2_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow)));
  }
};

template <typename T, typename Flavour, typename MT = T>
class SparseAdamFunctor;

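// Sparse (SelectedRows) Adam update. Each operator()(i) call updates one
// parameter element: the owning row is located by binary search over the
// sorted gradient rows; rows missing from the gradient are treated as zero
// gradient unless lazy_mode is set, in which case they are skipped. MT is the
// higher-precision type used for the optional master parameter copy.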
template <typename T, typename MT>
class SparseAdamFunctor<T, GPUAdam, MT> {
 private:
  MT beta1_;
  MT beta2_;
  MT epsilon_;

  const MT* beta1_pow_;
  const MT* beta2_pow_;
  const MT* moment1_;
  MT* moment1_out_;
  const MT* moment2_;
  MT* moment2_out_;
  const MT* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;
  const MT* master_param_;
  MT* master_param_out_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;
  bool lazy_mode_;

 public:
  SparseAdamFunctor(MT beta1, MT beta2, MT epsilon, const MT* beta1_pow,
                    const MT* beta2_pow, const MT* mom1, MT* mom1_out,
                    const MT* mom2, MT* mom2_out, const MT* lr, const T* grad,
                    const T* param, T* param_out, const MT* master_param,
                    MT* master_param_out, const int64_t* rows,
                    int64_t row_numel, int64_t row_count, bool lazy_mode)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out),
        master_param_(master_param),
        master_param_out_(master_param_out),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count),
        lazy_mode_(lazy_mode) {}

  inline HOSTDEVICE void adam_update(size_t i, MT g) const {
    // The following code is the same as dense
    MT mom1 = moment1_[i];
    MT mom2 = moment2_[i];
    MT lr = *lr_;
    MT beta1_pow = *beta1_pow_;
    MT beta2_pow = *beta2_pow_;
    MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);

    // Calculation
    lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
          (static_cast<MT>(1.0) - beta1_pow);

    mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
    mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) +
                       epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = static_cast<T>(p);
    if (master_param_out_) {
      master_param_out_[i] = p;
    }
  }

  inline HOSTDEVICE void operator()(size_t i) const {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
    if (lazy_mode_ && row_idx < 0) {
      return;
    } else {
      MT g = row_idx >= 0
                 ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_])
                 : static_cast<MT>(0);
      adam_update(i, g);
    }
  }
};

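// CPU flavour of the sparse update: adam_update() applies the dense formula to
// a single element, while operator()(numel) below walks whole rows instead of
// binary-searching per element.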
template <typename T>
class SparseAdamFunctor<T, CPUAdam, T> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;

 public:
  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                    const T* beta2_pow, const T* mom1, T* mom1_out,
                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
                    const T* param, T* param_out, const int64_t* rows,
                    int64_t row_numel, int64_t row_count, bool lazy_mode)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count) {}

  inline HOSTDEVICE void adam_update(size_t i, T g) const {
    // The following code is the same as dense
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    T p = param_[i];

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = p;
  }

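  // Sweeps all parameter rows in one call. Rows present in the sorted sparse
  // gradient receive the full Adam update; the remaining rows get the
  // zero-gradient update, i.e. the moments are only decayed by beta1 / beta2.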
  inline void operator()(size_t numel) const {
    // lr could be reused
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
    int64_t row_count = static_cast<int64_t>(numel / row_numel_);

    for (int64_t i = 0, j = 0; i != row_count; ++i) {
      if (i == *(rows_ + j)) {
        for (int64_t k = 0; k != row_numel_; ++k) {
          T g = grad_[j * row_numel_ + k];
          adam_update(i * row_numel_ + k, g);
        }
        ++j;
      } else {
        for (int64_t k = 0; k != row_numel_; ++k) {
          T mom1 = moment1_[i * row_numel_ + k];
          T mom2 = moment2_[i * row_numel_ + k];
          T p = param_[i * row_numel_ + k];

          mom1 = beta1_ * mom1;
          mom2 = beta2_ * mom2;

          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
          // Write back to global memory
          moment1_out_[i * row_numel_ + k] = mom1;
          moment2_out_[i * row_numel_ + k] = mom2;
          param_out_[i * row_numel_ + k] = p;
        }
      }
    }
  }
};

template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received type is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;

    int64_t min_row_size_to_use_multithread =
        ctx.Attr<int64_t>("min_row_size_to_use_multithread");
    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");
    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but got %d",
                            skip_update_tensor->numel()));
      std::vector<bool> skip_update_vec;
      paddle::framework::TensorToVector(*skip_update_tensor,
                                        ctx.device_context(), &skip_update_vec);
      skip_update = skip_update_vec[0];
    }
    // skip_update=true, just copy input to output, and TensorCopy will call
    // mutable_data
    if (skip_update) {
      VLOG(4) << "Adam skip update";
      framework::TensorCopy(
          *param, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), param_out);
      framework::TensorCopy(
          *mom1, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), mom1_out);
      framework::TensorCopy(
          *mom2, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), mom2_out);
      framework::TensorCopy(
          *beta1_pow, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(),
          beta1_pow_out);
      framework::TensorCopy(
          *beta2_pow, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(),
          beta2_pow_out);
      return;
    }

    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but got %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
    }
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but got %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
    }
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    if (ctx.HasInput("EpsilonTensor")) {
      auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
      PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(EpsilonTensor) size must be 1, but got %d",
                            epsilon_tensor->numel()));
      epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
    }

    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << " beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();

    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is: %d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is: %d.",
                          beta2_pow_out->numel()));

    if (grad_var->IsType<framework::LoDTensor>()) {
      T beta1_p = beta1_pow->data<T>()[0];
      T beta2_p = beta2_pow->data<T>()[0];

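      // Advance beta1^t and beta2^t on the CPU unless the accumulators are
      // maintained globally outside this kernel (use_global_beta_pow).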
      if (!use_global_beta_pow) {
        beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
            beta1 * beta1_pow->data<T>()[0];
        beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
            beta2 * beta2_pow->data<T>()[0];
      }

      auto* grad = ctx.Input<LoDTensor>("Grad");

      T* param_out_ptr = param_out->mutable_data<T>(ctx.GetPlace());
      T* mom1_out_ptr = mom1_out->mutable_data<T>(ctx.GetPlace());
      T* mom2_out_ptr = mom2_out->mutable_data<T>(ctx.GetPlace());

      T learning_rate = lr->data<T>()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p));
      T eps = epsilon * sqrt(1 - beta2_p);

      jit::adam_attr_t attr(beta1, beta2);
      int64_t numel = param->numel();

      const T* param_ptr = param->data<T>();
      const T* mom1_ptr = mom1->data<T>();
      const T* mom2_ptr = mom2->data<T>();
      const T* grad_ptr = grad->data<T>();

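      // Dense CPU path: dispatch to the JIT-generated Adam kernel and walk the
      // parameter in fixed-size chunks so the loop can run under OpenMP when
      // MKLML is enabled; the tail that does not fill a whole chunk is handled
      // separately below. The bias corrections are already folded into
      // learning_rate and eps above.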
      auto adam =
          jit::KernelFuncs<jit::AdamTuple<T>, platform::CPUPlace>::Cache().At(
              attr);

      static constexpr int64_t chunk_size = 512;

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
      for (int64_t i = 0; i < numel / chunk_size; ++i) {
        const int64_t offset = i * chunk_size;
        adam(beta1, beta2, -learning_rate, eps, chunk_size, grad_ptr + offset,
             mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset,
             mom1_out_ptr + offset, mom2_out_ptr + offset,
             param_out_ptr + offset);
      }

      if (numel % chunk_size != 0) {
        const int64_t offset = (numel / chunk_size) * chunk_size;
        const int64_t tail_numel = numel % chunk_size;
        adam(beta1, beta2, -learning_rate, eps, tail_numel, grad_ptr + offset,
             mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset,
             mom1_out_ptr + offset, mom2_out_ptr + offset,
             param_out_ptr + offset);
      }
    } else if (grad_var->IsType<pten::SelectedRows>()) {
      auto* grad = ctx.Input<pten::SelectedRows>("Grad");
      if (grad->rows().size() == 0) {
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      pten::SelectedRows tmp_grad_merge;
      const pten::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<DeviceContext, T> merge_func;
        merge_func(ctx.template device_context<DeviceContext>(), *grad,
                   &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }

      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      SparseAdamFunctor<T, CPUAdam> functor(
          beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
          mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
          mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
          lr->data<T>(), grad_data, param->data<T>(),
          param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
          grad_merge.rows().size(), lazy_mode);
      // update beta1 and beta2
      if (!use_global_beta_pow) {
        beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
            beta1 * beta1_pow->data<T>()[0];
        beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
            beta2 * beta2_pow->data<T>()[0];
      }
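      // Lazy mode: only the rows that actually appear in the sparse gradient
      // are updated; all other parameter rows are left unchanged.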
      if (lazy_mode) {
        VLOG(3) << "run cpu lazy mode";
        size_t row_count = grad_merge.rows().size();
        std::vector<int64_t> cpu_rows(grad_merge.rows());
        for (size_t row_index = 0; row_index < row_count; ++row_index) {
          for (size_t offset = 0; offset < row_numel; ++offset) {
            size_t i = cpu_rows[row_index] * row_numel + offset;
            functor.adam_update(i, grad_data[row_index * row_numel + offset]);
          }
        }
      }
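      // On non-Windows builds the non-lazy sparse update can be split across
      // the framework thread pool when FLAGS_inner_op_parallelism is set and
      // the parameter has enough rows; each worker updates a contiguous range
      // of parameter rows, looking up the matching gradient row (or zero
      // gradient) for every row in its range.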
#ifndef _WIN32
      else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
               min_row_size_to_use_multithread > 0 &&
               param->dims()[0] > min_row_size_to_use_multithread) {
        VLOG(3) << "use multi thread, inner_op_parallelism="
                << FLAGS_inner_op_parallelism
                << " min_row_size_to_use_multithread="
                << min_row_size_to_use_multithread;
        if (FLAGS_inner_op_parallelism > 10) {
          VLOG(1) << "FLAGS_inner_op_parallelism " << FLAGS_inner_op_parallelism
                  << " is too large!";
        }
        auto& grad_rows = grad_merge.rows();
        std::unordered_map<size_t, int> row_id_to_grad_row_offset;
        size_t param_row_count = param->numel() / row_numel;
        if (param_row_count < 1000) {
          VLOG(1) << "param_row_count should be larger than 1000 to use "
                     "multi thread, currently "
                  << param_row_count;
        }
        for (size_t i = 0; i < grad_rows.size(); ++i) {
          row_id_to_grad_row_offset[grad_rows[i]] = i;
        }
        std::vector<std::future<void>> fs;
        int64_t line_in_each_thread =
            param_row_count / FLAGS_inner_op_parallelism + 1;
        for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
          int64_t start = i * line_in_each_thread;
          int64_t end = (i + 1) * line_in_each_thread;
          if (start >= static_cast<int64_t>(param_row_count)) {
            break;
          }
          if (end > static_cast<int64_t>(param_row_count)) {
            end = static_cast<int64_t>(param_row_count);
          }
          fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset,
                                         &grad_data, row_numel, start, end]() {
            for (int64_t row_id = start; row_id < end; ++row_id) {
              auto iter = row_id_to_grad_row_offset.find(row_id);
              if (iter != row_id_to_grad_row_offset.end()) {
                for (size_t row_offset = 0U; row_offset < row_numel;
                     ++row_offset) {
                  functor.adam_update(
                      row_id * row_numel + row_offset,
                      grad_data[iter->second * row_numel + row_offset]);
                }
              } else {
                for (size_t row_offset = 0U; row_offset < row_numel;
                     ++row_offset) {
                  functor.adam_update(row_id * row_numel + row_offset, 0);
                }
              }
            }
          }));
        }
        for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
      }
#endif        // !_WIN32
      else {  // NOLINT
        functor(param->numel());
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adam_op"));
    }
  }
};

}  // namespace operators
}  // namespace paddle