// parameter_optimizer.h
#ifndef PADDLE_PARAMETER_OPTIMIZER_H_
#define PADDLE_PARAMETER_OPTIMIZER_H_

#include <glog/logging.h>
#include <functional>
#include <string>
#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
#include "tensor.h"

namespace paddle {
namespace optimizer {

class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
20
  ParameterOptimizer(LrPolicy *lr) : lr_policy_(lr), num_sample_passed_(0) {}
21 22
  virtual ~ParameterOptimizer() { delete parameter_; };

D
dzhwinter 已提交
23 24
  static ParameterOptimizer *Create(const std::string &config_proto);
  virtual void Update(const Tensor *gradient) = 0;
D
dzhwinter 已提交
25 26
  virtual real *get_weight() const;
  virtual void set_weight(Tensor *parameter);
27

D
dzhwinter 已提交
28
protected:
29
  OptimizerConfig config_;
D
dzhwinter 已提交
30
  Tensor *parameter_;
31 32

  // learning rate policy
D
dzhwinter 已提交
33 34
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
35 36 37 38 39 40
};

}  // namespace optimizer
}  // namespace paddle

#endif