提交 ec65fa83 编写于 作者: D dzhwinter

"protobuf required to optional"

上级 65d9e33b
......@@ -26,7 +26,6 @@ const char *AdamOptimizer::SerializeState(int *state_len) {
AdamOptimizerState state;
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*momentums_, state.mutable_momentums());
TensorToProto(*velocitys_, state.mutable_velocitys());
......@@ -42,6 +41,7 @@ void AdamOptimizer::DeserializeState(const std::string &str) {
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.momentums(), momentums_);
ProtoToTensor(state.velocitys(), velocitys_);
}
} // namespace optimizer
......
......@@ -45,11 +45,9 @@ public:
config_.mutable_sgd()->set_nesterov(false);
config_.set_lr_policy(OptimizerConfig::Const);
config_.mutable_const_lr()->set_learning_rate(0.1);
std::string str = config_.SerializeAsString();
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
opts_.push_back(opt);
opts_table_[opts_.size()] = OptimizerConfig::SGD;
}
void CreateAdam() {
......@@ -64,7 +62,6 @@ public:
std::string str = config_.SerializeAsString();
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
opts_.push_back(opt);
opts_table_[opts_.size()] = OptimizerConfig::Adam;
}
void TestGetWeight() {
......@@ -86,21 +83,15 @@ public:
}
void TestCheckPoint() {
std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
{OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
{OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
};
for (size_t i = 0; i < opts_.size(); ++i) {
int state_len = 0;
std::string state = opts_[i]->SerializeState(&state_len);
EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
opts_[i]->DeserializeState(state);
}
}
private:
std::vector<ParameterOptimizer*> opts_;
std::map<int, OptimizerConfig::Optimizer> opts_table_;
OptimizerConfig config_;
};
......
......@@ -42,7 +42,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) {
state.ParseFromString(str);
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.parameter(), momentums_);
if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
}
} // namespace optimizer
......
......@@ -55,12 +55,12 @@ message AdamConfig {
message ConstLrConfig {
// learninRate Policy
required double learning_rate = 1 [default = 1.0];
optional double learning_rate = 1 [default = 1.0];
}
message LinearLrConfig {
// learninRate Policy
required double learning_rate = 1 [default = 1.0];
optional double learning_rate = 1 [default = 1.0];
optional double lr_decay_a = 2;
optional double lr_decay_b = 3;
}
......@@ -74,7 +74,7 @@ enum DataType {
PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
}
required DataType data_type = 1;
optional DataType data_type = 1;
repeated bytes content = 2;
}
......@@ -132,7 +132,7 @@ message OptimizerConfig {
Adagrad = 3;
Adam = 4;
}
required Optimizer optimizer = 1;
optional Optimizer optimizer = 1;
optional SGDConfig sgd = 3;
optional AdadeltaConfig adadelta = 4;
optional AdagradConfig adagrad = 5;
......@@ -142,7 +142,7 @@ message OptimizerConfig {
Const = 0;
Linear = 1;
}
required LrPolicy lr_policy = 11;
optional LrPolicy lr_policy = 11;
optional ConstLrConfig const_lr = 12;
optional LinearLrConfig linear_lr = 13;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册