adam_optimizer.h 1.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14

D
dzhwinter 已提交
15
#pragma once
16 17 18 19 20 21

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
22
class AdamOptimizer : public ParameterOptimizer {
23
public:
D
dzhwinter 已提交
24
  AdamOptimizer(Tensor *parameter,
D
dzhwinter 已提交
25 26
                LrPolicy *lr,
                double beta_1,
D
dzhwinter 已提交
27 28 29 30
                double beta_2,
                double epsilon,
                double decay)
      : ParameterOptimizer(parameter, lr),
D
dzhwinter 已提交
31 32
        momentums_(new Tensor(parameter->size())),
        velocitys_(new Tensor(parameter->size())),
D
dzhwinter 已提交
33 34 35
        beta_1_(beta_1),
        beta_2_(beta_2),
        epsilon_(epsilon),
D
dzhwinter 已提交
36
        decay_(decay) {}
37 38 39 40
  ~AdamOptimizer() {
    if (momentums_) delete momentums_;
    if (velocitys_) delete velocitys_;
  }
D
dzhwinter 已提交
41
  void Update(const Tensor *gradient);
42
  std::string SerializeState();
D
dzhwinter 已提交
43
  void DeserializeState(const std::string &state);
44 45

private:
D
dzhwinter 已提交
46 47
  Tensor *momentums_;
  Tensor *velocitys_;
D
dzhwinter 已提交
48 49 50 51
  double beta_1_;
  double beta_2_;
  double epsilon_;
  double decay_;
52 53 54 55
};

}  // namespace optimizer
}  // namespace paddle