// parameter_optimizer.cc
#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

ParameterOptimizer *ParameterOptimizer::create(
    const ::std::string &config_proto) {
  paddle::OptimizerConfig config;
  // ParseFromString returns true on success.
  CHECK(config.ParseFromString(config_proto))
      << "failed to parse optimizer config";

  auto select_lr_policy = [=](const OptimizerConfig &config) -> BaseLr * {
    std::string s(config.lr_policy());
    if (s == "ConstLr") return new ConstLr(config.const_lr().learning_rate());
    if (s == "LinearLr")
      return new LinearLr(config.linear_lr().learning_rate(),
                          config.linear_lr().lr_decay_a(),
                          config.linear_lr().lr_decay_b());
    // Unrecognized lr_policy name; create() checks for nullptr below.
    return nullptr;
  };
  BaseLr *lr = select_lr_policy(config);
  CHECK(lr) << "unsupported lr_policy: " << config.lr_policy();
  auto select_optimizer =
      [=](const OptimizerConfig &config) -> ParameterOptimizer * {
    std::string s(config.optimizer_name());
    if (s == "SGD") {
      return new SGDOptimizer(config.sgd().momentum(),
                              config.sgd().decay(),
                              config.sgd().nesterov(),
                              lr);
    }
    if (s == "Adadelta") {
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
    if (s == "Adagrad") {
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
    if (s == "Adam") {
      return new AdadeltaOptimizer(config.adadelta().rho(),
                                   config.adadelta().epsilon(),
                                   config.adadelta().decay(),
                                   lr);
    }
    // Fall back to plain SGD when the optimizer name is unrecognized.
    return new SGDOptimizer(config.sgd().momentum(),
                            config.sgd().decay(),
                            config.sgd().nesterov(),
                            lr);
  };
  return select_optimizer(config);
}
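// Usage sketch (illustrative, not part of the library): create() expects a
// serialized OptimizerConfig. The setters below are the protobuf-generated
// counterparts of the accessors used above; the exact proto layout is
// assumed.
//
//   paddle::OptimizerConfig config;
//   config.set_optimizer_name("SGD");
//   config.mutable_sgd()->set_momentum(0.9);
//   config.set_lr_policy("ConstLr");
//   config.mutable_const_lr()->set_learning_rate(0.01);
//   ParameterOptimizer *opt =
//       ParameterOptimizer::create(config.SerializeAsString());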

real *ParameterOptimizer::get_weight() const {
  return parameter_->get_buffer();
}

void ParameterOptimizer::set_weight(Tensor *p) { parameter_ = p; }
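// Sketch of the weight round-trip (assumes the Tensor outlives the
// optimizer; the optimizer borrows the buffer, it does not own it):
//
//   Tensor t(...);                 // hypothetical construction
//   opt->set_weight(&t);
//   real *w = opt->get_weight();   // same buffer as t.get_buffer()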

}  // namespace optimizer
}  // namespace paddle