parameter_optimizer.h 1.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14

D
dzhwinter 已提交
15
#pragma once
16 17 18 19 20 21

#include <glog/logging.h>
#include <functional>
#include <string>
#include "OptimizerConfig.pb.h"
#include "lr_policy.h"
D
dzhwinter 已提交
22
#include "serialization.h"
D
dzhwinter 已提交
23
#include "tensor.h"
24 25 26 27 28 29 30 31 32 33

namespace paddle {
namespace optimizer {

class ParameterOptimizer {
public:
  /**
   * @brief  update hook for algorithm need to traverse parameter more than
   * once.
   */
D
dzhwinter 已提交
34 35
  ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
      : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
D
dzhwinter 已提交
36 37 38 39
  virtual ~ParameterOptimizer() {
    delete parameter_;
    delete lr_policy_;
  }
40

D
dzhwinter 已提交
41 42
  static ParameterOptimizer *Create(const std::string &config_proto,
                                    Tensor *parameter);
D
dzhwinter 已提交
43
  virtual void Update(const Tensor *gradient) = 0;
44
  virtual float *get_weight(int *param_size) const;
45
  virtual std::string SerializeState() = 0;
D
dzhwinter 已提交
46
  virtual void DeserializeState(const std::string &state) = 0;
47

D
dzhwinter 已提交
48
protected:
D
dzhwinter 已提交
49
  Tensor *parameter_;
50
  // learning rate policy
D
dzhwinter 已提交
51 52
  LrPolicy *lr_policy_;
  uint64_t num_sample_passed_;
53 54 55 56
};

}  // namespace optimizer
}  // namespace paddle