Commit b72e8aa3 authored by D dzhwinter

"seperate serialization proto state"

Parent b7e68e06
@@ -6,12 +6,13 @@ set(OPITMIZER_SRCS
     adam_optimizer.cc
     optimizer.cc
     parameter_optimizer.cc
-    sgd_optmizer.cc
+    sgd_optimizer.cc
 )
 add_library(optimizer STATIC ${OPITMIZER_SRCS})
 add_dependencies(optimizer gen_proto_cpp)
 add_simple_unittest(tensor_test)
+add_simple_unittest(serialization_test)
 add_simple_unittest(parameter_optimizer_test)
 add_dependencies(parameter_optimizer_test optimizer)
@@ -17,16 +17,10 @@ public:
   TensorT(size_t size) : height_(1), width_(size) { data_ = new T[size]; }
   TensorT(T* data, size_t size) : height_(1), width_(size), data_(data) {}
   TensorT(T* data, size_t h, size_t w) : height_(h), width_(w), data_(data_) {}
-  TensorT(const TensorT& t)
-      : TensorT(1, t.size(), 0, t.get_buffer(), false, false) {}
   ~TensorT() {
     if (data_) delete data_;
   }
-  TensorT& operator=(const TensorT& t) {
-    this->width_ = t.size();
-    this->data_ = t.get_buffer();
-  }
   T* get_buffer() { return this->data_; }
   T& operator[](const size_t idx) {
     CHECK(idx >= 0 && idx < this->width_) << "out of index range";
......
...@@ -26,7 +26,7 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) { ...@@ -26,7 +26,7 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
} }
const char* AdadeltaOptimizer::SerializeState(int* state_len) { const char* AdadeltaOptimizer::SerializeState(int* state_len) {
OptimizerState state; AdadeltaOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_)); state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
...@@ -34,22 +34,14 @@ const char* AdadeltaOptimizer::SerializeState(int* state_len) { ...@@ -34,22 +34,14 @@ const char* AdadeltaOptimizer::SerializeState(int* state_len) {
TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
TensorToProto(*accum_delta_, state.mutable_accum_delta()); TensorToProto(*accum_delta_, state.mutable_accum_delta());
TensorToProto(*update_delta_, state.mutable_update_delta()); TensorToProto(*update_delta_, state.mutable_update_delta());
state.set_nesterov(epsilon_);
state.set_momentum(rho_); *state_len =
state.set_decay(decay_); CalStateSize(parameter_, accum_gradient_, accum_delta_, update_delta_);
// can be used when memory alignment to system
*state_len += CalStateSize(parameter_,
accum_gradient_,
accum_delta_,
update_delta_,
rho_,
epsilon_,
decay_);
return state.SerializeAsString().c_str(); return state.SerializeAsString().c_str();
} }
void AdadeltaOptimizer::DeSerializeState(const std::string& str) { void AdadeltaOptimizer::DeserializeState(const std::string& str) {
OptimizerState state; AdadeltaOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
lr_policy_->set(state.learning_rate()); lr_policy_->set(state.learning_rate());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
...@@ -58,6 +50,7 @@ void AdadeltaOptimizer::DeSerializeState(const std::string& str) { ...@@ -58,6 +50,7 @@ void AdadeltaOptimizer::DeSerializeState(const std::string& str) {
ProtoToTensor(state.accum_gradient(), accum_gradient_); ProtoToTensor(state.accum_gradient(), accum_gradient_);
ProtoToTensor(state.accum_delta(), accum_delta_); ProtoToTensor(state.accum_delta(), accum_delta_);
ProtoToTensor(state.update_delta(), update_delta_); ProtoToTensor(state.update_delta(), update_delta_);
}
} // namespace optimizer } // namespace optimizer
} // namespace optimizer } // namespace paddle
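For context, a minimal sketch (not part of this commit) of the round trip the renamed interface implies. It only uses ParameterOptimizer, SerializeState, and DeserializeState from the diff, and mirrors the length-aware buffer handling that paddle_create_optimizer performs further below; the helper name is illustrative.

// Sketch only: round-trip an optimizer's state through the new interface.
#include <string>
#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

// Serialize the state, copy it by explicit length (the protobuf-encoded
// bytes may contain '\0'), and feed it straight back, the same way
// paddle_create_optimizer() does when it is handed a saved state buffer.
inline void RoundTripState(ParameterOptimizer* opt) {
  int state_len = 0;
  const char* buf = opt->SerializeState(&state_len);
  std::string state(buf, buf + state_len);
  opt->DeserializeState(state);
}

}  // namespace optimizer
}  // namespace paddle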
@@ -13,7 +13,7 @@ public:
         rho_(rho),
         epsilon_(epsilon),
         decay_(decay) {
-    size_t size = p->size();
+    size_t size = parameter->size();
     if (accum_gradient_) delete accum_gradient_;
     accum_gradient_ = new Tensor(size);
     if (accum_delta_) delete accum_delta_;
@@ -28,7 +28,7 @@ public:
   }
   void Update(const Tensor *gradient);
   const char *SerializeState(int *state_len);
-  void DeSerializeState(const std::string &state);
+  void DeserializeState(const std::string &state);
 
 private:
   Tensor *accum_gradient_;
......
@@ -17,8 +17,25 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
                learning_rate * decay_ * param[i];
   }
 }
-const char* SGDOptimizer::SerializeState(int* state_len) { NIMPL; }
-void SGDOptimizer::DeSerializeState(const std::string& str) { NIMPL; }
+const char* AdagradOptimizer::SerializeState(int* state_len) {
+  AdagradOptimizerState state;
+  state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
+  state.set_num_sample_passed(num_sample_passed_);
+  TensorToProto(*parameter_, state.mutable_parameter());
+  TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
+  *state_len = CalStateSize(parameter_, accum_gradient_);
+  return state.SerializeAsString().c_str();
+}
+
+void AdagradOptimizer::DeserializeState(const std::string& str) {
+  AdagradOptimizerState state;
+  state.ParseFromString(str);
+  lr_policy_->set(state.learning_rate());
+  num_sample_passed_ = state.num_sample_passed();
+  ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.accum_gradient(), accum_gradient_);
+}
 
 }  // namespace optimizer
 }  // namespace paddle
...@@ -12,7 +12,7 @@ public: ...@@ -12,7 +12,7 @@ public:
double epsilon, double epsilon,
double decay) double decay)
: ParameterOptimizer(parameter, lr), epsilon_(epsilon), decay_(decay) { : ParameterOptimizer(parameter, lr), epsilon_(epsilon), decay_(decay) {
size_t size = p->size(); size_t size = parameter->size();
if (accum_gradient_) delete accum_gradient_; if (accum_gradient_) delete accum_gradient_;
accum_gradient_ = new Tensor(size); accum_gradient_ = new Tensor(size);
} }
...@@ -21,7 +21,7 @@ public: ...@@ -21,7 +21,7 @@ public:
} }
void Update(const Tensor *gradient); void Update(const Tensor *gradient);
const char *SerializeState(int *state_len); const char *SerializeState(int *state_len);
void DeSerializeState(const std::string &state); void DeserializeState(const std::string &state);
private: private:
Tensor *accum_gradient_; Tensor *accum_gradient_;
......
...@@ -22,32 +22,26 @@ void AdamOptimizer::Update(const Tensor *gradient) { ...@@ -22,32 +22,26 @@ void AdamOptimizer::Update(const Tensor *gradient) {
} }
} }
const char *AdadeltaOptimizer::SerializeState(int *state_len) { const char *AdamOptimizer::SerializeState(int *state_len) {
OptimizerState state; AdamOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_)); state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*velocitys_, state.mutable_momentums()); TensorToProto(*velocitys_, state.mutable_momentums());
state.set_beta_1(beta_1_); *state_len = CalStateSize(parameter_, momentums_, velocitys_);
state.set_beta_2(beta_2_);
state.set_decay(decay_);
*state_len += CalStateSize(
parameter_, momentums_, velocitys_, beta_1_, beta_2, epsilon_ decay_);
return state.SerializeAsString().c_str(); return state.SerializeAsString().c_str();
} }
void AdadeltaOptimizer::DeSerializeState(const std::string &str) { void AdamOptimizer::DeserializeState(const std::string &str) {
OptimizerState state; AdamOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
lr_policy_->set(state.learning_rate()); lr_policy_->set(state.learning_rate());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.velocitys(), velocitys__); ProtoToTensor(state.velocitys(), velocitys_);
beta_1_ = state.beta_1();
beta_2_ = state.beta_2();
} }
} // namespace optimizer } // namespace optimizer
} // namespace paddle } // namespace paddle
...@@ -8,7 +8,8 @@ namespace optimizer { ...@@ -8,7 +8,8 @@ namespace optimizer {
class AdamOptimizer : public ParameterOptimizer { class AdamOptimizer : public ParameterOptimizer {
public: public:
AdamOptimizer(Tensor *parameter, AdamOptimizer(Tensor *parameter,
LrPolicy *lr double beta_1, LrPolicy *lr,
double beta_1,
double beta_2, double beta_2,
double epsilon, double epsilon,
double decay) double decay)
...@@ -17,7 +18,7 @@ public: ...@@ -17,7 +18,7 @@ public:
beta_2_(beta_2), beta_2_(beta_2),
epsilon_(epsilon), epsilon_(epsilon),
decay_(decay) { decay_(decay) {
size_t size = p->size(); size_t size = parameter->size();
momentums_ = new Tensor(size); momentums_ = new Tensor(size);
velocitys_ = new Tensor(size); velocitys_ = new Tensor(size);
} }
...@@ -26,6 +27,8 @@ public: ...@@ -26,6 +27,8 @@ public:
if (velocitys_) delete velocitys_; if (velocitys_) delete velocitys_;
} }
void Update(const Tensor *gradient); void Update(const Tensor *gradient);
const char *SerializeState(int *state_len);
void DeserializeState(const std::string &state);
private: private:
Tensor *momentums_; Tensor *momentums_;
......
...@@ -49,7 +49,7 @@ paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto, ...@@ -49,7 +49,7 @@ paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
optimizer->impl = ParameterOptimizer::Create(config, parameter); optimizer->impl = ParameterOptimizer::Create(config, parameter);
if (state != nullptr) { if (state != nullptr) {
std::string s(state, state + state_len); std::string s(state, state + state_len);
optimizer->impl->DeSerializeState(s); optimizer->impl->DeserializeState(s);
} }
return optimizer; return optimizer;
} }
......
...@@ -75,14 +75,14 @@ int paddle_update_parameter(paddle_optimizer* o, ...@@ -75,14 +75,14 @@ int paddle_update_parameter(paddle_optimizer* o,
int num_bytes); int num_bytes);
/** /**
* @brief optimizer instance * @brief optimizer for get parameter buffer
* @param param_buffer, initilized parameter buffer * @param param_buffer, initilized parameter buffer
* @return return content length * @return return content length
*/ */
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer); int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);
/** /**
* @brief optimzizer instance * @brief optimzizer for saving training state
* @param training state for receive SerializeState * @param training state for receive SerializeState
* @return return state_buffer length * @return return state_buffer length
*/ */
......
...@@ -62,7 +62,7 @@ ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto, ...@@ -62,7 +62,7 @@ ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
<< "have not select any Optimizer. use SGDOptimizer in default"; << "have not select any Optimizer. use SGDOptimizer in default";
return new SGDOptimizer(parameter, lr, 0.0, 0.0, false); return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
}; };
return select_optimizer(config); return select_optimizer(parameter, config);
} }
float *ParameterOptimizer::get_weight(int *param_size) const { float *ParameterOptimizer::get_weight(int *param_size) const {
......
...@@ -8,10 +8,6 @@ ...@@ -8,10 +8,6 @@
#include "serialization.h" #include "serialization.h"
#include "tensor.h" #include "tensor.h"
// Not Implemen Yet, macr
// o
#define NIMPL crash(__PRETTY_FUNCTION__, " not implemented yet")
namespace paddle { namespace paddle {
namespace optimizer { namespace optimizer {
...@@ -30,7 +26,7 @@ public: ...@@ -30,7 +26,7 @@ public:
virtual void Update(const Tensor *gradient) = 0; virtual void Update(const Tensor *gradient) = 0;
virtual float *get_weight(int *param_size) const; virtual float *get_weight(int *param_size) const;
virtual const char *SerializeState(int *state_len) = 0; virtual const char *SerializeState(int *state_len) = 0;
virtual void DeSerializeState(const std::string &state) = 0; virtual void DeserializeState(const std::string &state) = 0;
protected: protected:
Tensor *parameter_; Tensor *parameter_;
......
#include "parameter_optimizer.h" #include "parameter_optimizer.h"
#include <cmath> #include <cmath>
#include <tuple> #include <map>
#include <vector> #include <vector>
#include "adadelta_optimizer.h" #include "adadelta_optimizer.h"
#include "adagrad_optimizer.h" #include "adagrad_optimizer.h"
#include "adam_optimizer.h" #include "adam_optimizer.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "sgd_optimizer.h" #include "sgd_optimizer.h"
using namespace paddle; using namespace paddle;
using namespace paddle::optimizer; using namespace paddle::optimizer;
Tensor* FillTensor(size_t size) { Tensor* FillTensor(size_t size) {
Tensor* param = new Tensor(size); Tensor* param = new Tensor(size);
Tensor& p = *param; Tensor& p = *param;
for (auto i = 0; i < p.size(); ++i) { for (size_t i = 0; i < p.size(); ++i) {
p[i] = (float)rand() / (float)RAND_MAX; p[i] = (float)rand() / (float)RAND_MAX;
} }
return param; return param;
...@@ -22,7 +23,7 @@ Tensor* FillTensor(size_t size) { ...@@ -22,7 +23,7 @@ Tensor* FillTensor(size_t size) {
Tensor* FixedTensor(size_t size) { Tensor* FixedTensor(size_t size) {
Tensor* param = new Tensor(size); Tensor* param = new Tensor(size);
Tensor& p = *param; Tensor& p = *param;
for (auto i = 0; i < p.size(); ++i) { for (size_t i = 0; i < p.size(); ++i) {
p[i] = i; p[i] = i;
} }
return param; return param;
...@@ -31,7 +32,7 @@ Tensor* FixedTensor(size_t size) { ...@@ -31,7 +32,7 @@ Tensor* FixedTensor(size_t size) {
class OptimizerTest : public testing::Test { class OptimizerTest : public testing::Test {
public: public:
// init tensor shape // init tensor shape
const size_t size = 5; const size_t kSize = 5;
virtual void SetUp() { virtual void SetUp() {
CreateSGD(); CreateSGD();
...@@ -40,68 +41,77 @@ public: ...@@ -40,68 +41,77 @@ public:
virtual void TearDown() {} virtual void TearDown() {}
void CreateSGD() { void CreateSGD() {
config.set_optimizer(OptimizerConfig::SGD); Tensor* parameter = FillTensor(kSize);
config.mutable_sgd()->set_momentum(0.0); config_.set_optimizer(OptimizerConfig::SGD);
config.mutable_sgd()->set_decay(0.0); config_.mutable_sgd()->set_momentum(0.0);
config.mutable_sgd()->set_nesterov(false); config_.mutable_sgd()->set_decay(0.0);
config.set_lr_policy(OptimizerConfig::ConstLr); config_.mutable_sgd()->set_nesterov(false);
config.mutable_const_lr()->set_learning_rate(0.1); config_.set_lr_policy(OptimizerConfig::ConstLr);
config_.mutable_const_lr()->set_learning_rate(0.1);
ParameterOptimizer* opt = ParameterOptimizer* opt =
ParameterOptimizer::Create(config.SerializeAsString()); ParameterOptimizer::Create(config_.SerializeAsString(), parameter);
opts.push_back(opt); opts_.push_back(opt);
opts_table_[opts_.size()] = OptimizerConfig::SGD;
} }
void CreateAdam() { void CreateAdam() {
config.set_optimizer(OptimizerConfig::Adam); Tensor* parameter = FixedTensor(kSize);
config.mutable_adam()->set_beta_1(0.9); config_.set_optimizer(OptimizerConfig::Adam);
config.mutable_adam()->set_beta_2(0.1); config_.mutable_adam()->set_beta_1(0.9);
config.mutable_adam()->set_epsilon(1e-3); config_.mutable_adam()->set_beta_2(0.1);
config.mutable_adam()->set_decay(0.0); config_.mutable_adam()->set_epsilon(1e-3);
config.set_lr_policy(OptimizerConfig::ConstLr); config_.mutable_adam()->set_decay(0.0);
config.mutable_const_lr()->set_learning_rate(0.1); config_.set_lr_policy(OptimizerConfig::ConstLr);
config_.mutable_const_lr()->set_learning_rate(0.1);
ParameterOptimizer* opt = ParameterOptimizer* opt =
ParameterOptimizer::Create(config.SerializeAsString()); ParameterOptimizer::Create(config_.SerializeAsString(), parameter);
opts.push_back(opt); opts_.push_back(opt);
} opts_table_[opts_.size()] = OptimizerConfig::Adam;
void TestSetWeight() {
Tensor* p = FillTensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->set_weight(p);
}
} }
void TestGetWeight() { void TestGetWeight() {
Tensor* p = FixedTensor(size); Tensor* p = FixedTensor(kSize);
for (size_t i = 0; i < opts.size(); ++i) { for (size_t i = 0; i < opts_.size(); ++i) {
opts[i]->set_weight(p);
}
for (size_t i = 0; i < opts.size(); ++i) {
int s = 0; int s = 0;
float* newp = (float*)opts[i]->get_weight(&s); float* newp = (float*)opts_[i]->get_weight(&s);
for (size_t j = 0; j < size; ++j) { for (size_t j = 0; j < kSize; ++j) {
EXPECT_EQ(newp[j], (*p)[j]); EXPECT_EQ(newp[j], (*p)[j]);
} }
} }
} }
void TestUpdate() { void TestUpdate() {
Tensor* g = FixedTensor(size); Tensor* g = FixedTensor(kSize);
for (size_t i = 0; i < opts.size(); ++i) { for (size_t i = 0; i < opts_.size(); ++i) {
opts[i]->Update(g); opts_[i]->Update(g);
}
}
void TestCheckPoint() {
std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
{OptimizerConfig::SGD, kSize}, {OptimizerConfig::Adam, kSize * 3},
};
for (size_t i = 0; i < opts_.size(); ++i) {
int state_len = 0;
std::string state = opts_[i]->SerializeState(&state_len);
EXPECT_EQ(state_len, expected_state_len[opts_table_[i]]);
opts_[i]->DeserializeState(state);
} }
} }
private: private:
std::vector<ParameterOptimizer*> opts; std::vector<ParameterOptimizer*> opts_;
OptimizerConfig config; std::map<int, OptimizerConfig::Optimizer> opts_table_;
OptimizerConfig config_;
}; };
TEST_F(OptimizerTest, test_set_get_weight) { TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }
TestSetWeight();
TestGetWeight();
}
TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); } TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
......
...@@ -10,15 +10,16 @@ ...@@ -10,15 +10,16 @@
namespace paddle { namespace paddle {
namespace optimizer { namespace optimizer {
static unsigned CalStateSize(int* state_len) { return 0; } static unsigned CalStateSize() { return 0; }
template <typename HEAD, typename... TAIL> template <typename HEAD, typename... TAIL>
unsigned CalStateSize(const HEAD& head, const TAIL&... tail) { unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
if (std::is_fundamental<HEAD>::value) { return sizeof head + CalStateSize(tail...);
return sizeof head + CalStateSize(tail...); }
} else {
return sizeof(head[0]) * head->size() + CalStateSize(tail...); template <typename... TAIL>
} unsigned CalStateSize(const Tensor* head, const TAIL&... tail) {
return head->size() + CalStateSize(tail...);
} }
static void TensorToProto(const Tensor& tensor, TensorProto* proto) { static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
...@@ -32,7 +33,6 @@ static void TensorToProto(const Tensor& tensor, TensorProto* proto) { ...@@ -32,7 +33,6 @@ static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
} }
static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) { static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
CHECK(proto.size() == tensor->size()) << "unmatch shape of proto and tensor";
std::stringstream sin; std::stringstream sin;
for (auto i = 0; i < proto.content_size(); ++i) { for (auto i = 0; i < proto.content_size(); ++i) {
sin << proto.content(i); sin << proto.content(i);
......
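The CalStateSize change above replaces the runtime is_fundamental branch with a dedicated overload for Tensor pointers. A stand-alone sketch of that recursion, assuming only a stub Tensor (the real one lives in tensor.h), so the numbers are illustrative:

// Stand-alone sketch of the CalStateSize recursion; Tensor is a stub here.
#include <cstddef>
#include <iostream>

struct Tensor {
  size_t size() const { return 5; }  // pretend this tensor holds 5 elements
};

// Base case: no arguments left, nothing to count.
static unsigned CalStateSize() { return 0; }

// Generic case: fundamental values (double, bool, ...) contribute sizeof.
template <typename HEAD, typename... TAIL>
unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
  return sizeof head + CalStateSize(tail...);
}

// More specialized overload for tensors: contributes the element count,
// which is what the updated SerializeState() implementations report.
template <typename... TAIL>
unsigned CalStateSize(const Tensor* head, const TAIL&... tail) {
  return head->size() + CalStateSize(tail...);
}

int main() {
  Tensor parameter, momentums;
  const Tensor* p = &parameter;
  const Tensor* m = &momentums;
  // Two 5-element tensors plus one double: 5 + 5 + sizeof(double).
  std::cout << CalStateSize(p, m, 0.1) << "\n";
  return 0;
}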
#include "serialization.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "tensor.h"
using namespace paddle; using namespace paddle;
using namespace paddle::optimizer; using namespace paddle::optimizer;
TEST(Tensor, indexer) { TEST(TensorToProto, Case1) {
Tensor t(3); Tensor t(3), t1(3);
for (auto i = 0; i < t.size(); ++i) { for (size_t i = 0; i < t.size(); ++i) {
t[i] = i; t[i] = i;
t1[i] = 0;
}
TensorProto proto;
TensorToProto(t, &proto);
ProtoToTensor(proto, &t1);
for (size_t i = 0; i < t1.size(); ++i) {
EXPECT_EQ(t1[i], t[i]);
} }
ASSERT_EQ(t[2], 2);
ASSERT_EQ(t[1], 1);
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
......
#include "serialization.h"
#include "sgd_optimizer.h" #include "sgd_optimizer.h"
#include "serialization.h"
namespace paddle { namespace paddle {
namespace optimizer { namespace optimizer {
...@@ -28,29 +28,24 @@ void SGDOptimizer::Update(const Tensor *gradient) { ...@@ -28,29 +28,24 @@ void SGDOptimizer::Update(const Tensor *gradient) {
} }
const char *SGDOptimizer::SerializeState(int *state_len) { const char *SGDOptimizer::SerializeState(int *state_len) {
OptimizerState state; SGDOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_)); state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*momentums_, state.mutable_momentums()); TensorToProto(*momentums_, state.mutable_momentums());
state.set_momentum(momentum_); *state_len = CalStateSize(parameter_, momentums_);
state.set_decay(decay_);
state.set_nesterov(nesterov_);
*state_len +=
CalStateSize(parameter_, momentums_, momentum_, decay_, nesterov_);
return state.SerializeAsString().c_str(); return state.SerializeAsString().c_str();
} }
void SGDOptimizer::DeSerializeState(const std::string &str) { void SGDOptimizer::DeserializeState(const std::string &str) {
OptimizerState state; SGDOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
lr_policy_->set(state.learning_rate()); lr_policy_->set(state.learning_rate());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.parameter(), momentums_); ProtoToTensor(state.parameter(), momentums_);
momentum_ = state.momentum();
} }
} // namespace optimizer } // namespace optimizer
......
...@@ -14,7 +14,7 @@ public: ...@@ -14,7 +14,7 @@ public:
decay_(d), decay_(d),
nesterov_(n) { nesterov_(n) {
if (momentum_ != 0.0) { if (momentum_ != 0.0) {
size_t size = p->size(); size_t size = parameter->size();
// TODO: fix it with align aware allocator bind to Tensor // TODO: fix it with align aware allocator bind to Tensor
if (momentums_) delete momentums_; if (momentums_) delete momentums_;
momentums_ = new Tensor(size); momentums_ = new Tensor(size);
...@@ -25,7 +25,7 @@ public: ...@@ -25,7 +25,7 @@ public:
} }
void Update(const Tensor* gradient); void Update(const Tensor* gradient);
const char* SerializeState(int* state_len); const char* SerializeState(int* state_len);
void DeSerializeState(const std::string& state); void DeserializeState(const std::string& state);
private: private:
Tensor* momentums_; Tensor* momentums_;
......
...@@ -78,36 +78,51 @@ enum DataType { ...@@ -78,36 +78,51 @@ enum DataType {
repeated bytes content = 2; repeated bytes content = 2;
} }
message OptimizerState { message SGDOptimizerState {
// learning rate policy
optional double learning_rate = 101; optional double learning_rate = 101;
optional double lr_decay_a = 102; optional double lr_decay_a = 102;
optional double lr_decay_b = 103; optional double lr_decay_b = 103;
optional double num_sample_passed = 104; optional double num_sample_passed = 104;
// momentum // state
optional TensorProto parameter = 105; optional TensorProto parameter = 1;
optional TensorProto momentums = 1; optional TensorProto momentums = 2;
}
// adadelta message AdadeltaOptimizerState {
// learning rate policy
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
optional TensorProto accum_gradient = 2; optional TensorProto accum_gradient = 2;
optional TensorProto accum_delta = 3; optional TensorProto accum_delta = 3;
optional TensorProto update_delta = 4; optional TensorProto update_delta = 4;
}
// adam message AdagradOptimizerState {
optional TensorProto velocitys = 5; // learning rate policy
optional double learning_rate = 101;
// momentum optional double lr_decay_a = 102;
optional double momentum = 6; optional double lr_decay_b = 103;
optional double decay = 7; optional double num_sample_passed = 104;
optional bool nesterov = 8; // state
optional TensorProto parameter = 1;
// adadelta optional TensorProto accum_gradient = 2;
optional double rho = 9; }
optional double epsilon = 10;
// adam
optional double beta_1 = 11;
optional double beta_2 = 12;
message AdamOptimizerState {
// learning rate policy
optional double learning_rate = 101;
optional double lr_decay_a = 102;
optional double lr_decay_b = 103;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
optional TensorProto momentums = 2;
optional TensorProto velocitys = 3;
} }
message OptimizerConfig { message OptimizerConfig {
......
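A small sketch (not part of the commit) of what the split buys: each optimizer now builds and parses only its own state message. The generated header name and the absence of a namespace qualifier below are assumptions, since the proto's package statement is not visible here; the field names come from the messages above.

// Sketch only: build and re-parse an SGDOptimizerState on its own.
// "OptimizerConfig.pb.h" is the assumed name of the protoc-generated header;
// depending on the proto package, the message may need a namespace qualifier.
#include <string>
#include "OptimizerConfig.pb.h"

std::string SaveSgdState() {
  SGDOptimizerState state;
  state.set_learning_rate(0.1);
  state.set_num_sample_passed(100);
  // In the C++ code above, the tensor fields are filled with
  // TensorToProto(*parameter_, state.mutable_parameter()) and
  // TensorToProto(*momentums_, state.mutable_momentums()).
  return state.SerializeAsString();
}

void RestoreSgdState(const std::string& bytes) {
  SGDOptimizerState state;
  state.ParseFromString(bytes);
  // SGDOptimizer::DeserializeState() does exactly this parse and then copies
  // the TensorProto contents back into its tensors with ProtoToTensor().
}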