提交 e1acd73f 编写于 作者: D dongzhihong

"fix typo deleted part"

上级 6935dd7b
...@@ -17,7 +17,7 @@ public: ...@@ -17,7 +17,7 @@ public:
// constant learning rate policy // constant learning rate policy
class ConstLr final : public LrPolicy { class ConstLr final : public LrPolicy {
public: public:
ConstLr(double lr) : learning_rate(lr){}; ConstLr(double lr) : learning_rate_(lr){};
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return learning_rate_; return learning_rate_;
} }
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
*state_len = str.size(); *state_len = str.size();
return str.c_str(); return str.c_str();
} }
void DeserializeState(const std::string &state) { void DeserializeState(const std::string &str) {
LrPolicyState state; LrPolicyState state;
state.ParseFromString(str); state.ParseFromString(str);
learning_rate_ = state.learning_rate(); learning_rate_ = state.learning_rate();
......
...@@ -30,7 +30,11 @@ void SGDOptimizer::Update(const Tensor *gradient) { ...@@ -30,7 +30,11 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const char *SGDOptimizer::SerializeState(int *state_len) { const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state; SGDOptimizerState state;
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter()); std::string lr_str = this->lr_policy_->SerializeState(state_len);
LrPolicyState lr_state;
lr_state.ParseFromString(lr_str);
state.mutable_lr_state() = lr_state;
TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len += str.size(); *state_len += str.size();
......
...@@ -95,7 +95,7 @@ message SGDOptimizerState { ...@@ -95,7 +95,7 @@ message SGDOptimizerState {
message AdadeltaOptimizerState { message AdadeltaOptimizerState {
// learning rate policy // learning rate policy
optional LrPolicyState lrstate = 101; optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// state // state
optional TensorProto parameter = 1; optional TensorProto parameter = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册