momentum_op.h 22.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
S
sidgoyal78 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
S
sneaxiy 已提交
16
#include <memory>
D
dzhwinter 已提交
17
#include <string>
Y
Yi Wang 已提交
18 19
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
D
dzhwinter 已提交
20 21
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
22
#include "paddle/fluid/platform/float16.h"
D
dzhwinter 已提交
23
#include "paddle/fluid/platform/for_range.h"
S
sidgoyal78 已提交
24 25 26 27

namespace paddle {
namespace operators {

D
dzhwinter 已提交
28 29 30 31 32
using framework::Tensor;
using framework::SelectedRows;
// Tag types that select the Nesterov / plain-momentum specializations of
// DenseMomentumFunctor and SparseMomentumFunctor below. In this file they are
// only used as template arguments, so they are declared but never defined.
struct NoNesterov;
struct UseNesterov;

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
namespace details {

// Maps a storage element type to the type used for arithmetic in
// multi-precision mode: platform::float16 is promoted to float, and every
// other type maps to itself.
template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<platform::float16> {
 public:
  using Type = float;
};

// Eigen-based dense momentum step used on the CPU path:
//   velocity_out = velocity * mu + grad
//   param_out    = param - (grad + velocity_out * mu) * lr   (Nesterov)
//   param_out    = param - lr * velocity_out                 (plain momentum)
// `grad` may be any Eigen expression over the flattened gradient (for
// instance one that already folds in an L2-decay term); it appears on the
// right-hand side of both assignments.
template <typename T>
struct CPUDenseUpdater {
  template <typename G>
  void operator()(const Tensor& param, const Tensor& velocity, const T& mu,
                  const T& lr, const bool use_nesterov, G&& grad,
                  Tensor* param_out, Tensor* velocity_out) const {
    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
    auto p_in = framework::EigenVector<T>::Flatten(param);
    auto v_in = framework::EigenVector<T>::Flatten(velocity);

    v_out = v_in * mu + grad;
    if (!use_nesterov) {
      p_out = p_in - lr * v_out;
      return;
    }
    p_out = p_in - (grad + v_out * mu) * lr;
  }
};

}  // namespace details

// Arithmetic ("master") type for storage type T: float when T is
// platform::float16, T itself otherwise (see details::MPTypeTrait).
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

71 72 73 74 75 76
// Kind of weight regularization folded into the momentum update. The
// functors in this file only act on kL2DECAY, which adds
// regularization_coeff * param to the gradient.
enum class RegularizationType : int {
  kNONE = 0,
  kL1DECAY = 1,  // do not need support right now
  kL2DECAY = 2,
};

77 78 79 80 81
// Declares the inputs, outputs and attributes of the momentum op. Make() is
// only declared here; its definition is out of line (not visible in this
// header).
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

82 83 84 85 86 87
class MomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
C
Chengmo 已提交
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
                      platform::errors::NotFound(
                          "Input(param) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
                      platform::errors::NotFound(
                          "Input(grad) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), true,
                      platform::errors::NotFound(
                          "Input(velocity) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("LearningRate"), true,
        platform::errors::NotFound(
            "Input(LearningRate) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Param").front(),
        framework::proto::VarType::LOD_TENSOR,
        platform::errors::InvalidArgument(
            "The input var's type should be LoDTensor, but the received is %s",
            ctx->GetInputsVarType("Param").front()));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
                      platform::errors::NotFound(
                          "Output(ParamOut) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("VelocityOut"), true,
        platform::errors::NotFound(
            "Output(VelocityOut) of Momentum should not be null."));
115

116 117
    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
C
Chengmo 已提交
118 119 120 121 122
                      platform::errors::InvalidArgument(
                          "Maybe the Input variable LearningRate has not "
                          "been initialized. You may need to confirm "
                          "if you put exe.run(startup_program) "
                          "after optimizer.minimize function."));
123
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
C
Chengmo 已提交
124 125 126 127
                      platform::errors::InvalidArgument(
                          "Learning_rate should be a scalar. But Received "
                          "LearningRate's dim [%s]",
                          framework::product(lr_dims)));
128

129 130 131 132 133
    auto param_dim = ctx->GetInputDim("Param");
    if (ctx->GetInputsVarType("Grad")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Grad"),
C
Chengmo 已提交
134 135 136 137
          platform::errors::InvalidArgument(
              "Param and Grad input of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Grad's dim [%s].",
              param_dim, ctx->GetInputDim("Grad")));
138 139
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Velocity"),
C
Chengmo 已提交
140 141 142 143
          platform::errors::InvalidArgument(
              "Param and Velocity of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Velocity [%s].",
              param_dim, ctx->GetInputDim("Velocity")));
144 145 146 147 148
    }

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("VelocityOut", param_dim);
  }
S
sneaxiy 已提交
149

150 151
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
152 153
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
154 155 156 157
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

D
dzhwinter 已提交
158 159 160
template <typename T>
class CPUDenseMomentumFunctor {
 public:
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  void operator()(const Tensor* param, const Tensor* grad,
                  const Tensor* velocity, const Tensor* learning_rate,
                  const T mu, const bool use_nesterov,
                  const RegularizationType regularization_flag,
                  const T regularization_coeff, Tensor* param_out,
                  Tensor* velocity_out) {
    auto grad_vec = framework::EigenVector<T>::Flatten(*grad);
    auto* lr = learning_rate->data<MultiPrecisionType<T>>();

    details::CPUDenseUpdater<T> updater;
    if (regularization_flag == RegularizationType::kL2DECAY) {
      auto param_vec = framework::EigenVector<T>::Flatten(*param);
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              param_vec * regularization_coeff + grad_vec, param_out,
              velocity_out);
D
dzhwinter 已提交
176
    } else {
177 178
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              grad_vec, param_out, velocity_out);
D
dzhwinter 已提交
179 180 181 182
    }
  }
};

183
template <typename T, typename MT, typename UpdateMethod>
D
dzhwinter 已提交
184 185 186 187 188
class DenseMomentumFunctor;

// NOTE(dzh) for performance.
// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
// functor.
189 190
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, UseNesterov> {
D
dzhwinter 已提交
191
 private:
192 193
  const T* param_;
  const T* grad_;
194 195 196 197 198
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
D
dzhwinter 已提交
199
  const int64_t num_;
200
  T* param_out_;
201 202 203 204
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;
D
dzhwinter 已提交
205 206

 public:
207 208 209 210 211 212 213
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
214 215 216
      : param_(param),
        grad_(grad),
        velocity_(velocity),
D
dzhwinter 已提交
217
        lr_(learning_rate),
218
        master_param_(master_param),
D
dzhwinter 已提交
219
        mu_(mu),
220
        rescale_grad_(rescale_grad),
D
dzhwinter 已提交
221
        num_(num),
222 223
        param_out_(param_out),
        velocity_out_(velocity_out),
224
        master_param_out_(master_param_out),
225 226
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
D
dzhwinter 已提交
227 228
  inline HOSTDEVICE void operator()(size_t i) const {
    // put memory access in register
229 230 231 232 233
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];
234 235 236 237 238

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

239 240
    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
D
dzhwinter 已提交
241
    // write reigster to memory
242
    velocity_out_[i] = velocity_out;
243 244 245 246
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
D
dzhwinter 已提交
247 248 249
  }
};

250 251
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, NoNesterov> {
D
dzhwinter 已提交
252
 private:
253 254
  const T* param_;
  const T* grad_;
255 256 257 258 259
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
D
dzhwinter 已提交
260
  const int64_t num_;
261
  T* param_out_;
262 263 264 265
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;
D
dzhwinter 已提交
266 267

 public:
268 269 270 271 272 273 274
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
275 276 277
      : param_(param),
        grad_(grad),
        velocity_(velocity),
D
dzhwinter 已提交
278
        lr_(learning_rate),
279
        master_param_(master_param),
D
dzhwinter 已提交
280
        mu_(mu),
281
        rescale_grad_(rescale_grad),
D
dzhwinter 已提交
282
        num_(num),
283 284
        param_out_(param_out),
        velocity_out_(velocity_out),
285
        master_param_out_(master_param_out),
286 287
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
D
dzhwinter 已提交
288 289
  inline HOSTDEVICE void operator()(size_t i) const {
    // put memory access in register
290 291 292 293 294
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];
295 296 297 298 299

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

300 301
    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - lr * velocity_out;
D
dzhwinter 已提交
302
    // write reigster to memory
303
    velocity_out_[i] = velocity_out;
304 305 306 307
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
D
dzhwinter 已提交
308 309 310
  }
};

311
template <typename T, typename MT, typename UpdateMethod>
D
dzhwinter 已提交
312 313
class SparseMomentumFunctor;

314 315
template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, UseNesterov> {
D
dzhwinter 已提交
316
 private:
317 318
  const T* param_;
  const T* grad_;
319 320 321 322 323
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
D
dzhwinter 已提交
324 325 326
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
327
  T* param_out_;
328 329 330 331
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;
D
dzhwinter 已提交
332 333

 public:
334 335 336 337
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
338
                        int64_t row_numel, int64_t row_height,
339 340 341
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
342 343 344
      : param_(param),
        grad_(grad),
        velocity_(velocity),
D
dzhwinter 已提交
345
        lr_(lr),
346
        master_param_(master_param),
D
dzhwinter 已提交
347
        mu_(mu),
348
        rescale_grad_(rescale_grad),
D
dzhwinter 已提交
349 350 351
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
352 353
        param_out_(param_out),
        velocity_out_(velocity_out),
354
        master_param_out_(master_param_out),
355 356
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
D
dzhwinter 已提交
357 358 359 360

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
361 362 363 364 365
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
D
dzhwinter 已提交
366
    // put memory access in register
367 368 369 370
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];
371 372 373 374 375

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

376 377
    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
D
dzhwinter 已提交
378
    // write reigster to memory
379
    velocity_out_[i] = velocity_out;
380 381 382 383
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
D
dzhwinter 已提交
384 385 386
  }
};

387 388
template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, NoNesterov> {
D
dzhwinter 已提交
389
 private:
390 391
  const T* param_;
  const T* grad_;
392 393 394 395 396
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
D
dzhwinter 已提交
397 398 399
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
400
  T* param_out_;
401 402 403 404
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;
D
dzhwinter 已提交
405 406

 public:
407 408 409 410
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
411
                        int64_t row_numel, int64_t row_height,
412 413 414
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
415 416 417
      : param_(param),
        grad_(grad),
        velocity_(velocity),
D
dzhwinter 已提交
418
        lr_(lr),
419
        master_param_(master_param),
D
dzhwinter 已提交
420
        mu_(mu),
421
        rescale_grad_(rescale_grad),
D
dzhwinter 已提交
422 423 424
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
425 426
        param_out_(param_out),
        velocity_out_(velocity_out),
427
        master_param_out_(master_param_out),
428 429
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
D
dzhwinter 已提交
430 431 432 433

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
434 435 436 437 438
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
D
dzhwinter 已提交
439
    // put memory access in register
440 441 442 443
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];
444 445 446 447 448

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

449 450
    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - velocity_out * lr;
D
dzhwinter 已提交
451
    // write reigster to memory
452
    velocity_out_[i] = velocity_out;
453 454 455 456
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
D
dzhwinter 已提交
457 458 459 460
  }
};

template <typename DeviceContext, typename T>
S
sidgoyal78 已提交
461
class MomentumOpKernel : public framework::OpKernel<T> {
462 463
  using MPDType = MultiPrecisionType<T>;

S
sidgoyal78 已提交
464 465
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
466 467 468 469 470
    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      InnerCompute<MPDType>(ctx, multi_precision);
    } else {
      InnerCompute<T>(ctx, multi_precision);
471
    }
472
  }
473

474 475 476 477 478 479 480 481
 private:
  template <typename MT>
  void InnerCompute(const framework::ExecutionContext& ctx,
                    const bool multi_precision) const {
    std::string regularization_method =
        ctx.Attr<std::string>("regularization_method");
    MT regularization_coeff =
        static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
482 483 484 485 486 487
    RegularizationType regularization_flag{
        RegularizationType::kNONE};  // disable regularization
    if (regularization_method == "l2_decay") {
      regularization_flag = RegularizationType::kL2DECAY;
    }

488 489
    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
490
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
S
sidgoyal78 已提交
491

492 493 494
    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
495
    auto velocity = ctx.Input<framework::Tensor>("Velocity");
D
dzhwinter 已提交
496
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
497

498 499 500 501 502 503 504 505 506 507 508 509 510 511
    const framework::Tensor* master_param = nullptr;
    framework::Tensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<framework::Tensor>("MasterParam");
      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
    }

D
dzhwinter 已提交
512
    param_out->mutable_data<T>(ctx.GetPlace());
513 514 515 516 517 518
    velocity_out->mutable_data<MT>(ctx.GetPlace());
    const MT* master_in_data =
        multi_precision ? master_param->data<MT>() : nullptr;
    MT* master_out_data =
        multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
                        : nullptr;
D
dzhwinter 已提交
519

520 521 522
    auto* grad_var = ctx.InputVar("Grad");
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto grad = ctx.Input<framework::Tensor>("Grad");
D
dzhwinter 已提交
523
      if (platform::is_cpu_place(ctx.GetPlace())) {
524 525 526 527
        CPUDenseMomentumFunctor<MT> functor;
        functor(param, grad, velocity, learning_rate, mu, use_nesterov,
                regularization_flag, regularization_coeff, param_out,
                velocity_out);
D
dzhwinter 已提交
528 529 530 531 532
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param->numel());
        if (use_nesterov) {
533 534 535 536 537 538
          DenseMomentumFunctor<T, MT, UseNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<MT>(),
              learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
              param->numel(), regularization_flag, regularization_coeff,
              param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
D
dzhwinter 已提交
539 540 541
          for_range(functor);

        } else {
542 543 544 545 546 547
          DenseMomentumFunctor<T, MT, NoNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<MT>(),
              learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
              param->numel(), regularization_flag, regularization_coeff,
              param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
D
dzhwinter 已提交
548 549
          for_range(functor);
        }
550
      }
D
dzhwinter 已提交
551

552 553 554
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // sparse update embedding with selectedrows
      auto grad = ctx.Input<framework::SelectedRows>("Grad");
S
sidgoyal78 已提交
555

556 557
      // sparse update maybe empty.
      if (grad->rows().size() == 0) {
M
minqiyang 已提交
558
        VLOG(3) << "Grad SelectedRows contains no data!";
559 560
        return;
      }
S
sneaxiy 已提交
561 562 563

      framework::SelectedRows tmp_merged_grad;
      framework::SelectedRows* merged_grad = &tmp_merged_grad;
D
dzhwinter 已提交
564 565 566 567
      math::scatter::MergeAdd<DeviceContext, T> merge_func;
      merge_func(ctx.template device_context<DeviceContext>(), *grad,
                 merged_grad);

S
sneaxiy 已提交
568
      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
D
dzhwinter 已提交
569 570 571 572 573
      int64_t row_numel =
          merged_grad->value().numel() / merged_grad->rows().size();
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
          param->numel());
D
dzhwinter 已提交
574
      if (use_nesterov) {
575
        SparseMomentumFunctor<T, MT, UseNesterov> functor(
D
dzhwinter 已提交
576
            param->data<T>(), merged_grad->value().data<T>(),
577 578
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
D
dzhwinter 已提交
579
            static_cast<int64_t>(merged_grad->rows().size()),
580
            regularization_flag, regularization_coeff,
D
dzhwinter 已提交
581
            param_out->mutable_data<T>(ctx.GetPlace()),
582
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
D
dzhwinter 已提交
583 584 585
        for_range(functor);

      } else {
586
        SparseMomentumFunctor<T, MT, NoNesterov> functor(
D
dzhwinter 已提交
587
            param->data<T>(), merged_grad->value().data<T>(),
588 589
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
D
dzhwinter 已提交
590
            static_cast<int64_t>(merged_grad->rows().size()),
591
            regularization_flag, regularization_coeff,
D
dzhwinter 已提交
592
            param_out->mutable_data<T>(ctx.GetPlace()),
593
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
D
dzhwinter 已提交
594
        for_range(functor);
595
      }
K
kavyasrinet 已提交
596
    } else {
C
Chengmo 已提交
597 598 599 600 601 602
      PADDLE_ENFORCE_EQ(false, true,
                        platform::errors::PermissionDenied(
                            "Unsupported Variable Type of Grad "
                            "in MomentumOp. Excepted LodTensor "
                            "or SelectedRows, But received [%s]",
                            paddle::framework::ToTypeName(grad_var->Type())));
K
kavyasrinet 已提交
603
    }
S
sidgoyal78 已提交
604 605 606 607 608
  }
};

}  // namespace operators
}  // namespace paddle