// parameter_optimizer.cc
#include "parameter_optimizer.h"

#include <glog/logging.h>

#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"

namespace paddle {
namespace optimizer {

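// Factory method: deserialize an OptimizerConfig protobuf message and
// build the matching optimizer together with its learning-rate policy.
// The caller owns the returned pointer.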
template <class T>
ParameterOptimizer<T> *ParameterOptimizer<T>::create(
    const ::std::string &config_proto) {
  paddle::OptimizerConfig config;
  // ParseFromString() returns true on success, so check for truth
  // rather than comparing against 0.
  CHECK(config.ParseFromString(config_proto))
      << "error : cannot parse optimizer config";
  CHECK(config_valid(config_proto)) << "error : invalid optimizer config";

  // C++ cannot switch on a string, so dispatch on the policy name with
  // an if/else chain instead.
  BaseLr *lr = nullptr;
  if (config.lr_policy() == "ConstLr") {
    lr = new ConstLr(config.lr_config().learning_rate());
  }
  CHECK(lr != nullptr) << "error : unknown lr policy "
                       << config.lr_policy();

  ParameterOptimizer<T> *opt = nullptr;
  if (config.optimizer_name() == "SGD") {
    opt = new SGDOptimizer<T>(config.sgd().momentum(),
                              config.sgd().decay(),
                              config.sgd().nesterov(),
                              lr);
  } else if (config.optimizer_name() == "Adagrad") {
    opt = new AdagradOptimizer<T>(
        config.adagrad().epsilon(), config.adagrad().decay(), lr);
  } else if (config.optimizer_name() == "Adadelta") {
    opt = new AdadeltaOptimizer<T>(config.adadelta().rho(),
                                   config.adadelta().epsilon(),
                                   config.adadelta().decay(),
                                   lr);
  } else if (config.optimizer_name() == "Adam") {
    opt = new AdamOptimizer<T>(config.adam().beta_1(),
                               config.adam().beta_2(),
                               config.adam().epsilon(),
                               config.adam().decay(),
                               lr);
  }
  CHECK(opt != nullptr) << "error : unknown optimizer "
                        << config.optimizer_name();

  return opt;
}
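
// Usage sketch (hypothetical: set_optimizer_name(), set_lr_policy(),
// mutable_lr_config() and mutable_sgd() are assumed, not confirmed,
// members of OptimizerConfig, inferred from the accessors used above):
//
//   paddle::OptimizerConfig config;
//   config.set_optimizer_name("SGD");
//   config.set_lr_policy("ConstLr");
//   config.mutable_lr_config()->set_learning_rate(0.01);
//   config.mutable_sgd()->set_momentum(0.9);
//   ParameterOptimizer<float> *opt =
//       ParameterOptimizer<float>::create(config.SerializeAsString());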

template <class T>
T *ParameterOptimizer<T>::get_weight() const {
  // The tensor is stored as `parameter_` (see set_weight below); expose
  // its raw buffer.
  return parameter_->get_buffer();
}

template <class T>
::std::string ParameterOptimizer<T>::get_config_proto() {
  // Refresh the dynamic values in the config before saving a checkpoint.
  // lr_policy() returns a const reference in protobuf-generated code, so
  // the field must be modified through mutable_lr_policy(); the method
  // can no longer be const because it mutates config_.
  config_.mutable_lr_policy()->set_learning_rate(
      lr_policy->get_learning_rate(num_sample_passed));
  config_.set_num_sample_passed(num_sample_passed);
  config_.set_iterations(iterations);
  // Return the serialized proto by value: the former
  // `SerializeAsString().c_str()` handed back a pointer into a temporary
  // string that is destroyed as soon as the function returns.
  return config_.SerializeAsString();
}

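// Point the optimizer at the parameter tensor to update. Only the raw
// pointer is stored, so the caller retains ownership of the tensor.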
template <class T>
void ParameterOptimizer<T>::set_weight(const Tensor<T> *p) {
  parameter_ = p;
}

template <class T>
bool ParameterOptimizer<T>::config_valid(const ::std::string &config) const {
  // TODO(zhihong): add more value checks; fail fast on a bad config.
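  // A concrete check might look like the following sketch (assuming the
  // same proto accessors used in create(); names are illustrative):
  //
  //   paddle::OptimizerConfig c;
  //   if (!c.ParseFromString(config)) return false;
  //   if (c.lr_config().learning_rate() <= 0.0) return false;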
  return true;
}

// Explicit instantiations for the supported parameter types.
template class ParameterOptimizer<float>;
template class ParameterOptimizer<double>;

}  // namespace optimizer
}  // namespace paddle