From 3b1294ae48efd6118948f359eced8b4209dbe19a Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Tue, 6 Jun 2017 17:02:05 +0800
Subject: [PATCH] "add checkpoint interface: set state, get state"

---
 paddle/optimizer/optimizer.cc                 | 13 ++++++-
 paddle/optimizer/optimizer.h                  |  6 +++-
 paddle/optimizer/parameter_optimizer.h        |  4 +++
 paddle/optimizer/parameter_optimizer_test.cpp | 30 ++++++++--------
 paddle/optimizer/serialization.h              | 36 +++++++++++++++++++
 paddle/optimizer/sgd_optimizer.h              |  2 ++
 paddle/optimizer/sgd_optmizer.cc              | 27 ++++++++++++++
 proto/OptimizerConfig.proto                   | 20 +++++++++++
 8 files changed, 121 insertions(+), 17 deletions(-)
 create mode 100644 paddle/optimizer/serialization.h

diff --git a/paddle/optimizer/optimizer.cc b/paddle/optimizer/optimizer.cc
index e9bcdcd8016..5076029494c 100644
--- a/paddle/optimizer/optimizer.cc
+++ b/paddle/optimizer/optimizer.cc
@@ -34,10 +34,16 @@ struct paddle_optimizer {
 };
 
 paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
-                                          int config_proto_len) {
+                                          const int config_proto_len,
+                                          const char** state,
+                                          const int state_size) {
   paddle_optimizer* optimizer = new paddle_optimizer;
   std::string config(config_proto, config_proto + config_proto_len);
   optimizer->impl = ParameterOptimizer::Create(config);
+  if (state != nullptr) {
+    std::string s(*state, *state + state_size);
+    optimizer->impl->DeSerializeState(s);
+  }
   return optimizer;
 }
 
@@ -71,3 +77,8 @@ void* paddle_optimizer_get_weights(paddle_optimizer* o) {
   void* buffer = (void*)o->impl->get_weight();
   return buffer;
 }
+
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
+  *state = o->impl->SerializeState();
+  return PADDLE_SUCCESS;
+}
diff --git a/paddle/optimizer/optimizer.h b/paddle/optimizer/optimizer.h
index a2c2b13405b..c3328331fc3 100644
--- a/paddle/optimizer/optimizer.h
+++ b/paddle/optimizer/optimizer.h
@@ -45,7 +45,9 @@ typedef struct paddle_optimizer paddle_optimizer;
  * @return return optimizer instance
  */
 paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
-                                          int config_proto_len);
+                                          const int config_proto_len,
+                                          const char** state,
+                                          const int state_size);
 
 /**
  * @brief release optimizer
@@ -86,6 +88,8 @@ int paddle_optimizer_set_weights(paddle_optimizer* o,
  */
 void* paddle_optimizer_get_weights(paddle_optimizer* o);
 
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/paddle/optimizer/parameter_optimizer.h b/paddle/optimizer/parameter_optimizer.h
index 69e964069b4..e60c7778205 100644
--- a/paddle/optimizer/parameter_optimizer.h
+++ b/paddle/optimizer/parameter_optimizer.h
@@ -11,6 +11,8 @@ namespace paddle {
 namespace optimizer {
 
+const std::string kOptimizerVersion = "1.0";
+
 class ParameterOptimizer {
 public:
   /**
@@ -21,6 +23,8 @@ public:
   virtual ~ParameterOptimizer() { delete parameter_; };
 
   static ParameterOptimizer *Create(const std::string &config_proto);
+  virtual const char *SerializeState();
+  virtual void DeSerializeState(const std::string &state);
   virtual void Update(const Tensor *gradient) = 0;
   virtual real *get_weight() const;
   virtual void set_weight(Tensor *parameter);
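The new constructor signature and paddle_optimizer_get_state are easiest to read from the caller's side. Below is a minimal sketch of the intended checkpoint round-trip; checkpoint_roundtrip and the persistence step are invented for illustration, and paddle_release_optimizer is assumed from the header's "release optimizer" doc comment. Note that the API hands back the state as an opaque pointer but never reports its byte length, so the caller must learn state_size out of band.

    #include "paddle/optimizer/optimizer.h"

    void checkpoint_roundtrip(const unsigned char* config, int config_len) {
      // Fresh start: no checkpoint yet, so pass a null state.
      paddle_optimizer* opt =
          paddle_create_optimizer(config, config_len, /*state=*/nullptr,
                                  /*state_size=*/0);

      // ... training updates happen here ...

      // Snapshot the optimizer-internal state as opaque bytes.
      const char* state = nullptr;
      paddle_optimizer_get_state(opt, &state);
      int state_size = /* obtained out of band; the API does not return it */ 0;
      // persist(state, state_size);  -- hypothetical durable store
      paddle_release_optimizer(opt);  // name assumed from the header's docs

      // Resume: hand the saved bytes back to the constructor.
      paddle_optimizer* resumed =
          paddle_create_optimizer(config, config_len, &state, state_size);
      paddle_release_optimizer(resumed);
    }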
diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp
index cc791483431..2b3ad84ca95 100644
--- a/paddle/optimizer/parameter_optimizer_test.cpp
+++ b/paddle/optimizer/parameter_optimizer_test.cpp
@@ -10,7 +10,7 @@
 using namespace paddle;
 using namespace paddle::optimizer;
 
-Tensor* fill_n_Tensor(size_t size) {
+Tensor* FillTensor(size_t size) {
   real* ptr = new real[size];
   Tensor* param = new Tensor(ptr, size);
   Tensor& p = *param;
@@ -20,7 +20,7 @@
   return param;
 }
 
-Tensor* fix_n_Tensor(size_t size) {
+Tensor* FixedTensor(size_t size) {
   real* ptr = new real[size];
   Tensor* param = new Tensor(ptr, size);
   Tensor& p = *param;
@@ -36,12 +36,12 @@ public:
   const size_t size = 5;
 
   virtual void SetUp() {
-    create_sgd();
-    create_adam();
+    CreateSGD();
+    CreateAdam();
   }
   virtual void TearDown() {}
 
-  void create_sgd() {
+  void CreateSGD() {
     config.set_optimizer(OptimizerConfig::SGD);
     config.mutable_sgd()->set_momentum(0.0);
     config.mutable_sgd()->set_decay(0.0);
@@ -54,7 +54,7 @@ public:
     opts.push_back(opt);
   }
 
-  void create_adam() {
+  void CreateAdam() {
     config.set_optimizer(OptimizerConfig::Adam);
     config.mutable_adam()->set_beta_1(0.9);
     config.mutable_adam()->set_beta_2(0.1);
@@ -66,15 +66,15 @@ public:
         ParameterOptimizer::Create(config.SerializeAsString());
     opts.push_back(opt);
   }
-  void test_set_weight() {
-    Tensor* p = fill_n_Tensor(size);
+  void TestSetWeight() {
+    Tensor* p = FillTensor(size);
     for (size_t i = 0; i < opts.size(); ++i) {
       opts[i]->set_weight(p);
     }
   }
 
-  void test_get_weight() {
-    Tensor* p = fix_n_Tensor(size);
+  void TestGetWeight() {
+    Tensor* p = FixedTensor(size);
     for (size_t i = 0; i < opts.size(); ++i) {
       opts[i]->set_weight(p);
     }
@@ -85,8 +85,8 @@ public:
       }
     }
   }
-  void test_update() {
-    Tensor* g = fix_n_Tensor(size);
+  void TestUpdate() {
+    Tensor* g = FixedTensor(size);
     for (size_t i = 0; i < opts.size(); ++i) {
       opts[i]->Update(g);
     }
@@ -98,10 +98,10 @@ private:
 };
 
 TEST_F(OptimizerTest, test_set_get_weight) {
-  test_set_weight();
-  test_get_weight();
+  TestSetWeight();
+  TestGetWeight();
 }
-TEST_F(OptimizerTest, test_update) { test_update(); }
+TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
diff --git a/paddle/optimizer/serialization.h b/paddle/optimizer/serialization.h
new file mode 100644
index 00000000000..edb90d0a164
--- /dev/null
+++ b/paddle/optimizer/serialization.h
@@ -0,0 +1,36 @@
+#ifndef PADDLE_OPTIMIZER_SERIALIZATION_H
+#define PADDLE_OPTIMIZER_SERIALIZATION_H
+
+#include <sstream>
+#include <string>
+#include "OptimizerConfig.pb.h"
+#include "paddle/utils/Logging.h"
+#include "tensor.h"
+
+namespace paddle {
+namespace optimizer {
+
+static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
+  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
+  proto->set_size(tensor.size());
+  std::stringstream os;
+  for (size_t i = 0; i < tensor.size(); ++i) {
+    os << tensor[i];
+    proto->add_content(os.str());
+    os.str("");  // reset the buffer; clear() alone only resets error flags
+  }
+}
+
+static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
+  CHECK(proto.size() == tensor->size()) << "mismatched size of proto and tensor";
+  std::stringstream sin;
+  for (auto i = 0; i < proto.content_size(); ++i) {
+    sin << proto.content(i);
+    sin >> (*tensor)[i];
+    sin.clear();  // reset eof so the next element can be parsed
+  }
+}
+
+}  // namespace optimizer
+}  // namespace paddle
+#endif
diff --git a/paddle/optimizer/sgd_optimizer.h b/paddle/optimizer/sgd_optimizer.h
index 1f6728d61e3..284d0a4d0c7 100644
--- a/paddle/optimizer/sgd_optimizer.h
+++ b/paddle/optimizer/sgd_optimizer.h
@@ -12,6 +12,8 @@ public:
       : ParameterOptimizer(lr), momentum_(m), decay_(d), nesterov_(n) {}
   virtual ~SGDOptimizer() { delete momentums_; }
   void Update(const Tensor* gradient);
+  const char* SerializeState();
+  void DeSerializeState(const std::string& state);
   void set_weight(Tensor* p);
   real* get_weight() const;
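A quick round-trip shows how the two serialization.h helpers compose (a sketch written in the style of parameter_optimizer_test.cpp above; RoundTripEqual and the fill values are invented here, everything else comes from this patch). Because elements travel as decimal strings with the stream's default six significant digits, the round trip is exact for these values but only approximate for arbitrary floats.

    #include "paddle/optimizer/serialization.h"

    using namespace paddle;
    using namespace paddle::optimizer;

    bool RoundTripEqual(size_t size) {
      real* src_ptr = new real[size];
      Tensor src(src_ptr, size);
      for (size_t i = 0; i < size; ++i) src[i] = 0.5 * i;

      TensorProto proto;
      TensorToProto(src, &proto);  // one decimal string per element

      real* dst_ptr = new real[size];
      Tensor dst(dst_ptr, size);
      ProtoToTensor(proto, &dst);  // parse the strings back into the tensor

      for (size_t i = 0; i < size; ++i) {
        if (src[i] != dst[i]) return false;
      }
      return true;
    }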
const;
diff --git a/paddle/optimizer/sgd_optmizer.cc b/paddle/optimizer/sgd_optmizer.cc
index c58ab5bbe2b..f4fa7756eab 100644
--- a/paddle/optimizer/sgd_optmizer.cc
+++ b/paddle/optimizer/sgd_optmizer.cc
@@ -1,3 +1,4 @@
+#include "serialization.h"
 #include "sgd_optimizer.h"
 
 namespace paddle {
@@ -37,5 +38,31 @@ void SGDOptimizer::Update(const Tensor *gradient) {
   }
 }
 
+const char *SGDOptimizer::SerializeState() {
+  OptimizerState state;
+  // version is a global const value
+  state.set_version(kOptimizerVersion);
+  TensorToProto(*parameter_, state.add_data());
+  TensorToProto(*momentums_, state.add_data());
+  state.add_hyperparam(momentum_);
+  // keep the bytes alive after return; a temporary's c_str() would dangle
+  static std::string serialized;
+  serialized = state.SerializeAsString();
+  return serialized.c_str();
+}
+
+void SGDOptimizer::DeSerializeState(const std::string &str) {
+  OptimizerState state;
+  state.ParseFromString(str);
+  CHECK(state.version() == kOptimizerVersion)
+      << "wrong version of state, expected: " << kOptimizerVersion
+      << ", got: " << state.version();
+
+  ProtoToTensor(state.data(0), parameter_);
+  if (state.data_size() == 2) {
+    ProtoToTensor(state.data(1), momentums_);
+    momentum_ = state.hyperparam(0);
+  }
+}
+
 }  // namespace optimizer
 }  // namespace paddle
diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto
index 5dd26373379..f492364a5aa 100644
--- a/proto/OptimizerConfig.proto
+++ b/proto/OptimizerConfig.proto
@@ -64,6 +64,26 @@ message LinearLr {
   optional double lr_decay_b = 3;
 }
 
+message TensorProto {
+  enum DataType {
+    PADDLE_ELEMENT_TYPE_INT32 = 0;
+    PADDLE_ELEMENT_TYPE_UINT32 = 1;
+    PADDLE_ELEMENT_TYPE_INT64 = 2;
+    PADDLE_ELEMENT_TYPE_UINT64 = 3;
+    PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
+    PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
+  }
+  required DataType data_type = 1;
+  repeated bytes content = 2;
+  optional uint64 size = 3;
+}
+
+message OptimizerState {
+  // version of the saved format, checked on load to reject old state
+  required string version = 100;
+  repeated TensorProto data = 1;
+  repeated double hyperparam = 3;
+}
 
 message OptimizerConfig {
   // common config of optimizer
-- 
GitLab
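Decoded, the blob produced by SGDOptimizer::SerializeState is an OptimizerState like the following, shown in protobuf text format (the two-element tensors and the 0.9 momentum are illustrative values, not output from this patch):

    version: "1.0"
    data {                                    # parameter_
      data_type: PADDLE_ELEMENT_TYPE_FLOAT32
      content: "3.1"
      content: "3.2"
      size: 2
    }
    data {                                    # momentums_
      data_type: PADDLE_ELEMENT_TYPE_FLOAT32
      content: "0.01"
      content: "0.02"
      size: 2
    }
    hyperparam: 0.9                           # momentum_

DeSerializeState always restores the parameters from data(0), but restores the momentum buffer and its hyperparameter only when a second tensor is present, so a parameter-only checkpoint still loads. One caveat of the const char* interface: serialized protobuf bytes carry no length and may contain embedded NUL bytes, so a std::string (or pointer-plus-size) return would be a safer contract than c_str().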