parameter_optimizer.cc 2.3 KB
Newer Older
1
#include <glog/logging.h>
D
dzhwinter 已提交
2 3 4
// #include "adadelta_optimizer.h"
// #include "adagrad_optimizer.h"
// #include "adam_optimizer.h"
5 6 7 8
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"
9 10 11 12

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
13
ParameterOptimizer *ParameterOptimizer::create(
14 15 16 17
    const ::std::string &config_proto) {
  paddle::OptimizerConfig config;
  CHECK(config.ParseFromString(config_proto) == 0)
      << "error : optimizer config";
18

D
dzhwinter 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
  auto select_lr_policy = [=](const OptimizerConfig &config) -> BaseLr * {
    std::string s(config.lr_policy());
    if (s == "ConstLr") return new ConstLr(config.lr_config().learning_rate());
    if (s == "LinearLr")
      return new LinearLr(config.lr_config().learning_rate(),
                          config.lr_config().lr_decay_a(),
                          config.lr_config().lr_decay_b());
    // default
    return new ConstLr(config.lr_config().learning_rate());
  };
  BaseLr *lr = select_lr_policy(config);
  auto select_optimizer =
      [=](const OptimizerConfig &config) -> ParameterOptimizer * {
    std::string s(config.optimizer_name());
    if (s == "SGD") {
      return new SGDOptimizer(config.sgd().momentum(),
                              config.sgd().decay(),
                              config.sgd().nesterov(),
                              lr);
    }
D
dzhwinter 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52
    // if (s == "Adadelta") {
    //   return new AdagradOptimizer(
    //       config.adagrad().epsilon(), config.adagrad().decay(), lr);
    // }
    // if (s == "Adagrad") {
    //   return new AdagradOptimizer(
    //       config.adagrad().epsilon(), config.adagrad().decay(), lr);
    // }
    // if (s == "Adam") {
    //   return new AdadeltaOptimizer(config.adadelta().rho(),
    //                                config.adadelta().epsilon(),
    //                                config.adadelta().decay(),
    //                                lr);
    // }
D
dzhwinter 已提交
53 54 55 56 57 58 59
    // default
    return new SGDOptimizer(config.sgd().momentum(),
                            config.sgd().decay(),
                            config.sgd().nesterov(),
                            lr);
  };
  return select_optimizer(config);
60 61
}

D
dzhwinter 已提交
62 63
real *ParameterOptimizer::get_weight() const {
  return parameter_->get_buffer();
64 65
}

D
dzhwinter 已提交
66
void ParameterOptimizer::set_weight(Tensor *p) { parameter_ = p; }
67 68 69

}  // namespace optimizer
}  // namespace paddle