// lr_policy.h: learning-rate policy classes for the Paddle optimizer.
#ifndef PADDLE_OPTIMIZER_LR_POLICY_H_
#define PADDLE_OPTIMIZER_LR_POLICY_H_

#include "OptimizerConfig.pb.h"

namespace paddle {
namespace optimizer {

// Abstract interface for learning-rate policies. A concrete policy maps the
// training progress (number of samples passed) to a learning rate.
class BaseLr {
public:
  // @param lr  base learning rate, stored in `learning_rate` for subclasses.
  // `explicit` so a bare double never silently converts into a policy.
  explicit BaseLr(double lr) : learning_rate(lr) {}

  // Virtual so policies can be destroyed safely through a BaseLr pointer.
  virtual ~BaseLr() = default;

  // Returns the learning rate to use.
  // @param num_sample_passed  presumably the count of samples processed so
  //                           far (inferred from the name; confirm with the
  //                           caller).
  virtual double get_learning_rate(const uint64_t num_sample_passed) = 0;

protected:
  // Rate supplied at construction; subclasses build their schedule from it.
  double learning_rate;
};

// constant learning rate policy
class ConstLr final : public BaseLr {
public:
  double get_learning_rate(const uint64_t num_sample_passed) {
    return learning_rate;
  }
};

}  // namespace optimizer
}  // namespace paddle

#endif