parameter_optimizer.cc 3.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
D
dzhwinter 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14

15
#include <glog/logging.h>
D
dzhwinter 已提交
16 17 18
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
19 20 21 22
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"
23 24 25 26

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
27 28
/**
 * Factory: deserializes an OptimizerConfig from `config_proto` and builds the
 * optimizer it describes for `parameter`.
 *
 * @param config_proto  serialized OptimizerConfig message; a parse failure
 *                      trips the CHECK below.
 * @param parameter     tensor handed to the constructed optimizer.
 * @return a `new`-allocated optimizer. When the config names none of the
 *         known optimizers an SGDOptimizer (0.0, 0.0, false) is returned;
 *         when it names no known learning-rate policy a ConstLr(0.1) is used.
 *         NOTE(review): the caller presumably owns the returned pointer, and
 *         the optimizer presumably owns the LrPolicy -- confirm in the
 *         headers.
 */
ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                                               Tensor *parameter) {
  paddle::OptimizerConfig config;
  CHECK(config.ParseFromString(config_proto) == true)
      << "failed parse optimizer config";

  // Builds the learning-rate policy named by the config, or the default.
  auto make_lr_policy = [&config]() -> LrPolicy * {
    switch (config.lr_policy()) {
      case OptimizerConfig::Const:
        return new ConstLr(config.const_lr().learning_rate());
      case OptimizerConfig::Linear: {
        const auto &linear = config.linear_lr();
        return new LinearLr(
            linear.learning_rate(), linear.lr_decay_a(), linear.lr_decay_b());
      }
      default:
        break;
    }
    // default
    LOG(WARNING) << " have not select any LrPolicy. use ConstLr in default";
    return new ConstLr(0.1);
  };

  // The policy is created before the optimizer is selected, so any policy
  // warning is logged ahead of the optimizer messages.
  LrPolicy *lr = make_lr_policy();

  // Builds the optimizer named by the config, or the default.
  auto make_optimizer = [&config, parameter, lr]() -> ParameterOptimizer * {
    switch (config.optimizer()) {
      case OptimizerConfig::SGD: {
        LOG(INFO) << "creating SGD optimizer";
        const auto &sgd = config.sgd();
        return new SGDOptimizer(
            parameter, lr, sgd.momentum(), sgd.decay(), sgd.nesterov());
      }
      case OptimizerConfig::Adadelta: {
        LOG(INFO) << "creating Adadelta optimizer";
        const auto &adadelta = config.adadelta();
        return new AdadeltaOptimizer(
            parameter, lr, adadelta.rho(), adadelta.epsilon(), adadelta.decay());
      }
      case OptimizerConfig::Adagrad: {
        LOG(INFO) << "creating Adagrad optimizer";
        const auto &adagrad = config.adagrad();
        return new AdagradOptimizer(
            parameter, lr, adagrad.epsilon(), adagrad.decay());
      }
      case OptimizerConfig::Adam: {
        LOG(INFO) << "creating Adam optimizer";
        const auto &adam = config.adam();
        return new AdamOptimizer(parameter,
                                 lr,
                                 adam.beta_1(),
                                 adam.beta_2(),
                                 adam.epsilon(),
                                 adam.decay());
      }
      default:
        break;
    }
    // default
    LOG(WARNING)
        << "have not select any Optimizer. use SGDOptimizer in default";
    return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
  };

  return make_optimizer();
}

86 87
/**
 * Exposes the buffer of the parameter tensor held by this optimizer.
 *
 * @param param_size  out-parameter; receives parameter_->size() converted
 *                    to int. Written unconditionally, so it must be non-null.
 * @return the pointer returned by parameter_->get_buffer().
 */
float *ParameterOptimizer::get_weight(int *param_size) const {
  const auto element_count = parameter_->size();
  *param_size = static_cast<int>(element_count);
  return parameter_->get_buffer();
}

}  // namespace optimizer
}  // namespace paddle