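// parameter_optimizer.cc
// NOTE(review): this file was recovered from a web blame view; the page banner
// reported that "tests/git@gitcode.net:taosdata/tdengine.git" does not exist at
// commit 0145cbe100921844e8f54988f90084fb02367afe.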
#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                                               Tensor *parameter) {
15
  paddle::OptimizerConfig config;
D
dzhwinter 已提交
16
  CHECK(config.ParseFromString(config_proto) == true)
D
dzhwinter 已提交
17 18
      << "failed parse optimizer config";
  auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
D
dzhwinter 已提交
19
    if (config.lr_policy() == OptimizerConfig::Const)
D
dzhwinter 已提交
20
      return new ConstLr(config.const_lr().learning_rate());
D
dzhwinter 已提交
21
    if (config.lr_policy() == OptimizerConfig::Linear)
D
dzhwinter 已提交
22 23 24
      return new LinearLr(config.linear_lr().learning_rate(),
                          config.linear_lr().lr_decay_a(),
                          config.linear_lr().lr_decay_b());
D
dzhwinter 已提交
25
    // default
26 27
    LOG(WARNING) << " have not select any LrPolicy. use ConstLr in default";
    return new ConstLr(0.1);
D
dzhwinter 已提交
28
  };
D
dzhwinter 已提交
29
  LrPolicy *lr = select_lr_policy(config);
D
dzhwinter 已提交
30
  auto select_optimizer =
D
dzhwinter 已提交
31 32
      [=](Tensor *parameter,
          const OptimizerConfig &config) -> ParameterOptimizer * {
D
dzhwinter 已提交
33
    if (config.optimizer() == OptimizerConfig::SGD) {
D
dzhwinter 已提交
34 35 36
      return new SGDOptimizer(parameter,
                              lr,
                              config.sgd().momentum(),
D
dzhwinter 已提交
37
                              config.sgd().decay(),
D
dzhwinter 已提交
38
                              config.sgd().nesterov());
D
dzhwinter 已提交
39
    }
D
dzhwinter 已提交
40
    if (config.optimizer() == OptimizerConfig::Adadelta) {
D
dzhwinter 已提交
41 42 43
      return new AdadeltaOptimizer(parameter,
                                   lr,
                                   config.adadelta().rho(),
44
                                   config.adadelta().epsilon(),
D
dzhwinter 已提交
45
                                   config.adadelta().decay());
D
dzhwinter 已提交
46
    }
D
dzhwinter 已提交
47
    if (config.optimizer() == OptimizerConfig::Adagrad) {
D
dzhwinter 已提交
48
      return new AdagradOptimizer(
D
dzhwinter 已提交
49
          parameter, lr, config.adagrad().epsilon(), config.adagrad().decay());
D
dzhwinter 已提交
50
    }
D
dzhwinter 已提交
51
    if (config.optimizer() == OptimizerConfig::Adam) {
D
dzhwinter 已提交
52 53 54
      return new AdamOptimizer(parameter,
                               lr,
                               config.adam().beta_1(),
55 56
                               config.adam().beta_2(),
                               config.adam().epsilon(),
D
dzhwinter 已提交
57
                               config.adam().decay());
D
dzhwinter 已提交
58
    }
D
dzhwinter 已提交
59
    // default
60 61
    LOG(WARNING)
        << "have not select any Optimizer. use SGDOptimizer in default";
D
dzhwinter 已提交
62
    return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
D
dzhwinter 已提交
63
  };
D
dzhwinter 已提交
64
  return select_optimizer(parameter, config);
65 66
}

// Exposes the parameter tensor's buffer.
//
// @param param_size  out-param; receives parameter_->size() converted to int.
//                    Must be non-null (it is dereferenced unconditionally).
// @return the pointer returned by parameter_->get_buffer().
float *ParameterOptimizer::get_weight(int *param_size) const {
  // NOTE(review): if size() returns a wider type than int, sizes above INT_MAX
  // would be truncated here — confirm tensors never get that large.
  *param_size = static_cast<int>(parameter_->size());
  return parameter_->get_buffer();
}

}  // namespace optimizer
}  // namespace paddle