提交 65d9e33b 编写于 作者: D dzhwinter

"modify config name"

上级 df5bc787
...@@ -31,4 +31,3 @@ add_simple_unittest(test_FPException) ...@@ -31,4 +31,3 @@ add_simple_unittest(test_FPException)
add_simple_unittest(test_GpuProfiler) add_simple_unittest(test_GpuProfiler)
add_simple_unittest(test_BaseMatrix) add_simple_unittest(test_BaseMatrix)
add_simple_unittest(test_Matrix) add_simple_unittest(test_Matrix)
add_simple_unittest(test_Matrix2)
...@@ -28,7 +28,8 @@ const char *AdamOptimizer::SerializeState(int *state_len) { ...@@ -28,7 +28,8 @@ const char *AdamOptimizer::SerializeState(int *state_len) {
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*velocitys_, state.mutable_momentums()); TensorToProto(*momentums_, state.mutable_momentums());
TensorToProto(*velocitys_, state.mutable_velocitys());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len = str.size();
return str.c_str(); return str.c_str();
......
...@@ -21,8 +21,8 @@ public: ...@@ -21,8 +21,8 @@ public:
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return learning_rate; return learning_rate;
} }
const char *SerializeState(int *state_len); const char *SerializeState(int *state_len) { return nullptr; }
void DeserializeState(const std::string &state); void DeserializeState(const std::string &state) {}
private: private:
double learning_rate; double learning_rate;
...@@ -35,8 +35,13 @@ public: ...@@ -35,8 +35,13 @@ public:
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b); return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
} }
const char *SerializeState(int *state_len); const char *SerializeState(int *state_len) {
void DeserializeState(const std::string &state); // TODO(zhihong) : add lr_policy serialization
return nullptr;
}
void DeserializeState(const std::string &state) {
// TODO(zhihong) : add lr_policy serialization
}
private: private:
double learning_rate; double learning_rate;
......
...@@ -13,13 +13,13 @@ namespace optimizer { ...@@ -13,13 +13,13 @@ namespace optimizer {
ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto, ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
Tensor *parameter) { Tensor *parameter) {
paddle::OptimizerConfig config; paddle::OptimizerConfig config;
CHECK(config.ParseFromString(config_proto) == 0) CHECK(config.ParseFromString(config_proto) == true)
<< "failed parse optimizer config"; << "failed parse optimizer config";
auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * { auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
if (config.lr_policy() == OptimizerConfig::ConstLr) if (config.lr_policy() == OptimizerConfig::Const)
return new ConstLr(config.const_lr().learning_rate()); return new ConstLr(config.const_lr().learning_rate());
if (config.lr_policy() == OptimizerConfig::LinearLr) if (config.lr_policy() == OptimizerConfig::Linear)
return new LinearLr(config.linear_lr().learning_rate(), return new LinearLr(config.linear_lr().learning_rate(),
config.linear_lr().lr_decay_a(), config.linear_lr().lr_decay_a(),
config.linear_lr().lr_decay_b()); config.linear_lr().lr_decay_b());
......
...@@ -2,11 +2,8 @@ ...@@ -2,11 +2,8 @@
#include <cmath> #include <cmath>
#include <map> #include <map>
#include <vector> #include <vector>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "sgd_optimizer.h" #include "lr_policy.h"
using namespace paddle; using namespace paddle;
using namespace paddle::optimizer; using namespace paddle::optimizer;
...@@ -41,12 +38,12 @@ public: ...@@ -41,12 +38,12 @@ public:
virtual void TearDown() {} virtual void TearDown() {}
void CreateSGD() { void CreateSGD() {
Tensor* parameter = FillTensor(kSize); Tensor* parameter = FixedTensor(kSize);
config_.set_optimizer(OptimizerConfig::SGD); config_.set_optimizer(OptimizerConfig::SGD);
config_.mutable_sgd()->set_momentum(0.0); config_.mutable_sgd()->set_momentum(0.0);
config_.mutable_sgd()->set_decay(0.0); config_.mutable_sgd()->set_decay(0.0);
config_.mutable_sgd()->set_nesterov(false); config_.mutable_sgd()->set_nesterov(false);
config_.set_lr_policy(OptimizerConfig::ConstLr); config_.set_lr_policy(OptimizerConfig::Const);
config_.mutable_const_lr()->set_learning_rate(0.1); config_.mutable_const_lr()->set_learning_rate(0.1);
std::string str = config_.SerializeAsString(); std::string str = config_.SerializeAsString();
...@@ -62,7 +59,7 @@ public: ...@@ -62,7 +59,7 @@ public:
config_.mutable_adam()->set_beta_2(0.1); config_.mutable_adam()->set_beta_2(0.1);
config_.mutable_adam()->set_epsilon(1e-3); config_.mutable_adam()->set_epsilon(1e-3);
config_.mutable_adam()->set_decay(0.0); config_.mutable_adam()->set_decay(0.0);
config_.set_lr_policy(OptimizerConfig::ConstLr); config_.set_lr_policy(OptimizerConfig::Const);
config_.mutable_const_lr()->set_learning_rate(0.1); config_.mutable_const_lr()->set_learning_rate(0.1);
std::string str = config_.SerializeAsString(); std::string str = config_.SerializeAsString();
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter); ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
...@@ -90,12 +87,13 @@ public: ...@@ -90,12 +87,13 @@ public:
void TestCheckPoint() { void TestCheckPoint() {
std::map<OptimizerConfig::Optimizer, int> expected_state_len = { std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
{OptimizerConfig::SGD, kSize}, {OptimizerConfig::Adam, kSize * 3}, {OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
{OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
}; };
for (size_t i = 0; i < opts_.size(); ++i) { for (size_t i = 0; i < opts_.size(); ++i) {
int state_len = 0; int state_len = 0;
std::string state = opts_[i]->SerializeState(&state_len); std::string state = opts_[i]->SerializeState(&state_len);
EXPECT_EQ(state_len, expected_state_len[opts_table_[i]]); EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
opts_[i]->DeserializeState(state); opts_[i]->DeserializeState(state);
} }
} }
......
...@@ -29,11 +29,9 @@ void SGDOptimizer::Update(const Tensor *gradient) { ...@@ -29,11 +29,9 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const char *SGDOptimizer::SerializeState(int *state_len) { const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state; SGDOptimizerState state;
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*momentums_, state.mutable_momentums()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len = str.size();
return str.c_str(); return str.c_str();
...@@ -42,9 +40,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) { ...@@ -42,9 +40,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
void SGDOptimizer::DeserializeState(const std::string &str) { void SGDOptimizer::DeserializeState(const std::string &str) {
SGDOptimizerState state; SGDOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.parameter(), momentums_); ProtoToTensor(state.parameter(), momentums_);
} }
......
...@@ -53,12 +53,12 @@ message AdamConfig { ...@@ -53,12 +53,12 @@ message AdamConfig {
optional double decay = 44; optional double decay = 44;
} }
message ConstLr { message ConstLrConfig {
// learning rate policy // learning rate policy
required double learning_rate = 1 [default = 1.0]; required double learning_rate = 1 [default = 1.0];
} }
message LinearLr { message LinearLrConfig {
// learning rate policy // learning rate policy
required double learning_rate = 1 [default = 1.0]; required double learning_rate = 1 [default = 1.0];
optional double lr_decay_a = 2; optional double lr_decay_a = 2;
...@@ -139,12 +139,12 @@ message OptimizerConfig { ...@@ -139,12 +139,12 @@ message OptimizerConfig {
optional AdamConfig adam = 6; optional AdamConfig adam = 6;
enum LrPolicy { enum LrPolicy {
ConstLr = 0; Const = 0;
LinearLr = 1; Linear = 1;
} }
required LrPolicy lr_policy = 11; required LrPolicy lr_policy = 11;
optional ConstLr const_lr = 12; optional ConstLrConfig const_lr = 12;
optional LinearLr linear_lr = 13; optional LinearLrConfig linear_lr = 13;
// common config of optimizer // common config of optimizer
// gradient clip when L2 exceeding value // gradient clip when L2 exceeding value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册