momentum_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;

enum class RegularizationType {
  kNONE = 0,
  kL1DECAY = 1,  // not supported for now
  kL2DECAY = 2,
};
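
// Update rule implemented by the functors below, summarized for reference
// (grad' includes the decay term only when the flag is kL2DECAY):
//
//   grad'        = grad + regularization_coeff * param
//   velocity_out = mu * velocity + grad'
//   param_out    = param - lr * velocity_out                    (plain)
//   param_out    = param - lr * (grad' + mu * velocity_out)     (Nesterov)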

class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class MomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
                      platform::errors::NotFound(
                          "Input(param) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
                      platform::errors::NotFound(
                          "Input(grad) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), true,
                      platform::errors::NotFound(
                          "Input(velocity) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("LearningRate"), true,
        platform::errors::NotFound(
            "Input(LearningRate) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Param").front(),
        framework::proto::VarType::LOD_TENSOR,
        platform::errors::InvalidArgument(
            "The input var's type should be LoDTensor, but the received is %s",
            ctx->GetInputsVarType("Param").front()));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
                      platform::errors::NotFound(
                          "Output(ParamOut) of Momentum should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("VelocityOut"), true,
        platform::errors::NotFound(
            "Output(VelocityOut) of Momentum should not be null."));

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      platform::errors::InvalidArgument(
                          "Maybe the Input variable LearningRate has not "
                          "been initialized. You may need to confirm "
                          "that you put exe.run(startup_program) "
                          "after the optimizer.minimize function."));
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      platform::errors::InvalidArgument(
                          "LearningRate should be a scalar, but received "
                          "one of size [%s].",
                          framework::product(lr_dims)));

    auto param_dim = ctx->GetInputDim("Param");
    if (ctx->GetInputsVarType("Grad")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Grad"),
          platform::errors::InvalidArgument(
              "Param and Grad inputs of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Grad's dim [%s].",
              param_dim, ctx->GetInputDim("Grad")));
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Velocity"),
          platform::errors::InvalidArgument(
              "Param and Velocity inputs of MomentumOp should have the same "
              "dimension. But received Param's dim [%s] and Velocity's "
              "dim [%s].",
              param_dim, ctx->GetInputDim("Velocity")));
    }

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("VelocityOut", param_dim);
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};
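
// GetExpectedKernelType above keys kernel selection on Param's data type and
// the execution place; Grad itself may arrive either as a dense LoDTensor or
// as SelectedRows (see MomentumOpKernel::Compute below).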

template <typename T>
class CPUDenseMomentumFunctor {
 private:
  const Tensor* param_;
  const Tensor* grad_;
  const Tensor* velocity_;
  const Tensor* learning_rate_;
  const T mu_;
  const bool use_nesterov_;
  RegularizationType regularization_flag_;
  const T regularization_coeff_;
  Tensor* param_out_;
  Tensor* velocity_out_;

 public:
  CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad,
                          const Tensor* velocity, const Tensor* learning_rate,
                          const T mu, const bool use_nesterov,
                          RegularizationType regularization_flag,
                          const T regularization_coeff, Tensor* param_out,
                          Tensor* velocity_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        learning_rate_(learning_rate),
        mu_(mu),
        use_nesterov_(use_nesterov),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff),
        param_out_(param_out),
        velocity_out_(velocity_out) {}

  inline void operator()() {
    auto param_out = framework::EigenVector<T>::Flatten(*param_out_);
    auto velocity_out = framework::EigenVector<T>::Flatten(*velocity_out_);

    auto param = framework::EigenVector<T>::Flatten(*param_);
    auto velocity = framework::EigenVector<T>::Flatten(*velocity_);
    auto grad = framework::EigenVector<T>::Flatten(*grad_);
    auto* lr = learning_rate_->data<T>();

    if (regularization_flag_ == RegularizationType::kL2DECAY) {
      velocity_out = velocity * mu_ + param * regularization_coeff_ + grad;
      if (use_nesterov_) {
        param_out =
            param -
            (param * regularization_coeff_ + grad + velocity_out * mu_) * lr[0];
      } else {
        param_out = param - lr[0] * velocity_out;
      }
    } else {
      velocity_out = velocity * mu_ + grad;
      if (use_nesterov_) {
        param_out = param - (grad + velocity_out * mu_) * lr[0];
      } else {
        param_out = param - lr[0] * velocity_out;
      }
    }
  }
};
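
// Illustrative worked example (not part of the op): with mu = 0.9, lr = 0.1,
// param = 1, velocity = 0, grad = 1 and no regularization, one step gives
//   velocity_out = 0.9 * 0 + 1 = 1
//   plain:    param_out = 1 - 0.1 * 1             = 0.9
//   Nesterov: param_out = 1 - (1 + 0.9 * 1) * 0.1 = 0.81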

template <typename T, typename UpdateMethod>
class DenseMomentumFunctor;

// NOTE(dzh): for performance.
// To avoid an if/else branch inside the GPU kernel, UseNesterov/NoNesterov
// are implemented as two separate functors.
template <typename T>
class DenseMomentumFunctor<T, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const T* velocity_;
  const T* lr_;
  const T mu_;
  const int64_t num_;
  T* param_out_;
  T* velocity_out_;
  RegularizationType regularization_flag_;
  const T regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const T* velocity,
                       const T* learning_rate, const T mu, const int64_t num,
                       RegularizationType regularization_flag,
                       const T regularization_coeff, T* param_out,
                       T* velocity_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        mu_(mu),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // load values from memory into registers
    const T param = param_[i];
    T grad = grad_[i];
    const T lr = lr_[0];
    const T velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    T velocity_out = velocity * mu_ + grad;
    T param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = param_out;
  }
};

template <typename T>
class DenseMomentumFunctor<T, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const T* velocity_;
  const T* lr_;
  const T mu_;
  const int64_t num_;
  T* param_out_;
  T* velocity_out_;
  RegularizationType regularization_flag_;
  const T regularization_coeff_;

 public:
  DenseMomentumFunctor(const T* param, const T* grad, const T* velocity,
                       const T* learning_rate, const T mu, const int64_t num,
                       RegularizationType regularization_flag,
                       const T regularization_coeff, T* param_out,
                       T* velocity_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(learning_rate),
        mu_(mu),
        num_(num),
        param_out_(param_out),
        velocity_out_(velocity_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // load values from memory into registers
    const T param = param_[i];
    T grad = grad_[i];
    const T lr = lr_[0];
    const T velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    T velocity_out = velocity * mu_ + grad;
    T param_out = param - lr * velocity_out;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = param_out;
  }
};

template <typename T, typename UpdateMethod>
class SparseMomentumFunctor;
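
// For a SelectedRows gradient the functors still run over every element of
// param; rows_ holds the merged row indices actually present in the gradient
// (BinarySearch assumes they are sorted). A row missing from rows_ gets a
// zero gradient, so its velocity still decays by mu and its param entry is
// still updated from that velocity.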

template <typename T>
class SparseMomentumFunctor<T, UseNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const T* velocity_;
  const T* lr_;
  const T mu_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  T* velocity_out_;
  RegularizationType regularization_flag_;
  const T regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const T* velocity,
                        const T* lr, const T mu, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        RegularizationType regularization_flag,
                        const T regularization_coeff, T* param_out,
                        T* velocity_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        mu_(mu),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    T grad = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
                          : static_cast<T>(0);
    // load values from memory into registers
    const T param = param_[i];
    const T lr = lr_[0];
    const T velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    T velocity_out = velocity * mu_ + grad;
    T param_out = param - (grad + velocity_out * mu_) * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = param_out;
  }
};

template <typename T>
class SparseMomentumFunctor<T, NoNesterov> {
 private:
  const T* param_;
  const T* grad_;
  const T* velocity_;
  const T* lr_;
  const T mu_;
  const int64_t* rows_;
  const int64_t row_numel_;
  const int64_t row_height_;
  T* param_out_;
  T* velocity_out_;
  RegularizationType regularization_flag_;
  const T regularization_coeff_;

 public:
  SparseMomentumFunctor(const T* param, const T* grad, const T* velocity,
                        const T* lr, const T mu, const int64_t* rows,
                        int64_t row_numel, int64_t row_height,
                        RegularizationType regularization_flag,
                        const T regularization_coeff, T* param_out,
                        T* velocity_out)
      : param_(param),
        grad_(grad),
        velocity_(velocity),
        lr_(lr),
        mu_(mu),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        param_out_(param_out),
        velocity_out_(velocity_out),
        regularization_flag_(regularization_flag),
        regularization_coeff_(regularization_coeff) {}

  inline HOSTDEVICE void operator()(size_t i) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    T grad = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
                          : static_cast<T>(0);
    // load values from memory into registers
    const T param = param_[i];
    const T lr = lr_[0];
    const T velocity = velocity_[i];

    grad = regularization_flag_ == RegularizationType::kL2DECAY
               ? grad + regularization_coeff_ * param
               : grad;

    T velocity_out = velocity * mu_ + grad;
    T param_out = param - velocity_out * lr;
    // write registers back to memory
    velocity_out_[i] = velocity_out;
    param_out_[i] = param_out;
  }
};
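
// Dispatch performed by MomentumOpKernel::Compute below (a sketch of the
// existing code, for orientation):
//   - dense grad, CPU place -> CPUDenseMomentumFunctor (Eigen expressions);
//   - dense grad, GPU place -> DenseMomentumFunctor driven by
//     platform::ForRange, which invokes operator()(i) once per element:
//       platform::ForRange<DeviceContext> for_range(dev_ctx, param->numel());
//       for_range(functor);
//   - SelectedRows grad     -> rows merged with math::scatter::MergeAdd,
//     then SparseMomentumFunctor driven by the same ForRange mechanism.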

template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    std::string regularization_method =
        ctx.Attr<std::string>("regularization_method");
    if (regularization_method != "" || !regularization_method.empty()) {
      PADDLE_ENFORCE_EQ("l2_decay", regularization_method,
                        platform::errors::InvalidArgument(
                            "if regularization_method is not null, "
                            "it should be l2_decay, but received %s",
                            regularization_method));
    }

    T regularization_coeff =
        static_cast<T>(ctx.Attr<float>("regularization_coeff"));
    RegularizationType regularization_flag{
        RegularizationType::kNONE};  // disable regularization
    if (regularization_method == "l2_decay") {
      regularization_flag = RegularizationType::kL2DECAY;
    }

    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto* velocity = ctx.Input<framework::Tensor>("Velocity");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");

    param_out->mutable_data<T>(ctx.GetPlace());
    velocity_out->mutable_data<T>(ctx.GetPlace());

    auto* grad_var = ctx.InputVar("Grad");
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto grad = ctx.Input<framework::Tensor>("Grad");
      if (platform::is_cpu_place(ctx.GetPlace())) {
        CPUDenseMomentumFunctor<T> functor(
            param, grad, velocity, learning_rate, mu, use_nesterov,
            regularization_flag, regularization_coeff, param_out, velocity_out);
        functor();
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param->numel());
        if (use_nesterov) {
          DenseMomentumFunctor<T, UseNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<T>(),
              learning_rate->data<T>(), mu, param->numel(), regularization_flag,
              regularization_coeff, param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<T>(ctx.GetPlace()));
          for_range(functor);

        } else {
          DenseMomentumFunctor<T, NoNesterov> functor(
              param->data<T>(), grad->data<T>(), velocity->data<T>(),
              learning_rate->data<T>(), mu, param->numel(), regularization_flag,
              regularization_coeff, param_out->mutable_data<T>(ctx.GetPlace()),
              velocity_out->mutable_data<T>(ctx.GetPlace()));
          for_range(functor);
        }
      }

    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // sparse update (e.g. for an embedding) with a SelectedRows gradient
      auto grad = ctx.Input<framework::SelectedRows>("Grad");

      // the sparse update may be empty.
      if (grad->rows().size() == 0) {
        VLOG(3) << "Grad SelectedRows contains no data!";
        return;
      }

      framework::SelectedRows tmp_merged_grad;
      framework::SelectedRows* merged_grad = &tmp_merged_grad;
      math::scatter::MergeAdd<DeviceContext, T> merge_func;
      merge_func(ctx.template device_context<DeviceContext>(), *grad,
                 merged_grad);

      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
      int64_t row_numel =
          merged_grad->value().numel() / merged_grad->rows().size();
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
          param->numel());
      if (use_nesterov) {
        SparseMomentumFunctor<T, UseNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<T>(ctx.GetPlace()));
        for_range(functor);

      } else {
        SparseMomentumFunctor<T, NoNesterov> functor(
            param->data<T>(), merged_grad->value().data<T>(),
            velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
            static_cast<int64_t>(merged_grad->rows().size()),
            regularization_flag, regularization_coeff,
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<T>(ctx.GetPlace()));
        for_range(functor);
      }
    } else {
      PADDLE_ENFORCE_EQ(false, true,
                        platform::errors::PermissionDenied(
                            "Unsupported Variable Type of Grad "
                            "in MomentumOp. Expected LoDTensor "
                            "or SelectedRows, but received [%s]",
                            paddle::framework::ToTypeName(grad_var->Type())));
    }
  }
};

}  // namespace operators
}  // namespace paddle