Commit 1814fc29 authored by dzhwinter

"fix lr_policy serialization"

Parent b72e8aa3
@@ -12,7 +12,6 @@ set(OPITMIZER_SRCS
 add_library(optimizer STATIC ${OPITMIZER_SRCS})
 add_dependencies(optimizer gen_proto_cpp)
-add_simple_unittest(tensor_test)
 add_simple_unittest(serialization_test)
 add_simple_unittest(parameter_optimizer_test)
 add_dependencies(parameter_optimizer_test optimizer)
@@ -27,23 +27,22 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
 const char* AdadeltaOptimizer::SerializeState(int* state_len) {
   AdadeltaOptimizerState state;
-  state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
+  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   TensorToProto(*accum_delta_, state.mutable_accum_delta());
   TensorToProto(*update_delta_, state.mutable_update_delta());
-  *state_len =
-      CalStateSize(parameter_, accum_gradient_, accum_delta_, update_delta_);
-  return state.SerializeAsString().c_str();
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
 }
 void AdadeltaOptimizer::DeserializeState(const std::string& str) {
   AdadeltaOptimizerState state;
   state.ParseFromString(str);
-  lr_policy_->set(state.learning_rate());
+  // TODO(zhihong) : add lr_policy DeserializeState
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
......
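Note on the new pattern in adadelta (and the other optimizers below): the reported state length is now simply the size of the serialized protobuf string rather than a hand-computed CalStateSize sum. One detail worth sketching is buffer lifetime: `str` above is a local `std::string`, so a pointer obtained from it is only valid while that string lives. A minimal sketch of keeping the serialized bytes alive in a member follows; the `serialized_state_` member and `DemoOptimizer` class are illustrations, not part of this commit.

```cpp
#include <cstdio>
#include <string>

// Sketch only: a hypothetical holder keeps the serialized bytes alive in a
// member, so the pointer handed back through SerializeState() does not
// dangle once a local protobuf string would go out of scope. The
// serialized_state_ member and DemoOptimizer are assumptions, not part of
// this commit.
class DemoOptimizer {
public:
  const char* SerializeState(int* state_len) {
    // In the real code these bytes come from state.SerializeAsString().
    serialized_state_ = std::string("\x08\x2a", 2);
    *state_len = static_cast<int>(serialized_state_.size());
    return serialized_state_.c_str();  // valid for the lifetime of *this
  }

private:
  std::string serialized_state_;
};

int main() {
  DemoOptimizer opt;
  int len = 0;
  const char* bytes = opt.SerializeState(&len);
  std::printf("serialized %d bytes, first byte 0x%02x\n", len,
              static_cast<unsigned>(static_cast<unsigned char>(bytes[0])));
  return 0;
}
```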
@@ -10,17 +10,13 @@ public:
   AdadeltaOptimizer(
       Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay)
       : ParameterOptimizer(parameter, lr),
+        accum_gradient_(new Tensor(parameter->size())),
+        accum_delta_(new Tensor(parameter->size())),
+        update_delta_(new Tensor(parameter->size())),
         rho_(rho),
         epsilon_(epsilon),
-        decay_(decay) {
-    size_t size = parameter->size();
-    if (accum_gradient_) delete accum_gradient_;
-    accum_gradient_ = new Tensor(size);
-    if (accum_delta_) delete accum_delta_;
-    accum_delta_ = new Tensor(size);
-    if (update_delta_) delete update_delta_;
-    update_delta_ = new Tensor(size);
-  }
+        decay_(decay) {}
   ~AdadeltaOptimizer() {
     if (accum_gradient_) delete accum_gradient_;
     if (accum_delta_) delete accum_delta_;
......
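The constructor now allocates the accumulator tensors directly in the member-initializer list, so the old body that deleted possibly-uninitialized pointers and re-allocated them goes away. A hedged alternative sketch of the same shape using `std::unique_ptr` (not what this commit does; `DemoTensor` stands in for the real Tensor) would also drop the manual deletes in the destructor:

```cpp
#include <cstddef>
#include <memory>

// Hedged alternative sketch (not what this commit does): holding the
// accumulators in std::unique_ptr members gives the same
// allocate-in-the-initializer-list shape while removing the manual deletes
// in the destructor. DemoTensor stands in for paddle's Tensor.
struct DemoTensor {
  explicit DemoTensor(std::size_t size) : size(size) {}
  std::size_t size;
};

class DemoAdadelta {
public:
  explicit DemoAdadelta(std::size_t parameter_size)
      : accum_gradient_(new DemoTensor(parameter_size)),
        accum_delta_(new DemoTensor(parameter_size)),
        update_delta_(new DemoTensor(parameter_size)) {}
  // No user-written destructor: unique_ptr releases the tensors.

private:
  std::unique_ptr<DemoTensor> accum_gradient_;
  std::unique_ptr<DemoTensor> accum_delta_;
  std::unique_ptr<DemoTensor> update_delta_;
};

int main() {
  DemoAdadelta opt(1024);  // all three accumulators sized like the parameter
  return 0;
}
```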
@@ -19,19 +19,20 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
 }
 const char* AdagradOptimizer::SerializeState(int* state_len) {
   AdagradOptimizerState state;
-  state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
+  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
-  *state_len = CalStateSize(parameter_, accum_gradient_);
-  return state.SerializeAsString().c_str();
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
 }
 void AdagradOptimizer::DeserializeState(const std::string& str) {
   AdagradOptimizerState state;
   state.ParseFromString(str);
-  lr_policy_->set(state.learning_rate());
+  // TODO(zhihong) : add lr_policy DeserializeState
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   ProtoToTensor(state.accum_gradient(), accum_gradient_);
......
@@ -11,11 +11,10 @@ public:
                    LrPolicy *lr,
                    double epsilon,
                    double decay)
-      : ParameterOptimizer(parameter, lr), epsilon_(epsilon), decay_(decay) {
-    size_t size = parameter->size();
-    if (accum_gradient_) delete accum_gradient_;
-    accum_gradient_ = new Tensor(size);
-  }
+      : ParameterOptimizer(parameter, lr),
+        accum_gradient_(new Tensor(parameter->size())),
+        epsilon_(epsilon),
+        decay_(decay) {}
   ~AdagradOptimizer() {
     if (accum_gradient_) delete accum_gradient_;
   }
......
@@ -24,20 +24,20 @@ void AdamOptimizer::Update(const Tensor *gradient) {
 const char *AdamOptimizer::SerializeState(int *state_len) {
   AdamOptimizerState state;
-  state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
+  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*velocitys_, state.mutable_momentums());
-  *state_len = CalStateSize(parameter_, momentums_, velocitys_);
-  return state.SerializeAsString().c_str();
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
 }
 void AdamOptimizer::DeserializeState(const std::string &str) {
   AdamOptimizerState state;
   state.ParseFromString(str);
-  lr_policy_->set(state.learning_rate());
+  // TODO(zhihong) : add lr_policy DeserializeState
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
......
@@ -14,14 +14,12 @@ public:
                 double epsilon,
                 double decay)
       : ParameterOptimizer(parameter, lr),
+        momentums_(new Tensor(parameter->size())),
+        velocitys_(new Tensor(parameter->size())),
         beta_1_(beta_1),
         beta_2_(beta_2),
         epsilon_(epsilon),
-        decay_(decay) {
-    size_t size = parameter->size();
-    momentums_ = new Tensor(size);
-    velocitys_ = new Tensor(size);
-  }
+        decay_(decay) {}
   ~AdamOptimizer() {
     if (momentums_) delete momentums_;
     if (velocitys_) delete velocitys_;
......
@@ -10,7 +10,8 @@ class LrPolicy {
 public:
   virtual ~LrPolicy() {}
   virtual double LearningRate(const uint64_t num_sample_passed) = 0;
-  virtual void set(double current_learning_rate) = 0;
+  virtual const char *SerializeState(int *state_len) = 0;
+  virtual void DeserializeState(const std::string &state) = 0;
 };
 // constant learning rate policy
@@ -20,9 +21,8 @@ public:
   double LearningRate(const uint64_t num_sample_passed) {
     return learning_rate;
   }
-  void set(double current_learning_rate) {
-    learning_rate = current_learning_rate;
-  }
+  const char *SerializeState(int *state_len);
+  void DeserializeState(const std::string &state);
 private:
   double learning_rate;
@@ -35,9 +35,8 @@ public:
   double LearningRate(const uint64_t num_sample_passed) {
     return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
   }
-  void set(double current_learning_rate) {
-    learning_rate = current_learning_rate;
-  }
+  const char *SerializeState(int *state_len);
+  void DeserializeState(const std::string &state);
 private:
   double learning_rate;
......
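LrPolicy loses its `set()` mutator and instead declares SerializeState/DeserializeState, which the optimizer .cc files above still mark as TODO. Purely as an illustration of what a constant-rate policy might eventually do, here is a hedged sketch that round-trips the learning rate through a byte string held in a member; the encoding, the `state_` member, and the `DemoConstLr` name are assumptions, since this commit defines only the interface.

```cpp
#include <cstdint>
#include <cstring>
#include <string>

// Hedged sketch of how a constant-rate policy might eventually implement
// the new interface, assuming the learning rate is round-tripped through a
// byte string held in a member. The real project serializes through
// protobuf; this encoding, the state_ member and DemoConstLr are
// assumptions, not part of this commit.
class DemoConstLr {
public:
  explicit DemoConstLr(double lr) : learning_rate_(lr) {}

  const char* SerializeState(int* state_len) {
    state_.assign(reinterpret_cast<const char*>(&learning_rate_),
                  sizeof(learning_rate_));
    *state_len = static_cast<int>(state_.size());
    return state_.c_str();
  }

  void DeserializeState(const std::string& state) {
    if (state.size() == sizeof(learning_rate_)) {
      std::memcpy(&learning_rate_, state.data(), sizeof(learning_rate_));
    }
  }

  double LearningRate(uint64_t /*num_sample_passed*/) const {
    return learning_rate_;
  }

private:
  double learning_rate_;
  std::string state_;  // keeps the serialized bytes alive for the caller
};

int main() {
  DemoConstLr lr(0.01);
  int len = 0;
  std::string bytes(lr.SerializeState(&len), static_cast<std::size_t>(len));
  DemoConstLr restored(0.0);
  restored.DeserializeState(bytes);
  return restored.LearningRate(0) == 0.01 ? 0 : 1;
}
```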
@@ -19,7 +19,10 @@ public:
    */
   ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
       : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
-  virtual ~ParameterOptimizer() { delete parameter_; };
+  virtual ~ParameterOptimizer() {
+    delete parameter_;
+    delete lr_policy_;
+  }
   static ParameterOptimizer *Create(const std::string &config_proto,
                                     Tensor *parameter);
......
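With the destructor now deleting `lr_policy_` as well as `parameter_`, the optimizer owns both objects it is constructed with. A small stand-in sketch of that ownership contract (DemoTensor, DemoLrPolicy and DemoOptimizer are illustrations, not the real classes): the caller hands over raw pointers and must not delete them again.

```cpp
#include <cstdio>

// Ownership sketch with stand-in types: after this change the optimizer's
// destructor deletes both the parameter tensor and the LrPolicy, so the
// caller passes raw pointers and must not delete them itself. DemoTensor,
// DemoLrPolicy and DemoOptimizer are illustrations, not the real classes.
struct DemoTensor {
  ~DemoTensor() { std::puts("tensor freed"); }
};
struct DemoLrPolicy {
  ~DemoLrPolicy() { std::puts("lr policy freed"); }
};

class DemoOptimizer {
public:
  DemoOptimizer(DemoTensor* parameter, DemoLrPolicy* lr)
      : parameter_(parameter), lr_policy_(lr) {}
  ~DemoOptimizer() {
    delete parameter_;
    delete lr_policy_;  // new in this change: the policy is owned too
  }

private:
  DemoTensor* parameter_;
  DemoLrPolicy* lr_policy_;
};

int main() {
  // Both objects are released when the optimizer goes out of scope.
  DemoOptimizer opt(new DemoTensor, new DemoLrPolicy);
  return 0;
}
```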
@@ -10,18 +10,6 @@
 namespace paddle {
 namespace optimizer {
-static unsigned CalStateSize() { return 0; }
-template <typename HEAD, typename... TAIL>
-unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
-  return sizeof head + CalStateSize(tail...);
-}
-template <typename... TAIL>
-unsigned CalStateSize(const Tensor* head, const TAIL&... tail) {
-  return head->size() + CalStateSize(tail...);
-}
 static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
   proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
   std::stringstream os;
......
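The variadic CalStateSize helpers become dead code once the state length is read off the serialized protobuf string, so they are removed; only the tensor/proto conversion helpers remain. For context, a hedged sketch of the text round trip that TensorToProto/ProtoToTensor appear to perform through a stringstream, using `std::vector` stand-ins instead of the real Tensor and TensorProto types:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hedged sketch of the stringstream round trip that TensorToProto /
// ProtoToTensor appear to perform: each float is streamed to text and then
// parsed back. std::vector stands in for the Tensor and for the repeated
// field of TensorProto; this is an illustration, not the real API.
static std::vector<std::string> FloatsToStrings(const std::vector<float>& t) {
  std::vector<std::string> out;
  for (float v : t) {
    std::ostringstream os;
    os << v;
    out.push_back(os.str());
  }
  return out;
}

static std::vector<float> StringsToFloats(const std::vector<std::string>& in) {
  std::vector<float> out;
  for (const std::string& s : in) {
    std::istringstream is(s);
    float v = 0.0f;
    is >> v;
    out.push_back(v);
  }
  return out;
}

int main() {
  std::vector<float> tensor = {0.5f, -1.25f, 3.0f};
  std::vector<float> restored = StringsToFloats(FloatsToStrings(tensor));
  return restored == tensor ? 0 : 1;  // exact for these representable values
}
```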
@@ -29,19 +29,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 const char *SGDOptimizer::SerializeState(int *state_len) {
   SGDOptimizerState state;
-  state.set_learning_rate(lr_policy_->LearningRate(num_sample_passed_));
+  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
-  *state_len = CalStateSize(parameter_, momentums_);
-  return state.SerializeAsString().c_str();
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
 }
 void SGDOptimizer::DeserializeState(const std::string &str) {
   SGDOptimizerState state;
   state.ParseFromString(str);
-  lr_policy_->set(state.learning_rate());
+  // TODO(zhihong) : add lr_policy DeserializeState
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
......
@@ -16,7 +16,6 @@ public:
     if (momentum_ != 0.0) {
       size_t size = parameter->size();
       // TODO: fix it with align aware allocator bind to Tensor
-      if (momentums_) delete momentums_;
       momentums_ = new Tensor(size);
     }
   }
......