parameter_optimizer.cc 2.3 KB
Newer Older
1
#include <glog/logging.h>
D
dzhwinter 已提交
2 3 4
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
5 6 7 8
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"
9 10 11 12

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
13 14
ParameterOptimizer *ParameterOptimizer::Create(
    const std::string &config_proto) {
15 16
  paddle::OptimizerConfig config;
  CHECK(config.ParseFromString(config_proto) == 0)
D
dzhwinter 已提交
17
      << "failed parse optimizer config";
18

D
dzhwinter 已提交
19 20 21 22
  auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
    if (config.lr_policy() == OptimizerConfig::ConstLr)
      return new ConstLr(config.const_lr().learning_rate());
    if (config.lr_policy() == OptimizerConfig::LinearLr)
D
dzhwinter 已提交
23 24 25
      return new LinearLr(config.linear_lr().learning_rate(),
                          config.linear_lr().lr_decay_a(),
                          config.linear_lr().lr_decay_b());
D
dzhwinter 已提交
26
    // default
D
dzhwinter 已提交
27
    return nullptr;
D
dzhwinter 已提交
28
  };
D
dzhwinter 已提交
29
  LrPolicy *lr = select_lr_policy(config);
D
dzhwinter 已提交
30 31
  auto select_optimizer =
      [=](const OptimizerConfig &config) -> ParameterOptimizer * {
D
dzhwinter 已提交
32
    if (config.optimizer() == OptimizerConfig::SGD) {
D
dzhwinter 已提交
33 34 35 36 37
      return new SGDOptimizer(config.sgd().momentum(),
                              config.sgd().decay(),
                              config.sgd().nesterov(),
                              lr);
    }
D
dzhwinter 已提交
38
    if (config.optimizer() == OptimizerConfig::Adadelta) {
D
dzhwinter 已提交
39 40 41
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
D
dzhwinter 已提交
42
    if (config.optimizer() == OptimizerConfig::Adagrad) {
D
dzhwinter 已提交
43 44 45
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
D
dzhwinter 已提交
46
    if (config.optimizer() == OptimizerConfig::Adam) {
D
dzhwinter 已提交
47 48 49 50 51
      return new AdadeltaOptimizer(config.adadelta().rho(),
                                   config.adadelta().epsilon(),
                                   config.adadelta().decay(),
                                   lr);
    }
D
dzhwinter 已提交
52 53 54 55 56 57 58
    // default
    return new SGDOptimizer(config.sgd().momentum(),
                            config.sgd().decay(),
                            config.sgd().nesterov(),
                            lr);
  };
  return select_optimizer(config);
59 60
}

D
dzhwinter 已提交
61 62
real *ParameterOptimizer::get_weight() const {
  return parameter_->get_buffer();
63 64
}

D
dzhwinter 已提交
65
void ParameterOptimizer::set_weight(Tensor *p) { parameter_ = p; }
66 67 68

}  // namespace optimizer
}  // namespace paddle