diff --git a/paddle/optimizer/lr_policy.h b/paddle/optimizer/lr_policy.h
index ab5101e2e860ddc8eaeccda0ae77e0d3fb92f186..036c376e10f465c2866a230caf9224f4af5478bc 100644
--- a/paddle/optimizer/lr_policy.h
+++ b/paddle/optimizer/lr_policy.h
@@ -17,7 +17,7 @@ public:
 // constant learning rate policy
 class ConstLr final : public LrPolicy {
 public:
-  ConstLr(double lr) : learning_rate(lr){};
+  ConstLr(double lr) : learning_rate_(lr){};
   double LearningRate(const uint64_t num_sample_passed) {
     return learning_rate_;
   }
@@ -28,7 +28,7 @@ public:
     *state_len = str.size();
     return str.c_str();
   }
-  void DeserializeState(const std::string &state) {
+  void DeserializeState(const std::string &str) {
     LrPolicyState state;
     state.ParseFromString(str);
     learning_rate_ = state.learning_rate();
diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc
index 9e5477b2ff3b216d864969ba69eae2ce2bb6bdea..527e65144da514901c3968d126cc7ee0eaaebd9f 100644
--- a/paddle/optimizer/sgd_optimizer.cc
+++ b/paddle/optimizer/sgd_optimizer.cc
@@ -30,7 +30,11 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 const char *SGDOptimizer::SerializeState(int *state_len) {
   SGDOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
-  state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter());
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  LrPolicyState lr_state;
+  lr_state.ParseFromString(lr_str);
+  *state.mutable_lr_state() = lr_state;
+  TensorToProto(*parameter_, state.mutable_parameter());
   if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
   auto str = state.SerializeAsString();
   *state_len += str.size();
diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto
index 19ce289ea39fff0657ceb8eac04e9af20bcd0a03..290932898eb67308daebae2b03df7becd52e221f 100644
--- a/proto/OptimizerConfig.proto
+++ b/proto/OptimizerConfig.proto
@@ -95,7 +95,7 @@ message SGDOptimizerState {
 
 message AdadeltaOptimizerState {
   // learning rate policy
-  optional LrPolicyState lrstate = 101;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
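
For context, a minimal sketch of the serialization round trip these changes enable: an LrPolicyState is nested inside SGDOptimizerState, the whole optimizer state is flattened to a byte string, and both pieces can be read back. The message and field names come from the diff above; the sample values, the `paddle` package/namespace, and the standalone driver are illustrative assumptions, not code from this patch.

// Sketch only: relies on the C++ classes generated from proto/OptimizerConfig.proto
// (assumes the messages live in a `paddle` package; adjust the namespace if not).
#include <cassert>
#include <string>
#include "OptimizerConfig.pb.h"

int main() {
  // Nest the learning-rate policy state inside the optimizer state,
  // mirroring the new lines in SGDOptimizer::SerializeState.
  paddle::LrPolicyState lr_state;
  lr_state.set_learning_rate(0.01);  // field name taken from DeserializeState above

  paddle::SGDOptimizerState state;
  state.set_num_sample_passed(128);
  *state.mutable_lr_state() = lr_state;

  // Flatten the whole optimizer state into one checkpoint blob.
  std::string blob = state.SerializeAsString();

  // Restore the blob and read the nested learning-rate state back out.
  paddle::SGDOptimizerState restored;
  restored.ParseFromString(blob);
  assert(restored.lr_state().learning_rate() == 0.01);
  return 0;
}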