#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

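// Builds the optimizer described by a serialized OptimizerConfig protobuf:
// picks a learning-rate policy, then an optimizer, falling back to defaults
// (ConstLr(0.1), plain SGD) when the config selects neither.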
ParameterOptimizer *ParameterOptimizer::Create(
    const std::string &config_proto) {
  paddle::OptimizerConfig config;
  CHECK(config.ParseFromString(config_proto))
      << "failed to parse optimizer config";

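  // Translate the lr_policy field of the config into a concrete LrPolicy.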
  auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
    if (config.lr_policy() == OptimizerConfig::ConstLr)
      return new ConstLr(config.const_lr().learning_rate());
    if (config.lr_policy() == OptimizerConfig::LinearLr)
      return new LinearLr(config.linear_lr().learning_rate(),
                          config.linear_lr().lr_decay_a(),
                          config.linear_lr().lr_decay_b());
    // default
    LOG(WARNING) << "no LrPolicy selected in config; using ConstLr(0.1) by default";
    return new ConstLr(0.1);
  };
  LrPolicy *lr = select_lr_policy(config);
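  // Translate the optimizer field of the config into a concrete
  // ParameterOptimizer, forwarding the hyper-parameters of the matching
  // config sub-message.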
  auto select_optimizer =
      [=](const OptimizerConfig &config) -> ParameterOptimizer * {
    if (config.optimizer() == OptimizerConfig::SGD) {
      return new SGDOptimizer(config.sgd().momentum(),
                              config.sgd().decay(),
                              config.sgd().nesterov(),
                              lr);
    }
    if (config.optimizer() == OptimizerConfig::Adadelta) {
      return new AdadeltaOptimizer(config.adadelta().rho(),
                                   config.adadelta().epsilon(),
                                   config.adadelta().decay(),
                                   lr);
    }
    if (config.optimizer() == OptimizerConfig::Adagrad) {
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
    if (config.optimizer() == OptimizerConfig::Adam) {
      return new AdamOptimizer(config.adam().beta_1(),
                               config.adam().beta_2(),
                               config.adam().epsilon(),
                               config.adam().decay(),
                               lr);
    }
    // default
    LOG(WARNING)
        << "no Optimizer selected in config; using plain SGDOptimizer by default";
    return new SGDOptimizer(0.0, 0.0, false, lr);
  };
  return select_optimizer(config);
}
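// Usage sketch (illustrative only, not code from this repo): the setters
// below are assumed to be the standard protobuf-generated counterparts of the
// getters used above, and `param` is a hypothetical, already-allocated Tensor.
//
//   paddle::OptimizerConfig cfg;
//   cfg.set_optimizer(paddle::OptimizerConfig::SGD);
//   cfg.mutable_sgd()->set_momentum(0.9);
//   cfg.set_lr_policy(paddle::OptimizerConfig::ConstLr);
//   cfg.mutable_const_lr()->set_learning_rate(0.1);
//
//   ParameterOptimizer *opt =
//       ParameterOptimizer::Create(cfg.SerializeAsString());
//   opt->set_weight(&param);  // bind the weights before get_weight()/updates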

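// Exposes the raw buffer of the bound parameter; the number of floats it
// holds is written to param_size.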
float *ParameterOptimizer::get_weight(int *param_size) const {
  *param_size = static_cast<int>(parameter_->size());
  return parameter_->get_buffer();
}

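// Stores the parameter tensor that get_weight() will expose.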
void ParameterOptimizer::set_weight(Tensor *p) { parameter_ = p; }

}  // namespace optimizer
}  // namespace paddle