/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::SelectedRows;
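// Tag types that select the parameter update rule at compile time:
//   NoNesterov :  velocity_out = mu * velocity + grad
//                 param_out    = param - lr * velocity_out
//   UseNesterov:  velocity_out = mu * velocity + grad
//                 param_out    = param - (grad + mu * velocity_out) * lr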
struct NoNesterov;
struct UseNesterov;

namespace details {
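// Dense momentum update on CPU, written with Eigen expressions. The gradient
// is taken as a generic expression G so the caller can fuse L2 decay
// (param * regularization_coeff + grad) into it without materializing a
// temporary tensor.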

template <typename T>
struct CPUDenseUpdater {
  template <typename G>
  void operator()(const Tensor& param, const Tensor& velocity, const T& mu,
                  const T& lr, const bool use_nesterov, G&& grad,
                  Tensor* param_out, Tensor* velocity_out) const {
    auto param_out_vec = framework::EigenVector<T>::Flatten(*param_out);
    auto velocity_out_vec = framework::EigenVector<T>::Flatten(*velocity_out);

    auto param_vec = framework::EigenVector<T>::Flatten(param);
    auto velocity_vec = framework::EigenVector<T>::Flatten(velocity);
    velocity_out_vec = velocity_vec * mu + grad;
    if (use_nesterov) {
      param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr;
    } else {
      param_out_vec = param_vec - lr * velocity_out_vec;
    }
  }
};

}  // namespace details

template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
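// MultiPrecisionType<platform::float16> is float; for other types it is the
// type itself. This is the precision used for the FP32 master copy of the
// parameters when the `multi_precision` attribute is enabled.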

enum class RegularizationType {
  kNONE = 0,
  kL1DECAY = 1,  // not supported yet
  kL2DECAY = 2,
};
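// With kL2DECAY the update functors add regularization_coeff * param to the
// gradient before the momentum step; kNONE leaves the gradient unchanged.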

class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class MomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
                      platform::errors::NotFound(
                          "Input(param) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
                      platform::errors::NotFound(
                          "Input(grad) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), true,
                      platform::errors::NotFound(
                          "Input(velocity) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("LearningRate"), true,
        platform::errors::NotFound(
            "Input(LearningRate) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Param").front(),
        framework::proto::VarType::LOD_TENSOR,
        platform::errors::InvalidArgument(
            "The input var's type should be LoDTensor, but the received is %s",
            ctx->GetInputsVarType("Param").front()));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
                      platform::errors::NotFound(
                          "Output(ParamOut) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("VelocityOut"), true,
        platform::errors::NotFound(
            "Output(VelocityOut) of Momentum should not be null."));

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      platform::errors::InvalidArgument(
                          "Maybe the Input variable LearningRate has not "
                          "been initialized. You may need to confirm "
                          "if you put exe.run(startup_program) "
                          "after optimizer.minimize function."));
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      platform::errors::InvalidArgument(
                          "Learning_rate should be a scalar. But Received "
                          "LearningRate's dim [%s]",
                          framework::product(lr_dims)));

    auto param_dim = ctx->GetInputDim("Param");
    if (ctx->GetInputsVarType("Grad")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Grad"),
          platform::errors::InvalidArgument(
              "Param and Grad input of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Grad's dim [%s].",
              param_dim, ctx->GetInputDim("Grad")));
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Velocity"),
          platform::errors::InvalidArgument(
              "Param and Velocity of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Velocity [%s].",
              param_dim, ctx->GetInputDim("Velocity")));
    }

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("VelocityOut", param_dim);
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

template <typename T>
class CPUDenseMomentumFunctor {
 public:
  void operator()(const Tensor* param, const Tensor* grad,
                  const Tensor* velocity, const Tensor* learning_rate,
                  const T mu, const bool use_nesterov,
                  const RegularizationType regularization_flag,
                  const T regularization_coeff, Tensor* param_out,
                  Tensor* velocity_out) {
    auto grad_vec = framework::EigenVector<T>::Flatten(*grad);
    auto* lr = learning_rate->data<MultiPrecisionType<T>>();

    details::CPUDenseUpdater<T> updater;
    if (regularization_flag == RegularizationType::kL2DECAY) {
      auto param_vec = framework::EigenVector<T>::Flatten(*param);
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              param_vec * regularization_coeff + grad_vec, param_out,
              velocity_out);
    } else {
      updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
              grad_vec, param_out, velocity_out);
    }
  }
};

template <typename T, typename MT, typename UpdateMethod>
class DenseMomentumFunctor;

// NOTE(dzh): for performance, avoid if/else inside the GPU kernel by
// implementing UseNesterov/NoNesterov as two separate functors.
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t num_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
  inline HOSTDEVICE void operator()(size_t i) const {
    // read values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t num_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
                       const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}
  inline HOSTDEVICE void operator()(size_t i) const {
    // read values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - lr * velocity_out;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT, typename UpdateMethod>
class SparseMomentumFunctor;
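// Sparse (SelectedRows) variant: rows_ holds the row indices present in the
// merged gradient, and BinarySearch maps a parameter row to its slot there.
// Rows that are absent contribute a zero gradient, so those elements are
// updated by the decayed velocity alone.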

template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
    // read values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const MT* velocity_;
  const MultiPrecisionType<MT>* lr_;
  const MT* master_param_;
  const MT mu_;
  const MT rescale_grad_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  MT* velocity_out_;
  MT* master_param_out_;
  const RegularizationType regularization_flag_;
  const MT regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* lr,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        master_param_(master_param),
        mu_(mu),
        rescale_grad_(rescale_grad),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        master_param_out_(master_param_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    MT grad =
        row_idx >= 0
            ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
                  rescale_grad_
            : static_cast<MT>(0);
    // read values from memory into registers
    const MT param =
        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
    const MT lr = static_cast<MT>(lr_[0]);
    const MT velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    MT velocity_out = velocity * mu_ + grad;
    MT param_out = param - velocity_out * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = static_cast<T>(param_out);
    if (master_param_out_) {
      master_param_out_[i] = param_out;
    }
  }
};

template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
  using MPDType = MultiPrecisionType<T>;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
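    // When multi_precision is true the update is carried out in MPDType
    // (float for float16 parameters); otherwise in T itself.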
    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      InnerCompute<MPDType>(ctx, multi_precision);
    } else {
      InnerCompute<T>(ctx, multi_precision);
    }
  }

 private:
  template <typename MT>
  void InnerCompute(const framework::ExecutionContext& ctx,
                    const bool multi_precision) const {
    std::string regularization_method =
        ctx.Attr<std::string>("regularization_method");
    MT regularization_coeff =
        static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
    RegularizationType regularization_flag{
        RegularizationType::kNONE};  // disable regularization
    if (regularization_method == "l2_decay") {
      regularization_flag = RegularizationType::kL2DECAY;
    }

    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto velocity = ctx.Input<framework::Tensor>("Velocity");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");

    const framework::Tensor* master_param = nullptr;
    framework::Tensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<framework::Tensor>("MasterParam");
      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
    }

    param_out->mutable_data<T>(ctx.GetPlace());
    velocity_out->mutable_data<MT>(ctx.GetPlace());
    const MT* master_in_data =
        multi_precision ? master_param->data<MT>() : nullptr;
    MT* master_out_data =
        multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
                        : nullptr;

    auto* grad_var = ctx.InputVar("Grad");
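    // Grad is either a dense LoDTensor or a SelectedRows (sparse) gradient.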
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto grad = ctx.Input<framework::Tensor>("Grad");
      if (platform::is_cpu_place(ctx.GetPlace())) {
        CPUDenseMomentumFunctor<MT> functor;
        functor(param, grad, velocity, learning_rate, mu, use_nesterov,
                regularization_flag, regularization_coeff, param_out,
                velocity_out);
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param->numel());
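        // for_range applies the element-wise functor to every index in
        // [0, param->numel()) on the device.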
        if (use_nesterov) {
          DenseMomentumFunctor<T, MT, UseNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<MT>(),
              learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
              param->numel(), regularization_flag, regularization_coeff,
              param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
          for_range(functor);

        } else {
          DenseMomentumFunctor<T, MT, NoNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<MT>(),
              learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
              param->numel(), regularization_flag, regularization_coeff,
              param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
          for_range(functor);
        }
      }

    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // sparse update with a SelectedRows gradient (e.g. an embedding table)
      auto grad = ctx.Input<framework::SelectedRows>("Grad");

      // the sparse update may be empty.
      if (grad->rows().size() == 0) {
        VLOG(3) << "Grad SelectedRows contains no data!";
        return;
      }

      framework::SelectedRows tmp_merged_grad;
      framework::SelectedRows* merged_grad = &tmp_merged_grad;
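      // MergeAdd sums duplicate row indices so each row appears at most once
      // in merged_grad before the per-element update below.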
      math::scatter::MergeAdd<DeviceContext, T> merge_func;
      merge_func(ctx.template device_context<DeviceContext>(), *grad,
                 merged_grad);

      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
      int64_t row_numel =
          merged_grad->value().numel() / merged_grad->rows().size();
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
          param->numel());
      if (use_nesterov) {
        SparseMomentumFunctor<T, MT, UseNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
        for_range(functor);

      } else {
        SparseMomentumFunctor<T, MT, NoNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<MT>(), learning_rate->data<MPDType>(),
            master_in_data, mu, rescale_grad, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
        for_range(functor);
      }
    } else {
      PADDLE_ENFORCE_EQ(false, true,
                        platform::errors::PermissionDenied(
                            "Unsupported Variable Type of Grad "
                            "in MomentumOp. Excepted LodTensor "
                            "or SelectedRows, But received [%s]",
                            paddle::framework::ToTypeName(grad_var->Type())));
    }
  }
};

}  // namespace operators
}  // namespace paddle