// parameter_optimizer.h
#ifndef PADDLE_PARAMETER_OPTIMIZER_H_
#define PADDLE_PARAMETER_OPTIMIZER_H_

#include <cstdint>
#include <functional>
#include <string>

#include <glog/logging.h>

#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
#include "tensor.h"

namespace paddle {
namespace optimizer {

14 15
const std::string kOptimizerVersion = "1.0";

16 17 18 19 20 21
class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
22
  ParameterOptimizer(LrPolicy *lr) : lr_policy_(lr), num_sample_passed_(0) {}
23 24
  virtual ~ParameterOptimizer() { delete parameter_; };

D
dzhwinter 已提交
25
  static ParameterOptimizer *Create(const std::string &config_proto);
26 27
  virtual const char *SerializeState();
  virtual void DeSerializeState(const std::string &state);
D
dzhwinter 已提交
28
  virtual void Update(const Tensor *gradient) = 0;
D
dzhwinter 已提交
29 30
  virtual real *get_weight() const;
  virtual void set_weight(Tensor *parameter);
31

D
dzhwinter 已提交
32
protected:
33
  OptimizerConfig config_;
D
dzhwinter 已提交
34
  Tensor *parameter_;
35 36

  // learning rate policy
D
dzhwinter 已提交
37 38
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
39 40 41 42 43 44
};

}  // namespace optimizer
}  // namespace paddle

#endif