#ifndef PADDLE_OPTIMIZER_LR_POLICY_H_
#define PADDLE_OPTIMIZER_LR_POLICY_H_

#include <algorithm>

#include "OptimizerConfig.pb.h"

namespace paddle {
namespace optimizer {

// Abstract base class for learning-rate schedules: maps the number of
// samples the optimizer has processed so far to the learning rate that
// should be used for the next update step.
class LrPolicy {
public:
  virtual ~LrPolicy() {}
  // Returns the learning rate to apply after `num_sample_passed` samples
  // have been consumed by training.
  virtual double LearningRate(const uint64_t num_sample_passed) = 0;
};

// constant learning rate policy
D
dzhwinter 已提交
17
class ConstLr final : public LrPolicy {
18
public:
D
dzhwinter 已提交
19 20
  ConstLr(double lr) : learning_rate(lr){};
  double LearningRate(const uint64_t num_sample_passed) {
21 22
    return learning_rate;
  }
D
dzhwinter 已提交
23 24 25

protected:
  double learning_rate;
26 27
};

D
dzhwinter 已提交
28
class LinearLr final : public LrPolicy {
D
dzhwinter 已提交
29 30
public:
  LinearLr(double lr, double lr_decay_a, double lr_decay_b)
D
dzhwinter 已提交
31 32
      : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
  double LearningRate(const uint64_t num_sample_passed) {
D
dzhwinter 已提交
33 34 35 36
    return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
  }

private:
D
dzhwinter 已提交
37
  double learning_rate;
D
dzhwinter 已提交
38 39 40 41
  double lr_decay_a;
  double lr_decay_b;
};

}  // namespace optimizer
}  // namespace paddle

#endif