diff --git a/paddle/optimizer/CMakeLists.txt b/paddle/optimizer/CMakeLists.txt
index 746f4a69f84534cdc98e9285eddf559fea786239..e93ba102945425197348587d9b9ef6f8af2ffe04 100644
--- a/paddle/optimizer/CMakeLists.txt
+++ b/paddle/optimizer/CMakeLists.txt
@@ -6,12 +6,13 @@ set(OPITMIZER_SRCS
   adam_optimizer.cc
   optimizer.cc
   parameter_optimizer.cc
-  sgd_optmizer.cc
+  sgd_optimizer.cc
 )

 add_library(optimizer STATIC ${OPITMIZER_SRCS})
 add_dependencies(optimizer gen_proto_cpp)

 add_simple_unittest(tensor_test)
+add_simple_unittest(serialization_test)
 add_simple_unittest(parameter_optimizer_test)
 add_dependencies(parameter_optimizer_test optimizer)
diff --git a/paddle/optimizer/Tensor.h b/paddle/optimizer/Tensor.h
index a00a59bc6af4d714a8250f0267017d1bbfc19108..b8f212e81ff7935ef14ff3896c328feb789e7bf6 100644
--- a/paddle/optimizer/Tensor.h
+++ b/paddle/optimizer/Tensor.h
@@ -17,16 +17,10 @@ public:
   TensorT(size_t size) : height_(1), width_(size) { data_ = new T[size]; }
   TensorT(T* data, size_t size) : height_(1), width_(size), data_(data) {}
   TensorT(T* data, size_t h, size_t w) : height_(h), width_(w), data_(data_) {}
-  TensorT(const TensorT& t)
-      : TensorT(1, t.size(), 0, t.get_buffer(), false, false) {}
   ~TensorT() {
     if (data_) delete data_;
   }
-  TensorT& operator=(const TensorT& t) {
-    this->width_ = t.size();
-    this->data_ = t.get_buffer();
-  }
   T* get_buffer() { return this->data_; }
   T& operator[](const size_t idx) {
     CHECK(idx >= 0 && idx < this->width_) << "out of index range";
diff --git a/paddle/optimizer/Tensor_test.cpp b/paddle/optimizer/Tensor_test.cpp
deleted file mode 100644
index cdf733093235871ca2fdda36232a8dcdd9e22ba3..0000000000000000000000000000000000000000
--- a/paddle/optimizer/Tensor_test.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "gtest/gtest.h"
-#include "tensor.h"
-
-using namespace paddle;
-using namespace paddle::optimizer;
-
-TEST(Tensor, indexer) {
-  Tensor t(3);
-  for (auto i = 0; i < t.size(); ++i) {
-    t[i] = i;
-  }
-  ASSERT_EQ(t[2], 2);
-  ASSERT_EQ(t[1], 1);
-}
-
-int main(int argc, char** argv) {
-  testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/paddle/optimizer/adadelta_optimizer.cc b/paddle/optimizer/adadelta_optimizer.cc
index a6b079ce256d83a3cab49802bfb83ab5020b17ad..d1c6571d9b47301c447cc005c2bd1b1931fbf4ee 100644
--- a/paddle/optimizer/adadelta_optimizer.cc
+++ b/paddle/optimizer/adadelta_optimizer.cc
@@ -26,7 +26,7 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
 }

 const char* AdadeltaOptimizer::SerializeState(int* state_len) {
-  OptimizerState state;
+  AdadeltaOptimizerState state;
   state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
   state.set_num_sample_passed(num_sample_passed_);

@@ -34,22 +34,14 @@ const char* AdadeltaOptimizer::SerializeState(int* state_len) {
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   TensorToProto(*accum_delta_, state.mutable_accum_delta());
   TensorToProto(*update_delta_, state.mutable_update_delta());
-  state.set_nesterov(epsilon_);
-  state.set_momentum(rho_);
-  state.set_decay(decay_);
-  // can be used when memory alignment to system
-  *state_len += CalStateSize(parameter_, accum_gradient_, accum_delta_, update_delta_, rho_, epsilon_, decay_);
+
+  *state_len =
+      CalStateSize(parameter_, accum_gradient_, accum_delta_, update_delta_);
   return state.SerializeAsString().c_str();
 }

-void AdadeltaOptimizer::DeSerializeState(const std::string& str) {
-  OptimizerState state;
+void AdadeltaOptimizer::DeserializeState(const std::string& str) {
+  AdadeltaOptimizerState state;
   state.ParseFromString(str);
   lr_policy_->set(state.learning_rate());
   num_sample_passed_ = state.num_sample_passed();
@@ -58,6 +50,7 @@ void AdadeltaOptimizer::DeSerializeState(const std::string& str) {
   ProtoToTensor(state.accum_gradient(), accum_gradient_);
   ProtoToTensor(state.accum_delta(), accum_delta_);
   ProtoToTensor(state.update_delta(), update_delta_);
+}

 }  // namespace optimizer
-}  // namespace optimizer
+}  // namespace paddle
diff --git a/paddle/optimizer/adadelta_optimizer.h b/paddle/optimizer/adadelta_optimizer.h
index e0f544a90e5aceaff326c4e76d872481d6cb6907..58a26ebb7a73afa89992cbf3964767486f610690 100644
--- a/paddle/optimizer/adadelta_optimizer.h
+++ b/paddle/optimizer/adadelta_optimizer.h
@@ -13,7 +13,7 @@ public:
         rho_(rho),
         epsilon_(epsilon),
         decay_(decay) {
-    size_t size = p->size();
+    size_t size = parameter->size();
     if (accum_gradient_) delete accum_gradient_;
     accum_gradient_ = new Tensor(size);
     if (accum_delta_) delete accum_delta_;
@@ -28,7 +28,7 @@ public:
   }
   void Update(const Tensor *gradient);
   const char *SerializeState(int *state_len);
-  void DeSerializeState(const std::string &state);
+  void DeserializeState(const std::string &state);

 private:
   Tensor *accum_gradient_;
diff --git a/paddle/optimizer/adagrad_optimizer.cc b/paddle/optimizer/adagrad_optimizer.cc
index 6a17cf0ed06a1c50dc81414312f6b071b2f42f18..ebc4d9e83ae8527253dd15fb4987c2148acf807b 100644
--- a/paddle/optimizer/adagrad_optimizer.cc
+++ b/paddle/optimizer/adagrad_optimizer.cc
@@ -17,8 +17,25 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
                     learning_rate * decay_ * param[i];
   }
 }
-const char* SGDOptimizer::SerializeState(int* state_len) { NIMPL; }
+const char* AdagradOptimizer::SerializeState(int* state_len) {
+  AdagradOptimizerState state;
+  state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
+  state.set_num_sample_passed(num_sample_passed_);
+
+  TensorToProto(*parameter_, state.mutable_parameter());
+  TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
+  *state_len = CalStateSize(parameter_, accum_gradient_);
+  return state.SerializeAsString().c_str();
+}
+
+void AdagradOptimizer::DeserializeState(const std::string& str) {
+  AdagradOptimizerState state;
+  state.ParseFromString(str);
+  lr_policy_->set(state.learning_rate());
+  num_sample_passed_ = state.num_sample_passed();
+  ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.accum_gradient(), accum_gradient_);
+}

-void SGDOptimizer::DeSerializeState(const std::string& str) { NIMPL; }
-// namespace optimizer
 }  // namespace optimizer
+}  // namespace paddle
diff --git a/paddle/optimizer/adagrad_optimizer.h b/paddle/optimizer/adagrad_optimizer.h
index ebc0fe2acc6fdcba22bd9c08925f098b156bcccb..90fc1dd4ac900090d67ff41aa4131184eb3fb604 100644
--- a/paddle/optimizer/adagrad_optimizer.h
+++ b/paddle/optimizer/adagrad_optimizer.h
@@ -12,7 +12,7 @@ public:
                    double epsilon,
                    double decay)
       : ParameterOptimizer(parameter, lr), epsilon_(epsilon), decay_(decay) {
-    size_t size = p->size();
+    size_t size = parameter->size();
     if (accum_gradient_) delete accum_gradient_;
     accum_gradient_ = new Tensor(size);
   }
@@ -21,7 +21,7 @@ public:
   }
   void Update(const Tensor *gradient);
   const char *SerializeState(int *state_len);
-  void DeSerializeState(const std::string &state);
+  void DeserializeState(const std::string &state);

 private:
   Tensor *accum_gradient_;
diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc
index 974039cf6dcb3d65d6aa96f7c2148e78d03af7bb..53b3350d68f67dea31555b1952aeb9b47286cce4 100644
--- a/paddle/optimizer/adam_optimizer.cc
+++ b/paddle/optimizer/adam_optimizer.cc
@@ -22,32 +22,26 @@ void AdamOptimizer::Update(const Tensor *gradient) {
   }
 }

-const char *AdadeltaOptimizer::SerializeState(int *state_len) {
-  OptimizerState state;
+const char *AdamOptimizer::SerializeState(int *state_len) {
+  AdamOptimizerState state;
   state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
   state.set_num_sample_passed(num_sample_passed_);

   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*velocitys_, state.mutable_momentums());
-  state.set_beta_1(beta_1_);
-  state.set_beta_2(beta_2_);
-  state.set_decay(decay_);
-  *state_len += CalStateSize(
-      parameter_, momentums_, velocitys_, beta_1_, beta_2, epsilon_ decay_);
+  *state_len = CalStateSize(parameter_, momentums_, velocitys_);
   return state.SerializeAsString().c_str();
 }

-void AdadeltaOptimizer::DeSerializeState(const std::string &str) {
-  OptimizerState state;
+void AdamOptimizer::DeserializeState(const std::string &str) {
+  AdamOptimizerState state;
   state.ParseFromString(str);
   lr_policy_->set(state.learning_rate());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
-  ProtoToTensor(state.velocitys(), velocitys__);
-  beta_1_ = state.beta_1();
-  beta_2_ = state.beta_2();
+  ProtoToTensor(state.velocitys(), velocitys_);
 }
 }  // namespace optimizer
 }  // namespace paddle
diff --git a/paddle/optimizer/adam_optimizer.h b/paddle/optimizer/adam_optimizer.h
index b8be2ca2227ea3dbee93e267affbd1fc5faddcd9..04bc01154fb6b274b142db99b37328f6b0cb4d4a 100644
--- a/paddle/optimizer/adam_optimizer.h
+++ b/paddle/optimizer/adam_optimizer.h
@@ -8,7 +8,8 @@ namespace optimizer {
 class AdamOptimizer : public ParameterOptimizer {
 public:
   AdamOptimizer(Tensor *parameter,
-                LrPolicy *lr double beta_1,
+                LrPolicy *lr,
+                double beta_1,
                 double beta_2,
                 double epsilon,
                 double decay)
@@ -17,7 +18,7 @@ public:
         beta_2_(beta_2),
         epsilon_(epsilon),
         decay_(decay) {
-    size_t size = p->size();
+    size_t size = parameter->size();
     momentums_ = new Tensor(size);
     velocitys_ = new Tensor(size);
   }
@@ -26,6 +27,8 @@ public:
     if (velocitys_) delete velocitys_;
   }
   void Update(const Tensor *gradient);
+  const char *SerializeState(int *state_len);
+  void DeserializeState(const std::string &state);

 private:
   Tensor *momentums_;
diff --git a/paddle/optimizer/optimizer.cc b/paddle/optimizer/optimizer.cc
index c06c0737b25d7d4979bd483a7b64130893646169..54662dc37891d3211950453b210db4b475837df4 100644
--- a/paddle/optimizer/optimizer.cc
+++ b/paddle/optimizer/optimizer.cc
@@ -49,7 +49,7 @@ paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
   optimizer->impl = ParameterOptimizer::Create(config, parameter);
   if (state != nullptr) {
     std::string s(state, state + state_len);
-    optimizer->impl->DeSerializeState(s);
+    optimizer->impl->DeserializeState(s);
   }
   return optimizer;
 }
diff --git a/paddle/optimizer/optimizer.h b/paddle/optimizer/optimizer.h
index a5f468b06b54c8d483ca7c81bac5ea54d4c5d6d4..aabf7a458dd30092ed1e522c4d88c6cfe63fcce1 100644
--- a/paddle/optimizer/optimizer.h
+++ b/paddle/optimizer/optimizer.h
@@ -75,14 +75,14 @@ int paddle_update_parameter(paddle_optimizer* o,
                             int num_bytes);

 /**
- * @brief optimizer instance
+ * @brief optimizer for get parameter buffer
  * @param param_buffer, initilized parameter buffer
  * @return return content length
  */
 int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);

 /**
- * @brief optimzizer instance
+ * @brief optimzizer for saving training state
  * @param training state for receive SerializeState
  * @return return state_buffer length
  */
diff --git a/paddle/optimizer/parameter_optimizer.cc b/paddle/optimizer/parameter_optimizer.cc
index ae3e97bba8d10916908a71d94cf7e905e56c2653..38df3b75d77ea7ecb243e26bb387cdcc5cd441ca 100644
--- a/paddle/optimizer/parameter_optimizer.cc
+++ b/paddle/optimizer/parameter_optimizer.cc
@@ -62,7 +62,7 @@ ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
         << "have not select any Optimizer. use SGDOptimizer in default";
     return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
   };
-  return select_optimizer(config);
+  return select_optimizer(parameter, config);
 }

 float *ParameterOptimizer::get_weight(int *param_size) const {
diff --git a/paddle/optimizer/parameter_optimizer.h b/paddle/optimizer/parameter_optimizer.h
index 1abd659d4844bf6a0e99975664d7147da6ac5b14..658b22406d68bc1733e33d7d14de710f2eaeb1d9 100644
--- a/paddle/optimizer/parameter_optimizer.h
+++ b/paddle/optimizer/parameter_optimizer.h
@@ -8,10 +8,6 @@
 #include "serialization.h"
 #include "tensor.h"

-// Not Implemen Yet, macr
-// o
-#define NIMPL crash(__PRETTY_FUNCTION__, " not implemented yet")
-
 namespace paddle {
 namespace optimizer {

@@ -30,7 +26,7 @@ public:
   virtual void Update(const Tensor *gradient) = 0;
   virtual float *get_weight(int *param_size) const;
   virtual const char *SerializeState(int *state_len) = 0;
-  virtual void DeSerializeState(const std::string &state) = 0;
+  virtual void DeserializeState(const std::string &state) = 0;

 protected:
   Tensor *parameter_;
diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp
index d39d50a1abf4651650252ef55440959d71a1c5ac..8d3bfa9cf4b75ffbc2474b7dc0c040c9fa06adbe 100644
--- a/paddle/optimizer/parameter_optimizer_test.cpp
+++ b/paddle/optimizer/parameter_optimizer_test.cpp
@@ -1,19 +1,20 @@
 #include "parameter_optimizer.h"
 #include
-#include
+#include <map>
 #include
 #include "adadelta_optimizer.h"
 #include "adagrad_optimizer.h"
 #include "adam_optimizer.h"
 #include "gtest/gtest.h"
 #include "sgd_optimizer.h"
+
 using namespace paddle;
 using namespace paddle::optimizer;

 Tensor* FillTensor(size_t size) {
   Tensor* param = new Tensor(size);
   Tensor& p = *param;
-  for (auto i = 0; i < p.size(); ++i) {
+  for (size_t i = 0; i < p.size(); ++i) {
     p[i] = (float)rand() / (float)RAND_MAX;
   }
   return param;
@@ -22,7 +23,7 @@ Tensor* FillTensor(size_t size) {
 Tensor* FixedTensor(size_t size) {
   Tensor* param = new Tensor(size);
   Tensor& p = *param;
-  for (auto i = 0; i < p.size(); ++i) {
+  for (size_t i = 0; i < p.size(); ++i) {
     p[i] = i;
   }
   return param;
@@ -31,7 +32,7 @@ Tensor* FixedTensor(size_t size) {
 class OptimizerTest : public testing::Test {
 public:
   // init tensor shape
-  const size_t size = 5;
+  const size_t kSize = 5;

   virtual void SetUp() {
     CreateSGD();
@@ -40,68 +41,77 @@ public:
   virtual void TearDown() {}

   void CreateSGD() {
-    config.set_optimizer(OptimizerConfig::SGD);
-    config.mutable_sgd()->set_momentum(0.0);
-    config.mutable_sgd()->set_decay(0.0);
-    config.mutable_sgd()->set_nesterov(false);
-    config.set_lr_policy(OptimizerConfig::ConstLr);
-    config.mutable_const_lr()->set_learning_rate(0.1);
+    Tensor* parameter = FillTensor(kSize);
+    config_.set_optimizer(OptimizerConfig::SGD);
+    config_.mutable_sgd()->set_momentum(0.0);
+    config_.mutable_sgd()->set_decay(0.0);
+    config_.mutable_sgd()->set_nesterov(false);
+    config_.set_lr_policy(OptimizerConfig::ConstLr);
+    config_.mutable_const_lr()->set_learning_rate(0.1);
     ParameterOptimizer* opt =
-        ParameterOptimizer::Create(config.SerializeAsString());
-    opts.push_back(opt);
+        ParameterOptimizer::Create(config_.SerializeAsString(), parameter);
+    opts_.push_back(opt);
+    opts_table_[opts_.size()] = OptimizerConfig::SGD;
   }

   void CreateAdam() {
-    config.set_optimizer(OptimizerConfig::Adam);
-    config.mutable_adam()->set_beta_1(0.9);
-    config.mutable_adam()->set_beta_2(0.1);
-    config.mutable_adam()->set_epsilon(1e-3);
-    config.mutable_adam()->set_decay(0.0);
-    config.set_lr_policy(OptimizerConfig::ConstLr);
-    config.mutable_const_lr()->set_learning_rate(0.1);
+    Tensor* parameter = FixedTensor(kSize);
+    config_.set_optimizer(OptimizerConfig::Adam);
+    config_.mutable_adam()->set_beta_1(0.9);
+    config_.mutable_adam()->set_beta_2(0.1);
+    config_.mutable_adam()->set_epsilon(1e-3);
+    config_.mutable_adam()->set_decay(0.0);
+    config_.set_lr_policy(OptimizerConfig::ConstLr);
+    config_.mutable_const_lr()->set_learning_rate(0.1);
     ParameterOptimizer* opt =
-        ParameterOptimizer::Create(config.SerializeAsString());
-    opts.push_back(opt);
-  }
-  void TestSetWeight() {
-    Tensor* p = FillTensor(size);
-    for (size_t i = 0; i < opts.size(); ++i) {
-      opts[i]->set_weight(p);
-    }
+        ParameterOptimizer::Create(config_.SerializeAsString(), parameter);
+    opts_.push_back(opt);
+    opts_table_[opts_.size()] = OptimizerConfig::Adam;
   }

   void TestGetWeight() {
-    Tensor* p = FixedTensor(size);
-    for (size_t i = 0; i < opts.size(); ++i) {
-      opts[i]->set_weight(p);
-    }
-    for (size_t i = 0; i < opts.size(); ++i) {
+    Tensor* p = FixedTensor(kSize);
+    for (size_t i = 0; i < opts_.size(); ++i) {
       int s = 0;
-      float* newp = (float*)opts[i]->get_weight(&s);
-      for (size_t j = 0; j < size; ++j) {
+      float* newp = (float*)opts_[i]->get_weight(&s);
+      for (size_t j = 0; j < kSize; ++j) {
         EXPECT_EQ(newp[j], (*p)[j]);
       }
     }
   }
+
   void TestUpdate() {
-    Tensor* g = FixedTensor(size);
-    for (size_t i = 0; i < opts.size(); ++i) {
-      opts[i]->Update(g);
+    Tensor* g = FixedTensor(kSize);
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      opts_[i]->Update(g);
+    }
+  }
+
+  void TestCheckPoint() {
+    std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
+        {OptimizerConfig::SGD, kSize}, {OptimizerConfig::Adam, kSize * 3},
+    };
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      int state_len = 0;
+      std::string state = opts_[i]->SerializeState(&state_len);
+      EXPECT_EQ(state_len, expected_state_len[opts_table_[i]]);
+      opts_[i]->DeserializeState(state);
     }
   }

 private:
-  std::vector<ParameterOptimizer*> opts;
-  OptimizerConfig config;
+  std::vector<ParameterOptimizer*> opts_;
+  std::map<int, OptimizerConfig::Optimizer> opts_table_;
+  OptimizerConfig config_;
 };

-TEST_F(OptimizerTest, test_set_get_weight) {
-  TestSetWeight();
-  TestGetWeight();
-}
+TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }
+
 TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }

+TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
diff --git a/paddle/optimizer/serialization.h b/paddle/optimizer/serialization.h
index 60bb7e2765156e29f12fce14baef8aae0d2bb762..07874502a5015235cdb6a00d1032c0178e10b1c7 100644
--- a/paddle/optimizer/serialization.h
+++ b/paddle/optimizer/serialization.h
@@ -10,15 +10,16 @@
 namespace paddle {
 namespace optimizer {

-static unsigned CalStateSize(int* state_len) { return 0; }
+static unsigned CalStateSize() { return 0; }

 template <typename HEAD, typename... TAIL>
 unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
-  if (std::is_fundamental<HEAD>::value) {
-    return sizeof head + CalStateSize(tail...);
-  } else {
-    return sizeof(head[0]) * head->size() + CalStateSize(tail...);
-  }
+  return sizeof head + CalStateSize(tail...);
+}
+
+template <typename... TAIL>
+unsigned CalStateSize(const Tensor* head, const TAIL&... tail) {
+  return head->size() + CalStateSize(tail...);
 }

 static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
@@ -32,7 +33,6 @@ static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
 }

 static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
-  CHECK(proto.size() == tensor->size()) << "unmatch shape of proto and tensor";
   std::stringstream sin;
   for (auto i = 0; i < proto.content_size(); ++i) {
     sin << proto.content(i);
diff --git a/paddle/optimizer/serialization_test.cpp b/paddle/optimizer/serialization_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..98fbdf5a5e2cb703440b4178eee6a324a9e9b3cf
--- /dev/null
+++ b/paddle/optimizer/serialization_test.cpp
@@ -0,0 +1,24 @@
+#include "serialization.h"
+#include "gtest/gtest.h"
+
+using namespace paddle;
+using namespace paddle::optimizer;
+
+TEST(TensorToProto, Case1) {
+  Tensor t(3), t1(3);
+  for (size_t i = 0; i < t.size(); ++i) {
+    t[i] = i;
+    t1[i] = 0;
+  }
+  TensorProto proto;
+  TensorToProto(t, &proto);
+  ProtoToTensor(proto, &t1);
+  for (size_t i = 0; i < t1.size(); ++i) {
+    EXPECT_EQ(t1[i], t[i]);
+  }
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/optimizer/sgd_optmizer.cc b/paddle/optimizer/sgd_optimizer.cc
similarity index 81%
rename from paddle/optimizer/sgd_optmizer.cc
rename to paddle/optimizer/sgd_optimizer.cc
index b2c6b7a1acf2d6456565b3fd2c06caa9875c67c1..8b4ea777d2dab7346ca1148816756c91423d3bc3 100644
--- a/paddle/optimizer/sgd_optmizer.cc
+++ b/paddle/optimizer/sgd_optimizer.cc
@@ -1,5 +1,5 @@
-#include "serialization.h"
 #include "sgd_optimizer.h"
+#include "serialization.h"

 namespace paddle {
 namespace optimizer {
@@ -28,29 +28,24 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 }

 const char *SGDOptimizer::SerializeState(int *state_len) {
-  OptimizerState state;
+  SGDOptimizerState state;
   state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
   state.set_num_sample_passed(num_sample_passed_);

   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
-  state.set_momentum(momentum_);
-  state.set_decay(decay_);
-  state.set_nesterov(nesterov_);
-  *state_len +=
-      CalStateSize(parameter_, momentums_, momentum_, decay_, nesterov_);
+  *state_len = CalStateSize(parameter_, momentums_);
   return state.SerializeAsString().c_str();
 }

-void SGDOptimizer::DeSerializeState(const std::string &str) {
-  OptimizerState state;
+void SGDOptimizer::DeserializeState(const std::string &str) {
+  SGDOptimizerState state;
   state.ParseFromString(str);
   lr_policy_->set(state.learning_rate());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   ProtoToTensor(state.parameter(), momentums_);
-  momentum_ = state.momentum();
 }

 }  // namespace optimizer
diff --git a/paddle/optimizer/sgd_optimizer.h b/paddle/optimizer/sgd_optimizer.h
index d0ac375d2b8f13383a4281db81e4d832dbbf6b8e..1d4ea46f1a4aadfb6d9853f3c6d1c6bcf2ee1f48 100644
--- a/paddle/optimizer/sgd_optimizer.h
+++ b/paddle/optimizer/sgd_optimizer.h
@@ -14,7 +14,7 @@ public:
         decay_(d),
         nesterov_(n) {
     if (momentum_ != 0.0) {
-      size_t size = p->size();
+      size_t size = parameter->size();
       // TODO: fix it with align aware allocator bind to Tensor
       if (momentums_) delete momentums_;
       momentums_ = new Tensor(size);
@@ -25,7 +25,7 @@ public:
   }
   void Update(const Tensor* gradient);
   const char* SerializeState(int* state_len);
-  void DeSerializeState(const std::string& state);
+  void DeserializeState(const std::string& state);

 private:
   Tensor* momentums_;
diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto
index 1ccba6d207612162a76756b82af196de2e3fbfd6..aab2fdad693f1e2ec5fdf9ae0f1843e6f196aa7f 100644
--- a/proto/OptimizerConfig.proto
+++ b/proto/OptimizerConfig.proto
@@ -78,36 +78,51 @@ enum DataType {
   repeated bytes content = 2;
 }

-message OptimizerState {
+message SGDOptimizerState {
+  // learning rate policy
   optional double learning_rate = 101;
   optional double lr_decay_a = 102;
   optional double lr_decay_b = 103;
   optional double num_sample_passed = 104;
-  // momentum
-  optional TensorProto parameter = 105;
-  optional TensorProto momentums = 1;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto momentums = 2;
+}

-  // adadelta
+message AdadeltaOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
   optional TensorProto accum_gradient = 2;
   optional TensorProto accum_delta = 3;
   optional TensorProto update_delta = 4;
+}

-  // adam
-  optional TensorProto velocitys = 5;
-
-  // momentum
-  optional double momentum = 6;
-  optional double decay = 7;
-  optional bool nesterov = 8;
-
-  // adadelta
-  optional double rho = 9;
-  optional double epsilon = 10;
-
-  // adam
-  optional double beta_1 = 11;
-  optional double beta_2 = 12;
+message AdagradOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto accum_gradient = 2;
+}
+message AdamOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto momentums = 2;
+  optional TensorProto velocitys = 3;
 }

 message OptimizerConfig {