parameter_optimizer.h 1000 字节
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4 5 6 7

#include <glog/logging.h>
#include <functional>
#include <string>
#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
D
dzhwinter 已提交
8
#include "tensor.h"
9 10 11 12

namespace paddle {
namespace optimizer {

13 14
// Version tag of this optimizer library.
// NOTE(review): presumably checked against serialized state/configs at the
// use sites — confirm there before relying on its format.
const std::string kOptimizerVersion("1.0");

15 16 17 18 19 20
class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
21
  ParameterOptimizer(LrPolicy *lr) : lr_policy_(lr), num_sample_passed_(0) {}
22 23
  virtual ~ParameterOptimizer() { delete parameter_; };

D
dzhwinter 已提交
24
  static ParameterOptimizer *Create(const std::string &config_proto);
25 26
  virtual const char *SerializeState();
  virtual void DeSerializeState(const std::string &state);
D
dzhwinter 已提交
27
  virtual void Update(const Tensor *gradient) = 0;
28
  virtual float *get_weight(int *param_size) const;
D
dzhwinter 已提交
29
  virtual void set_weight(Tensor *parameter);
30

D
dzhwinter 已提交
31
protected:
D
dzhwinter 已提交
32
  Tensor *parameter_;
33 34

  // learning rate policy
D
dzhwinter 已提交
35 36
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
37 38 39 40
};

}  // namespace optimizer
}  // namespace paddle