lr_policy.h
#ifndef PADDLE_OPTIMIZER_LR_POLICY_H_
#define PADDLE_OPTIMIZER_LR_POLICY_H_

#include <algorithm>
#include <cstdint>  // uint64_t

#include "OptimizerConfig.pb.h"

namespace paddle {
namespace optimizer {

// Base class for learning rate policies: maps the number of samples
// processed so far to the learning rate for the next update.
class BaseLr {
public:
  BaseLr(double lr) : learning_rate(lr) {}
  virtual ~BaseLr() {}
  virtual double get_learning_rate(const uint64_t num_sample_passed) = 0;

protected:
  double learning_rate;
};

// constant learning rate policy
class ConstLr final : public BaseLr {
public:
  ConstLr(double lr) : BaseLr(lr) {}
  double get_learning_rate(const uint64_t num_sample_passed) override {
    return learning_rate;
  }
};

// linearly decaying learning rate policy: decreases by lr_decay_a per
// sample processed, clamped below at lr_decay_b
class LinearLr final : public BaseLr {
public:
  LinearLr(double lr, double lr_decay_a, double lr_decay_b)
      : BaseLr(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
  double get_learning_rate(const uint64_t num_sample_passed) override {
    return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
  }

private:
  double lr_decay_a;
  double lr_decay_b;
};

}  // namespace optimizer
}  // namespace paddle

#endif  // PADDLE_OPTIMIZER_LR_POLICY_H_
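
For illustration only, a minimal usage sketch (not part of this header): it drives both policies through the BaseLr interface and assumes the header above is available as "lr_policy.h" (building it also requires the generated OptimizerConfig.pb.h that the header includes). The sample counts and decay constants are made up.

// usage_sketch.cc -- hypothetical example, not part of the Paddle source tree.
#include <cstdint>
#include <cstdio>
#include <memory>

#include "lr_policy.h"  // the header above

int main() {
  using paddle::optimizer::BaseLr;
  using paddle::optimizer::ConstLr;
  using paddle::optimizer::LinearLr;

  // Constant policy: always returns the initial rate (0.1 here).
  std::unique_ptr<BaseLr> constant(new ConstLr(0.1));

  // Linear decay: starts at 0.1, drops by 1e-5 per sample, floored at 0.01.
  std::unique_ptr<BaseLr> linear(new LinearLr(0.1, 1e-5, 0.01));

  const uint64_t samples[] = {0, 1000, 10000};
  for (uint64_t n : samples) {
    // Expected: const stays at 0.1; linear gives 0.1, 0.09, then the 0.01 floor.
    std::printf("n=%llu const=%.4f linear=%.4f\n",
                static_cast<unsigned long long>(n),
                constant->get_learning_rate(n),
                linear->get_learning_rate(n));
  }
  return 0;
}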