提交 6935dd7b 编写于 作者: D dongzhihong

"lr state serialization"

上级 f448edf1
...@@ -19,34 +19,54 @@ class ConstLr final : public LrPolicy { ...@@ -19,34 +19,54 @@ class ConstLr final : public LrPolicy {
public: public:
ConstLr(double lr) : learning_rate(lr){}; ConstLr(double lr) : learning_rate(lr){};
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return learning_rate; return learning_rate_;
}
const char *SerializeState(int *state_len) {
LrPolicyState state;
state.set_learning_rate(learning_rate_);
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void DeserializeState(const std::string &state) {
LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
} }
const char *SerializeState(int *state_len) { return nullptr; }
void DeserializeState(const std::string &state) {}
private: private:
double learning_rate; double learning_rate_;
}; };
class LinearLr final : public LrPolicy { class LinearLr final : public LrPolicy {
public: public:
LinearLr(double lr, double lr_decay_a, double lr_decay_b) LinearLr(double lr, double lr_decay_a, double lr_decay_b)
: learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {} : learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b); return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
lr_decay_b_);
} }
const char *SerializeState(int *state_len) { const char *SerializeState(int *state_len) {
// TODO(zhihong) : add lr_policy serialization LrPolicyState state;
return nullptr; state.set_learning_rate(learning_rate_);
state.set_lr_decay_a(lr_decay_a_);
state.set_lr_decay_b(lr_decay_b_);
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
} }
void DeserializeState(const std::string &state) { void DeserializeState(const std::string &str) {
// TODO(zhihong) : add lr_policy serialization LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
lr_decay_a_ = state.lr_decay_a();
lr_decay_b_ = state.lr_decay_b();
} }
private: private:
double learning_rate; double learning_rate_;
double lr_decay_a; double lr_decay_a_;
double lr_decay_b; double lr_decay_b_;
}; };
} // namespace optimizer } // namespace optimizer
......
...@@ -30,10 +30,10 @@ void SGDOptimizer::Update(const Tensor *gradient) { ...@@ -30,10 +30,10 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const char *SGDOptimizer::SerializeState(int *state_len) { const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state; SGDOptimizerState state;
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
......
...@@ -78,11 +78,15 @@ enum DataType { ...@@ -78,11 +78,15 @@ enum DataType {
repeated bytes content = 2; repeated bytes content = 2;
} }
message LrPolicyState {
// learninRate Policy
optional double learning_rate = 1 [default = 1.0];
optional double lr_decay_a = 2;
optional double lr_decay_b = 3;
}
message SGDOptimizerState { message SGDOptimizerState {
// learning rate policy optional LrPolicyState lrstate = 101;
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
...@@ -91,9 +95,7 @@ message SGDOptimizerState { ...@@ -91,9 +95,7 @@ message SGDOptimizerState {
message AdadeltaOptimizerState { message AdadeltaOptimizerState {
// learning rate policy // learning rate policy
optional double learning_rate = 101; optional LrPolicyState lrstate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
...@@ -102,11 +104,9 @@ message AdadeltaOptimizerState { ...@@ -102,11 +104,9 @@ message AdadeltaOptimizerState {
optional TensorProto update_delta = 4; optional TensorProto update_delta = 4;
} }
message AdagradOptimizerState { message AdagradOptimizerState {
// learning rate policy optional LrPolicyState lrstate = 101;
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
...@@ -114,10 +114,7 @@ message AdagradOptimizerState { ...@@ -114,10 +114,7 @@ message AdagradOptimizerState {
} }
message AdamOptimizerState { message AdamOptimizerState {
// learning rate policy optional LrPolicyState lrstate = 101;
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册