Commit ec65fa83 authored by dzhwinter

"protobuf required to optional"

Parent 65d9e33b
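This commit relaxes the `required` fields in OptimizerConfig.proto to `optional`. In proto2, a message missing a `required` field fails to parse outright; an `optional` field may be absent, in which case parsing succeeds, the declared `[default = ...]` value applies, and presence can be queried through the generated `has_*()` accessor. A minimal sketch of what this buys the caller, assuming the protoc-generated C++ bindings for OptimizerConfig.proto (the header name and the absence of a package prefix are assumptions):

```cpp
#include <string>
#include "OptimizerConfig.pb.h"  // assumed protoc output for OptimizerConfig.proto

// With `optional` fields, an empty or partial payload still parses and
// unset fields fall back to their declared defaults.
bool LoadOptimizerConfig(const std::string &bytes, OptimizerConfig *config) {
  // Under the old schema, ParseFromString() returned false whenever a
  // `required` field (optimizer, lr_policy, ...) was absent from `bytes`.
  if (!config->ParseFromString(bytes)) return false;
  if (!config->has_lr_policy()) {
    // Field not set by the sender: pick an explicit fallback instead of
    // rejecting the whole message.
    config->set_lr_policy(OptimizerConfig::Const);
  }
  // ConstLrConfig.learning_rate is declared [default = 1.0], so reading it
  // while unset yields 1.0 rather than failing.
  double lr = config->const_lr().learning_rate();
  return lr > 0.0;
}
```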
@@ -26,7 +26,6 @@ const char *AdamOptimizer::SerializeState(int *state_len) {
   AdamOptimizerState state;
   // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
   TensorToProto(*velocitys_, state.mutable_velocitys());
@@ -42,6 +41,7 @@ void AdamOptimizer::DeserializeState(const std::string &str) {
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.momentums(), momentums_);
   ProtoToTensor(state.velocitys(), velocitys_);
 }
 }  // namespace optimizer
...
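The SerializeState/DeserializeState pair above is the checkpoint hook for the optimizer. A minimal round-trip sketch built only from calls visible in this diff; the surrounding setup, error handling, and the `optimizer::` namespace qualification are assumptions:

```cpp
#include <string>

// Hypothetical helper mirroring what TestCheckPoint() below does: snapshot
// an optimizer's state into a string, then restore it.
void CheckpointRoundTrip(optimizer::ParameterOptimizer *opt) {
  int state_len = 0;
  std::string state = opt->SerializeState(&state_len);  // snapshot
  // ... persist `state` (state_len bytes) to durable storage here ...
  opt->DeserializeState(state);  // restores num_sample_passed_ and tensors
}
```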
@@ -45,11 +45,9 @@ public:
     config_.mutable_sgd()->set_nesterov(false);
     config_.set_lr_policy(OptimizerConfig::Const);
     config_.mutable_const_lr()->set_learning_rate(0.1);
     std::string str = config_.SerializeAsString();
     ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
     opts_.push_back(opt);
-    opts_table_[opts_.size()] = OptimizerConfig::SGD;
   }
   void CreateAdam() {
@@ -64,7 +62,6 @@ public:
     std::string str = config_.SerializeAsString();
     ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
     opts_.push_back(opt);
-    opts_table_[opts_.size()] = OptimizerConfig::Adam;
   }
   void TestGetWeight() {
@@ -86,21 +83,15 @@ public:
   }
   void TestCheckPoint() {
-    std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
-        {OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
-        {OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
-    };
     for (size_t i = 0; i < opts_.size(); ++i) {
       int state_len = 0;
       std::string state = opts_[i]->SerializeState(&state_len);
-      EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
       opts_[i]->DeserializeState(state);
     }
   }
 private:
   std::vector<ParameterOptimizer*> opts_;
-  std::map<int, OptimizerConfig::Optimizer> opts_table_;
   OptimizerConfig config_;
 };
...
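A plausible reason, not stated in the commit, for dropping the hard-coded `expected_state_len` check: proto2 serializes nothing for an unset `optional` field, so once the state messages use `optional` the output of `SerializeState()` no longer has a fixed length per optimizer type. A sketch of that behavior (assuming `AdamOptimizerState` is generated alongside the config messages):

```cpp
#include "OptimizerConfig.pb.h"  // assumed generated header

// Unset optional fields contribute zero bytes to the encoding, so the
// serialized size depends on which fields were actually populated.
int EncodedSizeDemo() {
  AdamOptimizerState s;
  int empty_len = s.ByteSize();     // minimal: nothing set yet
  s.set_num_sample_passed(100);     // field used in SerializeState above
  return s.ByteSize() - empty_len;  // positive: the field now takes space
}
```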
@@ -42,7 +42,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) {
   state.ParseFromString(str);
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
-  ProtoToTensor(state.parameter(), momentums_);
+  if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
 }
 }  // namespace optimizer
...
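A hedged reading of the SGD change above (the diff does not state it): vanilla SGD with `momentum_ == 0.0` presumably never allocates `momentums_`, so unconditionally copying into it would write through an unset tensor pointer. The guard pattern, sketched with hypothetical names (`TensorProto`, `Tensor`, and the helper are assumptions; only `ProtoToTensor` appears in this diff):

```cpp
// Hypothetical helper illustrating the guard added in DeserializeState:
// only touch the momentum tensor when momentum is actually configured.
void RestoreMomentums(const TensorProto &src, double momentum,
                      Tensor *momentums) {
  if (momentum != 0.0 && momentums != nullptr) {
    ProtoToTensor(src, momentums);  // the copy the diff makes conditional
  }
}
```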
@@ -55,12 +55,12 @@ message AdamConfig {
 message ConstLrConfig {
   // learninRate Policy
-  required double learning_rate = 1 [default = 1.0];
+  optional double learning_rate = 1 [default = 1.0];
 }

 message LinearLrConfig {
   // learninRate Policy
-  required double learning_rate = 1 [default = 1.0];
+  optional double learning_rate = 1 [default = 1.0];
   optional double lr_decay_a = 2;
   optional double lr_decay_b = 3;
 }
@@ -74,7 +74,7 @@ enum DataType {
     PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
     PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
   }
-  required DataType data_type = 1;
+  optional DataType data_type = 1;
   repeated bytes content = 2;
 }
@@ -132,7 +132,7 @@ message OptimizerConfig {
     Adagrad = 3;
     Adam = 4;
   }
-  required Optimizer optimizer = 1;
+  optional Optimizer optimizer = 1;
   optional SGDConfig sgd = 3;
   optional AdadeltaConfig adadelta = 4;
   optional AdagradConfig adagrad = 5;
@@ -142,7 +142,7 @@ message OptimizerConfig {
     Const = 0;
     Linear = 1;
   }
-  required LrPolicy lr_policy = 11;
+  optional LrPolicy lr_policy = 11;
   optional ConstLrConfig const_lr = 12;
   optional LinearLrConfig linear_lr = 13;
...
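Changing `required` to `optional` is wire-compatible (field numbers and encodings are untouched), but it shifts validation from the parser to the reader: a payload missing `data_type` now parses successfully. A sketch of the presence check readers take on; the message enclosing the `data_type`/`content` hunk is not named in the diff, so `TensorProto` here is an assumption:

```cpp
#include "OptimizerConfig.pb.h"  // assumed generated header

// With optional fields the parser no longer rejects incomplete messages,
// so consumers validate presence explicitly.
bool ValidateTensor(const TensorProto &t) {
  if (!t.has_data_type()) return false;  // formerly enforced by `required`
  return t.content_size() > 0;  // repeated bytes: require at least one chunk
}
```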