adam_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <math.h>  // for sqrt in CPU and CUDA
#include <Eigen/Dense>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

namespace scatter = paddle::operators::math::scatter;

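// Read a scalar float value from a Tensor (used for the optional
// Beta1Tensor / Beta2Tensor inputs); a GPU tensor is copied to CPU first.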
static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
  const float* tensor_data = tensor->data<float>();
  framework::Tensor cpu_tensor;
  if (platform::is_gpu_place(tensor->place())) {
    TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
    tensor_data = cpu_tensor.data<float>();
  }
  return tensor_data[0];
}

class AdamOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

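// Tag types that select an implementation flavour: GPUAdam updates one
// element per call (driven by ForRange), CPUAdam processes the whole tensor
// in a single vectorized pass.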
struct GPUAdam;
struct CPUAdam;

template <typename T, typename Flavour>
struct AdamFunctor;

template <typename T>
struct AdamFunctor<T, GPUAdam> {
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
              T* mom2_out, const T* lr, const T* grad, const T* param,
              T* param_out)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out) {}

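  // Per-element Adam update:
  //   lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
  //   m_t  = beta1 * m_{t-1} + (1 - beta1) * g
  //   v_t  = beta2 * v_{t-1} + (1 - beta2) * g^2
  //   p_t  = p_{t-1} - lr_t * m_t / (sqrt(v_t) + epsilon)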
  inline HOSTDEVICE void operator()(size_t i) const {
    // Merge all memory accesses together.
    T g = grad_[i];
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    T p = param_[i];

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = p;
  }
};

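// CPU flavour of the dense update: identical math, expressed with Eigen
// array maps so the whole parameter is updated in one vectorized pass.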
template <typename T>
struct AdamFunctor<T, CPUAdam> {
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
              T* mom2_out, const T* lr, const T* grad, const T* param,
              T* param_out)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out) {}

  void operator()(size_t numel) const {
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> g{
        grad_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom1{
        moment1_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom2{
        moment2_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> param{
        param_, static_cast<Eigen::Index>(numel)};

    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param_out{
        param_out_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment1_out{
        moment1_out_, static_cast<Eigen::Index>(numel)};
    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment2_out{
        moment2_out_, static_cast<Eigen::Index>(numel)};

    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    moment1_out = beta1_ * mom1 + (1 - beta1_) * g;
    moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g;
    param_out = param - lr * (moment1_out / (moment2_out.sqrt() + epsilon_));
  }
};

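// Sparse (SelectedRows) variants of the Adam update. In lazy mode only the
// rows that actually appear in the gradient are updated.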
template <typename T, typename Flavour>
struct SparseAdamFunctor;

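// GPU flavour of the sparse update: one invocation per parameter element; a
// binary search over the sorted gradient rows decides whether that element
// has a gradient (elements of missing rows use g = 0 unless lazy_mode skips
// them).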
template <typename T>
struct SparseAdamFunctor<T, GPUAdam> {
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;
  bool lazy_mode_;

  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                    const T* beta2_pow, const T* mom1, T* mom1_out,
                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
                    const T* param, T* param_out, const int64_t* rows,
                    int64_t row_numel, int64_t row_count, bool lazy_mode)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count),
        lazy_mode_(lazy_mode) {}

  inline HOSTDEVICE void adam_update(size_t i, T g) const {
    // The following code is the same as the dense update
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    T p = param_[i];

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = p;
  }

  inline HOSTDEVICE void operator()(size_t i) const {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
    if (lazy_mode_ && row_idx < 0) {
      return;
    } else {
      T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
      adam_update(i, g);
    }
  }
};

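// CPU flavour of the sparse update: walks the sorted gradient rows
// sequentially. Rows without a gradient still get the beta1/beta2 decay
// applied to their moments. The lazy_mode argument is not stored; lazy
// updates on CPU are driven directly by the kernel below via adam_update.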
template <typename T>
struct SparseAdamFunctor<T, CPUAdam> {
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;

  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                    const T* beta2_pow, const T* mom1, T* mom1_out,
                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
                    const T* param, T* param_out, const int64_t* rows,
                    int64_t row_numel, int64_t row_count, bool lazy_mode)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count) {}

  inline HOSTDEVICE void adam_update(size_t i, T g) const {
    // The following code is the same as the dense update
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    T p = param_[i];

    // Calculation
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = p;
  }

  inline void operator()(size_t numel) const {
    // lr can be computed once and reused for every row
    T lr = *lr_;
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
    int64_t row_count = static_cast<int64_t>(numel / row_numel_);

    for (int64_t i = 0, j = 0; i != row_count; ++i) {
      if (i == *(rows_ + j)) {
        for (int64_t k = 0; k != row_numel_; ++k) {
          T g = grad_[j * row_numel_ + k];
          adam_update(i * row_numel_ + k, g);
        }
        ++j;
      } else {
        for (int64_t k = 0; k != row_numel_; ++k) {
          T mom1 = moment1_[i * row_numel_ + k];
          T mom2 = moment2_[i * row_numel_ + k];
          T p = param_[i * row_numel_ + k];

          mom1 = beta1_ * mom1;
          mom2 = beta2_ * mom2;

          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
          // Write back to global memory
          moment1_out_[i * row_numel_ + k] = mom1;
          moment2_out_[i * row_numel_ + k] = mom2;
          param_out_[i * row_numel_ + k] = p;
        }
      }
    }
  }
};

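// Adam optimizer kernel. Dispatches on the gradient type: dense LoDTensor
// gradients use AdamFunctor, SelectedRows gradients use SparseAdamFunctor
// (with optional lazy updates and, on CPU, optional multi-threading).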
template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                   "The Var(%s)'s type should be LoDTensor, "
                   "but the received is %s",
                   ctx.Inputs("Param").front(),
                   framework::ToTypeName(param_var->Type()));

    using paddle::framework::LoDTensor;
    using paddle::operators::detail::Ref;

    int64_t min_row_size_to_use_multithread =
        ctx.Attr<int64_t>("min_row_size_to_use_multithread");
    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
    // auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
    auto* grad_var = ctx.InputVar("Grad");
    auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
    auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
    auto& lr =
        Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate");

    auto& beta1_pow =
        Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow");
    auto& beta2_pow =
        Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow");

    auto& param_out =
        Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut");
    auto& mom1_out =
        Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
    auto& mom2_out =
        Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment2Out");

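    // beta1/beta2 come from attributes but may be overridden at runtime by
    // the optional Beta1Tensor/Beta2Tensor inputs.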
    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
    }
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
    }

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");

      if (platform::is_cpu_place(ctx.GetPlace())) {
        AdamFunctor<T, CPUAdam> functor(
            beta1, beta2, epsilon, beta1_pow.template data<T>(),
            beta2_pow.template data<T>(), mom1.template data<T>(),
            mom1_out.template mutable_data<T>(ctx.GetPlace()),
            mom2.template data<T>(),
            mom2_out.template mutable_data<T>(ctx.GetPlace()),
            lr.template data<T>(), grad.template data<T>(),
            param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()));
        functor(param.numel());
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        AdamFunctor<T, GPUAdam> functor(
            beta1, beta2, epsilon, beta1_pow.template data<T>(),
            beta2_pow.template data<T>(), mom1.template data<T>(),
            mom1_out.template mutable_data<T>(ctx.GetPlace()),
            mom2.template data<T>(),
            mom2_out.template mutable_data<T>(ctx.GetPlace()),
            lr.template data<T>(), grad.template data<T>(),
            param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()));

        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param.numel());
        for_range(functor);
      }
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      auto& grad =
          Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
      if (grad.rows().size() == 0) {
        VLOG(3) << "grad row size is 0, skip the sparse Adam update";
        return;
      }

      std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

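      // The row indices must be strictly increasing (sorted and unique) for
      // the lookups below; otherwise duplicated rows are merged first.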
      framework::SelectedRows tmp_grad_merge;
      const framework::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = &grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<DeviceContext, T> merge_func;
        merge_func(ctx.template device_context<DeviceContext>(), grad,
                   &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }

      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      if (platform::is_cpu_place(ctx.GetPlace())) {
        SparseAdamFunctor<T, CPUAdam> functor(
            beta1, beta2, epsilon, beta1_pow.template data<T>(),
            beta2_pow.template data<T>(), mom1.template data<T>(),
            mom1_out.template mutable_data<T>(ctx.GetPlace()),
            mom2.template data<T>(),
            mom2_out.template mutable_data<T>(ctx.GetPlace()),
            lr.template data<T>(), grad_data, param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
            grad_merge.rows().size(), lazy_mode);
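        // Lazy mode on CPU: update only the rows present in the merged
        // sparse gradient.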
        if (lazy_mode) {
          VLOG(3) << "run cpu lazy mode";
          size_t row_count = grad_merge.rows().size();
          std::vector<int64_t> cpu_rows(grad_merge.rows());
          for (size_t row_index = 0; row_index < row_count; ++row_index) {
            for (size_t offset = 0; offset < row_numel; ++offset) {
              size_t i = cpu_rows[row_index] * row_numel + offset;
              functor.adam_update(i, grad_data[row_index * row_numel + offset]);
            }
          }
        }
#ifndef _WIN32
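        // Multi-threaded path (not built on Windows): parameter rows are
        // split into contiguous chunks and each chunk is updated by a task
        // on the framework thread pool; rows with no gradient are updated
        // with g = 0.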
        else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
                 min_row_size_to_use_multithread > 0 &&
                 param.dims()[0] > min_row_size_to_use_multithread) {
          VLOG(3) << "use the multi-threaded update, inner_op_parallelism="
                  << FLAGS_inner_op_parallelism
                  << " min_row_size_to_use_multithread="
                  << min_row_size_to_use_multithread;
          if (FLAGS_inner_op_parallelism > 10) {
            VLOG(1) << "FLAGS_inner_op_parallelism "
                    << FLAGS_inner_op_parallelism << " is too large!";
          }
          auto& grad_rows = grad_merge.rows();
          std::unordered_map<size_t, int> row_id_to_grad_row_offset;
          size_t param_row_count = param.numel() / row_numel;
          if (param_row_count < 1000) {
            VLOG(1) << "param_row_count should be larger than 1000 to use "
                       "the multi-threaded update, currently "
                    << param_row_count;
          }
          for (size_t i = 0; i < grad_rows.size(); ++i) {
            row_id_to_grad_row_offset[grad_rows[i]] = i;
          }
          std::vector<std::future<void>> fs;
          int64_t line_in_each_thread =
              param_row_count / FLAGS_inner_op_parallelism + 1;
          for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
            int64_t start = i * line_in_each_thread;
            int64_t end = (i + 1) * line_in_each_thread;
            if (start >= static_cast<int64_t>(param_row_count)) {
              break;
            }
            if (end > static_cast<int64_t>(param_row_count)) {
              end = static_cast<int64_t>(param_row_count);
            }
            fs.push_back(
                framework::Async([&functor, &row_id_to_grad_row_offset,
                                  &grad_data, row_numel, start, end]() {
                  for (int64_t row_id = start; row_id < end; ++row_id) {
                    auto iter = row_id_to_grad_row_offset.find(row_id);
                    if (iter != row_id_to_grad_row_offset.end()) {
                      for (size_t row_offset = 0U; row_offset < row_numel;
                           ++row_offset) {
                        functor.adam_update(
                            row_id * row_numel + row_offset,
                            grad_data[iter->second * row_numel + row_offset]);
                      }
                    } else {
                      for (size_t row_offset = 0U; row_offset < row_numel;
                           ++row_offset) {
                        functor.adam_update(row_id * row_numel + row_offset, 0);
                      }
                    }
                  }
                }));
          }
          for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
        }
#endif          // !_WIN32
        else {  // NOLINT
          functor(param.numel());
        }
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        SparseAdamFunctor<T, GPUAdam> functor(
            beta1, beta2, epsilon, beta1_pow.template data<T>(),
            beta2_pow.template data<T>(), mom1.template data<T>(),
            mom1_out.template mutable_data<T>(ctx.GetPlace()),
            mom2.template data<T>(),
            mom2_out.template mutable_data<T>(ctx.GetPlace()),
            lr.template data<T>(), grad_data, param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
            grad_merge.rows().size(), lazy_mode);

        // FIXME(minqiyang): remove BinarySearch in GPU later
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param.numel());
        for_range(functor);
      }
    } else {
      PADDLE_THROW("Variable type not supported by adam_op");
    }
  }
};

}  // namespace operators
}  // namespace paddle