sgd_optimizer.h 870 字节
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4 5 6 7

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
8
class SGDOptimizer : public ParameterOptimizer {
9
public:
D
dzhwinter 已提交
10 11
  SGDOptimizer(Tensor* parameter, LrPolicy* lr, double m, double d, bool n)
      : ParameterOptimizer(parameter, lr),
12 13 14
        momentums_(nullptr),
        momentum_(m),
        decay_(d),
D
dzhwinter 已提交
15 16
        nesterov_(n) {
    if (momentum_ != 0.0) {
D
dzhwinter 已提交
17
      size_t size = parameter->size();
D
dzhwinter 已提交
18 19 20 21 22 23 24
      // TODO: fix it with align aware allocator bind to Tensor
      momentums_ = new Tensor(size);
    }
  }
  virtual ~SGDOptimizer() {
    if (momentums_) delete momentums_;
  }
D
dzhwinter 已提交
25
  void Update(const Tensor* gradient);
26
  std::string SerializeState();
D
dzhwinter 已提交
27
  void DeserializeState(const std::string& state);
D
dzhwinter 已提交
28

29
private:
D
dzhwinter 已提交
30
  Tensor* momentums_;
D
dzhwinter 已提交
31 32 33
  double momentum_;
  double decay_;
  bool nesterov_;
34 35 36 37
};

}  // namespace optimizer
}  // namespace paddle