Commit f28b4d68 authored by Helin Wang

Fix parameter server checkpoint serialization

Parent: cdb5f292
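The root cause this commit addresses: each optimizer's SerializeState used to build a local std::string and return str.c_str(), a pointer that dangles as soon as the function returns, while the state_len out-parameter was only accumulated. The new interface returns the serialized state by value, and the C wrapper hands the pserver a heap-allocated copy that the Go side frees. A minimal sketch of the old and new patterns (illustrative only, not the PaddlePaddle sources):

    #include <string>

    // Old pattern (broken): `str` is destroyed when the function returns,
    // so the returned pointer dangles and the checkpoint reads garbage.
    const char* SerializeStateOld(int* state_len) {
      std::string str = "proto-bytes";  // stand-in for state.SerializeAsString()
      *state_len += str.size();
      return str.c_str();  // dangling after return
    }

    // New pattern (fixed): return the bytes by value; the caller owns a copy.
    std::string SerializeStateNew() {
      std::string str = "proto-bytes";
      return str;
    }

Returning std::string also removes the ad-hoc state_len out-parameter, which the lr-policy serializers set with `=` while the optimizer serializers accumulated with `+=`.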
@@ -72,21 +72,34 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
 	}
 	o.config = c
-	o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
-		C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
+	o.opt = C.paddle_create_optimizer(
+		(*C.uchar)(&c[0]),
+		C.int(len(c)),
+		C.paddle_element_type(p.ElementType),
+		cbuffer,
+		C.int(paramBufferSize),
+		(*C.char)(cstate),
+		C.int(len(s)),
+	)
 	return o
 }
 
 func (o *optimizer) GetWeights() []byte {
 	var buffer unsafe.Pointer
+	// we do not own the buffer; no need to free it later.
 	bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer)
 	return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
 }
 
 func (o *optimizer) GetStates() []byte {
 	var cbuffer *C.char
+	// we own the state buffer and need to free it later.
 	cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
-	return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
+	buf := cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
+	cpy := make([]byte, len(buf))
+	copy(cpy, buf)
+	C.free(unsafe.Pointer(cbuffer))
+	return cpy
 }
 
 func (o *optimizer) UpdateParameter(g Gradient) error {
......
@@ -15,8 +15,12 @@
 package pserver
 
 import (
+	"encoding/binary"
 	"io/ioutil"
+	"math"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func TestOptimizerCreateRelease(t *testing.T) {
@@ -36,3 +40,39 @@ func TestOptimizerCreateRelease(t *testing.T) {
 	o := newOptimizer(param, nil)
 	o.Cleanup()
 }
+
+func float32Bytes(float float32) []byte {
+	bits := math.Float32bits(float)
+	bytes := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bytes, bits)
+	return bytes
+}
+
+func TestOptimizerState(t *testing.T) {
+	p := Parameter{
+		Name:        "a",
+		ElementType: Int32,
+	}
+	weights := float32Bytes(100)
+	p.Content = weights
+	config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
+	if err != nil {
+		t.Fatalf("read optimizer proto failed")
+	}
+	param := ParameterWithConfig{
+		Param:  p,
+		Config: config,
+	}
+	o := newOptimizer(param, nil)
+	s := o.GetStates()
+
+	// overwrite the param content and check that the state is restored.
+	param.Param.Content = float32Bytes(300)
+	o1 := newOptimizer(param, s)
+	s1 := o1.GetStates()
+	assert.Equal(t, s, s1)
+	assert.Equal(t, weights, o.GetWeights())
+	assert.Equal(t, weights, o1.GetWeights())
+	o.Cleanup()
+	o1.Cleanup()
+}
@@ -297,6 +297,13 @@ func (s *Service) checkpoint() (err error) {
 		return
 	}
 
+	if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
+		err = os.MkdirAll(s.checkpointPath, os.ModePerm)
+		if err != nil {
+			return
+		}
+	}
+
 	id := uuid.NewV4().String()
 	p := path.Join(s.checkpointPath, id)
 	f, err := os.Create(p)
......
@@ -25,19 +25,17 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
   }
 }
 
-const char* AdadeltaOptimizer::SerializeState(int* state_len) {
+std::string AdadeltaOptimizer::SerializeState() {
   AdadeltaOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
-  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  std::string lr_str = this->lr_policy_->SerializeState();
   state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   TensorToProto(*accum_delta_, state.mutable_accum_delta());
   TensorToProto(*update_delta_, state.mutable_update_delta());
-  auto str = state.SerializeAsString();
-  *state_len += str.size();
-  return str.c_str();
+  return state.SerializeAsString();
 }
 
 void AdadeltaOptimizer::DeserializeState(const std::string& str) {
......
@@ -23,7 +23,7 @@ public:
     if (update_delta_) delete update_delta_;
   }
   void Update(const Tensor *gradient);
-  const char *SerializeState(int *state_len);
+  std::string SerializeState();
   void DeserializeState(const std::string &state);
 
 private:
......
@@ -17,17 +17,15 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
                        learning_rate * decay_ * param[i];
   }
 }
 
-const char* AdagradOptimizer::SerializeState(int* state_len) {
+std::string AdagradOptimizer::SerializeState() {
   AdagradOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
-  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  std::string lr_str = this->lr_policy_->SerializeState();
   state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
-  auto str = state.SerializeAsString();
-  *state_len += str.size();
-  return str.c_str();
+  return state.SerializeAsString();
 }
 
 void AdagradOptimizer::DeserializeState(const std::string& str) {
......
@@ -19,7 +19,7 @@ public:
     if (accum_gradient_) delete accum_gradient_;
   }
   void Update(const Tensor *gradient);
-  const char *SerializeState(int *state_len);
+  std::string SerializeState();
   void DeserializeState(const std::string &state);
 
 private:
......
@@ -22,18 +22,16 @@ void AdamOptimizer::Update(const Tensor *gradient) {
   }
 }
 
-const char *AdamOptimizer::SerializeState(int *state_len) {
+std::string AdamOptimizer::SerializeState() {
   AdamOptimizerState state;
-  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  std::string lr_str = this->lr_policy_->SerializeState();
   state.mutable_lr_state()->ParseFromString(lr_str);
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
   TensorToProto(*velocitys_, state.mutable_velocitys());
-  auto str = state.SerializeAsString();
-  *state_len += str.size();
-  return str.c_str();
+  return state.SerializeAsString();
 }
 
 void AdamOptimizer::DeserializeState(const std::string &str) {
......
@@ -25,7 +25,7 @@ public:
     if (velocitys_) delete velocitys_;
   }
   void Update(const Tensor *gradient);
-  const char *SerializeState(int *state_len);
+  std::string SerializeState();
   void DeserializeState(const std::string &state);
 
 private:
......
@@ -10,7 +10,7 @@ class LrPolicy {
 public:
   virtual ~LrPolicy() {}
   virtual double LearningRate(const uint64_t num_sample_passed) = 0;
-  virtual const char *SerializeState(int *state_len) = 0;
+  virtual std::string SerializeState() = 0;
   virtual void DeserializeState(const std::string &state) = 0;
 };
@@ -21,12 +21,10 @@ public:
   double LearningRate(const uint64_t num_sample_passed) {
     return learning_rate_;
   }
-  const char *SerializeState(int *state_len) {
+  std::string SerializeState() {
     LrPolicyState state;
     state.set_learning_rate(learning_rate_);
-    auto str = state.SerializeAsString();
-    *state_len = str.size();
-    return str.c_str();
+    return state.SerializeAsString();
   }
   void DeserializeState(const std::string &str) {
     LrPolicyState state;
@@ -46,14 +44,12 @@ public:
     return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
                     lr_decay_b_);
   }
-  const char *SerializeState(int *state_len) {
+  std::string SerializeState() {
     LrPolicyState state;
     state.set_learning_rate(learning_rate_);
     state.set_lr_decay_a(lr_decay_a_);
    state.set_lr_decay_b(lr_decay_b_);
-    auto str = state.SerializeAsString();
-    *state_len = str.size();
-    return str.c_str();
+    return state.SerializeAsString();
   }
   void DeserializeState(const std::string &str) {
     LrPolicyState state;
......
 #include "optimizer.h"
+#include <glog/logging.h>
+#include <cstdlib>
+#include <cstring>
 #include <string>
 #include "parameter_optimizer.h"
@@ -78,7 +81,13 @@ int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
 }
 
 int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
-  int state_len = 0;
-  *state = o->impl->SerializeState(&state_len);
+  std::string s = o->impl->SerializeState();
+  int state_len = s.size();
+
+  if (state_len > 0) {
+    *state = (char*)std::malloc(state_len);
+    std::memcpy((void*)*state, (const void*)s.c_str(), state_len);
+  }
+
   return state_len;
 }
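
Under the new contract, paddle_optimizer_get_state allocates the returned buffer with std::malloc, so ownership passes to the caller; the Go GetStates above copies the bytes into a Go slice and then calls C.free. A hedged sketch of a C++ caller honoring the same contract (read_optimizer_state is a hypothetical helper; it assumes only the declarations from optimizer.h shown in this diff):

    #include <cstdlib>
    #include <string>
    #include "optimizer.h"

    // Hypothetical caller: copy the serialized state into C++-owned memory,
    // then release the malloc'ed buffer that paddle_optimizer_get_state returned.
    std::string read_optimizer_state(paddle_optimizer* o) {
      const char* state = nullptr;
      int len = paddle_optimizer_get_state(o, &state);
      std::string copy;
      if (len > 0 && state != nullptr) {
        copy.assign(state, len);
        std::free(const_cast<char*>(state));  // caller owns the buffer
      }
      return copy;
    }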
@@ -32,6 +32,7 @@ ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                      Tensor *parameter,
                      const OptimizerConfig &config) -> ParameterOptimizer * {
     if (config.optimizer() == OptimizerConfig::SGD) {
+      LOG(INFO) << "creating SGD optimizer";
       return new SGDOptimizer(parameter,
                               lr,
                               config.sgd().momentum(),
@@ -39,6 +40,7 @@ ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                               config.sgd().nesterov());
     }
     if (config.optimizer() == OptimizerConfig::Adadelta) {
+      LOG(INFO) << "creating Adadelta optimizer";
       return new AdadeltaOptimizer(parameter,
                                    lr,
                                    config.adadelta().rho(),
@@ -46,10 +48,12 @@ ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                                    config.adadelta().decay());
     }
     if (config.optimizer() == OptimizerConfig::Adagrad) {
+      LOG(INFO) << "creating Adagrad optimizer";
       return new AdagradOptimizer(
           parameter, lr, config.adagrad().epsilon(), config.adagrad().decay());
     }
     if (config.optimizer() == OptimizerConfig::Adam) {
+      LOG(INFO) << "creating Adam optimizer";
       return new AdamOptimizer(parameter,
                                lr,
                                config.adam().beta_1(),
......
@@ -28,7 +28,7 @@ public:
                                     Tensor *parameter);
   virtual void Update(const Tensor *gradient) = 0;
   virtual float *get_weight(int *param_size) const;
-  virtual const char *SerializeState(int *state_len) = 0;
+  virtual std::string SerializeState() = 0;
   virtual void DeserializeState(const std::string &state) = 0;
 
 protected:
......
@@ -85,6 +85,7 @@ public:
     for (size_t i = 0; i < opts_.size(); ++i) {
       int s = 0;
       float* newp = (float*)opts_[i]->get_weight(&s);
+      EXPECT_EQ(s, kSize);
       for (size_t j = 0; j < kSize; ++j) {
         EXPECT_EQ(newp[j], (*p)[j]);
       }
@@ -99,10 +100,20 @@ public:
   }
 
   void TestCheckPoint() {
+    paddle::optimizer::Tensor* p = FixedTensor(kSize);
     for (size_t i = 0; i < opts_.size(); ++i) {
-      int state_len = 0;
-      std::string state = opts_[i]->SerializeState(&state_len);
+      auto state = opts_[i]->SerializeState();
+      opts_[i]->DeserializeState(state);
+      auto state1 = opts_[i]->SerializeState();
       opts_[i]->DeserializeState(state);
+      EXPECT_EQ(state, state1);
+
+      int s = 0;
+      float* newp = (float*)opts_[i]->get_weight(&s);
+      EXPECT_EQ(s, kSize);
+      for (size_t j = 0; j < kSize; ++j) {
+        EXPECT_EQ(newp[j], (*p)[j]);
+      }
     }
   }
......
@@ -21,7 +21,22 @@ TEST(TensorToProto, Case1) {
   paddle::optimizer::Tensor t(3), t1(3);
   for (size_t i = 0; i < t.size(); ++i) {
     t[i] = i;
-    t1[i] = 0;
+    t1[i] = 10;
+  }
+  paddle::TensorProto proto;
+  paddle::optimizer::TensorToProto(t, &proto);
+  paddle::optimizer::ProtoToTensor(proto, &t1);
+  for (size_t i = 0; i < t1.size(); ++i) {
+    EXPECT_EQ(t1[i], t[i]);
+  }
+}
+
+TEST(TensorToProto, Case2) {
+  paddle::optimizer::Tensor t(1), t1(1);
+  for (size_t i = 0; i < t.size(); ++i) {
+    t[i] = i;
+    t1[i] = 10;
   }
   paddle::TensorProto proto;
......
@@ -27,16 +27,14 @@ void SGDOptimizer::Update(const Tensor *gradient) {
   }
 }
 
-const char *SGDOptimizer::SerializeState(int *state_len) {
+std::string SGDOptimizer::SerializeState() {
   SGDOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
-  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  std::string lr_str = this->lr_policy_->SerializeState();
   state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
-  auto str = state.SerializeAsString();
-  *state_len += str.size();
-  return str.c_str();
+  return state.SerializeAsString();
 }
 
 void SGDOptimizer::DeserializeState(const std::string &str) {
......
@@ -23,7 +23,7 @@ public:
     if (momentums_) delete momentums_;
   }
   void Update(const Tensor* gradient);
-  const char* SerializeState(int* state_len);
+  std::string SerializeState();
   void DeserializeState(const std::string& state);
 
 private:
......