/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <math.h>  // for sqrt in CPU and CUDA

#include <Eigen/Dense>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/squared_l2_norm.h"

namespace paddle {
namespace operators {

namespace scatter = paddle::operators::math::scatter;

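// LAMB moment update with the beta-pow scalars passed by value (the "REG"
// suffix presumably means register/host-resident beta pows). For element i:
//   mom1 = beta1 * mom1 + (1 - beta1) * g
//   mom2 = beta2 * mom2 + (1 - beta2) * g * g
//   trust_ratio_div = mom1_hat / (sqrt(mom2_hat) + epsilon) + weight_decay * p
// where mom1_hat and mom2_hat are the bias-corrected first and second moments.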
template <typename T, bool IsMultiPrecision>
struct LambMomentREGUpdateFunctor {
  using MT = typename std::conditional<IsMultiPrecision,
                                       typename details::MPTypeTrait<T>::Type,
                                       T>::type;

  MT weight_decay_;
  MT beta1_;
  MT beta2_;
  MT epsilon_;

  MT beta1_pow_;
  MT* beta1_pow_out_;
  MT beta2_pow_;
  MT* beta2_pow_out_;
  const MT* moment1_;
  MT* moment1_out_;
  const MT* moment2_;
  MT* moment2_out_;
  const T* grad_;
  const MT* param_;
  MT* trust_ratio_div_;
  const bool* skip_update_;

  LambMomentREGUpdateFunctor(MT weight_decay,
                             MT beta1,
                             MT beta2,
                             MT epsilon,
                             MT beta1_pow,
                             MT beta2_pow,
                             const MT* mom1,
                             MT* mom1_out,
                             const MT* mom2,
                             MT* mom2_out,
                             const T* grad,
                             const MT* param,
                             MT* trust_ratio_div,
                             const bool* skip_update)
      : weight_decay_(weight_decay),
        beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        grad_(grad),
        param_(param),
        trust_ratio_div_(trust_ratio_div),
        skip_update_(skip_update) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    if (skip_update_ && *skip_update_) return;

    MT g = static_cast<MT>(grad_[i]);
    MT mom1 = moment1_[i];
    MT mom2 = moment2_[i];
    MT beta1_pow = beta1_pow_;
    MT beta2_pow = beta2_pow_;
    MT p = param_[i];

    mom1 = beta1_ * mom1 + (static_cast<MT>(1) - beta1_) * g;
    mom2 = beta2_ * mom2 + (static_cast<MT>(1) - beta2_) * g * g;

    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;

    MT mom1_unbiased = mom1 / (static_cast<MT>(1) - beta1_pow);
    MT mom2_unbiased = mom2 / (static_cast<MT>(1) - beta2_pow);
    trust_ratio_div_[i] =
        mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
        weight_decay_ * p;
  }
};

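// Same moment update as above, but the beta-pow scalars are read through
// pointers (the "MEN" suffix presumably means memory/device-resident beta
// pows), so this variant works when Beta1Pow/Beta2Pow live on the device.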
template <typename T, bool IsMultiPrecision>
struct LambMomentMENUpdateFunctor {
  using MT = typename std::conditional<IsMultiPrecision,
                                       typename details::MPTypeTrait<T>::Type,
                                       T>::type;

  MT weight_decay_;
  MT beta1_;
  MT beta2_;
  MT epsilon_;

  const MT* beta1_pow_;
  const MT* beta2_pow_;
  const MT* moment1_;
  MT* moment1_out_;
  const MT* moment2_;
  MT* moment2_out_;
  const T* grad_;
  const MT* param_;
  MT* trust_ratio_div_;
  const bool* skip_update_;

  LambMomentMENUpdateFunctor(MT weight_decay,
                             MT beta1,
                             MT beta2,
                             MT epsilon,
                             const MT* beta1_pow,
                             const MT* beta2_pow,
                             const MT* mom1,
                             MT* mom1_out,
                             const MT* mom2,
                             MT* mom2_out,
                             const T* grad,
                             const MT* param,
                             MT* trust_ratio_div,
                             const bool* skip_update)
      : weight_decay_(weight_decay),
        beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        grad_(grad),
        param_(param),
        trust_ratio_div_(trust_ratio_div),
        skip_update_(skip_update) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    if (skip_update_ && *skip_update_) return;
    MT g = static_cast<MT>(grad_[i]);
    MT mom1 = moment1_[i];
    MT mom2 = moment2_[i];
    MT beta1_pow = *beta1_pow_;
    MT beta2_pow = *beta2_pow_;
    MT p = param_[i];

    mom1 = beta1_ * mom1 + (static_cast<MT>(1) - beta1_) * g;
    mom2 = beta2_ * mom2 + (static_cast<MT>(1) - beta2_) * g * g;

    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;

    MT mom1_unbiased = mom1 / (static_cast<MT>(1) - beta1_pow);
    MT mom2_unbiased = mom2 / (static_cast<MT>(1) - beta2_pow);
    trust_ratio_div_[i] =
        mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
        weight_decay_ * p;
  }
};

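// Sparse (SelectedRows) counterpart of the REG moment update: rows that are
// absent from the merged gradient contribute a zero gradient, located via a
// binary search over the sorted row index array.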
template <typename T>
struct SparseLambMomentREGUpdateFunctor {
  T weight_decay_;
  T beta1_;
  T beta2_;
  T epsilon_;

  T beta1_pow_;
  T beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* grad_;
  const T* param_;
  T* trust_ratio_div_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;

  const bool* skip_update_;

  SparseLambMomentREGUpdateFunctor(T weight_decay,
                                   T beta1,
                                   T beta2,
                                   T epsilon,
                                   T beta1_pow,
                                   T beta2_pow,
                                   const T* mom1,
                                   T* mom1_out,
                                   const T* mom2,
                                   T* mom2_out,
                                   const T* grad,
                                   const T* param,
                                   T* trust_ratio_div,
                                   const int64_t* rows,
                                   int64_t row_numel,
                                   int64_t row_count,
                                   const bool* skip_update)
      : weight_decay_(weight_decay),
        beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        grad_(grad),
        param_(param),
        trust_ratio_div_(trust_ratio_div),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count),
        skip_update_(skip_update) {}

  inline HOSTDEVICE void update(size_t i, T g) const {
    // The following code is the same as in the dense update
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T beta1_pow = beta1_pow_;
    T beta2_pow = beta2_pow_;
    T p = param_[i];

    mom1 = beta1_ * mom1 + (static_cast<T>(1) - beta1_) * g;
    mom2 = beta2_ * mom2 + (static_cast<T>(1) - beta2_) * g * g;

    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;

    T mom1_unbiased = mom1 / (static_cast<T>(1) - beta1_pow);
    T mom2_unbiased = mom2 / (static_cast<T>(1) - beta2_pow);
    trust_ratio_div_[i] =
        mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
        weight_decay_ * p;
  }

  inline HOSTDEVICE void operator()(size_t i) const {
    if (skip_update_ && *skip_update_) return;
    auto row_idx =
        phi::funcs::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
                       : static_cast<T>(0);
    update(i, g);
  }
};

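// Sparse (SelectedRows) counterpart of the MEN moment update; the beta-pow
// scalars are read through pointers.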
template <typename T>
struct SparseLambMomentMENUpdateFunctor {
  T weight_decay_;
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* grad_;
  const T* param_;
  T* trust_ratio_div_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;

  const bool* skip_update_;

  SparseLambMomentMENUpdateFunctor(T weight_decay,
                                   T beta1,
                                   T beta2,
                                   T epsilon,
                                   const T* beta1_pow,
                                   const T* beta2_pow,
                                   const T* mom1,
                                   T* mom1_out,
                                   const T* mom2,
                                   T* mom2_out,
                                   const T* grad,
                                   const T* param,
                                   T* trust_ratio_div,
                                   const int64_t* rows,
                                   int64_t row_numel,
                                   int64_t row_count,
                                   const bool* skip_update)
      : weight_decay_(weight_decay),
        beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        grad_(grad),
        param_(param),
        trust_ratio_div_(trust_ratio_div),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count),
        skip_update_(skip_update) {}

  inline HOSTDEVICE void update(size_t i, T g) const {
    // The following code is the same as in the dense update
    T mom1 = moment1_[i];
    T mom2 = moment2_[i];
    T beta1_pow = *beta1_pow_;
    T beta2_pow = *beta2_pow_;
    T p = param_[i];

    mom1 = beta1_ * mom1 + (static_cast<T>(1) - beta1_) * g;
    mom2 = beta2_ * mom2 + (static_cast<T>(1) - beta2_) * g * g;

    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;

    T mom1_unbiased = mom1 / (static_cast<T>(1) - beta1_pow);
    T mom2_unbiased = mom2 / (static_cast<T>(1) - beta2_pow);
    trust_ratio_div_[i] =
        mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
        weight_decay_ * p;
  }

  inline HOSTDEVICE void operator()(size_t i) const {
    if (skip_update_ && *skip_update_) return;
    auto row_idx =
        phi::funcs::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
                       : static_cast<T>(0);
    update(i, g);
  }
};

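// Mixed into LambParamUpateFunctor below: when NeedUpdateBetaPow is true,
// iteration i == 0 advances beta1^t and beta2^t in place; the specialization
// for false is a no-op, used when the beta pows were already updated on the
// CPU side.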
template <typename MT, bool NeedUpdateBetaPow /*=true*/>
struct LambBetaPowUpdateFunctor {
  void SetBetaPows(const MT* beta1pow,
                   const MT* beta2pow,
                   MT* beta1pow_out,
                   MT* beta2pow_out,
                   MT beta1,
                   MT beta2) {
    beta1pow_ = beta1pow;
    beta2pow_ = beta2pow;
    beta1pow_out_ = beta1pow_out;
    beta2pow_out_ = beta2pow_out;
    beta1_ = beta1;
    beta2_ = beta2;
  }

  HOSTDEVICE void UpdateBetaPow(size_t i) const {
    if (i == 0) {
      beta1pow_out_[0] = beta1pow_[0] * beta1_;
      beta2pow_out_[0] = beta2pow_[0] * beta2_;
    }
  }

 private:
  const MT* beta1pow_;
  const MT* beta2pow_;
  MT* beta1pow_out_;
  MT* beta2pow_out_;
  MT beta1_;
  MT beta2_;
};

template <typename MT>
struct LambBetaPowUpdateFunctor<MT, /*NeedUpdateBetaPow=*/false> {
  void SetBetaPows(const MT* beta1pow,
                   const MT* beta2pow,
                   MT* beta1pow_out,
                   MT* beta2pow_out,
                   MT beta1,
                   MT beta2) {}
  HOSTDEVICE void UpdateBetaPow(size_t) const {}
};

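// Applies the LAMB parameter update:
//   r = ||param|| / ||trust_ratio_div||   (falls back to 1 if either norm is 0)
//   param_out = param - lr * r * trust_ratio_div
// param_norm_ and trust_ratio_div_norm_ hold squared L2 norms, so square
// roots are taken before forming the trust ratio.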
template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
struct LambParamUpateFunctor
    : public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
  const MT* lr_;
  const T* param_;
  const MT* master_param_;
  const MT* param_norm_;
  const MT* trust_ratio_div_;
  const MT* trust_ratio_div_norm_;
  T* param_out_;
  MT* master_param_out_;

  const bool* skip_update_;

  LambParamUpateFunctor(const MT* lr,
                        const T* param,
                        const MT* master_param,
                        const MT* param_norm,
                        const MT* trust_ratio_div,
                        const MT* trust_ratio_div_norm,
                        T* param_out,
                        MT* master_param_out,
                        const bool* skip_update)
      : lr_(lr),
        param_(param),
        master_param_(master_param),
        param_norm_(param_norm),
        trust_ratio_div_(trust_ratio_div),
        trust_ratio_div_norm_(trust_ratio_div_norm),
        param_out_(param_out),
        master_param_out_(master_param_out),
        skip_update_(skip_update) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    if (skip_update_ && *skip_update_) return;
    MT lr = *lr_;
    MT pn = Eigen::numext::sqrt(*param_norm_);
    MT tn = Eigen::numext::sqrt(*trust_ratio_div_norm_);

    MT r = (pn > static_cast<MT>(0) && tn > static_cast<MT>(0))
               ? pn / tn
               : static_cast<MT>(1);
    lr *= r;
    MT p = IsMultiPrecision ? master_param_[i] : static_cast<MT>(param_[i]);
    MT param_out = p - lr * trust_ratio_div_[i];
    param_out_[i] = static_cast<T>(param_out);
    if (IsMultiPrecision) {
      master_param_out_[i] = param_out;
    }
    this->UpdateBetaPow(i);
  }
};

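// LAMB optimizer kernel. Dispatches on the multi_precision attribute: with
// multi precision enabled, the FP32 master parameter drives the update and
// the low-precision Param/ParamOut tensors are kept in sync with it.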
template <typename DeviceContext, typename T>
class LambOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using MT = typename details::MPTypeTrait<T>::Type;
    bool multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      ComputeImpl<MT, true>(ctx);
    } else {
      ComputeImpl<T, false>(ctx);
    }
  }

 private:
  template <typename MT, bool IsMultiPrecision>
  void ComputeImpl(const framework::ExecutionContext& ctx) const {
    if (!IsMultiPrecision) {
      constexpr auto kIsSameType = std::is_same<T, MT>::value;
      PADDLE_ENFORCE_EQ(
          kIsSameType,
          true,
          platform::errors::InvalidArgument(
              "When multi_precision=False, T and MT must be the same type."));
    }
    const auto* skip_update = ctx.Input<framework::LoDTensor>("SkipUpdate");
    const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
                                       ? skip_update->data<bool>()
                                       : nullptr;
    if (skip_update_flag && platform::is_cpu_place(skip_update->place()) &&
        (*skip_update_flag)) {
      return;
    }

    auto weight_decay = static_cast<MT>(ctx.Attr<float>("weight_decay"));
    auto beta1 = static_cast<MT>(ctx.Attr<float>("beta1"));
    auto beta2 = static_cast<MT>(ctx.Attr<float>("beta2"));
    auto epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
    const auto& param = GET_DATA_SAFELY(
        ctx.Input<framework::LoDTensor>("Param"), "Input", "Param", "Lamb");
    const auto* grad_var = ctx.InputVar("Grad");
    const auto& mom1 = GET_DATA_SAFELY(
        ctx.Input<framework::LoDTensor>("Moment1"), "Input", "Moment1", "Lamb");
    const auto& mom2 = GET_DATA_SAFELY(
        ctx.Input<framework::LoDTensor>("Moment2"), "Input", "Moment2", "Lamb");
    const auto& lr =
        GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("LearningRate"),
                        "Input",
                        "LearningRate",
                        "Lamb");

    const auto& beta1_pow =
        GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Beta1Pow"),
                        "Input",
                        "Beta1Pow",
                        "Lamb");
    const auto& beta2_pow =
        GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Beta2Pow"),
                        "Input",
                        "Beta2Pow",
                        "Lamb");

    auto& param_out =
        GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("ParamOut"),
                        "Output",
                        "ParamOut",
                        "Lamb");
    auto& mom1_out =
        GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Moment1Out"),
                        "Output",
                        "Moment1Out",
                        "Lamb");
    auto& mom2_out =
        GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Moment2Out"),
                        "Output",
                        "Moment2Out",
                        "Lamb");
    auto& beta1_pow_out =
        GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Beta1PowOut"),
                        "Output",
                        "Beta1PowOut",
                        "Lamb");
    auto& beta2_pow_out =
        GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Beta2PowOut"),
                        "Output",
                        "Beta2PowOut",
                        "Lamb");
    const auto* master_param =
        IsMultiPrecision ? ctx.Input<framework::LoDTensor>("MasterParam")
                         : nullptr;
    auto* master_param_out =
        IsMultiPrecision ? ctx.Output<framework::LoDTensor>("MasterParamOut")
                         : nullptr;

    if (IsMultiPrecision) {
      PADDLE_ENFORCE_NOT_NULL(master_param,
                              platform::errors::InvalidArgument(
                                  "Input(MasterParam) must be provided when "
                                  "multi_precision=True."));
      PADDLE_ENFORCE_NOT_NULL(master_param_out,
                              platform::errors::InvalidArgument(
                                  "Output(MasterParamOut) must be provided "
                                  "when multi_precision=True."));
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto numel = param.numel();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    auto trust_ratio_div =
        ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);
    auto* trust_ratio_div_ptr = trust_ratio_div.template data<MT>();

    const void* param_ptr = param.data();
    const void* master_param_ptr =
        master_param ? master_param->data() : nullptr;
    void* param_out_ptr = param_out.template mutable_data<T>(ctx.GetPlace());
    void* master_param_out_ptr =
        master_param_out
            ? master_param_out->template mutable_data<MT>(ctx.GetPlace())
            : nullptr;

    // Update moments
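    // When running on GPU while the beta-pow tensors live on the CPU, the REG
    // functors are used and beta1^t / beta2^t are advanced here on the host;
    // otherwise the MEN functors read them from device memory and the
    // parameter update functor advances them later (see
    // should_update_beta_pow_later).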
    bool should_update_beta_pow_later = false;
    const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
    MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
    VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
             << " , Beta2Pow place: " << beta2_pow.place();
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto& grad = grad_var->Get<framework::LoDTensor>();
      if (platform::is_gpu_place(ctx.GetPlace()) &&
          beta1_pow.place() == platform::CPUPlace() &&
          beta2_pow.place() == platform::CPUPlace()) {
        LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
            weight_decay,
            beta1,
            beta2,
            epsilon,
            *beta1_pow.template data<MT>(),
            *beta2_pow.template data<MT>(),
            mom1.template data<MT>(),
            mom1_out.template mutable_data<MT>(ctx.GetPlace()),
            mom2.template data<MT>(),
            mom2_out.template mutable_data<MT>(ctx.GetPlace()),
            grad.template data<T>(),
            static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                    : param_ptr),
            trust_ratio_div_ptr,
            skip_update_flag);
        for_range(moment_update_functor);
        beta1_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
            beta1 * beta1_pow.template data<MT>()[0];
        beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
            beta2 * beta2_pow.template data<MT>()[0];
      } else {
        beta1_pow_ptr = beta1_pow.template data<MT>();
        beta2_pow_ptr = beta2_pow.template data<MT>();
        beta1_pow_out_ptr =
            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
        beta2_pow_out_ptr =
            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
        should_update_beta_pow_later = true;
        LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
            weight_decay,
            beta1,
            beta2,
            epsilon,
            static_cast<const MT*>(beta1_pow_ptr),
            static_cast<const MT*>(beta2_pow_ptr),
            mom1.template data<MT>(),
            mom1_out.template mutable_data<MT>(ctx.GetPlace()),
            mom2.template data<MT>(),
            mom2_out.template mutable_data<MT>(ctx.GetPlace()),
            grad.template data<T>(),
            static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                    : param_ptr),
            trust_ratio_div_ptr,
            skip_update_flag);
        for_range(moment_update_functor);
      }
    } else if (grad_var->IsType<phi::SelectedRows>()) {
      PADDLE_ENFORCE_EQ(IsMultiPrecision,
                        false,
                        platform::errors::Unimplemented(
                            "SelectedRows gradient is not supported when "
                            "multi_precision=True."));
      constexpr bool kIsSameType = std::is_same<T, MT>::value;
      PADDLE_ENFORCE_EQ(kIsSameType,
                        true,
                        platform::errors::Unimplemented(
                            "SelectedRows gradient is not supported when "
                            "multi_precision=True."));
      auto& grad = GET_DATA_SAFELY(
          ctx.Input<phi::SelectedRows>("Grad"), "Input", "Grad", "Lamb");
      if (grad.rows().size() == 0) {
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      phi::SelectedRows tmp_grad_merge;
      const phi::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = &grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<DeviceContext, T> merge_func;
        merge_func(dev_ctx, grad, &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }

      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      auto* grad_merge_rows = &grad_merge.rows();
      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
          grad_merge_rows);
      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
      if (platform::is_gpu_place(ctx.GetPlace()) &&
          beta1_pow.place() == platform::CPUPlace() &&
          beta2_pow.place() == platform::CPUPlace()) {
        SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
            static_cast<T>(weight_decay),
            static_cast<T>(beta1),
            static_cast<T>(beta2),
            static_cast<T>(epsilon),
            *beta1_pow.template data<T>(),
            *beta2_pow.template data<T>(),
            mom1.template data<T>(),
            mom1_out.template mutable_data<T>(ctx.GetPlace()),
            mom2.template data<T>(),
            mom2_out.template mutable_data<T>(ctx.GetPlace()),
            grad_data,
            param.template data<T>(),
            trust_ratio_div.template data<T>(),
            rows,
            row_numel,
            grad_merge.rows().size(),
            skip_update_flag);
        for_range(moment_update_functor);
        beta1_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
            static_cast<T>(beta1) * beta1_pow.template data<T>()[0];
        beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
            static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
      } else {
        beta1_pow_ptr = beta1_pow.template data<MT>();
        beta2_pow_ptr = beta2_pow.template data<MT>();
        beta1_pow_out_ptr =
            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
        beta2_pow_out_ptr =
            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
        should_update_beta_pow_later = true;
        SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
            static_cast<T>(weight_decay),
            static_cast<T>(beta1),
            static_cast<T>(beta2),
            static_cast<T>(epsilon),
            reinterpret_cast<const T*>(beta1_pow_ptr),
            reinterpret_cast<const T*>(beta2_pow_ptr),
            mom1.template data<T>(),
            mom1_out.template mutable_data<T>(ctx.GetPlace()),
            mom2.template data<T>(),
            mom2_out.template mutable_data<T>(ctx.GetPlace()),
            grad_data,
            param.template data<T>(),
            trust_ratio_div.template data<T>(),
            rows,
            row_numel,
            grad_merge.rows().size(),
            skip_update_flag);
        for_range(moment_update_functor);
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by lamb_op. Expect LoDTensor or "
          "SelectedRows, but got %s",
          framework::ToTypeName(grad_var->Type())));
    }

    // Update parameter
    auto p_norm_t = ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
    auto* p_norm_ptr = p_norm_t.template data<MT>();

    auto trust_ratio_div_norm_t =
        ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
    auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.template data<MT>();

    // TODO(zengjinle): remove the following Eigen operations when
    // *skip_update == true.
    memory::Buffer buffer(dev_ctx.GetPlace());
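    // Compute the squared L2 norms of the parameter and of trust_ratio_div;
    // LambParamUpateFunctor takes their square roots to form the trust ratio.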
    phi::funcs::SquaredL2Norm(
        dev_ctx,
        reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                     : param_ptr),
        p_norm_ptr,
        numel,
        &buffer);
    phi::funcs::SquaredL2Norm(
        dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);

    if (VLOG_IS_ON(1)) {
      const auto& name = ctx.GetOp().Input("Param");
      auto pn = ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
      auto tn = ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
      auto dtype =
          framework::DataTypeToString(framework::DataTypeTrait<T>::DataType());
      VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
              << " , tn = " << tn[0];
    }

#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)         \
  do {                                                                       \
    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
        param_update_functor(lr.template data<MT>(),                         \
                             static_cast<const T*>(param_ptr),               \
                             static_cast<const MT*>(master_param_ptr),       \
                             p_norm_ptr,                                     \
                             trust_ratio_div_ptr,                            \
                             trust_ratio_div_norm_ptr,                       \
                             static_cast<T*>(param_out_ptr),                 \
                             static_cast<MT*>(master_param_out_ptr),         \
                             skip_update_flag);                              \
    if (__should_update_beta_pow) {                                          \
      param_update_functor.SetBetaPows(beta1_pow_ptr,                        \
                                       beta2_pow_ptr,                        \
                                       beta1_pow_out_ptr,                    \
                                       beta2_pow_out_ptr,                    \
                                       beta1,                                \
                                       beta2);                               \
    }                                                                        \
    for_range(param_update_functor);                                         \
  } while (0)

    if (should_update_beta_pow_later) {
      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
    } else {
      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
    }

#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
  }
};

}  // namespace operators
}  // namespace paddle