提交 7edabe74 编写于 作者: D dongzhihong

"polish name convention"

上级 e1acd73f
......@@ -33,7 +33,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
std::string lr_str = this->lr_policy_->SerializeState(state_len);
LrPolicyState lr_state;
lr_state.ParseFromString(lr_str);
state.mutable_lr_state() = lr_state;
state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString();
......@@ -44,6 +44,8 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
void SGDOptimizer::DeserializeState(const std::string &str) {
SGDOptimizerState state;
state.ParseFromString(str);
auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
......
......@@ -86,7 +86,7 @@ message LrPolicyState {
}
message SGDOptimizerState {
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......@@ -106,7 +106,7 @@ message AdadeltaOptimizerState {
message AdagradOptimizerState {
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......@@ -114,7 +114,7 @@ message AdagradOptimizerState {
}
message AdamOptimizerState {
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册