#include "parameter_optimizer.h"

// CHECK/LOG macros. NOTE(review): the original include line was mangled
// ("#include" with no header); glog is the conventional provider of CHECK
// in this codebase — confirm against the build files.
#include <glog/logging.h>
#include <string>

#include "optimizer_factory.h"

namespace paddle {
namespace optimizer {

/// Factory: deserialize an OptimizerConfig proto and build the matching
/// optimizer (SGD / Adagrad / Adadelta / Adam), falling back to SGD for an
/// unrecognized name, then attach the configured learning-rate policy.
///
/// @param config_proto  serialized paddle::OptimizerConfig
/// @return heap-allocated optimizer; caller takes ownership.
template <class T>
ParameterOptimizer<T> *ParameterOptimizer<T>::create(
    const ::std::string &config_proto) {
  paddle::OptimizerConfig config;
  // ParseFromString returns true on success; the original compared == 0,
  // which aborted on every *successful* parse.
  CHECK(config.ParseFromString(config_proto)) << "error : optimizer config";
  // config_valid returns true for a valid config; "== 0" inverted the check.
  CHECK(config_valid(config)) << "error : invalid optimizer config ";

  // C++ cannot switch on std::string; dispatch with an if/else chain.
  ParameterOptimizer<T> *opt = nullptr;
  const ::std::string &name = config.optimizer_name();
  if (name == "SGD") {
    opt = new SGDOptimizer<T>(config);
  } else if (name == "Adagrad") {
    opt = new AdagradOptimizer<T>(config);
  } else if (name == "Adadelta") {
    opt = new AdadeltaOptimizer<T>(config);
  } else if (name == "Adam") {
    opt = new AdamOptimizer<T>(config);
  } else {
    // Unknown optimizer name: keep the original fallback to plain SGD.
    opt = new SGDOptimizer<T>(config);
  }

  // NOTE(review): the proto schema is not visible here; the original
  // switched on config.lr_policy() with a string case label, so string
  // equality is assumed — confirm whether lr_policy is a string tag or an
  // enum in OptimizerConfig. `opt` is a pointer, so use -> (original used .).
  if (config.lr_policy() == "ConstLr") {
    opt->lr_policy = new ConstLr(config);
  }
  return opt;
}

/// @return raw buffer of the managed parameter tensor (non-owning).
/// NOTE(review): the original read `parameter.get().get_buffer()` but the
/// member assigned in set_weight is `parameter_`; unified on `parameter_` —
/// confirm the member's declared name/type in parameter_optimizer.h.
template <class T>
T *ParameterOptimizer<T>::get_weight() const {
  return parameter_->get_buffer();
}

/// Serialize the current optimizer state for checkpointing.
///
/// The dynamic fields (current learning rate, sample count, iteration count)
/// are refreshed into config_ before serialization.
///
/// @return pointer to an internal serialized buffer; valid until the next
///         call on the same thread. (The original returned c_str() of a
///         temporary std::string — a dangling pointer.)
template <class T>
char *ParameterOptimizer<T>::get_config_proto() const {
  // Protobuf const accessors return a read-only reference; mutation must go
  // through mutable_lr_policy() (the original called a setter on the const
  // accessor, which does not compile).
  config_.mutable_lr_policy()->set_learning_rate(
      lr_policy->get_learning_rate(num_sample_passed));
  config_.set_num_sample_passed(num_sample_passed);
  config_.set_iterations(iterations);
  // Keep the bytes alive past the return. thread_local avoids cross-thread
  // clobbering; the interface (char*) is preserved for existing callers.
  static thread_local ::std::string serialized;
  serialized = config_.SerializeAsString();
  return const_cast<char *>(serialized.c_str());
}

/// Replace the parameter tensor this optimizer updates (non-owning borrow).
template <class T>
void ParameterOptimizer<T>::set_weight(const Tensor *p) {
  parameter_ = p;
}

/// Validate a parsed OptimizerConfig.
/// TODO(zhihong): add more value checkers; currently accepts everything.
template <class T>
bool ParameterOptimizer<T>::config_valid(const ::std::string &config) const {
  return true;
}

// Explicit instantiations for the supported element types.
// NOTE(review): the original lines were mangled ("ParameterOptimzier" typo,
// template arguments stripped); float/double are the conventional pair —
// confirm against parameter_optimizer.h.
template class ParameterOptimizer<float>;
template class ParameterOptimizer<double>;

}  // namespace optimizer
}  // namespace paddle