From dec65aca7ddcbeac1ba54608bc487dc93d2d28f3 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 4 Jul 2017 01:24:27 +0800 Subject: [PATCH] "fix parameter accumulate size" --- paddle/optimizer/adadelta_optimizer.cc | 8 +++++--- paddle/optimizer/adagrad_optimizer.cc | 9 ++++++--- paddle/optimizer/adam_optimizer.cc | 9 ++++++--- paddle/optimizer/sgd_optimizer.cc | 2 -- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/paddle/optimizer/adadelta_optimizer.cc b/paddle/optimizer/adadelta_optimizer.cc index 465ad5e0d..6eec5d846 100644 --- a/paddle/optimizer/adadelta_optimizer.cc +++ b/paddle/optimizer/adadelta_optimizer.cc @@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) { const char* AdadeltaOptimizer::SerializeState(int* state_len) { AdadeltaOptimizerState state; - // TODO(zhihong) : add lr_policy serialization state.set_num_sample_passed(num_sample_passed_); + std::string lr_str = this->lr_policy_->SerializeState(state_len); + state.mutable_lr_state()->ParseFromString(lr_str); TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_delta_, state.mutable_accum_delta()); TensorToProto(*update_delta_, state.mutable_update_delta()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } void AdadeltaOptimizer::DeserializeState(const std::string& str) { AdadeltaOptimizerState state; state.ParseFromString(str); - // TODO(zhihong) : add lr_policy DeserializeState + auto lr_state = state.lr_state(); + this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); diff --git a/paddle/optimizer/adagrad_optimizer.cc b/paddle/optimizer/adagrad_optimizer.cc index bdaa7877d..5b92610ac 100644 --- a/paddle/optimizer/adagrad_optimizer.cc +++ b/paddle/optimizer/adagrad_optimizer.cc @@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) { } const char* AdagradOptimizer::SerializeState(int* state_len) { AdagradOptimizerState state; - // TODO(zhihong) : add lr_policy serialization state.set_num_sample_passed(num_sample_passed_); + std::string lr_str = this->lr_policy_->SerializeState(state_len); + state.mutable_lr_state()->ParseFromString(lr_str); TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } void AdagradOptimizer::DeserializeState(const std::string& str) { AdagradOptimizerState state; state.ParseFromString(str); - // TODO(zhihong) : add lr_policy DeserializeState + auto lr_state = state.lr_state(); + this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); + num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.accum_gradient(), accum_gradient_); diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc index ceab7397d..1ebb6b1e0 100644 --- a/paddle/optimizer/adam_optimizer.cc +++ b/paddle/optimizer/adam_optimizer.cc @@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) { const char *AdamOptimizer::SerializeState(int *state_len) { AdamOptimizerState state; - // TODO(zhihong) : add lr_policy serialization + std::string lr_str = this->lr_policy_->SerializeState(state_len); + state.mutable_lr_state()->ParseFromString(lr_str); state.set_num_sample_passed(num_sample_passed_); + TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*momentums_, state.mutable_momentums()); TensorToProto(*velocitys_, state.mutable_velocitys()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } void AdamOptimizer::DeserializeState(const std::string &str) { AdamOptimizerState state; state.ParseFromString(str); - // TODO(zhihong) : add lr_policy DeserializeState + auto lr_state = state.lr_state(); + this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc index 96570eab2..15418faa8 100644 --- a/paddle/optimizer/sgd_optimizer.cc +++ b/paddle/optimizer/sgd_optimizer.cc @@ -31,8 +31,6 @@ const char *SGDOptimizer::SerializeState(int *state_len) { SGDOptimizerState state; state.set_num_sample_passed(num_sample_passed_); std::string lr_str = this->lr_policy_->SerializeState(state_len); - LrPolicyState lr_state; - lr_state.ParseFromString(lr_str); state.mutable_lr_state()->ParseFromString(lr_str); TensorToProto(*parameter_, state.mutable_parameter()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); -- GitLab