parameter_optimizer.cc 1.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
#include "parameter_optimizer.h"
#include <glog/logging.h>
#include "optimizer_factory.h"

namespace paddle {
namespace optimizer {

/// Factory: builds a ParameterOptimizer from a serialized OptimizerConfig.
///
/// @param config_proto  bytes of a serialized paddle::OptimizerConfig.
/// @return an optimizer allocated with `new`. NOTE(review): presumably the
///         caller takes ownership and must delete it — confirm against the
///         header's contract.
///
/// Aborts (glog CHECK) if the bytes do not parse or config_valid() rejects
/// them. An unrecognised optimizer name falls back to SGD. The learning-rate
/// policy is always ConstLr, so `lr_policy` is never left unset.
template <class T>
ParameterOptimizer<T> *ParameterOptimizer<T>::create(
    const ::std::string &config_proto) {
  paddle::OptimizerConfig config;
  // ParseFromString returns true on success, so CHECK the value directly.
  CHECK(config.ParseFromString(config_proto)) << "error : optimizer config";

  // C++ cannot `switch` on strings, so dispatch with an if/else chain.
  const auto &name = config.optimizer_name();
  ParameterOptimizer<T> *opt = nullptr;
  if (name == "Adagrad") {
    opt = new AdagradOptimizer<T>(config);
  } else if (name == "Adadelta") {
    opt = new AdadeltaOptimizer<T>(config);
  } else if (name == "Adam") {
    opt = new AdamOptimizer<T>(config);
  } else {
    // "SGD", and the fallback for any unrecognised name.
    opt = new SGDOptimizer<T>(config);
  }

  // config_valid() is a non-static member that takes the serialized string.
  // It can only be invoked through an instance, so validate once one exists.
  // If validation fails, CHECK aborts the process.
  CHECK(opt->config_valid(config_proto)) << "error : invalid optimizer config ";

  // ConstLr is the only learning-rate policy implemented so far. It is used
  // both for "ConstLr" and as the fallback for anything else. The previous
  // switch left lr_policy unset for unknown names, and get_config_proto()
  // later dereferences it.
  // TODO: dispatch on config.lr_policy() once more policies exist.
  opt->lr_policy = new ConstLr(config);
  return opt;
}

/// Returns the raw element buffer of the tensor last passed to set_weight().
///
/// The member is `parameter_`, the same one set_weight() assigns; the old
/// body referenced a differently named `parameter` and used `.` on a pointer.
/// NOTE(review): the CHECK assumes parameter_ is initialised to nullptr
/// before set_weight() runs — confirm in the constructor.
template <class T>
T *ParameterOptimizer<T>::get_weight() const {
  CHECK(parameter_ != nullptr) << "get_weight() called before set_weight()";
  return parameter_->get_buffer();
}

/// Serializes the optimizer config for checkpointing. Before serializing, it
/// refreshes the dynamic training state: current learning rate, number of
/// samples passed, and iteration count.
///
/// @return pointer to the serialized bytes, with the following limits:
///   - The storage is thread-local.
///   - The next call on the same thread overwrites it, so copy the data if it
///     must live longer.
///   - Protobuf binary output may contain NUL bytes, so strlen() on the result
///     is not a reliable length.
template <class T>
char *ParameterOptimizer<T>::get_config_proto() const {
  // Fill in a copy rather than config_ itself, because this method is const.
  auto snapshot = config_;
  if (lr_policy != nullptr) {
    // NOTE(review): assumes lr_policy is a nested message with a
    // learning_rate field, as the original set_learning_rate() call implied.
    // Nested messages are modified through mutable_*(), not the const getter.
    snapshot.mutable_lr_policy()->set_learning_rate(
        lr_policy->get_learning_rate(num_sample_passed));
  }
  snapshot.set_num_sample_passed(num_sample_passed);
  snapshot.set_iterations(iterations);

  // Returning c_str() of the temporary from SerializeAsString() would dangle
  // once the statement ended. Keep the bytes alive in per-thread storage.
  static thread_local ::std::string buffer;
  buffer = snapshot.SerializeAsString();
  return &buffer[0];
}

/// Records which tensor this optimizer reads and updates.
///
/// @param p  tensor to track. Only the pointer is stored. NOTE(review): the
///           tensor presumably must outlive its use by this optimizer —
///           confirm ownership with callers.
template <class T>
void ParameterOptimizer<T>::set_weight(const Tensor<T> *p) {
  this->parameter_ = p;
}

/// Validates a serialized optimizer config.
///
/// @param config  serialized config; not inspected yet.
/// @return always true, because no checks are implemented at present.
template <class T>
bool ParameterOptimizer<T>::config_valid(const ::std::string &config) const {
  // TODO(zhihong) : add more value checker, failed ASAP
  static_cast<void>(config);  // unused until real checks are added
  const bool all_checks_passed = true;
  return all_checks_passed;
}

template class ParameterOptimzier<float>;
template class ParameterOptimzier<double>;

}  // namespace optimizer
}  // namespace paddle