parameter_optimizer.h 1.1 KB
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4 5 6 7

#include <cstdint>
#include <functional>
#include <string>

#include <glog/logging.h>

#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
#include "serialization.h"
#include "tensor.h"
10 11 12 13 14 15 16 17 18 19

namespace paddle {
namespace optimizer {

class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
20 21
  ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
      : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
D
dzhwinter 已提交
22 23 24 25
  virtual ~ParameterOptimizer() {
    delete parameter_;
    delete lr_policy_;
  }
26

D
dzhwinter 已提交
27 28
  static ParameterOptimizer *Create(const std::string &config_proto,
                                    Tensor *parameter);
D
dzhwinter 已提交
29
  virtual void Update(const Tensor *gradient) = 0;
30
  virtual float *get_weight(int *param_size) const;
D
dzhwinter 已提交
31
  virtual const char *SerializeState(int *state_len) = 0;
D
dzhwinter 已提交
32
  virtual void DeserializeState(const std::string &state) = 0;
33

D
dzhwinter 已提交
34
protected:
D
dzhwinter 已提交
35
  Tensor *parameter_;
36
  // learning rate policy
D
dzhwinter 已提交
37 38
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
39 40 41 42
};

}  // namespace optimizer
}  // namespace paddle