Commit 1814fc29 authored by dzhwinter

"fix lr_policy serialization"

Parent b72e8aa3
@@ -12,7 +12,6 @@ set(OPITMIZER_SRCS
add_library(optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(optimizer gen_proto_cpp)
add_simple_unittest(tensor_test)
add_simple_unittest(serialization_test)
add_simple_unittest(parameter_optimizer_test)
add_dependencies(parameter_optimizer_test optimizer)
@@ -27,23 +27,22 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
const char* AdadeltaOptimizer::SerializeState(int* state_len) {
AdadeltaOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
TensorToProto(*accum_delta_, state.mutable_accum_delta());
TensorToProto(*update_delta_, state.mutable_update_delta());
*state_len =
CalStateSize(parameter_, accum_gradient_, accum_delta_, update_delta_);
return state.SerializeAsString().c_str();
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void AdadeltaOptimizer::DeserializeState(const std::string& str) {
AdadeltaOptimizerState state;
state.ParseFromString(str);
lr_policy_->set(state.learning_rate());
// TODO(zhihong) : add lr_policy DeserializeState
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
......
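Note on the `SerializeState` change above: both the old `return state.SerializeAsString().c_str();` and the new three-line form return `c_str()` of a `std::string` that goes out of scope when the function returns, so the caller receives a dangling pointer. Below is a minimal sketch of one way to keep the bytes alive, assuming a hypothetical `std::string serialized_state_` member; this is not what the commit itself does.

```cpp
// Sketch only: cache the serialized bytes in a member so the returned
// pointer stays valid until the next SerializeState call.
// serialized_state_ is a hypothetical std::string member, not part of this commit.
const char* AdadeltaOptimizer::SerializeState(int* state_len) {
  AdadeltaOptimizerState state;
  state.set_num_sample_passed(num_sample_passed_);
  TensorToProto(*parameter_, state.mutable_parameter());
  TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
  TensorToProto(*accum_delta_, state.mutable_accum_delta());
  TensorToProto(*update_delta_, state.mutable_update_delta());
  serialized_state_ = state.SerializeAsString();  // lives as long as *this
  *state_len = static_cast<int>(serialized_state_.size());
  return serialized_state_.c_str();
}
```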
@@ -10,17 +10,13 @@ public:
AdadeltaOptimizer(
Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay)
: ParameterOptimizer(parameter, lr),
accum_gradient_(new Tensor(parameter->size())),
accum_delta_(new Tensor(parameter->size())),
update_delta_(new Tensor(parameter->size())),
rho_(rho),
epsilon_(epsilon),
decay_(decay) {
size_t size = parameter->size();
if (accum_gradient_) delete accum_gradient_;
accum_gradient_ = new Tensor(size);
if (accum_delta_) delete accum_delta_;
accum_delta_ = new Tensor(size);
if (update_delta_) delete update_delta_;
update_delta_ = new Tensor(size);
}
decay_(decay) {}
~AdadeltaOptimizer() {
if (accum_gradient_) delete accum_gradient_;
if (accum_delta_) delete accum_delta_;
......
@@ -19,19 +19,20 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
}
const char* AdagradOptimizer::SerializeState(int* state_len) {
AdagradOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
*state_len = CalStateSize(parameter_, accum_gradient_);
return state.SerializeAsString().c_str();
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void AdagradOptimizer::DeserializeState(const std::string& str) {
AdagradOptimizerState state;
state.ParseFromString(str);
lr_policy_->set(state.learning_rate());
// TODO(zhihong) : add lr_policy DeserializeState
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.accum_gradient(), accum_gradient_);
......
@@ -11,11 +11,10 @@ public:
LrPolicy *lr,
double epsilon,
double decay)
: ParameterOptimizer(parameter, lr), epsilon_(epsilon), decay_(decay) {
size_t size = parameter->size();
if (accum_gradient_) delete accum_gradient_;
accum_gradient_ = new Tensor(size);
}
: ParameterOptimizer(parameter, lr),
accum_gradient_(new Tensor(parameter->size())),
epsilon_(epsilon),
decay_(decay) {}
~AdagradOptimizer() {
if (accum_gradient_) delete accum_gradient_;
}
......
@@ -24,20 +24,20 @@ void AdamOptimizer::Update(const Tensor *gradient) {
const char *AdamOptimizer::SerializeState(int *state_len) {
AdamOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*velocitys_, state.mutable_momentums());
*state_len = CalStateSize(parameter_, momentums_, velocitys_);
return state.SerializeAsString().c_str();
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void AdamOptimizer::DeserializeState(const std::string &str) {
AdamOptimizerState state;
state.ParseFromString(str);
lr_policy_->set(state.learning_rate());
// TODO(zhihong) : add lr_policy DeserializeState
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
......
@@ -14,14 +14,12 @@ public:
double epsilon,
double decay)
: ParameterOptimizer(parameter, lr),
momentums_(new Tensor(parameter->size())),
velocitys_(new Tensor(parameter->size())),
beta_1_(beta_1),
beta_2_(beta_2),
epsilon_(epsilon),
decay_(decay) {
size_t size = parameter->size();
momentums_ = new Tensor(size);
velocitys_ = new Tensor(size);
}
decay_(decay) {}
~AdamOptimizer() {
if (momentums_) delete momentums_;
if (velocitys_) delete velocitys_;
......
@@ -10,7 +10,8 @@ class LrPolicy {
public:
virtual ~LrPolicy() {}
virtual double LearningRate(const uint64_t num_sample_passed) = 0;
virtual void set(double current_learning_rate) = 0;
virtual const char *SerializeState(int *state_len) = 0;
virtual void DeserializeState(const std::string &state) = 0;
};
// constant learning rate policy
@@ -20,9 +21,8 @@ public:
double LearningRate(const uint64_t num_sample_passed) {
return learning_rate;
}
void set(double current_learning_rate) {
learning_rate = current_learning_rate;
}
const char *SerializeState(int *state_len);
void DeserializeState(const std::string &state);
private:
double learning_rate;
@@ -35,9 +35,8 @@ public:
double LearningRate(const uint64_t num_sample_passed) {
return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
}
void set(double current_learning_rate) {
learning_rate = current_learning_rate;
}
const char *SerializeState(int *state_len);
void DeserializeState(const std::string &state);
private:
double learning_rate;
......
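The TODO comments in the optimizer hunks defer learning-rate state to the new `SerializeState`/`DeserializeState` pure virtuals declared on `LrPolicy` above. Their definitions are not part of this diff; the sketch below shows one possible `ConstLr` implementation, assuming a hypothetical `LrPolicyState` protobuf message with a `learning_rate` field and a hypothetical cached `serialized_state_` string member. `LinearLr` could follow the same pattern, additionally storing `lr_decay_a` and `lr_decay_b`.

```cpp
// Sketch only: one possible ConstLr implementation of the new interface.
// LrPolicyState is a hypothetical proto message (double learning_rate = 1;),
// and serialized_state_ a hypothetical std::string member used to keep the
// returned buffer alive across the call.
const char* ConstLr::SerializeState(int* state_len) {
  LrPolicyState state;
  state.set_learning_rate(learning_rate);
  serialized_state_ = state.SerializeAsString();
  *state_len = static_cast<int>(serialized_state_.size());
  return serialized_state_.c_str();
}

void ConstLr::DeserializeState(const std::string& str) {
  LrPolicyState state;
  state.ParseFromString(str);
  learning_rate = state.learning_rate();
}
```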
@@ -19,7 +19,10 @@ public:
*/
ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
: parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
virtual ~ParameterOptimizer() { delete parameter_; };
virtual ~ParameterOptimizer() {
delete parameter_;
delete lr_policy_;
}
static ParameterOptimizer *Create(const std::string &config_proto,
Tensor *parameter);
......
@@ -10,18 +10,6 @@
namespace paddle {
namespace optimizer {
static unsigned CalStateSize() { return 0; }
template <typename HEAD, typename... TAIL>
unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
return sizeof head + CalStateSize(tail...);
}
template <typename... TAIL>
unsigned CalStateSize(const Tensor* head, const TAIL&... tail) {
return head->size() + CalStateSize(tail...);
}
static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
std::stringstream os;
......
@@ -29,19 +29,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state;
state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*momentums_, state.mutable_momentums());
*state_len = CalStateSize(parameter_, momentums_);
return state.SerializeAsString().c_str();
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void SGDOptimizer::DeserializeState(const std::string &str) {
SGDOptimizerState state;
state.ParseFromString(str);
lr_policy_->set(state.learning_rate());
// TODO(zhihong) : add lr_policy DeserializeState
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
......
@@ -16,7 +16,6 @@ public:
if (momentum_ != 0.0) {
size_t size = parameter->size();
// TODO: fix it with align aware allocator bind to Tensor
if (momentums_) delete momentums_;
momentums_ = new Tensor(size);
}
}
......