提交 e1acd73f 编写于 作者: D dongzhihong

"fix typo deleted part"

上级 6935dd7b
......@@ -17,7 +17,7 @@ public:
// constant learning rate policy
class ConstLr final : public LrPolicy {
public:
ConstLr(double lr) : learning_rate(lr){};
ConstLr(double lr) : learning_rate_(lr){};
double LearningRate(const uint64_t num_sample_passed) {
return learning_rate_;
}
......@@ -28,7 +28,7 @@ public:
*state_len = str.size();
return str.c_str();
}
void DeserializeState(const std::string &state) {
void DeserializeState(const std::string &str) {
LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
......
......@@ -30,7 +30,11 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state;
state.set_num_sample_passed(num_sample_passed_);
state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter());
std::string lr_str = this->lr_policy_->SerializeState(state_len);
LrPolicyState lr_state;
lr_state.ParseFromString(lr_str);
state.mutable_lr_state() = lr_state;
TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString();
*state_len += str.size();
......
......@@ -95,7 +95,7 @@ message SGDOptimizerState {
message AdadeltaOptimizerState {
// learning rate policy
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册