提交 6935dd7b 编写于 作者: D dongzhihong

"lr state serialization"

上级 f448edf1
......@@ -19,34 +19,54 @@ class ConstLr final : public LrPolicy {
public:
ConstLr(double lr) : learning_rate(lr){};
double LearningRate(const uint64_t num_sample_passed) {
return learning_rate;
return learning_rate_;
}
const char *SerializeState(int *state_len) {
LrPolicyState state;
state.set_learning_rate(learning_rate_);
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void DeserializeState(const std::string &state) {
LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
}
const char *SerializeState(int *state_len) { return nullptr; }
void DeserializeState(const std::string &state) {}
private:
double learning_rate;
double learning_rate_;
};
class LinearLr final : public LrPolicy {
public:
LinearLr(double lr, double lr_decay_a, double lr_decay_b)
: learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
: learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
double LearningRate(const uint64_t num_sample_passed) {
return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
lr_decay_b_);
}
const char *SerializeState(int *state_len) {
// TODO(zhihong) : add lr_policy serialization
return nullptr;
LrPolicyState state;
state.set_learning_rate(learning_rate_);
state.set_lr_decay_a(lr_decay_a_);
state.set_lr_decay_b(lr_decay_b_);
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void DeserializeState(const std::string &state) {
// TODO(zhihong) : add lr_policy serialization
void DeserializeState(const std::string &str) {
LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
lr_decay_a_ = state.lr_decay_a();
lr_decay_b_ = state.lr_decay_b();
}
private:
double learning_rate;
double lr_decay_a;
double lr_decay_b;
double learning_rate_;
double lr_decay_a_;
double lr_decay_b_;
};
} // namespace optimizer
......
......@@ -30,10 +30,10 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state;
state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter());
state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString();
*state_len = str.size();
*state_len += str.size();
return str.c_str();
}
......
......@@ -78,11 +78,15 @@ enum DataType {
repeated bytes content = 2;
}
message LrPolicyState {
// learninRate Policy
optional double learning_rate = 1 [default = 1.0];
optional double lr_decay_a = 2;
optional double lr_decay_b = 3;
}
message SGDOptimizerState {
// learning rate policy
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional LrPolicyState lrstate = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......@@ -91,9 +95,7 @@ message SGDOptimizerState {
message AdadeltaOptimizerState {
// learning rate policy
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional LrPolicyState lrstate = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......@@ -102,11 +104,9 @@ message AdadeltaOptimizerState {
optional TensorProto update_delta = 4;
}
message AdagradOptimizerState {
// learning rate policy
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional LrPolicyState lrstate = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......@@ -114,10 +114,7 @@ message AdagradOptimizerState {
}
message AdamOptimizerState {
// learning rate policy
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional LrPolicyState lrstate = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册