// sgd_optimizer.h (917 bytes); last committed by dzhwinter.
#pragma once

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
8
class SGDOptimizer : public ParameterOptimizer {
9
public:
D
dzhwinter 已提交
10 11
  SGDOptimizer(Tensor* parameter, LrPolicy* lr, double m, double d, bool n)
      : ParameterOptimizer(parameter, lr),
12 13 14
        momentums_(nullptr),
        momentum_(m),
        decay_(d),
D
dzhwinter 已提交
15 16 17 18 19 20 21 22 23 24 25
        nesterov_(n) {
    if (momentum_ != 0.0) {
      size_t size = p->size();
      // TODO: fix it with align aware allocator bind to Tensor
      if (momentums_) delete momentums_;
      momentums_ = new Tensor(size);
    }
  }
  virtual ~SGDOptimizer() {
    if (momentums_) delete momentums_;
  }
D
dzhwinter 已提交
26
  void Update(const Tensor* gradient);
D
dzhwinter 已提交
27
  const char* SerializeState(int* state_len);
28
  void DeSerializeState(const std::string& state);
D
dzhwinter 已提交
29

30
private:
D
dzhwinter 已提交
31
  Tensor* momentums_;
D
dzhwinter 已提交
32 33 34
  double momentum_;
  double decay_;
  bool nesterov_;
35 36 37 38
};

}  // namespace optimizer
}  // namespace paddle