// parameter_optimizer.h (1011 bytes) -- last committed by dzhwinter
#pragma once
2 3 4 5 6 7

#include <glog/logging.h>
#include <functional>
#include <string>
#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
D
dzhwinter 已提交
8
#include "tensor.h"
9 10 11 12

namespace paddle {
namespace optimizer {

13 14
// Version identifier of the optimizer module.
const std::string kOptimizerVersion("1.0");

15 16 17 18 19 20
class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
21
  ParameterOptimizer(LrPolicy *lr) : lr_policy_(lr), num_sample_passed_(0) {}
22 23
  virtual ~ParameterOptimizer() { delete parameter_; };

D
dzhwinter 已提交
24
  static ParameterOptimizer *Create(const std::string &config_proto);
25 26
  virtual const char *SerializeState();
  virtual void DeSerializeState(const std::string &state);
D
dzhwinter 已提交
27
  virtual void Update(const Tensor *gradient) = 0;
D
dzhwinter 已提交
28 29
  virtual real *get_weight() const;
  virtual void set_weight(Tensor *parameter);
30

D
dzhwinter 已提交
31
protected:
32
  OptimizerConfig config_;
D
dzhwinter 已提交
33
  Tensor *parameter_;
34 35

  // learning rate policy
D
dzhwinter 已提交
36 37
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
38 39 40 41
};

}  // namespace optimizer
}  // namespace paddle