From 6935dd7bc96e101ec65de39be1d2d8f4f79f1af3 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 4 Jul 2017 00:16:45 +0800 Subject: [PATCH] "lr state serialization" --- paddle/optimizer/lr_policy.h | 46 ++++++++++++++++++++++--------- paddle/optimizer/sgd_optimizer.cc | 4 +-- proto/OptimizerConfig.proto | 27 ++++++++---------- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/paddle/optimizer/lr_policy.h b/paddle/optimizer/lr_policy.h index d8e33ad37a..ab5101e2e8 100644 --- a/paddle/optimizer/lr_policy.h +++ b/paddle/optimizer/lr_policy.h @@ -19,34 +19,54 @@ class ConstLr final : public LrPolicy { public: ConstLr(double lr) : learning_rate(lr){}; double LearningRate(const uint64_t num_sample_passed) { - return learning_rate; + return learning_rate_; + } + const char *SerializeState(int *state_len) { + LrPolicyState state; + state.set_learning_rate(learning_rate_); + auto str = state.SerializeAsString(); + *state_len = str.size(); + return str.c_str(); + } + void DeserializeState(const std::string &state) { + LrPolicyState state; + state.ParseFromString(str); + learning_rate_ = state.learning_rate(); } - const char *SerializeState(int *state_len) { return nullptr; } - void DeserializeState(const std::string &state) {} private: - double learning_rate; + double learning_rate_; }; class LinearLr final : public LrPolicy { public: LinearLr(double lr, double lr_decay_a, double lr_decay_b) - : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {} + : learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {} double LearningRate(const uint64_t num_sample_passed) { - return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b); + return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed, + lr_decay_b_); } const char *SerializeState(int *state_len) { - // TODO(zhihong) : add lr_policy serialization - return nullptr; + LrPolicyState state; + state.set_learning_rate(learning_rate_); + state.set_lr_decay_a(lr_decay_a_); + state.set_lr_decay_b(lr_decay_b_); + auto str = state.SerializeAsString(); + *state_len = str.size(); + return str.c_str(); } - void DeserializeState(const std::string &state) { - // TODO(zhihong) : add lr_policy serialization + void DeserializeState(const std::string &str) { + LrPolicyState state; + state.ParseFromString(str); + learning_rate_ = state.learning_rate(); + lr_decay_a_ = state.lr_decay_a(); + lr_decay_b_ = state.lr_decay_b(); } private: - double learning_rate; - double lr_decay_a; - double lr_decay_b; + double learning_rate_; + double lr_decay_a_; + double lr_decay_b_; }; } // namespace optimizer diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc index 34e051003f..9e5477b2ff 100644 --- a/paddle/optimizer/sgd_optimizer.cc +++ b/paddle/optimizer/sgd_optimizer.cc @@ -30,10 +30,10 @@ void SGDOptimizer::Update(const Tensor *gradient) { const char *SGDOptimizer::SerializeState(int *state_len) { SGDOptimizerState state; state.set_num_sample_passed(num_sample_passed_); - TensorToProto(*parameter_, state.mutable_parameter()); + state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); auto str = state.SerializeAsString(); - *state_len = str.size(); + *state_len += str.size(); return str.c_str(); } diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto index c698d3c2dd..19ce289ea3 100644 --- a/proto/OptimizerConfig.proto +++ b/proto/OptimizerConfig.proto @@ -78,11 +78,15 @@ enum DataType { repeated bytes content = 2; } +message LrPolicyState { + // learninRate Policy + optional double learning_rate = 1 [default = 1.0]; + optional double lr_decay_a = 2; + optional double lr_decay_b = 3; +} + message SGDOptimizerState { - // learning rate policy - optional double learning_rate = 101; - optional double lr_decay_a = 102; - optional double lr_decay_b = 103; + optional LrPolicyState lrstate = 101; optional double num_sample_passed = 104; // state optional TensorProto parameter = 1; @@ -91,9 +95,7 @@ message SGDOptimizerState { message AdadeltaOptimizerState { // learning rate policy - optional double learning_rate = 101; - optional double lr_decay_a = 102; - optional double lr_decay_b = 103; + optional LrPolicyState lrstate = 101; optional double num_sample_passed = 104; // state optional TensorProto parameter = 1; @@ -102,11 +104,9 @@ message AdadeltaOptimizerState { optional TensorProto update_delta = 4; } + message AdagradOptimizerState { - // learning rate policy - optional double learning_rate = 101; - optional double lr_decay_a = 102; - optional double lr_decay_b = 103; + optional LrPolicyState lrstate = 101; optional double num_sample_passed = 104; // state optional TensorProto parameter = 1; @@ -114,10 +114,7 @@ message AdagradOptimizerState { } message AdamOptimizerState { - // learning rate policy - optional double learning_rate = 101; - optional double lr_decay_a = 102; - optional double lr_decay_b = 103; + optional LrPolicyState lrstate = 101; optional double num_sample_passed = 104; // state optional TensorProto parameter = 1; -- GitLab