Commit 3b1294ae authored by dzhwinter

"add checkpoint interface: set state, get state"

Parent fd8c5107
@@ -34,10 +34,16 @@ struct paddle_optimizer {
 };
 
 paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
-                                          int config_proto_len) {
+                                          const int config_proto_len,
+                                          const char** state,
+                                          const int state_size) {
   paddle_optimizer* optimizer = new paddle_optimizer;
   std::string config(config_proto, config_proto + config_proto_len);
   optimizer->impl = ParameterOptimizer::Create(config);
+  if (state != nullptr) {
+    std::string s(*state, *state + state_size);
+    optimizer->impl->DeSerializeState(s);
+  }
   return optimizer;
 }
@@ -71,3 +77,8 @@ void* paddle_optimizer_get_weights(paddle_optimizer* o) {
   void* buffer = (void*)o->impl->get_weight();
   return buffer;
 }
+
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
+  /* write through the double pointer; a plain const char* parameter
+   * would be assigned locally and never reach the caller */
+  *state = o->impl->SerializeState();
+  return PADDLE_SUCCESS;
+}
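Taken together, the two entry points give a save/restore round trip. A minimal usage sketch (hypothetical variable names, error handling omitted; note that the current interface does not report the state size, so the caller is assumed to know it out of band):

/* Save: serialize the optimizer state after some training steps. */
const char* state = NULL;
paddle_optimizer_get_state(opt, &state);

/* Restore: hand the bytes back when creating a fresh optimizer.
 * Passing NULL for state skips restoration entirely; state_size is
 * assumed known to the caller, since the C API does not return it. */
paddle_optimizer* restored = paddle_create_optimizer(
    config_proto, config_proto_len, &state, state_size);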
@@ -45,7 +45,9 @@ typedef struct paddle_optimizer paddle_optimizer;
  * @return return optimizer instance
  */
 paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
-                                          int config_proto_len);
+                                          const int config_proto_len,
+                                          const char** state,
+                                          const int state_size);
 /**
  * @brief release optimizer
@@ -86,6 +88,8 @@ int paddle_optimizer_set_weights(paddle_optimizer* o,
  */
 void* paddle_optimizer_get_weights(paddle_optimizer* o);
 
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
+
 #ifdef __cplusplus
 }
 #endif
...
@@ -11,6 +11,8 @@
 namespace paddle {
 namespace optimizer {
 
+const std::string kOptimizerVersion = "1.0";
+
 class ParameterOptimizer {
 public:
   /**
@@ -21,6 +23,8 @@ public:
   virtual ~ParameterOptimizer() { delete parameter_; };
 
   static ParameterOptimizer *Create(const std::string &config_proto);
+  virtual const char *SerializeState();
+  virtual void DeSerializeState(const std::string &state);
   virtual void Update(const Tensor *gradient) = 0;
   virtual real *get_weight() const;
   virtual void set_weight(Tensor *parameter);
...
@@ -10,7 +10,7 @@
 using namespace paddle;
 using namespace paddle::optimizer;
 
-Tensor* fill_n_Tensor(size_t size) {
+Tensor* FillTensor(size_t size) {
   real* ptr = new real[size];
   Tensor* param = new Tensor(ptr, size);
   Tensor& p = *param;
@@ -20,7 +20,7 @@ Tensor* fill_n_Tensor(size_t size) {
   return param;
 }
 
-Tensor* fix_n_Tensor(size_t size) {
+Tensor* FixedTensor(size_t size) {
   real* ptr = new real[size];
   Tensor* param = new Tensor(ptr, size);
   Tensor& p = *param;
@@ -36,12 +36,12 @@ public:
   const size_t size = 5;
   virtual void SetUp() {
-    create_sgd();
-    create_adam();
+    CreateSGD();
+    CreateAdam();
   }
   virtual void TearDown() {}
 
-  void create_sgd() {
+  void CreateSGD() {
     config.set_optimizer(OptimizerConfig::SGD);
     config.mutable_sgd()->set_momentum(0.0);
     config.mutable_sgd()->set_decay(0.0);
@@ -54,7 +54,7 @@ public:
     opts.push_back(opt);
   }
 
-  void create_adam() {
+  void CreateAdam() {
     config.set_optimizer(OptimizerConfig::Adam);
     config.mutable_adam()->set_beta_1(0.9);
     config.mutable_adam()->set_beta_2(0.1);
@@ -66,15 +66,15 @@ public:
         ParameterOptimizer::Create(config.SerializeAsString());
     opts.push_back(opt);
   }
 
-  void test_set_weight() {
-    Tensor* p = fill_n_Tensor(size);
+  void TestSetWeight() {
+    Tensor* p = FillTensor(size);
     for (size_t i = 0; i < opts.size(); ++i) {
       opts[i]->set_weight(p);
     }
   }
 
-  void test_get_weight() {
-    Tensor* p = fix_n_Tensor(size);
+  void TestGetWeight() {
+    Tensor* p = FixedTensor(size);
     for (size_t i = 0; i < opts.size(); ++i) {
       opts[i]->set_weight(p);
     }
@@ -85,8 +85,8 @@ public:
       }
     }
   }
 
-  void test_update() {
-    Tensor* g = fix_n_Tensor(size);
+  void TestUpdate() {
+    Tensor* g = FixedTensor(size);
     for (size_t i = 0; i < opts.size(); ++i) {
       opts[i]->Update(g);
     }
@@ -98,10 +98,10 @@ private:
 };
 
 TEST_F(OptimizerTest, test_set_get_weight) {
-  test_set_weight();
-  test_get_weight();
+  TestSetWeight();
+  TestGetWeight();
 }
-TEST_F(OptimizerTest, test_update) { test_update(); }
+TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
...
#ifndef PADDLE_OPTIMIZER_SERIALIZATION_H
#define PADDLE_OPTIMIZER_SERIALIZATION_H

#include <sstream>
#include <string>
#include "OptimizerConfig.pb.h"
#include "paddle/utils/Logging.h"
#include "tensor.h"

namespace paddle {
namespace optimizer {

static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
  proto->set_size(tensor.size());
  std::stringstream os;
  for (size_t i = 0; i < tensor.size(); ++i) {
    os.str("");  // empty the buffer; clear() alone only resets error flags
    os << tensor[i];
    proto->add_content(os.str());
  }
}

static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
  CHECK(proto.size() == tensor->size()) << "proto and tensor sizes do not match";
  std::stringstream sin;
  for (auto i = 0; i < proto.content_size(); ++i) {
    sin << proto.content(i);
    sin >> (*tensor)[i];
    sin.clear();  // reset eofbit so the next extraction succeeds
  }
}

}  // namespace optimizer
}  // namespace paddle
#endif
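A quick round-trip sketch for the two helpers above (not part of this commit). It mirrors the tensor construction used in the tests and assumes real == float and that TensorProto is visible under the same using-directives as in the test file:

#include "serialization.h"
using namespace paddle::optimizer;

void TensorRoundTrip() {
  const size_t n = 3;
  real* src_buf = new real[n];
  src_buf[0] = 1.5; src_buf[1] = -2.0; src_buf[2] = 0.25;
  Tensor src(src_buf, n);

  TensorProto proto;
  TensorToProto(src, &proto);  // each element becomes one text "content" entry

  real* dst_buf = new real[n];
  Tensor dst(dst_buf, n);
  ProtoToTensor(proto, &dst);  // dst now holds {1.5, -2, 0.25}, up to the
                               // default 6-digit stream precision
}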
@@ -12,6 +12,8 @@ public:
       : ParameterOptimizer(lr), momentum_(m), decay_(d), nesterov_(n) {}
   virtual ~SGDOptimizer() { delete momentums_; }
   void Update(const Tensor* gradient);
+  const char* SerializeState();
+  void DeSerializeState(const std::string& state);
   void set_weight(Tensor* p);
   real* get_weight() const;
...
#include "serialization.h"
#include "sgd_optimizer.h" #include "sgd_optimizer.h"
namespace paddle { namespace paddle {
...@@ -37,5 +38,31 @@ void SGDOptimizer::Update(const Tensor *gradient) { ...@@ -37,5 +38,31 @@ void SGDOptimizer::Update(const Tensor *gradient) {
} }
} }
const char *SGDOptimizer::SerializeState() {
OptimizerState state;
// version is a global const value
state.set_version(kOptimizerVersion);
TensorToProto(*parameter_, state.add_data());
TensorToProto(*momentums_, state.add_data());
// state.add_data(param_proto);
// state.add_data(momentum_proto);
state.add_hyperparam(momentum_);
return state.SerializeAsString().c_str();
}
void SGDOptimizer::DeSerializeState(const std::string &str) {
OptimizerState state;
state.ParseFromString(str);
CHECK(state.version() == kOptimizerVersion)
<< "error version of state"
<< "expected : " << kOptimizerVersion << "get : " << state.version();
ProtoToTensor(state.data(0), parameter_);
if (state.data_size() == 2) {
ProtoToTensor(state.data(1), momentums_);
momentum_ = state.hyperparam(0);
}
}
} // namespace optimizer } // namespace optimizer
} // namespace paddle } // namespace paddle
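One caveat with the const char* contract above: protobuf's binary encoding may legitimately contain '\0' bytes, so the length of the state cannot be recovered from the pointer alone. A possible length-preserving variant (a sketch, not part of this commit; SerializeStateTo is a hypothetical name that would replace SerializeState):

// hypothetical signature: the out-parameter carries both bytes and length
void SGDOptimizer::SerializeStateTo(std::string* out) {
  OptimizerState state;
  state.set_version(kOptimizerVersion);
  TensorToProto(*parameter_, state.add_data());
  TensorToProto(*momentums_, state.add_data());
  state.add_hyperparam(momentum_);
  // SerializeToString stores the bytes in *out with their size, so
  // embedded '\0' bytes survive the round trip
  state.SerializeToString(out);
}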
@@ -64,6 +64,26 @@ message LinearLr {
   optional double lr_decay_b = 3;
 }
 
+message TensorProto {
+  enum DataType {
+    PADDLE_ELEMENT_TYPE_INT32 = 0;
+    PADDLE_ELEMENT_TYPE_UINT32 = 1;
+    PADDLE_ELEMENT_TYPE_INT64 = 2;
+    PADDLE_ELEMENT_TYPE_UINT64 = 3;
+    PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
+    PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
+  }
+  required DataType data_type = 1;
+  repeated bytes content = 2;
+  optional uint64 size = 3;
+}
+
+message OptimizerState {
+  // version is checked on load so that an old training state is only
+  // restored by a parser that understands its format
+  required string version = 100;
+  repeated TensorProto data = 1;
+  repeated double hyperparam = 3;
+}
+
 message OptimizerConfig {
   // common config of optimizer
...
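For concreteness, the state that SGDOptimizer::SerializeState assembles for a 3-element parameter looks roughly like this in protobuf text format (illustrative values; the actual checkpoint is the binary wire format):

version: "1.0"
data {
  data_type: PADDLE_ELEMENT_TYPE_FLOAT32
  size: 3
  content: "0.1" content: "0.2" content: "0.3"  # parameter values as text
}
data {
  data_type: PADDLE_ELEMENT_TYPE_FLOAT32
  size: 3
  content: "0" content: "0" content: "0"        # momentum buffer
}
hyperparam: 0.9                                 # momentum coefficient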