diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc
index bfe438ec936bf30008e1eee118ca02e2c7b30770..ceab7397d87349c64ca9e5d11990cb38068421be 100644
--- a/paddle/optimizer/adam_optimizer.cc
+++ b/paddle/optimizer/adam_optimizer.cc
@@ -26,7 +26,6 @@ const char *AdamOptimizer::SerializeState(int *state_len) {
   AdamOptimizerState state;
   // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
-
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
   TensorToProto(*velocitys_, state.mutable_velocitys());
@@ -42,6 +41,7 @@ void AdamOptimizer::DeserializeState(const std::string &str) {
   num_sample_passed_ = state.num_sample_passed();
 
   ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.momentums(), momentums_);
   ProtoToTensor(state.velocitys(), velocitys_);
 }
 }  // namespace optimizer
diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp
index f599b74d71c4b94ca1c04c0d94649a32baa52311..4e6254d9e4dab48279b4a880695959526d30d70c 100644
--- a/paddle/optimizer/parameter_optimizer_test.cpp
+++ b/paddle/optimizer/parameter_optimizer_test.cpp
@@ -45,11 +45,9 @@ public:
     config_.mutable_sgd()->set_nesterov(false);
     config_.set_lr_policy(OptimizerConfig::Const);
     config_.mutable_const_lr()->set_learning_rate(0.1);
-
     std::string str = config_.SerializeAsString();
     ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
     opts_.push_back(opt);
-    opts_table_[opts_.size()] = OptimizerConfig::SGD;
   }
 
   void CreateAdam() {
@@ -64,7 +62,6 @@ public:
     std::string str = config_.SerializeAsString();
     ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
     opts_.push_back(opt);
-    opts_table_[opts_.size()] = OptimizerConfig::Adam;
   }
 
   void TestGetWeight() {
@@ -86,21 +83,15 @@ public:
   }
 
   void TestCheckPoint() {
-    std::map<OptimizerConfig::Optimizer, size_t> expected_state_len = {
-        {OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
-        {OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
-    };
     for (size_t i = 0; i < opts_.size(); ++i) {
       int state_len = 0;
       std::string state = opts_[i]->SerializeState(&state_len);
-      EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
      opts_[i]->DeserializeState(state);
     }
   }
 
 private:
   std::vector<ParameterOptimizer*> opts_;
-  std::map<int, OptimizerConfig::Optimizer> opts_table_;
   OptimizerConfig config_;
 };
 
diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc
index 252f205bb07de7788b21994d7ae7dcc687b1f3c0..34e051003fa83f11b1f4a39c46856e0372836a1a 100644
--- a/paddle/optimizer/sgd_optimizer.cc
+++ b/paddle/optimizer/sgd_optimizer.cc
@@ -42,7 +42,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) {
   state.ParseFromString(str);
   num_sample_passed_ = state.num_sample_passed();
 
   ProtoToTensor(state.parameter(), parameter_);
-  ProtoToTensor(state.parameter(), momentums_);
+  if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
 }
 }  // namespace optimizer
diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto
index 56bda35be47729a57ce790c8e7ccef196d55a3e2..c698d3c2ddbf58a41ac6ee960af83a257325d1f9 100644
--- a/proto/OptimizerConfig.proto
+++ b/proto/OptimizerConfig.proto
@@ -55,12 +55,12 @@ message AdamConfig {
 
 message ConstLrConfig {
   // learninRate Policy
-  required double learning_rate = 1 [default = 1.0];
+  optional double learning_rate = 1 [default = 1.0];
 }
 
 message LinearLrConfig {
   // learninRate Policy
-  required double learning_rate = 1 [default = 1.0];
+  optional double learning_rate = 1 [default = 1.0];
   optional double lr_decay_a = 2;
   optional double lr_decay_b = 3;
 }
@@ -74,7 +74,7 @@ enum DataType {
     PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
     PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
   }
-  required DataType data_type = 1;
+  optional DataType data_type = 1;
   repeated bytes content = 2;
 }
@@ -132,7 +132,7 @@ message OptimizerConfig {
     Adagrad = 3;
     Adam = 4;
   }
-  required Optimizer optimizer = 1;
+  optional Optimizer optimizer = 1;
   optional SGDConfig sgd = 3;
   optional AdadeltaConfig adadelta = 4;
   optional AdagradConfig adagrad = 5;
@@ -142,7 +142,7 @@ message OptimizerConfig {
     Const = 0;
     Linear = 1;
   }
-  required LrPolicy lr_policy = 11;
+  optional LrPolicy lr_policy = 11;
   optional ConstLrConfig const_lr = 12;
   optional LinearLrConfig linear_lr = 13;
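
The adam_optimizer.cc change makes DeserializeState the mirror image of SerializeState: all three tensors (parameter, momentums, velocitys) are written out, and now all three are read back, so restoring a checkpoint no longer silently drops Adam's first-moment estimates. A round-trip check in the spirit of TestCheckPoint makes the symmetry visible. This is a minimal sketch, not repository code: the header name, the paddle::optimizer namespace, and the CheckRoundTrip helper are assumptions; only SerializeState and DeserializeState come from the diff.

#include <cassert>
#include <string>

#include "parameter_optimizer.h"  // header name assumed

// Serialize, restore, serialize again: with the momentums_ line added to
// AdamOptimizer::DeserializeState, the two blobs must now match byte for byte.
void CheckRoundTrip(paddle::optimizer::ParameterOptimizer *opt) {
  int len1 = 0;
  const char *buf1 = opt->SerializeState(&len1);
  std::string first(buf1, len1);  // (ptr, len): the blob may contain NUL bytes
  opt->DeserializeState(first);

  int len2 = 0;
  const char *buf2 = opt->SerializeState(&len2);
  std::string second(buf2, len2);

  assert(first == second);
}

Constructing the std::string from the pointer plus the reported length matters here: a serialized protobuf can contain embedded NUL bytes, so a plain const char* conversion could truncate the state.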
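The expected_state_len table deleted from TestCheckPoint encoded raw payload sizes: one kSize-float tensor plus a double (num_sample_passed) for SGD, and three such tensors (matching the three TensorToProto calls above) plus the double for Adam. The arithmetic, as a sketch; kSize's value is assumed here, since its definition lies outside the diff:

#include <cstddef>

// Raw payload behind the removed constants. The serialized protobuf also
// carries tag/length framing bytes, and SGD with momentum presumably stores a
// second tensor, so a fixed per-optimizer byte count is brittle, which is
// likely why the exact EXPECT_EQ was dropped in favor of a plain round trip.
constexpr std::size_t kSize = 5;  // assumed; defined elsewhere in the test
constexpr std::size_t kSgdLen = kSize * sizeof(float) + sizeof(double);
constexpr std::size_t kAdamLen = kSize * 3 * sizeof(float) + sizeof(double);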
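In sgd_optimizer.cc, restoring momentums_ is now conditional on momentum_, presumably because plain SGD (momentum == 0.0) never allocates that tensor, so the old unconditional ProtoToTensor wrote through an unallocated pointer whenever a momentum-free checkpoint was loaded. The shape of the fix, as a standalone illustration rather than paddle code:

#include <vector>

// A buffer that exists only when a feature is enabled must be guarded on the
// restore path too, or the disabled configuration dereferences a null pointer.
struct Tensor {
  std::vector<float> data;
};

struct MomentumState {
  double momentum = 0.0;        // plain SGD when 0.0
  Tensor *momentums = nullptr;  // allocated only when momentum != 0.0

  void Restore(const Tensor &saved) {
    // Unguarded `momentums->data = saved.data;` crashes for plain SGD.
    if (momentum != 0.0) momentums->data = saved.data;
  }
};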
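The OptimizerConfig.proto change replaces every required field with optional. Under proto2 semantics, a blob missing a required field fails to parse outright, and the field can never be removed or defaulted later without breaking previously stored data, which is why required is avoided for persisted state such as optimizer checkpoints; optional fields keep old blobs readable, and the [default = 1.0] annotations preserve behavior when a field is unset. A minimal reader-side sketch, assuming OptimizerConfig.proto declares no package (so the generated C++ type is a top-level OptimizerConfig):

#include <string>

#include "OptimizerConfig.pb.h"  // conventional protoc-generated header name

// With `required` fields, ParseFromString() returns false whenever one of
// them (e.g. lr_policy) is absent from the blob; with `optional` fields the
// same blob parses and the caller can choose an explicit fallback.
bool LoadConfig(const std::string &blob, OptimizerConfig *config) {
  if (!config->ParseFromString(blob)) return false;
  if (!config->has_lr_policy()) {
    config->set_lr_policy(OptimizerConfig::Const);  // explicit fallback policy
  }
  return true;
}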