/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::SelectedRows;
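// Empty tag types used as template arguments to select, at compile time,
// whether the update uses Nesterov momentum or plain momentum.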
struct NoNesterov;
struct UseNesterov;

namespace details {

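// Eigen-based dense momentum update used on CPU:
//   velocity_out = mu * velocity + grad
//   param_out    = param - lr * velocity_out                 (plain momentum)
//   param_out    = param - lr * (grad + mu * velocity_out)   (Nesterov)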
template <typename T>
struct CPUDenseUpdater {
  template <typename G>
  void operator()(const Tensor& param, const Tensor& velocity, const T& mu,
                  const T& lr, const bool use_nesterov, G&& grad,
                  Tensor* param_out, Tensor* velocity_out) const {
    auto param_out_vec = framework::EigenVector<T>::Flatten(*param_out);
    auto velocity_out_vec = framework::EigenVector<T>::Flatten(*velocity_out);

    auto param_vec = framework::EigenVector<T>::Flatten(param);
    auto velocity_vec = framework::EigenVector<T>::Flatten(velocity);
    velocity_out_vec = velocity_vec * mu + grad;
    if (use_nesterov) {
      param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr;
    } else {
      param_out_vec = param_vec - lr * velocity_out_vec;
    }
  }
};

}  // namespace details

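// Maps a parameter type to the type used for the actual arithmetic:
// float16 is promoted to float, other types map to themselves.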
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

enum class RegularizationType {
  kNONE = 0,
  kL1DECAY = 1,  // does not need to be supported right now
  kL2DECAY = 2,
};

class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class MomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
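  // Checks that all required inputs and outputs exist, that LearningRate is
  // an initialized scalar, and (for dense gradients) that Param, Grad and
  // Velocity have the same shape; then sets the output shapes.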
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
                      platform::errors::NotFound(
                          "Input(Param) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
                      platform::errors::NotFound(
                          "Input(Grad) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), true,
                      platform::errors::NotFound(
                          "Input(Velocity) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("LearningRate"), true,
        platform::errors::NotFound(
            "Input(LearningRate) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Param").front(),
        framework::proto::VarType::LOD_TENSOR,
        platform::errors::InvalidArgument(
            "The input var's type should be LoDTensor, but the received "
            "type is %s.",
            ctx->GetInputsVarType("Param").front()));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
                      platform::errors::NotFound(
                          "Output(ParamOut) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("VelocityOut"), true,
        platform::errors::NotFound(
            "Output(VelocityOut) of Momentum should not be null."));

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      platform::errors::InvalidArgument(
                          "The input variable LearningRate may not have been "
                          "initialized. Please confirm that you call "
                          "exe.run(startup_program) after the "
                          "optimizer.minimize function."));
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      platform::errors::InvalidArgument(
                          "LearningRate should be a scalar, but the received "
                          "LearningRate has [%s] elements.",
                          framework::product(lr_dims)));

    auto param_dim = ctx->GetInputDim("Param");
    if (ctx->GetInputsVarType("Grad")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Grad"),
          platform::errors::InvalidArgument(
              "Param and Grad input of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Grad's dim [%s].",
              param_dim, ctx->GetInputDim("Grad")));
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Velocity"),
          platform::errors::InvalidArgument(
              "Param and Velocity of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Velocity's dim "
              "[%s].",
              param_dim, ctx->GetInputDim("Velocity")));
    }

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("VelocityOut", param_dim);
    if (ctx->HasOutput("MasterParamOut")) {
      ctx->SetOutputDim("MasterParamOut", param_dim);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

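// CPU entry point for the dense update: applies optional L2 decay to the
// gradient and delegates to details::CPUDenseUpdater.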
template <typename T>
class CPUDenseMomentumFunctor {
 public:
  void operator()(const Tensor* param, const Tensor* grad,
                  const Tensor* velocity, const Tensor* learning_rate,
                  const T mu, const bool use_nesterov,
                  const RegularizationType regularization_flag,
                  const T regularization_coeff, Tensor* param_out,
                  Tensor* velocity_out) {
    auto grad_vec = framework::EigenVector<T>::Flatten(*grad);
    auto* lr = learning_rate->data<MultiPrecisionType<T>>();

    details::CPUDenseUpdater<T> updater;
    if (regularization_flag == RegularizationType::kL2DECAY) {
      auto param_vec = framework::EigenVector<T>::Flatten(*param);
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              param_vec * regularization_coeff + grad_vec, param_out,
              velocity_out);
    } else {
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              grad_vec, param_out, velocity_out);
    }
  }
};

template <typename T, typename MT, RegularizationType kRegType,
          typename UpdateMethod>
class DenseMomentumFunctor;

// NOTE(dzh): for performance, avoid if/else inside the kernel; the GPU
// UseNesterov/NoNesterov variants are implemented as two separate functors.
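// Both specializations compute velocity_out = mu * velocity + grad (with
// optional L2 decay folded into grad); UseNesterov then steps along
// grad + mu * velocity_out, while NoNesterov steps along velocity_out.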
template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t num_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const MT regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_coeff_(regularization_coeff) {}
  inline HOSTDEVICE void operator()(size_t i) const {
    // put memory accesses into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    if (kRegType == RegularizationType::kL2DECAY) {
      grad += regularization_coeff_ * param;
    }

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t num_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const MT regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_coeff_(regularization_coeff) {}
  inline HOSTDEVICE void operator()(size_t i) const {
    // put memory accesses into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    if (kRegType == RegularizationType::kL2DECAY) {
      grad += regularization_coeff_ * param;
    }

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - lr * velocity_out;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

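// Momentum update for a SelectedRows (sparse) gradient: each parameter row
// looks up its gradient row by binary search; rows without a gradient entry
// are updated with a zero gradient.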
template <typename T, typename MT, typename UpdateMethod>
class SparseMomentumFunctor;

template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
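    // Locate this element's row in the merged SelectedRows gradient; a
    // negative row_idx means the row has no gradient entry.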
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
    // put memory accesses into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
    // put memory accesses into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - velocity_out * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
  using MPDType = MultiPrecisionType<T>;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
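    // With multi_precision enabled the update math runs in MPDType (float for
    // float16 parameters) and a master copy of the parameter is maintained;
    // otherwise everything stays in the parameter type T.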
    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      InnerCompute<MPDType>(ctx, multi_precision);
    } else {
      InnerCompute<T>(ctx, multi_precision);
    }
  }

 private:
  template <typename MT>
  void InnerCompute(const framework::ExecutionContext& ctx,
                    const bool multi_precision) const {
    std::string regularization_method =
        ctx.Attr<std::string>("regularization_method");
    MT regularization_coeff =
        static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
    RegularizationType regularization_flag{
        RegularizationType::kNONE};  // disable regularization
    if (regularization_method == "l2_decay") {
      regularization_flag = RegularizationType::kL2DECAY;
    }

    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto velocity = ctx.Input<framework::Tensor>("Velocity");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");

    const framework::Tensor* master_param = nullptr;
    framework::Tensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<framework::Tensor>("MasterParam");
      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
    }

    param_out->mutable_data<T>(ctx.GetPlace());
    velocity_out->mutable_data<MT>(ctx.GetPlace());
    const MT* master_in_data =
        multi_precision ? master_param->data<MT>() : nullptr;
    MT* master_out_data =
        multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
                        : nullptr;

    auto* grad_var = ctx.InputVar("Grad");
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto grad = ctx.Input<framework::Tensor>("Grad");
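      // Dense gradient: use the Eigen-based functor on CPU and the
      // element-wise DenseMomentumFunctor kernels on GPU.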
      if (platform::is_cpu_place(ctx.GetPlace())) {
        CPUDenseMomentumFunctor<MT> functor;
        functor(param, grad, velocity, learning_rate, mu, use_nesterov,
                regularization_flag, regularization_coeff, param_out,
                velocity_out);
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param->numel());
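// Instantiates the dense functor for one (Nesterov, regularization)
// combination and launches it over all parameter elements.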
#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type)     \
  DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(          \
      param->data<T>(), grad->data<T>(), velocity->data<MT>(),          \
      learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad, \
      param->numel(), regularization_coeff,                             \
      param_out->mutable_data<T>(ctx.GetPlace()),                       \
      velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data); \
  for_range(functor);

        if (use_nesterov) {
          if (regularization_flag == RegularizationType::kL2DECAY) {
            PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
                                                RegularizationType::kL2DECAY);
          } else {
            PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
                                                RegularizationType::kNONE);
          }
        } else {
          if (regularization_flag == RegularizationType::kL2DECAY) {
            PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
                                                RegularizationType::kL2DECAY);
          } else {
            PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
                                                RegularizationType::kNONE);
          }
        }
      }

    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // sparse update (e.g. embedding) with a SelectedRows gradient
      auto grad = ctx.Input<framework::SelectedRows>("Grad");

      // the sparse update may be empty
      if (grad->rows().size() == 0) {
        VLOG(3) << "Grad SelectedRows contains no data!";
        return;
      }

      framework::SelectedRows tmp_merged_grad;
      framework::SelectedRows* merged_grad = &tmp_merged_grad;
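      // Merge duplicate rows of the sparse gradient so that each parameter
      // row is applied exactly once.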
      math::scatter::MergeAdd<DeviceContext, T> merge_func;
      merge_func(ctx.template device_context<DeviceContext>(), *grad,
                 merged_grad);

      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
      int64_t row_numel =
          merged_grad->value().numel() / merged_grad->rows().size();
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
          param->numel());
      if (use_nesterov) {
        SparseMomentumFunctor<T, MT, UseNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
        for_range(functor);

      } else {
        SparseMomentumFunctor<T, MT, NoNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
        for_range(functor);
      }
    } else {
      PADDLE_ENFORCE_EQ(false, true,
                        platform::errors::PermissionDenied(
                            "Unsupported variable type of Grad in MomentumOp. "
                            "Expected LoDTensor or SelectedRows, but "
                            "received [%s].",
                            paddle::framework::ToTypeName(grad_var->Type())));
    }
  }
};

}  // namespace operators
}  // namespace paddle