From ec65fa835cfed6c8435d0d3a15fb936a8cd705cc Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 19 Jun 2017 15:46:20 +0800 Subject: [PATCH] "protobuf required to optional" --- paddle/optimizer/adam_optimizer.cc | 2 +- paddle/optimizer/parameter_optimizer_test.cpp | 9 --------- paddle/optimizer/sgd_optimizer.cc | 2 +- proto/OptimizerConfig.proto | 10 +++++----- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc index bfe438ec93..ceab7397d8 100644 --- a/paddle/optimizer/adam_optimizer.cc +++ b/paddle/optimizer/adam_optimizer.cc @@ -26,7 +26,6 @@ const char *AdamOptimizer::SerializeState(int *state_len) { AdamOptimizerState state; // TODO(zhihong) : add lr_policy serialization state.set_num_sample_passed(num_sample_passed_); - TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*momentums_, state.mutable_momentums()); TensorToProto(*velocitys_, state.mutable_velocitys()); @@ -42,6 +41,7 @@ void AdamOptimizer::DeserializeState(const std::string &str) { num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); + ProtoToTensor(state.momentums(), momentums_); ProtoToTensor(state.velocitys(), velocitys_); } } // namespace optimizer diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp index f599b74d71..4e6254d9e4 100644 --- a/paddle/optimizer/parameter_optimizer_test.cpp +++ b/paddle/optimizer/parameter_optimizer_test.cpp @@ -45,11 +45,9 @@ public: config_.mutable_sgd()->set_nesterov(false); config_.set_lr_policy(OptimizerConfig::Const); config_.mutable_const_lr()->set_learning_rate(0.1); - std::string str = config_.SerializeAsString(); ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter); opts_.push_back(opt); - opts_table_[opts_.size()] = OptimizerConfig::SGD; } void CreateAdam() { @@ -64,7 +62,6 @@ public: std::string str = config_.SerializeAsString(); ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter); opts_.push_back(opt); - opts_table_[opts_.size()] = OptimizerConfig::Adam; } void TestGetWeight() { @@ -86,21 +83,15 @@ public: } void TestCheckPoint() { - std::map expected_state_len = { - {OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)}, - {OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)}, - }; for (size_t i = 0; i < opts_.size(); ++i) { int state_len = 0; std::string state = opts_[i]->SerializeState(&state_len); - EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]); opts_[i]->DeserializeState(state); } } private: std::vector opts_; - std::map opts_table_; OptimizerConfig config_; }; diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc index 252f205bb0..34e051003f 100644 --- a/paddle/optimizer/sgd_optimizer.cc +++ b/paddle/optimizer/sgd_optimizer.cc @@ -42,7 +42,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) { state.ParseFromString(str); num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); - ProtoToTensor(state.parameter(), momentums_); + if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_); } } // namespace optimizer diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto index 56bda35be4..c698d3c2dd 100644 --- a/proto/OptimizerConfig.proto +++ b/proto/OptimizerConfig.proto @@ -55,12 +55,12 @@ message AdamConfig { message ConstLrConfig { // learninRate Policy - required double learning_rate = 1 [default = 1.0]; + optional double learning_rate = 1 [default = 1.0]; } message LinearLrConfig { // learninRate Policy - required double learning_rate = 1 [default = 1.0]; + optional double learning_rate = 1 [default = 1.0]; optional double lr_decay_a = 2; optional double lr_decay_b = 3; } @@ -74,7 +74,7 @@ enum DataType { PADDLE_ELEMENT_TYPE_FLOAT32 = 4; PADDLE_ELEMENT_TYPE_FLOAT64 = 5; } - required DataType data_type = 1; + optional DataType data_type = 1; repeated bytes content = 2; } @@ -132,7 +132,7 @@ message OptimizerConfig { Adagrad = 3; Adam = 4; } - required Optimizer optimizer = 1; + optional Optimizer optimizer = 1; optional SGDConfig sgd = 3; optional AdadeltaConfig adadelta = 4; optional AdagradConfig adagrad = 5; @@ -142,7 +142,7 @@ message OptimizerConfig { Const = 0; Linear = 1; } - required LrPolicy lr_policy = 11; + optional LrPolicy lr_policy = 11; optional ConstLrConfig const_lr = 12; optional LinearLrConfig linear_lr = 13; -- GitLab