parameter_optimizer.h 1.1 KB
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4 5 6 7

#include <glog/logging.h>
#include <functional>
#include <string>
#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
D
dzhwinter 已提交
8
#include "serialization.h"
D
dzhwinter 已提交
9
#include "tensor.h"
10

D
dzhwinter 已提交
11 12 13 14
// NIMPL: placeholder for functionality that is not implemented yet.
// Expands to a call `crash(__PRETTY_FUNCTION__, " not implemented yet")`,
// passing the enclosing function's signature (__PRETTY_FUNCTION__ is a
// GCC/Clang extension) together with the message.
// NOTE(review): `crash` is not declared in this header — confirm where it is
// defined and that it actually terminates.
#define NIMPL crash(__PRETTY_FUNCTION__, " not implemented yet")

15 16 17 18 19 20 21 22 23
namespace paddle {
namespace optimizer {

class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
24 25
  ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
      : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
26 27
  virtual ~ParameterOptimizer() { delete parameter_; };

D
dzhwinter 已提交
28 29
  static ParameterOptimizer *Create(const std::string &config_proto,
                                    Tensor *parameter);
D
dzhwinter 已提交
30
  virtual void Update(const Tensor *gradient) = 0;
31
  virtual float *get_weight(int *param_size) const;
D
dzhwinter 已提交
32 33
  virtual const char *SerializeState(int *state_len) = 0;
  virtual void DeSerializeState(const std::string &state) = 0;
34

D
dzhwinter 已提交
35
protected:
D
dzhwinter 已提交
36
  Tensor *parameter_;
37
  // learning rate policy
D
dzhwinter 已提交
38 39
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
40 41 42 43
};

}  // namespace optimizer
}  // namespace paddle