parameter_optimizer.cc 2.8 KB
Newer Older
1
#include <glog/logging.h>
D
dzhwinter 已提交
2 3 4
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
5 6 7 8
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"
9 10 11 12

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
13 14
// Builds a ParameterOptimizer from a serialized OptimizerConfig.
//
// config_proto: serialized bytes of an OptimizerConfig message.
// parameter:    tensor handed to the constructed optimizer.
//
// Aborts (via CHECK) when `config_proto` cannot be parsed. If the config
// names no known learning-rate policy, a ConstLr(0.1) is used; if it names no
// known optimizer, a plain SGDOptimizer (no momentum, no decay, no nesterov)
// is used. Both fallbacks log a warning.
//
// The returned optimizer and the LrPolicy passed to it are allocated with
// `new`.
// NOTE(review): presumably the caller owns the returned optimizer and the
// optimizer owns the LrPolicy -- confirm against the class declarations.
ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                                               Tensor *parameter) {
  paddle::OptimizerConfig config;
  // ParseFromString() returns true on success, so the parse result itself must
  // be asserted. Comparing it against 0 would abort on every valid config and
  // accept every malformed one.
  CHECK(config.ParseFromString(config_proto))
      << "failed parse optimizer config";

  // Pick the learning-rate policy requested by the config.
  LrPolicy *lr = nullptr;
  if (config.lr_policy() == OptimizerConfig::ConstLr) {
    lr = new ConstLr(config.const_lr().learning_rate());
  } else if (config.lr_policy() == OptimizerConfig::LinearLr) {
    lr = new LinearLr(config.linear_lr().learning_rate(),
                      config.linear_lr().lr_decay_a(),
                      config.linear_lr().lr_decay_b());
  } else {
    // default
    LOG(WARNING) << " have not select any LrPolicy. use ConstLr in default";
    lr = new ConstLr(0.1);
  }

  // Pick the optimizer requested by the config.
  if (config.optimizer() == OptimizerConfig::SGD) {
    return new SGDOptimizer(parameter,
                            lr,
                            config.sgd().momentum(),
                            config.sgd().decay(),
                            config.sgd().nesterov());
  }
  if (config.optimizer() == OptimizerConfig::Adadelta) {
    return new AdadeltaOptimizer(parameter,
                                 lr,
                                 config.adadelta().rho(),
                                 config.adadelta().epsilon(),
                                 config.adadelta().decay());
  }
  if (config.optimizer() == OptimizerConfig::Adagrad) {
    return new AdagradOptimizer(
        parameter, lr, config.adagrad().epsilon(), config.adagrad().decay());
  }
  if (config.optimizer() == OptimizerConfig::Adam) {
    return new AdamOptimizer(parameter,
                             lr,
                             config.adam().beta_1(),
                             config.adam().beta_2(),
                             config.adam().epsilon(),
                             config.adam().decay());
  }
  // default
  LOG(WARNING)
      << "have not select any Optimizer. use SGDOptimizer in default";
  return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
}

68 69
// Writes the parameter's size (converted to int) into `*param_size` and
// returns the buffer pointer reported by the parameter tensor.
float *ParameterOptimizer::get_weight(int *param_size) const {
  const auto element_count = parameter_->size();
  *param_size = static_cast<int>(element_count);
  return parameter_->get_buffer();
}

}  // namespace optimizer
}  // namespace paddle