From e1acd73fab4e17db5700feba09339f09d7152406 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 4 Jul 2017 01:13:30 +0800 Subject: [PATCH] "fix typo deleted part" --- paddle/optimizer/lr_policy.h | 4 ++-- paddle/optimizer/sgd_optimizer.cc | 6 +++++- proto/OptimizerConfig.proto | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/optimizer/lr_policy.h b/paddle/optimizer/lr_policy.h index ab5101e2e86..036c376e10f 100644 --- a/paddle/optimizer/lr_policy.h +++ b/paddle/optimizer/lr_policy.h @@ -17,7 +17,7 @@ public: // constant learning rate policy class ConstLr final : public LrPolicy { public: - ConstLr(double lr) : learning_rate(lr){}; + ConstLr(double lr) : learning_rate_(lr){}; double LearningRate(const uint64_t num_sample_passed) { return learning_rate_; } @@ -28,7 +28,7 @@ public: *state_len = str.size(); return str.c_str(); } - void DeserializeState(const std::string &state) { + void DeserializeState(const std::string &str) { LrPolicyState state; state.ParseFromString(str); learning_rate_ = state.learning_rate(); diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc index 9e5477b2ff3..527e65144da 100644 --- a/paddle/optimizer/sgd_optimizer.cc +++ b/paddle/optimizer/sgd_optimizer.cc @@ -30,7 +30,11 @@ void SGDOptimizer::Update(const Tensor *gradient) { const char *SGDOptimizer::SerializeState(int *state_len) { SGDOptimizerState state; state.set_num_sample_passed(num_sample_passed_); - state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter()); + std::string lr_str = this->lr_policy_->SerializeState(state_len); + LrPolicyState lr_state; + lr_state.ParseFromString(lr_str); + *state.mutable_lr_state() = lr_state; + TensorToProto(*parameter_, state.mutable_parameter()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); auto str = state.SerializeAsString(); *state_len += str.size(); diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto index 
19ce289ea39..290932898eb 100644 --- a/proto/OptimizerConfig.proto +++ b/proto/OptimizerConfig.proto @@ -95,7 +95,7 @@ message SGDOptimizerState { message AdadeltaOptimizerState { // learning rate policy - optional LrPolicyState lrstate = 101; + optional LrPolicyState lr_state = 101; optional double num_sample_passed = 104; // state optional TensorProto parameter = 1; -- GitLab