#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

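// Deserializes an OptimizerConfig protobuf, picks the configured learning-rate
// policy and optimizer type, and returns a heap-allocated optimizer that wraps
// `parameter`. Unrecognized choices fall back to ConstLr / plain SGD below.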
ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                                               Tensor *parameter) {
  paddle::OptimizerConfig config;
  CHECK(config.ParseFromString(config_proto))
      << "failed parse optimizer config";
  auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
    if (config.lr_policy() == OptimizerConfig::Const)
      return new ConstLr(config.const_lr().learning_rate());
    if (config.lr_policy() == OptimizerConfig::Linear)
      return new LinearLr(config.linear_lr().learning_rate(),
                          config.linear_lr().lr_decay_a(),
                          config.linear_lr().lr_decay_b());
    // default
    LOG(WARNING) << " have not select any LrPolicy. use ConstLr in default";
    return new ConstLr(0.1);
  };

  LrPolicy *lr = select_lr_policy(config);
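  // Map the optimizer enum in the config onto a concrete ParameterOptimizer,
  // reusing the learning-rate policy selected above.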
  auto select_optimizer = [=](
      Tensor *parameter,
      const OptimizerConfig &config) -> ParameterOptimizer * {
    if (config.optimizer() == OptimizerConfig::SGD) {
      LOG(INFO) << "creating SGD optimizer";
      return new SGDOptimizer(parameter,
                              lr,
                              config.sgd().momentum(),
                              config.sgd().decay(),
                              config.sgd().nesterov());
    }
    if (config.optimizer() == OptimizerConfig::Adadelta) {
      LOG(INFO) << "creating Adadelta optimizer";
      return new AdadeltaOptimizer(parameter,
                                   lr,
                                   config.adadelta().rho(),
                                   config.adadelta().epsilon(),
                                   config.adadelta().decay());
    }
    if (config.optimizer() == OptimizerConfig::Adagrad) {
      LOG(INFO) << "creating Adagrad optimizer";
      return new AdagradOptimizer(
          parameter, lr, config.adagrad().epsilon(), config.adagrad().decay());
    }
    if (config.optimizer() == OptimizerConfig::Adam) {
      LOG(INFO) << "creating Adam optimizer";
      return new AdamOptimizer(parameter,
                               lr,
                               config.adam().beta_1(),
                               config.adam().beta_2(),
                               config.adam().epsilon(),
                               config.adam().decay());
    }
    // default
    LOG(WARNING) << "no Optimizer selected; using SGDOptimizer by default";
    return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
  };
  return select_optimizer(parameter, config);
}
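// A minimal usage sketch (illustrative only; the protobuf setters below are the
// standard generated counterparts of the getters used above, and the Tensor
// setup is assumed rather than taken from this file):
//
//   paddle::OptimizerConfig config;
//   config.set_lr_policy(paddle::OptimizerConfig::Const);
//   config.mutable_const_lr()->set_learning_rate(0.1);
//   config.set_optimizer(paddle::OptimizerConfig::SGD);
//   Tensor *param = ...;  // allocate and fill the parameter tensor
//   ParameterOptimizer *opt =
//       ParameterOptimizer::Create(config.SerializeAsString(), param);
//   int size = 0;
//   float *weights = opt->get_weight(&size);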

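// Exposes the raw parameter buffer and reports its element count via *param_size.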
float *ParameterOptimizer::get_weight(int *param_size) const {
  *param_size = static_cast<int>(parameter_->size());
  return parameter_->get_buffer();
}

}  // namespace optimizer
}  // namespace paddle