parameter_optimizer.h 1.0 KB
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4 5 6 7

#include <glog/logging.h>
#include <functional>
#include <string>
#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
D
dzhwinter 已提交
8
#include "serialization.h"
D
dzhwinter 已提交
9
#include "tensor.h"
10 11 12 13 14 15 16 17 18 19

namespace paddle {
namespace optimizer {

class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
20 21
  ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
      : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
22 23
  virtual ~ParameterOptimizer() { delete parameter_; };

D
dzhwinter 已提交
24 25
  static ParameterOptimizer *Create(const std::string &config_proto,
                                    Tensor *parameter);
D
dzhwinter 已提交
26
  virtual void Update(const Tensor *gradient) = 0;
27
  virtual float *get_weight(int *param_size) const;
D
dzhwinter 已提交
28
  virtual const char *SerializeState(int *state_len) = 0;
D
dzhwinter 已提交
29
  virtual void DeserializeState(const std::string &state) = 0;
30

D
dzhwinter 已提交
31
protected:
D
dzhwinter 已提交
32
  Tensor *parameter_;
33
  // learning rate policy
D
dzhwinter 已提交
34 35
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
36 37 38 39
};

}  // namespace optimizer
}  // namespace paddle