/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::SelectedRows;
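
// Empty tag types that select the parameter update rule at compile time.
// Both rules share the velocity update
//   velocity_out = mu * velocity + grad
// and differ in the parameter update:
//   NoNesterov:  param_out = param - lr * velocity_out
//   UseNesterov: param_out = param - (grad + mu * velocity_out) * lr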
struct NoNesterov;
struct UseNesterov;

namespace details {

template <typename T>
struct CPUDenseUpdater {
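  // grad is taken as a forwarding reference (G&&) so that an unevaluated
  // Eigen expression (e.g. param_vec * coeff + grad_vec for L2 decay) can be
  // passed in and fused into the update without materializing a temporary.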
  template <typename G>
  void operator()(const Tensor& param, const Tensor& velocity, const T& mu,
                  const T& lr, const bool use_nesterov, G&& grad,
                  Tensor* param_out, Tensor* velocity_out) const {
    auto param_out_vec = framework::EigenVector<T>::Flatten(*param_out);
    auto velocity_out_vec = framework::EigenVector<T>::Flatten(*velocity_out);

    auto param_vec = framework::EigenVector<T>::Flatten(param);
    auto velocity_vec = framework::EigenVector<T>::Flatten(velocity);
    velocity_out_vec = velocity_vec * mu + grad;
    if (use_nesterov) {
      param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr;
    } else {
      param_out_vec = param_vec - lr * velocity_out_vec;
    }
  }
};

}  // namespace details

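// Maps a low-precision type to the type used for multi-precision arithmetic:
// platform::float16 -> float, while float and double map to themselves.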
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

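// Controls how the gradient is regularized before the momentum update; with
// kL2DECAY the effective gradient becomes grad + regularization_coeff * param.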
enum class RegularizationType {
  kNONE = 0,
  kL1DECAY = 1,  // not supported yet
  kL2DECAY = 2,
};

class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class MomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
                      platform::errors::NotFound(
                          "Input(param) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
                      platform::errors::NotFound(
                          "Input(grad) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), true,
                      platform::errors::NotFound(
                          "Input(velocity) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("LearningRate"), true,
        platform::errors::NotFound(
            "Input(LearningRate) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Param").front(),
        framework::proto::VarType::LOD_TENSOR,
        platform::errors::InvalidArgument(
            "The input var's type should be LoDTensor, but the received is %s",
            ctx->GetInputsVarType("Param").front()));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
                      platform::errors::NotFound(
                          "Output(ParamOut) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("VelocityOut"), true,
        platform::errors::NotFound(
            "Output(VelocityOut) of Momentum should not be null."));

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      platform::errors::InvalidArgument(
                          "Maybe the Input variable LearningRate has not "
                          "been initialized. You may need to confirm "
                          "if you put exe.run(startup_program) "
                          "after optimizer.minimize function."));
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      platform::errors::InvalidArgument(
                          "Learning_rate should be a scalar. But Received "
                          "LearningRate's dim [%s]",
                          framework::product(lr_dims)));

    auto param_dim = ctx->GetInputDim("Param");
    if (ctx->GetInputsVarType("Grad")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Grad"),
          platform::errors::InvalidArgument(
              "Param and Grad input of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Grad's dim [%s].",
              param_dim, ctx->GetInputDim("Grad")));
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Velocity"),
          platform::errors::InvalidArgument(
              "Param and Velocity of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Velocity [%s].",
              param_dim, ctx->GetInputDim("Velocity")));
    }

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("VelocityOut", param_dim);
    if (ctx->HasOutput("MasterParamOut")) {
      ctx->SetOutputDim("MasterParamOut", param_dim);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

template <typename T>
class CPUDenseMomentumFunctor {
 public:
  void operator()(const Tensor* param, const Tensor* grad,
                  const Tensor* velocity, const Tensor* learning_rate,
                  const T mu, const bool use_nesterov,
                  const RegularizationType regularization_flag,
                  const T regularization_coeff, Tensor* param_out,
                  Tensor* velocity_out) {
    auto grad_vec = framework::EigenVector<T>::Flatten(*grad);
    auto* lr = learning_rate->data<MultiPrecisionType<T>>();

    details::CPUDenseUpdater<T> updater;
    if (regularization_flag == RegularizationType::kL2DECAY) {
      auto param_vec = framework::EigenVector<T>::Flatten(*param);
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              param_vec * regularization_coeff + grad_vec, param_out,
              velocity_out);
    } else {
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              grad_vec, param_out, velocity_out);
    }
  }
};

template <typename T, typename MT, typename UpdateMethod>
class DenseMomentumFunctor;

// NOTE(dzh): for performance, avoid if/else branches inside the GPU kernel;
// UseNesterov/NoNesterov are implemented as two separate functors instead.
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t num_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
  inline HOSTDEVICE void operator()(size_t i) const {
    // load values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t num_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
  inline HOSTDEVICE void operator()(size_t i) const {
    // load values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - lr * velocity_out;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT, typename UpdateMethod>
class SparseMomentumFunctor;

template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
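    // i indexes a dense element of the parameter. BinarySearch looks up the
    // element's row in the sparse gradient's row list and returns -1 when the
    // row is absent, in which case the gradient below is zero.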
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
    // load values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
    // load values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - velocity_out * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

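// The kernel below is registered for concrete device/type pairs in the
// corresponding .cc/.cu files. A minimal sketch of what a CPU registration
// looks like (illustrative only, not the actual registration code):
//
//   REGISTER_OP_CPU_KERNEL(
//       momentum,
//       paddle::operators::MomentumOpKernel<paddle::platform::CPUDeviceContext,
//                                           float>,
//       paddle::operators::MomentumOpKernel<paddle::platform::CPUDeviceContext,
//                                           double>);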
template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
  using MPDType = MultiPrecisionType<T>;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
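    // When the multi_precision attribute is set (e.g. float16 parameters with
    // a float32 master copy), the update arithmetic runs in MPDType;
    // otherwise it runs in T directly.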
    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      InnerCompute<MPDType>(ctx, multi_precision);
    } else {
      InnerCompute<T>(ctx, multi_precision);
    }
  }

 private:
  template <typename MT>
  void InnerCompute(const framework::ExecutionContext& ctx,
                    const bool multi_precision) const {
    std::string regularization_method =
        ctx.Attr<std::string>("regularization_method");
    MT regularization_coeff =
        static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
    RegularizationType regularization_flag{
        RegularizationType::kNONE};  // disable regularization
    if (regularization_method == "l2_decay") {
      regularization_flag = RegularizationType::kL2DECAY;
    }

    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
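    // rescale_grad scales each gradient value before it is used (e.g. to undo
    // loss scaling in mixed-precision training); 1.0 leaves gradients as-is.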
    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto velocity = ctx.Input<framework::Tensor>("Velocity");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");

    const framework::Tensor* master_param = nullptr;
    framework::Tensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<framework::Tensor>("MasterParam");
      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
    }

    param_out->mutable_data<T>(ctx.GetPlace());
    velocity_out->mutable_data<MT>(ctx.GetPlace());
    const MT* master_in_data =
        multi_precision ? master_param->data<MT>() : nullptr;
    MT* master_out_data =
        multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
                        : nullptr;

    auto* grad_var = ctx.InputVar("Grad");
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto grad = ctx.Input<framework::Tensor>("Grad");
      if (platform::is_cpu_place(ctx.GetPlace())) {
        CPUDenseMomentumFunctor<MT> functor;
        functor(param, grad, velocity, learning_rate, mu, use_nesterov,
                regularization_flag, regularization_coeff, param_out,
                velocity_out);
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param->numel());
        if (use_nesterov) {
          DenseMomentumFunctor<T, MT, UseNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<MT>(),
              learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
              param->numel(), regularization_flag, regularization_coeff,
              param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
          for_range(functor);

        } else {
          DenseMomentumFunctor<T, MT, NoNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<MT>(),
              learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
              param->numel(), regularization_flag, regularization_coeff,
              param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
          for_range(functor);
        }
      }

    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // sparse update (e.g. for embeddings) with a SelectedRows gradient
      auto grad = ctx.Input<framework::SelectedRows>("Grad");

      // the sparse update may be empty
      if (grad->rows().size() == 0) {
        VLOG(3) << "Grad SelectedRows contains no data!";
        return;
      }

      framework::SelectedRows tmp_merged_grad;
      framework::SelectedRows* merged_grad = &tmp_merged_grad;
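      // A SelectedRows gradient may contain duplicate row indices; MergeAdd
      // sums duplicates so that each row appears exactly once.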
      math::scatter::MergeAdd<DeviceContext, T> merge_func;
      merge_func(ctx.template device_context<DeviceContext>(), *grad,
                 merged_grad);

      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
      int64_t row_numel =
          merged_grad->value().numel() / merged_grad->rows().size();
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
          param->numel());
      if (use_nesterov) {
        SparseMomentumFunctor<T, MT, UseNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
        for_range(functor);

      } else {
        SparseMomentumFunctor<T, MT, NoNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
        for_range(functor);
      }
    } else {
      PADDLE_ENFORCE_EQ(false, true,
                        platform::errors::PermissionDenied(
                            "Unsupported Variable Type of Grad "
                            "in MomentumOp. Excepted LodTensor "
                            "or SelectedRows, But received [%s]",
                            paddle::framework::ToTypeName(grad_var->Type())));
    }
  }
};

}  // namespace operators
}  // namespace paddle