Commit 6398c15c authored by dzhwinter, committed by GitHub

Merge pull request #2718 from dzhwinter/lr_state

"lr state serialization"
@@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
 const char* AdadeltaOptimizer::SerializeState(int* state_len) {
   AdadeltaOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   TensorToProto(*accum_delta_, state.mutable_accum_delta());
   TensorToProto(*update_delta_, state.mutable_update_delta());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 void AdadeltaOptimizer::DeserializeState(const std::string& str) {
   AdadeltaOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
...
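All four optimizers in this commit follow the same pattern: the learning-rate policy serializes itself first, the bytes are re-parsed into the nested `lr_state` field of the optimizer's state message, and `DeserializeState` hands them back to the policy. A minimal caller-side sketch of the round trip (the `Checkpoint`/`Restore` helpers are illustrative assumptions, not part of the diff; `ParameterOptimizer` is assumed to be the common base declaring these two methods):

```cpp
#include <string>
#include "parameter_optimizer.h"  // assumed header declaring ParameterOptimizer

// Illustrative helper: capture the full optimizer state as a string.
std::string Checkpoint(ParameterOptimizer *opt) {
  int state_len = 0;
  const char *buf = opt->SerializeState(&state_len);
  // Copy immediately: buf points into a std::string that was local to
  // SerializeState, so a safer interface would return std::string by value.
  return std::string(buf, state_len);
}

// Illustrative helper: rebuild tensors and the lr policy from a checkpoint.
void Restore(ParameterOptimizer *opt, const std::string &saved) {
  opt->DeserializeState(saved);
}
```

Two caveats worth noting: `str.c_str()` points at a function-local `std::string`, so the returned buffer is already dangling when the caller receives it; and `*state_len` is first set by the policy's `SerializeState` and then incremented by `str.size()`, even though the policy bytes are already embedded inside `str`, so the reported length appears to double-count them.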
@@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
 }
 const char* AdagradOptimizer::SerializeState(int* state_len) {
   AdagradOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 void AdagradOptimizer::DeserializeState(const std::string& str) {
   AdagradOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   ProtoToTensor(state.accum_gradient(), accum_gradient_);
...
@@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
 const char *AdamOptimizer::SerializeState(int *state_len) {
   AdamOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
   TensorToProto(*velocitys_, state.mutable_velocitys());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 void AdamOptimizer::DeserializeState(const std::string &str) {
   AdamOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
...
@@ -17,36 +17,56 @@ public:
 // constant learning rate policy
 class ConstLr final : public LrPolicy {
 public:
-  ConstLr(double lr) : learning_rate(lr){};
+  ConstLr(double lr) : learning_rate_(lr){};
   double LearningRate(const uint64_t num_sample_passed) {
-    return learning_rate;
+    return learning_rate_;
+  }
+  const char *SerializeState(int *state_len) {
+    LrPolicyState state;
+    state.set_learning_rate(learning_rate_);
+    auto str = state.SerializeAsString();
+    *state_len = str.size();
+    return str.c_str();
+  }
+  void DeserializeState(const std::string &str) {
+    LrPolicyState state;
+    state.ParseFromString(str);
+    learning_rate_ = state.learning_rate();
   }
-  const char *SerializeState(int *state_len) { return nullptr; }
-  void DeserializeState(const std::string &state) {}

 private:
-  double learning_rate;
+  double learning_rate_;
 };

 class LinearLr final : public LrPolicy {
 public:
   LinearLr(double lr, double lr_decay_a, double lr_decay_b)
-      : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
+      : learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
   double LearningRate(const uint64_t num_sample_passed) {
-    return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
+    return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
+                    lr_decay_b_);
   }
   const char *SerializeState(int *state_len) {
-    // TODO(zhihong) : add lr_policy serialization
-    return nullptr;
+    LrPolicyState state;
+    state.set_learning_rate(learning_rate_);
+    state.set_lr_decay_a(lr_decay_a_);
+    state.set_lr_decay_b(lr_decay_b_);
+    auto str = state.SerializeAsString();
+    *state_len = str.size();
+    return str.c_str();
   }
-  void DeserializeState(const std::string &state) {
-    // TODO(zhihong) : add lr_policy serialization
+  void DeserializeState(const std::string &str) {
+    LrPolicyState state;
+    state.ParseFromString(str);
+    learning_rate_ = state.learning_rate();
+    lr_decay_a_ = state.lr_decay_a();
+    lr_decay_b_ = state.lr_decay_b();
   }

 private:
-  double learning_rate;
-  double lr_decay_a;
-  double lr_decay_b;
+  double learning_rate_;
+  double lr_decay_a_;
+  double lr_decay_b_;
 };
 }  // namespace optimizer
...
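A round-trip sketch for the policies above (illustrative only; it assumes lr_policy.h and the generated `LrPolicyState` class are available to the caller):

```cpp
#include <iostream>
#include <string>
#include "lr_policy.h"  // the header shown in this diff

int main() {
  // Linear decay: lr(t) = max(learning_rate - lr_decay_a * t, lr_decay_b).
  LinearLr policy(/*lr=*/0.1, /*lr_decay_a=*/1e-6, /*lr_decay_b=*/1e-3);

  int len = 0;
  const char *buf = policy.SerializeState(&len);
  std::string saved(buf, len);  // copy before the internal string is gone

  LinearLr restored(0.0, 0.0, 0.0);
  restored.DeserializeState(saved);
  std::cout << restored.LearningRate(/*num_sample_passed=*/1000) << "\n";
  return 0;
}
```

The same `c_str()` lifetime caveat as in the optimizers applies here, since both policies also return a pointer into a function-local `std::string`.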
@@ -30,16 +30,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 const char *SGDOptimizer::SerializeState(int *state_len) {
   SGDOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 void SGDOptimizer::DeserializeState(const std::string &str) {
   SGDOptimizerState state;
   state.ParseFromString(str);
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
...
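One pre-existing line in this hunk looks suspicious: deserialization restores `momentums_` from `state.parameter()` rather than `state.momentums()`, which is asymmetric with the serialization path above. A presumed correction, not part of this commit:

```cpp
// Presumed fix, mirroring TensorToProto(*momentums_, ...) in SerializeState:
if (momentum_ != 0.0) ProtoToTensor(state.momentums(), momentums_);
```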
@@ -78,11 +78,15 @@ enum DataType {
   repeated bytes content = 2;
 }

+message LrPolicyState {
+  // learning rate policy
+  optional double learning_rate = 1 [default = 1.0];
+  optional double lr_decay_a = 2;
+  optional double lr_decay_b = 3;
+}
+
 message SGDOptimizerState {
-  // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -91,9 +95,7 @@ message SGDOptimizerState {
 message AdadeltaOptimizerState {
   // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -102,11 +104,9 @@ message AdadeltaOptimizerState {
   optional TensorProto update_delta = 4;
 }

 message AdagradOptimizerState {
-  // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -114,10 +114,7 @@ message AdagradOptimizerState {
 }

 message AdamOptimizerState {
-  // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
...
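A short sketch of the new nested field from the C++ side, mirroring what the optimizer `SerializeState`/`DeserializeState` methods above now do (the generated header path and the `paddle` package name are assumptions):

```cpp
#include <string>
#include "OptimizerConfig.pb.h"  // generated from the proto above (path assumed)

void RoundTripExample() {
  paddle::SGDOptimizerState state;  // package name "paddle" is an assumption
  state.set_num_sample_passed(1024);
  // The lr policy's fields now live in one nested message instead of three
  // loose fields per optimizer state.
  state.mutable_lr_state()->set_learning_rate(0.01);
  state.mutable_lr_state()->set_lr_decay_a(1e-6);
  state.mutable_lr_state()->set_lr_decay_b(1e-3);

  std::string wire = state.SerializeAsString();

  paddle::SGDOptimizerState restored;
  restored.ParseFromString(wire);
  double lr = restored.lr_state().learning_rate();  // == 0.01
  (void)lr;
}
```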