diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc
index 527e65144da514901c3968d126cc7ee0eaaebd9f..96570eab26b340eb68a4ed0a5794d2ba5533034e 100644
--- a/paddle/optimizer/sgd_optimizer.cc
+++ b/paddle/optimizer/sgd_optimizer.cc
@@ -33,7 +33,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
   std::string lr_str = this->lr_policy_->SerializeState(state_len);
   LrPolicyState lr_state;
   lr_state.ParseFromString(lr_str);
-  state.mutable_lr_state() = lr_state;
+  state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
   auto str = state.SerializeAsString();
@@ -44,6 +44,8 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
 void SGDOptimizer::DeserializeState(const std::string &str) {
   SGDOptimizerState state;
   state.ParseFromString(str);
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto
index 290932898eb67308daebae2b03df7becd52e221f..2a87e293f64d3398dea2641c3ff292eceec7e154 100644
--- a/proto/OptimizerConfig.proto
+++ b/proto/OptimizerConfig.proto
@@ -86,7 +86,7 @@ message LrPolicyState {
 }
 
 message SGDOptimizerState {
-  optional LrPolicyState lrstate = 101;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -106,7 +106,7 @@ message AdadeltaOptimizerState {
 
 
 message AdagradOptimizerState {
-  optional LrPolicyState lrstate = 101;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -114,7 +114,7 @@ message AdagradOptimizerState {
 }
 
 message AdamOptimizerState {
-  optional LrPolicyState lrstate = 101;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;