提交 7edabe74 编写于 作者: D dongzhihong

"polish name convention"

上级 e1acd73f
...@@ -33,7 +33,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) { ...@@ -33,7 +33,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
std::string lr_str = this->lr_policy_->SerializeState(state_len); std::string lr_str = this->lr_policy_->SerializeState(state_len);
LrPolicyState lr_state; LrPolicyState lr_state;
lr_state.ParseFromString(lr_str); lr_state.ParseFromString(lr_str);
state.mutable_lr_state() = lr_state; state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
...@@ -44,6 +44,8 @@ const char *SGDOptimizer::SerializeState(int *state_len) { ...@@ -44,6 +44,8 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
void SGDOptimizer::DeserializeState(const std::string &str) { void SGDOptimizer::DeserializeState(const std::string &str) {
SGDOptimizerState state; SGDOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);
if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_); if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
......
...@@ -86,7 +86,7 @@ message LrPolicyState { ...@@ -86,7 +86,7 @@ message LrPolicyState {
} }
message SGDOptimizerState { message SGDOptimizerState {
optional LrPolicyState lrstate = 101; optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
...@@ -106,7 +106,7 @@ message AdadeltaOptimizerState { ...@@ -106,7 +106,7 @@ message AdadeltaOptimizerState {
message AdagradOptimizerState { message AdagradOptimizerState {
optional LrPolicyState lrstate = 101; optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
...@@ -114,7 +114,7 @@ message AdagradOptimizerState { ...@@ -114,7 +114,7 @@ message AdagradOptimizerState {
} }
message AdamOptimizerState { message AdamOptimizerState {
optional LrPolicyState lrstate = 101; optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册