Commit 3b1294ae authored by dzhwinter

"add checkpoint interface: set state, get state"

Parent fd8c5107
@@ -34,10 +34,16 @@ struct paddle_optimizer {
};
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
int config_proto_len) {
const int config_proto_len,
const char** state,
const int state_size) {
  paddle_optimizer* optimizer = new paddle_optimizer;
  std::string config(config_proto, config_proto + config_proto_len);
  optimizer->impl = ParameterOptimizer::Create(config);
  if (state != nullptr) {
    // Restore a previously serialized optimizer state (checkpoint resume).
    std::string s(*state, *state + state_size);
    optimizer->impl->DeSerializeState(s);
  }
  return optimizer;
}
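With the extended signature, a trainer resumes from a checkpoint at construction time. A minimal caller sketch (the helper name and its arguments are illustrative, not part of this commit):

/* Sketch: resume an optimizer from previously saved checkpoint bytes. */
paddle_optimizer* resume_optimizer(const unsigned char* config,
                                   int config_len,
                                   const char* ckpt,
                                   int ckpt_len) {
  /* A NULL state pointer keeps the old "start fresh" behavior; otherwise
     the constructor restores the state via DeSerializeState(). */
  return paddle_create_optimizer(config, config_len,
                                 ckpt != NULL ? &ckpt : NULL, ckpt_len);
}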
@@ -71,3 +77,8 @@ void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  void* buffer = (void*)o->impl->get_weight();
  return buffer;
}
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
  // Write through the out-parameter; assigning to a by-value pointer
  // argument would be lost when the function returns. The caller takes
  // ownership of the returned buffer.
  *state = o->impl->SerializeState();
  return PADDLE_SUCCESS;
}
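On the save side, a checkpointing routine fetches the blob through the out-parameter and persists it. A hedged sketch (file handling is illustrative; since the API carries no length, this assumes the serialized proto contains no embedded NUL bytes, a real limitation of a char*-based interface):

#include <stdio.h>
#include <string.h>

/* Sketch: write the optimizer state to a checkpoint file. */
int save_checkpoint(paddle_optimizer* o, const char* path) {
  const char* state = NULL;
  if (paddle_optimizer_get_state(o, &state) != PADDLE_SUCCESS) return -1;
  FILE* fp = fopen(path, "wb");
  if (fp == NULL) return -1;
  fwrite(state, 1, strlen(state), fp);
  fclose(fp);
  return 0;
}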
@@ -45,7 +45,9 @@ typedef struct paddle_optimizer paddle_optimizer;
* @param state serialized state blob from a previous run, or nullptr to start fresh
* @param state_size length in bytes of the state blob
* @return return optimizer instance
*/
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
int config_proto_len);
const int config_proto_len,
const char** state,
const int state_size);
/**
* @brief release optimizer
@@ -86,6 +88,8 @@ int paddle_optimizer_set_weights(paddle_optimizer* o,
*/
void* paddle_optimizer_get_weights(paddle_optimizer* o);
/**
* @brief serialize the optimizer state into a checkpointable blob
* @param state out-parameter that receives the serialized state buffer
* @return return exec status
*/
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
#ifdef __cplusplus
}
#endif
......
@@ -11,6 +11,8 @@
namespace paddle {
namespace optimizer {
const std::string kOptimizerVersion = "1.0";
class ParameterOptimizer {
public:
/**
@@ -21,6 +23,8 @@ public:
  virtual ~ParameterOptimizer() { delete parameter_; }
  static ParameterOptimizer *Create(const std::string &config_proto);
  virtual const char *SerializeState();
  virtual void DeSerializeState(const std::string &state);
  virtual void Update(const Tensor *gradient) = 0;
  virtual real *get_weight() const;
  virtual void set_weight(Tensor *parameter);
......
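Seen from the C++ side, the contract is symmetric: a fresh optimizer created from the same config must accept whatever SerializeState() produced. A minimal round-trip sketch (error handling and cleanup omitted; note that constructing a std::string from a bare char* stops at the first NUL byte, one more argument for a length-carrying return type):

// Sketch: checkpoint round trip through the ParameterOptimizer interface,
// assuming config_proto holds a serialized OptimizerConfig string.
void RoundTrip(const std::string &config_proto) {
  ParameterOptimizer *opt = ParameterOptimizer::Create(config_proto);
  std::string blob(opt->SerializeState());       // save side
  ParameterOptimizer *restored = ParameterOptimizer::Create(config_proto);
  restored->DeSerializeState(blob);              // resume side
}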
@@ -10,7 +10,7 @@
using namespace paddle;
using namespace paddle::optimizer;
Tensor* fill_n_Tensor(size_t size) {
Tensor* FillTensor(size_t size) {
real* ptr = new real[size];
Tensor* param = new Tensor(ptr, size);
Tensor& p = *param;
@@ -20,7 +20,7 @@ Tensor* fill_n_Tensor(size_t size) {
return param;
}
Tensor* fix_n_Tensor(size_t size) {
Tensor* FixedTensor(size_t size) {
real* ptr = new real[size];
Tensor* param = new Tensor(ptr, size);
Tensor& p = *param;
@@ -36,12 +36,12 @@ public:
const size_t size = 5;
virtual void SetUp() {
create_sgd();
create_adam();
CreateSGD();
CreateAdam();
}
virtual void TearDown() {}
void create_sgd() {
void CreateSGD() {
config.set_optimizer(OptimizerConfig::SGD);
config.mutable_sgd()->set_momentum(0.0);
config.mutable_sgd()->set_decay(0.0);
@@ -54,7 +54,7 @@ public:
opts.push_back(opt);
}
void create_adam() {
void CreateAdam() {
config.set_optimizer(OptimizerConfig::Adam);
config.mutable_adam()->set_beta_1(0.9);
config.mutable_adam()->set_beta_2(0.1);
@@ -66,15 +66,15 @@ public:
ParameterOptimizer::Create(config.SerializeAsString());
opts.push_back(opt);
}
void test_set_weight() {
Tensor* p = fill_n_Tensor(size);
void TestSetWeight() {
Tensor* p = FillTensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->set_weight(p);
}
}
void test_get_weight() {
Tensor* p = fix_n_Tensor(size);
void TestGetWeight() {
Tensor* p = FixedTensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->set_weight(p);
}
@@ -85,8 +85,8 @@ public:
}
}
}
void test_update() {
Tensor* g = fix_n_Tensor(size);
void TestUpdate() {
Tensor* g = FixedTensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->Update(g);
}
@@ -98,10 +98,10 @@ private:
};
TEST_F(OptimizerTest, test_set_get_weight) {
test_set_weight();
test_get_weight();
TestSetWeight();
TestGetWeight();
}
TEST_F(OptimizerTest, test_update) { test_update(); }
TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
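The suite still lacks coverage for the new checkpoint interface; a hypothetical follow-up test in the same style (the fixture method would sit in the public: section above) could round-trip each optimizer's state:

// Hypothetical fixture method, not part of this commit:
void TestCheckPoint() {
  // Serialize then immediately restore; the version CHECK inside
  // DeSerializeState guards against format drift.
  for (size_t i = 0; i < opts.size(); ++i) {
    std::string state(opts[i]->SerializeState());
    opts[i]->DeSerializeState(state);
  }
}

TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }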
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
......
#ifndef PADDLE_OPTIMIZER_SERIALIZATION_H
#define PADDLE_OPTIMIZER_SERIALIZATION_H
#include <sstream>
#include <string>
#include "OptimizerConfig.pb.h"
#include "paddle/utils/Logging.h"
#include "tensor.h"
namespace paddle {
namespace optimizer {
static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
  proto->set_size(tensor.size());
  std::stringstream os;
  for (size_t i = 0; i < tensor.size(); ++i) {
    os << tensor[i];
    proto->add_content(os.str());
    // str("") empties the buffer; clear() alone only resets the error
    // flags, so each element would otherwise accumulate all previous ones.
    os.str("");
    os.clear();
  }
}
static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
  CHECK(proto.size() == tensor->size()) << "proto and tensor sizes mismatch";
  std::stringstream sin;
  for (auto i = 0; i < proto.content_size(); ++i) {
    sin << proto.content(i);
    sin >> (*tensor)[i];
    // Reset eof/fail flags so the stream can be reused for the next element.
    sin.clear();
  }
}
} // namespace optimizer
} // namespace paddle
#endif
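The two helpers are inverses of each other; a short sketch of the intended round trip, reusing the Tensor(ptr, size) construction seen in the tests:

// Sketch: Tensor -> TensorProto -> Tensor; element values survive intact.
void RoundTripExample() {
  real* src = new real[3];
  Tensor t(src, 3);
  t[0] = 1.5; t[1] = 2.5; t[2] = 3.5;

  TensorProto proto;
  TensorToProto(t, &proto);          // one decimal string per element

  real* dst = new real[3];
  Tensor restored(dst, 3);
  ProtoToTensor(proto, &restored);   // restored now matches t element-wise
}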
@@ -12,6 +12,8 @@ public:
      : ParameterOptimizer(lr), momentum_(m), decay_(d), nesterov_(n) {}
  virtual ~SGDOptimizer() { delete momentums_; }
  void Update(const Tensor* gradient);
  const char* SerializeState();
  void DeSerializeState(const std::string& state);
  void set_weight(Tensor* p);
  real* get_weight() const;
......
#include "serialization.h"
#include "sgd_optimizer.h"
namespace paddle {
@@ -37,5 +38,31 @@ void SGDOptimizer::Update(const Tensor *gradient) {
}
}
const char *SGDOptimizer::SerializeState() {
  OptimizerState state;
  // version is a global const value
  state.set_version(kOptimizerVersion);
  TensorToProto(*parameter_, state.add_data());
  TensorToProto(*momentums_, state.add_data());
  state.add_hyperparam(momentum_);
  // Copy into a heap buffer owned by the caller: returning c_str() of the
  // temporary produced by SerializeAsString() would leave a dangling pointer.
  std::string serialized = state.SerializeAsString();
  char *buf = new char[serialized.size() + 1];
  std::memcpy(buf, serialized.c_str(), serialized.size() + 1);
  return buf;
}
void SGDOptimizer::DeSerializeState(const std::string &str) {
  OptimizerState state;
  state.ParseFromString(str);
  CHECK(state.version() == kOptimizerVersion)
      << "state version mismatch, expected: " << kOptimizerVersion
      << ", got: " << state.version();
  ProtoToTensor(state.data(0), parameter_);
  if (state.data_size() == 2) {
    ProtoToTensor(state.data(1), momentums_);
    momentum_ = state.hyperparam(0);
  }
}
} // namespace optimizer
} // namespace paddle
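Every optimizer implements the pair the same way, serializing its own auxiliary tensors and hyperparameters. A hypothetical Adam version (not part of this commit; the members velocitys_, beta_1_ and beta_2_ are assumptions) would mirror the SGD code:

// Hypothetical sketch: Adam state = parameter + both moments + betas.
const char *AdamOptimizer::SerializeState() {
  OptimizerState state;
  state.set_version(kOptimizerVersion);
  TensorToProto(*parameter_, state.add_data());
  TensorToProto(*momentums_, state.add_data());   // first moment estimate
  TensorToProto(*velocitys_, state.add_data());   // second moment estimate
  state.add_hyperparam(beta_1_);
  state.add_hyperparam(beta_2_);
  std::string serialized = state.SerializeAsString();
  char *buf = new char[serialized.size() + 1];
  std::memcpy(buf, serialized.c_str(), serialized.size() + 1);
  return buf;
}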
@@ -64,6 +64,26 @@ message LinearLr {
  optional double lr_decay_b = 3;
}

message TensorProto {
  enum DataType {
    PADDLE_ELEMENT_TYPE_INT32 = 0;
    PADDLE_ELEMENT_TYPE_UINT32 = 1;
    PADDLE_ELEMENT_TYPE_INT64 = 2;
    PADDLE_ELEMENT_TYPE_UINT64 = 3;
    PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
    PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
  }
  required DataType data_type = 1;
  repeated bytes content = 2;
  optional uint64 size = 3;
}
message OptimizerState {
  // version tag so the parser can match state written by older trainers
  required string version = 100;
  repeated TensorProto data = 1;
  repeated double hyperparam = 3;
}
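Decoded to protobuf text format, an SGD checkpoint produced by the code above looks roughly like this (values are illustrative):

# Illustrative OptimizerState for SGD, in text format.
version: "1.0"
# parameter tensor
data { data_type: PADDLE_ELEMENT_TYPE_FLOAT32 size: 2 content: "3.1" content: "3.1" }
# momentum tensor
data { data_type: PADDLE_ELEMENT_TYPE_FLOAT32 size: 2 content: "0.0" content: "0.0" }
hyperparam: 0.0  # momentum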
message OptimizerConfig {
  // common config of optimizer
......