diff --git a/paddle/optimizer/adadelta_optimizer.cc b/paddle/optimizer/adadelta_optimizer.cc index 465ad5e0d2089121a0f11ab916afe0420cbcfab7..6eec5d846fa5ef6b25e7646200dad1d452dda806 100644 --- a/paddle/optimizer/adadelta_optimizer.cc +++ b/paddle/optimizer/adadelta_optimizer.cc @@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) { const char* AdadeltaOptimizer::SerializeState(int* state_len) { AdadeltaOptimizerState state; - // TODO(zhihong) : add lr_policy serialization state.set_num_sample_passed(num_sample_passed_); + std::string lr_str = this->lr_policy_->SerializeState(state_len); + state.mutable_lr_state()->ParseFromString(lr_str); TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_delta_, state.mutable_accum_delta()); TensorToProto(*update_delta_, state.mutable_update_delta()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } void AdadeltaOptimizer::DeserializeState(const std::string& str) { AdadeltaOptimizerState state; state.ParseFromString(str); - // TODO(zhihong) : add lr_policy DeserializeState + auto lr_state = state.lr_state(); + this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); diff --git a/paddle/optimizer/adagrad_optimizer.cc b/paddle/optimizer/adagrad_optimizer.cc index bdaa7877d2bc58c17c51b977852d4b6fec511ed2..5b92610ac547ee11cedf2e49e4d7f1db4b2da646 100644 --- a/paddle/optimizer/adagrad_optimizer.cc +++ b/paddle/optimizer/adagrad_optimizer.cc @@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) { } const char* AdagradOptimizer::SerializeState(int* state_len) { AdagradOptimizerState state; - // TODO(zhihong) : add lr_policy serialization state.set_num_sample_passed(num_sample_passed_); + std::string lr_str = this->lr_policy_->SerializeState(state_len); + state.mutable_lr_state()->ParseFromString(lr_str); TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } void AdagradOptimizer::DeserializeState(const std::string& str) { AdagradOptimizerState state; state.ParseFromString(str); - // TODO(zhihong) : add lr_policy DeserializeState + auto lr_state = state.lr_state(); + this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); + num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.accum_gradient(), accum_gradient_); diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc index ceab7397d87349c64ca9e5d11990cb38068421be..1ebb6b1e0f7b4edcbac1b28319fd4de576f85f6a 100644 --- a/paddle/optimizer/adam_optimizer.cc +++ b/paddle/optimizer/adam_optimizer.cc @@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) { const char *AdamOptimizer::SerializeState(int *state_len) { AdamOptimizerState state; - // TODO(zhihong) : add lr_policy serialization + std::string lr_str = this->lr_policy_->SerializeState(state_len); + state.mutable_lr_state()->ParseFromString(lr_str); state.set_num_sample_passed(num_sample_passed_); + TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*momentums_, state.mutable_momentums()); TensorToProto(*velocitys_, state.mutable_velocitys()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } void AdamOptimizer::DeserializeState(const std::string &str) { AdamOptimizerState state; state.ParseFromString(str); - // TODO(zhihong) : add lr_policy DeserializeState + auto lr_state = state.lr_state(); + this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc index 96570eab26b340eb68a4ed0a5794d2ba5533034e..15418faa840c19e776f293700ee886991754fb04 100644 --- a/paddle/optimizer/sgd_optimizer.cc +++ b/paddle/optimizer/sgd_optimizer.cc @@ -31,8 +31,6 @@ const char *SGDOptimizer::SerializeState(int *state_len) { SGDOptimizerState state; state.set_num_sample_passed(num_sample_passed_); std::string lr_str = this->lr_policy_->SerializeState(state_len); - LrPolicyState lr_state; - lr_state.ParseFromString(lr_str); state.mutable_lr_state()->ParseFromString(lr_str); TensorToProto(*parameter_, state.mutable_parameter()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());