diff --git a/cmake/util.cmake b/cmake/util.cmake
index 8c9143462227e7081142f6be250b1a45e4b6d51b..87ad9d91d8701c56255c1e7f224764998df634a7 100644
--- a/cmake/util.cmake
+++ b/cmake/util.cmake
@@ -84,6 +84,7 @@ function(link_paddle_exe TARGET_NAME)
         paddle_parameter
         paddle_proto
         paddle_cuda
+        paddle_optimizer
         ${EXTERNAL_LIBS}
         ${CMAKE_THREAD_LIBS_INIT}
         ${CMAKE_DL_LIBS}
diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index 47ca1833967ee705d6558b1dad06a6335b30f03a..bf1b01309ec5d023af33f6d9d6da8a8d7295d8d1 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(gserver)
 add_subdirectory(pserver)
 add_subdirectory(trainer)
 add_subdirectory(scripts)
+add_subdirectory(optimizer)
 add_subdirectory(strings)
 
 # Do not build go directory until go cmake is working smoothly.
diff --git a/paddle/optimizer/CMakeLists.txt b/paddle/optimizer/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..791be95efa9e8063a2f4b9632d86db2f5bc751e6
--- /dev/null
+++ b/paddle/optimizer/CMakeLists.txt
@@ -0,0 +1,16 @@
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+
+set(OPTIMIZER_SRCS
+    adadelta_optimizer.cc
+    adagrad_optimizer.cc
+    adam_optimizer.cc
+    optimizer.cc
+    parameter_optimizer.cc
+    sgd_optimizer.cc
+    )
+
+add_library(paddle_optimizer STATIC ${OPTIMIZER_SRCS})
+add_dependencies(paddle_optimizer gen_proto_cpp)
+
+add_simple_unittest(serialization_test)
+add_simple_unittest(parameter_optimizer_test)
diff --git a/paddle/optimizer/adadelta_optimizer.cc b/paddle/optimizer/adadelta_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..465ad5e0d2089121a0f11ab916afe0420cbcfab7
--- /dev/null
+++ b/paddle/optimizer/adadelta_optimizer.cc
@@ -0,0 +1,55 @@
+#include "adadelta_optimizer.h"
+#include <algorithm>
+#include <cmath>
+
+namespace paddle {
+namespace optimizer {
+
+void AdadeltaOptimizer::Update(const Tensor* gradient) {
+  num_sample_passed_ += 1;
+  double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
+  Tensor& param = *parameter_;
+  const Tensor& grad = *gradient;
+  Tensor& accum_g = *accum_gradient_;
+  Tensor& accum_d = *accum_delta_;
+  Tensor& update_d = *update_delta_;
+  for (size_t i = 0; i < param.size(); ++i) {
+    accum_g[i] = rho_ * accum_g[i] + (1.0 - rho_) * grad[i] * grad[i];
+
+    update_d[i] = std::sqrt(accum_d[i] + epsilon_) /
+                  std::sqrt(accum_g[i] + epsilon_) * grad[i];
+
+    accum_d[i] = rho_ * accum_d[i] + (1.0 - rho_) * update_d[i] * update_d[i];
+
+    param[i] -= learning_rate * update_d[i] + learning_rate * decay_ * param[i];
+  }
+}
+
+const char* AdadeltaOptimizer::SerializeState(int* state_len) {
+  AdadeltaOptimizerState state;
+  // TODO(zhihong) : add lr_policy serialization
+  state.set_num_sample_passed(num_sample_passed_);
+
+  TensorToProto(*parameter_, state.mutable_parameter());
+  TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
+  TensorToProto(*accum_delta_, state.mutable_accum_delta());
+  TensorToProto(*update_delta_, state.mutable_update_delta());
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
+}
+
+void AdadeltaOptimizer::DeserializeState(const std::string& str) {
+  AdadeltaOptimizerState state;
+  state.ParseFromString(str);
+  // TODO(zhihong) : add lr_policy DeserializeState
+  num_sample_passed_ = state.num_sample_passed();
+
+  ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.accum_gradient(), accum_gradient_);
+  ProtoToTensor(state.accum_delta(), accum_delta_);
+  ProtoToTensor(state.update_delta(), update_delta_);
+}
+
+} // namespace optimizer
+} // namespace paddle
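For reference, the per-element rule implemented by AdadeltaOptimizer::Update above can be written as (with rho = rho_, epsilon = epsilon_, lambda = decay_, and eta the learning rate returned by lr_policy_):

    \begin{aligned}
    E[g^2]_i &\leftarrow \rho\,E[g^2]_i + (1-\rho)\,g_i^2 \\
    \Delta x_i &= \frac{\sqrt{E[\Delta x^2]_i + \epsilon}}{\sqrt{E[g^2]_i + \epsilon}}\,g_i \\
    E[\Delta x^2]_i &\leftarrow \rho\,E[\Delta x^2]_i + (1-\rho)\,\Delta x_i^2 \\
    x_i &\leftarrow x_i - \eta\,\Delta x_i - \eta\,\lambda\,x_i
    \end{aligned}

Here accum_gradient_ stores E[g^2], accum_delta_ stores E[\Delta x^2], and update_delta_ stores \Delta x.
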
diff --git a/paddle/optimizer/adadelta_optimizer.h b/paddle/optimizer/adadelta_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..1d5eab097f57d049855dd171a1aa6f74c48ae0e7
--- /dev/null
+++ b/paddle/optimizer/adadelta_optimizer.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+class AdadeltaOptimizer : public ParameterOptimizer {
+public:
+  AdadeltaOptimizer(
+      Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay)
+      : ParameterOptimizer(parameter, lr),
+        accum_gradient_(new Tensor(parameter->size())),
+        accum_delta_(new Tensor(parameter->size())),
+        update_delta_(new Tensor(parameter->size())),
+        rho_(rho),
+        epsilon_(epsilon),
+        decay_(decay) {}
+
+  ~AdadeltaOptimizer() {
+    if (accum_gradient_) delete accum_gradient_;
+    if (accum_delta_) delete accum_delta_;
+    if (update_delta_) delete update_delta_;
+  }
+  void Update(const Tensor *gradient);
+  const char *SerializeState(int *state_len);
+  void DeserializeState(const std::string &state);
+
+private:
+  Tensor *accum_gradient_;
+  Tensor *accum_delta_;
+  Tensor *update_delta_;
+  double rho_;
+  double epsilon_;
+  double decay_;
+};
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/adagrad_optimizer.cc b/paddle/optimizer/adagrad_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bdaa7877d2bc58c17c51b977852d4b6fec511ed2
--- /dev/null
+++ b/paddle/optimizer/adagrad_optimizer.cc
@@ -0,0 +1,42 @@
+#include <cmath>
+
+#include "adagrad_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+void AdagradOptimizer::Update(const Tensor* gradient) {
+  num_sample_passed_ += 1;
+  double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
+  Tensor& param = *parameter_;
+  Tensor& accum_g = *accum_gradient_;
+  const Tensor& grad = *gradient;
+  for (size_t i = 0; i < param.size(); ++i) {
+    accum_g[i] += grad[i] * grad[i];
+    param[i] -= learning_rate * grad[i] / std::sqrt(accum_g[i] + epsilon_) +
+                learning_rate * decay_ * param[i];
+  }
+}
+const char* AdagradOptimizer::SerializeState(int* state_len) {
+  AdagradOptimizerState state;
+  // TODO(zhihong) : add lr_policy serialization
+  state.set_num_sample_passed(num_sample_passed_);
+
+  TensorToProto(*parameter_, state.mutable_parameter());
+  TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
+}
+
+void AdagradOptimizer::DeserializeState(const std::string& str) {
+  AdagradOptimizerState state;
+  state.ParseFromString(str);
+  // TODO(zhihong) : add lr_policy DeserializeState
+  num_sample_passed_ = state.num_sample_passed();
+  ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.accum_gradient(), accum_gradient_);
+}
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/adagrad_optimizer.h b/paddle/optimizer/adagrad_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..15d0a965ad0c6967e73b14b465168fa66eb8fba3
--- /dev/null
+++ b/paddle/optimizer/adagrad_optimizer.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+class AdagradOptimizer : public ParameterOptimizer {
+public:
+  AdagradOptimizer(Tensor *parameter,
+                   LrPolicy *lr,
+                   double epsilon,
+                   double decay)
+      : ParameterOptimizer(parameter, lr),
+        accum_gradient_(new Tensor(parameter->size())),
+        epsilon_(epsilon),
+        decay_(decay) {}
+  ~AdagradOptimizer() {
+    if (accum_gradient_) delete accum_gradient_;
+  }
+  void Update(const Tensor *gradient);
+  const char *SerializeState(int *state_len);
+  void DeserializeState(const std::string &state);
+
+private:
+  Tensor *accum_gradient_;
+  double epsilon_;
+  double decay_;
+};
+
+} // namespace optimizer
+} // namespace paddle
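The Adagrad step above (with the same descent-direction convention as the other optimizers in this patch) is:

    \begin{aligned}
    G_i &\leftarrow G_i + g_i^2 \\
    x_i &\leftarrow x_i - \frac{\eta\,g_i}{\sqrt{G_i + \epsilon}} - \eta\,\lambda\,x_i
    \end{aligned}

where G is accum_gradient_, eta the learning rate from lr_policy_, and lambda = decay_.
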
diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ceab7397d87349c64ca9e5d11990cb38068421be
--- /dev/null
+++ b/paddle/optimizer/adam_optimizer.cc
@@ -0,0 +1,48 @@
+#include "adam_optimizer.h"
+#include <cmath>
+
+namespace paddle {
+namespace optimizer {
+
+void AdamOptimizer::Update(const Tensor *gradient) {
+  num_sample_passed_ += 1;
+  double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
+  double coef1 = 1.0 - std::pow(beta_1_, num_sample_passed_);
+  double coef2 = 1.0 - std::pow(beta_2_, num_sample_passed_);
+  learning_rate *= std::sqrt(coef2) / coef1;
+  Tensor &param = *parameter_;
+  const Tensor &grad = *gradient;
+  Tensor &m = *momentums_;
+  Tensor &v = *velocitys_;
+  for (size_t i = 0; i < param.size(); ++i) {
+    m[i] = beta_1_ * m[i] + (1.0 - beta_1_) * grad[i];
+    v[i] = beta_2_ * v[i] + (1.0 - beta_2_) * grad[i] * grad[i];
+    param[i] -=
+        learning_rate * (m[i] / std::sqrt(v[i] + epsilon_) + decay_ * param[i]);
+  }
+}
+
+const char *AdamOptimizer::SerializeState(int *state_len) {
+  AdamOptimizerState state;
+  // TODO(zhihong) : add lr_policy serialization
+  state.set_num_sample_passed(num_sample_passed_);
+  TensorToProto(*parameter_, state.mutable_parameter());
+  TensorToProto(*momentums_, state.mutable_momentums());
+  TensorToProto(*velocitys_, state.mutable_velocitys());
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
+}
+
+void AdamOptimizer::DeserializeState(const std::string &str) {
+  AdamOptimizerState state;
+  state.ParseFromString(str);
+  // TODO(zhihong) : add lr_policy DeserializeState
+  num_sample_passed_ = state.num_sample_passed();
+
+  ProtoToTensor(state.parameter(), parameter_);
+  ProtoToTensor(state.momentums(), momentums_);
+  ProtoToTensor(state.velocitys(), velocitys_);
+}
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/adam_optimizer.h b/paddle/optimizer/adam_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..0ea4c8bb8470504282b4d6c12039791ce896e401
--- /dev/null
+++ b/paddle/optimizer/adam_optimizer.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+class AdamOptimizer : public ParameterOptimizer {
+public:
+  AdamOptimizer(Tensor *parameter,
+                LrPolicy *lr,
+                double beta_1,
+                double beta_2,
+                double epsilon,
+                double decay)
+      : ParameterOptimizer(parameter, lr),
+        momentums_(new Tensor(parameter->size())),
+        velocitys_(new Tensor(parameter->size())),
+        beta_1_(beta_1),
+        beta_2_(beta_2),
+        epsilon_(epsilon),
+        decay_(decay) {}
+  ~AdamOptimizer() {
+    if (momentums_) delete momentums_;
+    if (velocitys_) delete velocitys_;
+  }
+  void Update(const Tensor *gradient);
+  const char *SerializeState(int *state_len);
+  void DeserializeState(const std::string &state);
+
+private:
+  Tensor *momentums_;
+  Tensor *velocitys_;
+  double beta_1_;
+  double beta_2_;
+  double epsilon_;
+  double decay_;
+};
+
+} // namespace optimizer
+} // namespace paddle
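For reference, AdamOptimizer::Update above folds the bias correction into the learning rate; with t = num_sample_passed_ (so coef1 = 1 - beta_1^t and coef2 = 1 - beta_2^t in the code):

    \begin{aligned}
    m_i &\leftarrow \beta_1\,m_i + (1-\beta_1)\,g_i \\
    v_i &\leftarrow \beta_2\,v_i + (1-\beta_2)\,g_i^2 \\
    \hat{\eta} &= \eta\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} \\
    x_i &\leftarrow x_i - \hat{\eta}\left(\frac{m_i}{\sqrt{v_i + \epsilon}} + \lambda\,x_i\right)
    \end{aligned}

where momentums_ stores m, velocitys_ stores v, and lambda = decay_.
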
diff --git a/paddle/optimizer/lr_policy.h b/paddle/optimizer/lr_policy.h
new file mode 100644
index 0000000000000000000000000000000000000000..d8e33ad37ab4c019a36f63f34babe65cf8c8fb16
--- /dev/null
+++ b/paddle/optimizer/lr_policy.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <algorithm>
+#include "OptimizerConfig.pb.h"
+
+namespace paddle {
+namespace optimizer {
+
+class LrPolicy {
+public:
+  virtual ~LrPolicy() {}
+  virtual double LearningRate(const uint64_t num_sample_passed) = 0;
+  virtual const char *SerializeState(int *state_len) = 0;
+  virtual void DeserializeState(const std::string &state) = 0;
+};
+
+// constant learning rate policy
+class ConstLr final : public LrPolicy {
+public:
+  ConstLr(double lr) : learning_rate(lr) {}
+  double LearningRate(const uint64_t num_sample_passed) {
+    return learning_rate;
+  }
+  const char *SerializeState(int *state_len) { return nullptr; }
+  void DeserializeState(const std::string &state) {}
+
+private:
+  double learning_rate;
+};
+
+class LinearLr final : public LrPolicy {
+public:
+  LinearLr(double lr, double lr_decay_a, double lr_decay_b)
+      : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
+  double LearningRate(const uint64_t num_sample_passed) {
+    return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
+  }
+  const char *SerializeState(int *state_len) {
+    // TODO(zhihong) : add lr_policy serialization
+    return nullptr;
+  }
+  void DeserializeState(const std::string &state) {
+    // TODO(zhihong) : add lr_policy serialization
+  }
+
+private:
+  double learning_rate;
+  double lr_decay_a;
+  double lr_decay_b;
+};
+
+} // namespace optimizer
+} // namespace paddle
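In equation form, the two policies above reduce to (t = num_sample_passed):

    \eta_{\text{const}}(t) = \eta_0, \qquad
    \eta_{\text{linear}}(t) = \max(\eta_0 - a\,t,\; b)

with eta_0 = learning_rate, a = lr_decay_a, and b = lr_decay_b acting as the floor.
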
diff --git a/paddle/optimizer/optimizer.cc b/paddle/optimizer/optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..54662dc37891d3211950453b210db4b475837df4
--- /dev/null
+++ b/paddle/optimizer/optimizer.cc
@@ -0,0 +1,83 @@
+#include "optimizer.h"
+#include <string>
+
+#include "parameter_optimizer.h"
+
+using namespace paddle;
+using namespace paddle::optimizer;
+
+template <paddle_element_type VType>
+struct EnumToType {};
+
+template <class T>
+struct TypeToEnum {};
+
+#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
+  template <>                                        \
+  struct TypeToEnum<TYPE> {                          \
+    static paddle_element_type v() { return ENUM; }; \
+    static constexpr TYPE value = ENUM;              \
+  };                                                 \
+  template <>                                        \
+  struct EnumToType<ENUM> {                          \
+    typedef TYPE Type;                               \
+  }
+
+MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
+MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
+MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
+MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
+// TODO(zhihong): only the types below are implemented, need to fix
+MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
+MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
+
+struct paddle_optimizer {
+  paddle::optimizer::ParameterOptimizer* impl;
+};
+
+paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
+                                          const int config_proto_len,
+                                          const paddle_element_type data_type,
+                                          void* param_buffer,
+                                          int num_bytes,
+                                          const char* state,
+                                          const int state_len) {
+  paddle_optimizer* optimizer = new paddle_optimizer;
+  std::string config(config_proto, config_proto + config_proto_len);
+  Tensor* parameter =
+      new Tensor(reinterpret_cast<float*>(param_buffer), num_bytes);
+  optimizer->impl = ParameterOptimizer::Create(config, parameter);
+  if (state != nullptr) {
+    std::string s(state, state + state_len);
+    optimizer->impl->DeserializeState(s);
+  }
+  return optimizer;
+}
+
+int paddle_release_optimizer(paddle_optimizer* o) {
+  if (o != nullptr) delete o->impl;
+  return PADDLE_SUCCESS;
+}
+
+int paddle_update_parameter(paddle_optimizer* o,
+                            const paddle_element_type data_type,
+                            const void* grad_buffer,
+                            int num_bytes) {
+  // TODO(zhihong): data_type is not honored yet; add runtime data type dispatch
+  auto grad_type = reinterpret_cast<const float*>(grad_buffer);
+  Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes);
+  o->impl->Update(gradient);
+  return PADDLE_SUCCESS;
+}
+
+int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
+  int param_size = 0;
+  *param_buffer = (void*)o->impl->get_weight(&param_size);
+  return param_size;
+}
+
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
+  int state_len = 0;
+  *state = o->impl->SerializeState(&state_len);
+  return state_len;
+}
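The MATCH_ENUM_TYPE instantiations above give a compile-time mapping in both directions between C++ types and paddle_element_type values. A hypothetical use inside this translation unit (requires <type_traits>) would look like:

    static_assert(
        std::is_same<EnumToType<PADDLE_ELEMENT_TYPE_FLOAT32>::Type, float>::value,
        "enum -> type mapping");
    paddle_element_type e = TypeToEnum<float>::v();  // PADDLE_ELEMENT_TYPE_FLOAT32
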
diff --git a/paddle/optimizer/optimizer.h b/paddle/optimizer/optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..aabf7a458dd30092ed1e522c4d88c6cfe63fcce1
--- /dev/null
+++ b/paddle/optimizer/optimizer.h
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <stdbool.h>
+#include <stdint.h>
+
+/**
+ * @brief optimizer library, independent of the other PaddlePaddle modules,
+ * which will be used in:
+ *   Case A: the gradient is optimized locally on the trainer.
+ *
+ *   Case B: the gradient is optimized on the parameter server.
+ */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef enum {
+  PADDLE_ELEMENT_TYPE_INT32 = 0,
+  PADDLE_ELEMENT_TYPE_UINT32 = 1,
+  PADDLE_ELEMENT_TYPE_INT64 = 2,
+  PADDLE_ELEMENT_TYPE_UINT64 = 3,
+  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
+  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
+} paddle_element_type;
+
+/**
+ * @brief execution status code
+ */
+const int32_t PADDLE_SUCCESS = 0;
+const int32_t PADDLE_ERROR = -1;
+
+typedef struct paddle_optimizer paddle_optimizer;
+/**
+ * These interfaces are called in the following order:
+ *   1. create optimizer with config
+ *   2. set weights
+ *   3. update_parameter
+ *   4. get_weights
+ *   5. release optimizer
+ */
+
+/**
+ * @brief create optimizer from a protobuf config
+ * @param config_proto, optimizer protobuf; see OptimizerConfig.proto for details
+ * @return return optimizer instance
+ */
+paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
+                                          const int config_proto_len,
+                                          const paddle_element_type data_type,
+                                          void* param_buffer,
+                                          int num_bytes,
+                                          const char* state,
+                                          const int state_len);
+
+/**
+ * @brief release optimizer
+ * @param optimizer
+ * @return return exec status
+ */
+int paddle_release_optimizer(paddle_optimizer* o);
+
+/**
+ * @brief update the parameter with a gradient
+ * @param data_type of gradient and parameter
+ * @param gradient, calculated by the optimizer caller.
+ *        TODO(zhihong): just pass the loss to reduce communication overhead.
+ *        See the Project Adam (OSDI'14) paper for details.
+ * @param num_bytes, gradient size
+ * @return return exec status
+ */
+int paddle_update_parameter(paddle_optimizer* o,
+                            const paddle_element_type data_type,
+                            const void* gradient,
+                            int num_bytes);
+
+/**
+ * @brief get the parameter buffer
+ * @param param_buffer, initialized parameter buffer
+ * @return return content length
+ */
+int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);
+
+/**
+ * @brief save the training state
+ * @param state, receives the SerializeState result
+ * @return return state_buffer length
+ */
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
+
+#ifdef __cplusplus
+}
+#endif
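A minimal caller-side sketch of the call order documented in the header above (hypothetical, not part of the patch); it assumes `config` holds a serialized paddle::OptimizerConfig, e.g. built the way parameter_optimizer_test.cpp does:

    #include <string>
    #include "optimizer.h"

    void run_optimizer_once(const std::string& config) {
      const int kLen = 100;
      float param[kLen] = {0};  // parameter buffer owned by the caller
      float grad[kLen] = {0};   // gradient computed by the caller

      paddle_optimizer* o = paddle_create_optimizer(
          reinterpret_cast<const unsigned char*>(config.data()),
          static_cast<int>(config.size()),
          PADDLE_ELEMENT_TYPE_FLOAT32,
          param,
          kLen,  // optimizer.cc currently forwards this value as the Tensor element count
          /*state=*/nullptr,
          /*state_len=*/0);

      paddle_update_parameter(o, PADDLE_ELEMENT_TYPE_FLOAT32, grad, kLen);

      void* weights = nullptr;
      int num_weights = paddle_optimizer_get_weights(o, &weights);

      const char* state = nullptr;
      int state_len = paddle_optimizer_get_state(o, &state);  // checkpoint blob

      paddle_release_optimizer(o);
      (void)num_weights;
      (void)state_len;
    }
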
diff --git a/paddle/optimizer/parameter_optimizer.cc b/paddle/optimizer/parameter_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f6218037925649e741d17f49af972ce2d50f8d3d
--- /dev/null
+++ b/paddle/optimizer/parameter_optimizer.cc
@@ -0,0 +1,74 @@
+#include <glog/logging.h>
+#include "adadelta_optimizer.h"
+#include "adagrad_optimizer.h"
+#include "adam_optimizer.h"
+#include "lr_policy.h"
+#include "sgd_optimizer.h"
+
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
+                                               Tensor *parameter) {
+  paddle::OptimizerConfig config;
+  CHECK(config.ParseFromString(config_proto) == true)
+      << "failed to parse optimizer config";
+  auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
+    if (config.lr_policy() == OptimizerConfig::Const)
+      return new ConstLr(config.const_lr().learning_rate());
+    if (config.lr_policy() == OptimizerConfig::Linear)
+      return new LinearLr(config.linear_lr().learning_rate(),
+                          config.linear_lr().lr_decay_a(),
+                          config.linear_lr().lr_decay_b());
+    // default
+    LOG(WARNING) << "no LrPolicy selected, falling back to ConstLr";
+    return new ConstLr(0.1);
+  };
+
+  LrPolicy *lr = select_lr_policy(config);
+  auto select_optimizer = [=](
+      Tensor *parameter,
+      const OptimizerConfig &config) -> ParameterOptimizer * {
+    if (config.optimizer() == OptimizerConfig::SGD) {
+      return new SGDOptimizer(parameter,
+                              lr,
+                              config.sgd().momentum(),
+                              config.sgd().decay(),
+                              config.sgd().nesterov());
+    }
+    if (config.optimizer() == OptimizerConfig::Adadelta) {
+      return new AdadeltaOptimizer(parameter,
+                                   lr,
+                                   config.adadelta().rho(),
+                                   config.adadelta().epsilon(),
+                                   config.adadelta().decay());
+    }
+    if (config.optimizer() == OptimizerConfig::Adagrad) {
+      return new AdagradOptimizer(
+          parameter, lr, config.adagrad().epsilon(), config.adagrad().decay());
+    }
+    if (config.optimizer() == OptimizerConfig::Adam) {
+      return new AdamOptimizer(parameter,
+                               lr,
+                               config.adam().beta_1(),
+                               config.adam().beta_2(),
+                               config.adam().epsilon(),
+                               config.adam().decay());
+    }
+    // default
+    LOG(WARNING)
+        << "no Optimizer selected, falling back to SGDOptimizer";
+    return new SGDOptimizer(parameter, lr, 0.0, 0.0, false);
+  };
+  return select_optimizer(parameter, config);
+}
+
+float *ParameterOptimizer::get_weight(int *param_size) const {
+  *param_size = (int)parameter_->size();
+  return parameter_->get_buffer();
+}
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/parameter_optimizer.h b/paddle/optimizer/parameter_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..d89c9abb791f947172078d4dce5b1c366852591b
--- /dev/null
+++ b/paddle/optimizer/parameter_optimizer.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <glog/logging.h>
+#include <functional>
+#include <string>
+#include "OptimizerConfig.pb.h"
+#include "lr_policy.h"
+#include "serialization.h"
+#include "tensor.h"
+
+namespace paddle {
+namespace optimizer {
+
+class ParameterOptimizer {
+public:
+  /**
+   * @brief update hook for algorithms that need to traverse the parameter more
+   * than once.
+   */
+  ParameterOptimizer(Tensor *parameter, LrPolicy *lr)
+      : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {}
+  virtual ~ParameterOptimizer() {
+    delete parameter_;
+    delete lr_policy_;
+  }
+
+  static ParameterOptimizer *Create(const std::string &config_proto,
+                                    Tensor *parameter);
+  virtual void Update(const Tensor *gradient) = 0;
+  virtual float *get_weight(int *param_size) const;
+  virtual const char *SerializeState(int *state_len) = 0;
+  virtual void DeserializeState(const std::string &state) = 0;
+
+protected:
+  Tensor *parameter_;
+  // learning rate policy
+  LrPolicy *lr_policy_;
+  uint64_t num_sample_passed_;
+};
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4e6254d9e4dab48279b4a880695959526d30d70c
--- /dev/null
+++ b/paddle/optimizer/parameter_optimizer_test.cpp
@@ -0,0 +1,107 @@
+#include "parameter_optimizer.h"
+#include <cmath>
+#include <map>
+#include <vector>
+#include "gtest/gtest.h"
+#include "lr_policy.h"
+
+using namespace paddle;
+using namespace paddle::optimizer;
+
+Tensor* FillTensor(size_t size) {
+  Tensor* param = new Tensor(size);
+  Tensor& p = *param;
+  for (size_t i = 0; i < p.size(); ++i) {
+    p[i] = (float)rand() / (float)RAND_MAX;
+  }
+  return param;
+}
+
+Tensor* FixedTensor(size_t size) {
+  Tensor* param = new Tensor(size);
+  Tensor& p = *param;
+  for (size_t i = 0; i < p.size(); ++i) {
+    p[i] = i;
+  }
+  return param;
+}
+
+class OptimizerTest : public testing::Test {
+public:
+  // init tensor shape
+  const size_t kSize = 5;
+
+  virtual void SetUp() {
+    CreateSGD();
+    CreateAdam();
+  }
+  virtual void TearDown() {}
+
+  void CreateSGD() {
+    Tensor* parameter = FixedTensor(kSize);
+    config_.set_optimizer(OptimizerConfig::SGD);
+    config_.mutable_sgd()->set_momentum(0.0);
+    config_.mutable_sgd()->set_decay(0.0);
+    config_.mutable_sgd()->set_nesterov(false);
+    config_.set_lr_policy(OptimizerConfig::Const);
+    config_.mutable_const_lr()->set_learning_rate(0.1);
+    std::string str = config_.SerializeAsString();
+    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
+    opts_.push_back(opt);
+  }
+
+  void CreateAdam() {
+    Tensor* parameter = FixedTensor(kSize);
+    config_.set_optimizer(OptimizerConfig::Adam);
+    config_.mutable_adam()->set_beta_1(0.9);
+    config_.mutable_adam()->set_beta_2(0.1);
+    config_.mutable_adam()->set_epsilon(1e-3);
+    config_.mutable_adam()->set_decay(0.0);
+    config_.set_lr_policy(OptimizerConfig::Const);
+    config_.mutable_const_lr()->set_learning_rate(0.1);
+    std::string str = config_.SerializeAsString();
+    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
+    opts_.push_back(opt);
+  }
+
+  void TestGetWeight() {
+    Tensor* p = FixedTensor(kSize);
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      int s = 0;
+      float* newp = (float*)opts_[i]->get_weight(&s);
+      for (size_t j = 0; j < kSize; ++j) {
+        EXPECT_EQ(newp[j], (*p)[j]);
+      }
+    }
+  }
+
+  void TestUpdate() {
+    Tensor* g = FixedTensor(kSize);
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      opts_[i]->Update(g);
+    }
+  }
+
+  void TestCheckPoint() {
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      int state_len = 0;
+      std::string state = opts_[i]->SerializeState(&state_len);
+      opts_[i]->DeserializeState(state);
+    }
+  }
+
+private:
+  std::vector<ParameterOptimizer*> opts_;
+  OptimizerConfig config_;
+};
+
+TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }
+
+TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
+
+TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/optimizer/serialization.h b/paddle/optimizer/serialization.h
new file mode 100644
index 0000000000000000000000000000000000000000..92fbf65cc6b98d7f92841bafe4ab77001ca03b7c
--- /dev/null
+++ b/paddle/optimizer/serialization.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include "OptimizerConfig.pb.h"
+#include "paddle/utils/Logging.h"
+#include "tensor.h"
+
+namespace paddle {
+namespace optimizer {
+
+static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
+  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
+  std::stringstream os;
+  for (size_t i = 0; i < tensor.size(); ++i) {
+    os << tensor[i];
+    proto->add_content(os.str());
+    os.str(std::string());
+  }
+}
+
+static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
+  std::stringstream sin;
+  for (auto i = 0; i < proto.content_size(); ++i) {
+    sin << proto.content(i);
+    sin >> (*tensor)[i];
+    sin.str(std::string());
+    sin.clear();
+  }
+}
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/serialization_test.cpp b/paddle/optimizer/serialization_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2454140dc243b40ed8348578360b30894213838
--- /dev/null
+++ b/paddle/optimizer/serialization_test.cpp
@@ -0,0 +1,25 @@
+#include "serialization.h"
+#include "gtest/gtest.h"
+
+using namespace paddle;
+using namespace paddle::optimizer;
+
+TEST(TensorToProto, Case1) {
+  Tensor t(3), t1(3);
+  for (size_t i = 0; i < t.size(); ++i) {
+    t[i] = i;
+    t1[i] = 0;
+  }
+
+  TensorProto proto;
+  TensorToProto(t, &proto);
+  ProtoToTensor(proto, &t1);
+  for (size_t i = 0; i < t1.size(); ++i) {
+    EXPECT_EQ(t1[i], t[i]);
+  }
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..34e051003fa83f11b1f4a39c46856e0372836a1a
--- /dev/null
+++ b/paddle/optimizer/sgd_optimizer.cc
@@ -0,0 +1,49 @@
+#include "sgd_optimizer.h"
+#include "serialization.h"
+
+namespace paddle {
+namespace optimizer {
+
+void SGDOptimizer::Update(const Tensor *gradient) {
+  num_sample_passed_ += 1;
+  double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
+  float velocity = 0.0;
+  Tensor &param = *parameter_;
+  const Tensor &grad = *gradient;
+  Tensor &m = *momentums_;
+  for (size_t i = 0; i < param.size(); ++i) {
+    if (momentum_ == 0.0) {
+      velocity = -learning_rate * grad[i] - learning_rate * decay_ * param[i];
+    } else {
+      m[i] = momentum_ * m[i] - learning_rate * grad[i] -
+             learning_rate * decay_ * param[i];
+      velocity = m[i];
+    }
+    if (nesterov_) {
+      param[i] += momentum_ * velocity - learning_rate * grad[i];
+    } else {
+      param[i] += velocity;
+    }
+  }
+}
+
+const char *SGDOptimizer::SerializeState(int *state_len) {
+  SGDOptimizerState state;
+  state.set_num_sample_passed(num_sample_passed_);
+  TensorToProto(*parameter_, state.mutable_parameter());
+  if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
+  auto str = state.SerializeAsString();
+  *state_len = str.size();
+  return str.c_str();
+}
+
+void SGDOptimizer::DeserializeState(const std::string &str) {
+  SGDOptimizerState state;
+  state.ParseFromString(str);
+  num_sample_passed_ = state.num_sample_passed();
+  ProtoToTensor(state.parameter(), parameter_);
+  if (momentum_ != 0.0) ProtoToTensor(state.momentums(), momentums_);
+}
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/paddle/optimizer/sgd_optimizer.h b/paddle/optimizer/sgd_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..b74a902e1aa40a7831b36ab826d72372a3588bcf
--- /dev/null
+++ b/paddle/optimizer/sgd_optimizer.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+class SGDOptimizer : public ParameterOptimizer {
+public:
+  SGDOptimizer(Tensor* parameter, LrPolicy* lr, double m, double d, bool n)
+      : ParameterOptimizer(parameter, lr),
+        momentums_(nullptr),
+        momentum_(m),
+        decay_(d),
+        nesterov_(n) {
+    if (momentum_ != 0.0) {
+      size_t size = parameter->size();
+      // TODO: fix this with an alignment-aware allocator bound to Tensor
+      momentums_ = new Tensor(size);
+    }
+  }
+  virtual ~SGDOptimizer() {
+    if (momentums_) delete momentums_;
+  }
+  void Update(const Tensor* gradient);
+  const char* SerializeState(int* state_len);
+  void DeserializeState(const std::string& state);
+
+private:
+  Tensor* momentums_;
+  double momentum_;
+  double decay_;
+  bool nesterov_;
+};
+
+} // namespace optimizer
+} // namespace paddle
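The SGD step implemented above, with mu = momentum_, lambda = decay_, and eta from lr_policy_, is:

    \begin{aligned}
    v_i &\leftarrow \mu\,v_i - \eta\,g_i - \eta\,\lambda\,x_i \\
    x_i &\leftarrow x_i + v_i && \text{(plain momentum)} \\
    x_i &\leftarrow x_i + \mu\,v_i - \eta\,g_i && \text{(nesterov\_ = true)}
    \end{aligned}

When momentum_ is zero no velocity tensor is stored and the update reduces to x_i <- x_i - eta g_i - eta lambda x_i.
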
diff --git a/paddle/optimizer/tensor.h b/paddle/optimizer/tensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..80a8c93081ea7758d3b5ba016a14d424954db913
--- /dev/null
+++ b/paddle/optimizer/tensor.h
@@ -0,0 +1,54 @@
+#pragma once
+/**
+ * @brief tensor used by optimizer
+ */
+
+#include <string.h>
+#include <memory>
+#include "paddle/utils/Common.h"
+#include "paddle/utils/Logging.h"
+
+namespace paddle {
+namespace optimizer {
+
+template <class T>
+class TensorT {
+public:
+  TensorT(size_t size) : height_(1), width_(size) {
+    data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
+    data_ = data_ptr_.get();
+  }
+
+  TensorT(T* data, size_t size)
+      : height_(1), width_(size), data_ptr_(nullptr), data_(data) {}
+
+  TensorT(T* data, size_t h, size_t w)
+      : height_(h), width_(w), data_ptr_(nullptr), data_(data) {}
+
+  virtual ~TensorT() {}
+
+  T* get_buffer() { return this->data_; }
+
+  T& operator[](const size_t idx) {
+    CHECK(idx >= 0 && idx < this->width_) << "out of index range";
+    return data_[idx];
+  }
+  T& operator[](const size_t idx) const {
+    CHECK(idx >= 0 && idx < this->width_) << "out of index range";
+    return data_[idx];
+  }
+  // TODO: replace with tensorshape
+  size_t size() const { return this->width_ * this->height_; }
+
+protected:
+  size_t height_;
+  size_t width_;
+  std::shared_ptr<T> data_ptr_;
+  T* data_;
+};
+
+// TODO(zhihong): design problem of dynamic datatype, need to fix it
+typedef TensorT<float> Tensor;
+
+} // namespace optimizer
+} // namespace paddle
diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt
index 62d5b9e38b21ee82d1e78c3bde5aa5df7e4a33ee..9b98dd3fde4d141a35d93c0981acb287831c3eaf 100644
--- a/proto/CMakeLists.txt
+++ b/proto/CMakeLists.txt
@@ -5,6 +5,7 @@ set(proto_filenames
     ParameterConfig.proto
     ParameterService.proto
     TrainerConfig.proto
+    OptimizerConfig.proto
     ParameterServerConfig.proto)
 
 set(PROTO_GEN)
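A small usage sketch (hypothetical, for illustration) of the TensorT template introduced in tensor.h above: the size-only constructor owns its storage through the shared_ptr, while the pointer constructors wrap an external buffer without taking ownership.

    #include "tensor.h"

    using paddle::optimizer::Tensor;  // TensorT<float>

    void tensor_sketch() {
      Tensor owned(4);                 // allocates and owns 4 floats
      owned[0] = 1.0f;

      float raw[4] = {0, 1, 2, 3};
      Tensor view(raw, 4);             // wraps `raw`; the caller keeps ownership
      float* buf = view.get_buffer();  // points back into `raw`
      size_t n = view.size();          // 4 (height_ * width_)
      (void)buf;
      (void)n;
    }
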
diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto
new file mode 100644
index 0000000000000000000000000000000000000000..c698d3c2ddbf58a41ac6ee960af83a257325d1f9
--- /dev/null
+++ b/proto/OptimizerConfig.proto
@@ -0,0 +1,154 @@
+syntax = "proto2";
+
+option optimize_for = LITE_RUNTIME;
+
+package paddle;
+
+message SGDConfig {
+  // SGD
+  // momentum: float >= 0. Parameter updates momentum.
+  // decay: float >= 0. Learning rate decay over each update.
+  // nesterov: boolean. Whether to apply Nesterov momentum.
+  optional double momentum = 21 [default = 0.0];
+  optional double decay = 23 [default = 0.0];
+  optional bool nesterov = 24 [default = false];
+
+}
+
+
+message AdadeltaConfig {
+  // Adadelta
+  // It is recommended to leave the parameters at their default values.
+  // rho: float >= 0.
+  // epsilon: float >= 0. Fuzz factor.
+  // decay: float >= 0. Learning rate decay over each update.
+
+  // reference : [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
+  optional double rho = 33 [default = 0.90];
+  optional double epsilon = 31 [default = 1e-5];
+  optional double decay = 32 [default = 0.0];
+
+}
+
+message AdagradConfig {
+  // Adagrad
+  // epsilon: float >= 0.
+  // decay: float >= 0. Learning rate decay over each update.
+
+  // reference : [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+  optional double epsilon = 41 [default = 1e-5];
+  optional double decay = 42 [default = 0.0];
+}
+
+message AdamConfig {
+  // Adam
+  // beta_1: float, 0 < beta < 1. Generally close to 1.
+  // beta_2: float, 0 < beta < 1. Generally close to 1.
+  // epsilon: float >= 0. Fuzz factor.
+  // decay: float >= 0. Learning rate decay over each update.
+  // reference : [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
+  optional double beta_1 = 41;
+  optional double beta_2 = 42;
+  optional double epsilon = 43;
+  optional double decay = 44;
+}
+
+message ConstLrConfig {
+  // constant learning rate policy
+  optional double learning_rate = 1 [default = 1.0];
+}
+
+message LinearLrConfig {
+  // linearly decaying learning rate policy
+  optional double learning_rate = 1 [default = 1.0];
+  optional double lr_decay_a = 2;
+  optional double lr_decay_b = 3;
+}
+
+message TensorProto {
+  enum DataType {
+    PADDLE_ELEMENT_TYPE_INT32 = 0;
+    PADDLE_ELEMENT_TYPE_UINT32 = 1;
+    PADDLE_ELEMENT_TYPE_INT64 = 2;
+    PADDLE_ELEMENT_TYPE_UINT64 = 3;
+    PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
+    PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
+  }
+  optional DataType data_type = 1;
+  repeated bytes content = 2;
+}
+
+message SGDOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto momentums = 2;
+}
+
+message AdadeltaOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto accum_gradient = 2;
+  optional TensorProto accum_delta = 3;
+  optional TensorProto update_delta = 4;
+}
+
+message AdagradOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto accum_gradient = 2;
+}
+
+message AdamOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto momentums = 2;
+  optional TensorProto velocitys = 3;
+}
+
+message OptimizerConfig {
+  enum Optimizer {
+    SGD = 1;
+    Adadelta = 2;
+    Adagrad = 3;
+    Adam = 4;
+  }
+  optional Optimizer optimizer = 1;
+  optional SGDConfig sgd = 3;
+  optional AdadeltaConfig adadelta = 4;
+  optional AdagradConfig adagrad = 5;
+  optional AdamConfig adam = 6;
+
+  enum LrPolicy {
+    Const = 0;
+    Linear = 1;
+  }
+  optional LrPolicy lr_policy = 11;
+  optional ConstLrConfig const_lr = 12;
+  optional LinearLrConfig linear_lr = 13;
+
+  // common config of optimizer
+  // gradient clip when L2 norm exceeds this value
+  optional double clip_norm = 101;
+  // gradient clip when L1 norm exceeds this value
+  optional double clip_value = 102;
+}
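A short sketch (hypothetical; assumes the generated OptimizerConfig.pb.h is on the include path) of building a configuration the unit test above does not exercise, Adagrad combined with the Linear learning-rate policy, and serializing it for paddle_create_optimizer:

    #include <string>
    #include "OptimizerConfig.pb.h"

    std::string make_adagrad_config() {
      paddle::OptimizerConfig config;
      config.set_optimizer(paddle::OptimizerConfig::Adagrad);
      config.mutable_adagrad()->set_epsilon(1e-5);
      config.mutable_adagrad()->set_decay(0.0);
      config.set_lr_policy(paddle::OptimizerConfig::Linear);
      config.mutable_linear_lr()->set_learning_rate(0.1);
      config.mutable_linear_lr()->set_lr_decay_a(1e-4);  // decay per sample
      config.mutable_linear_lr()->set_lr_decay_b(1e-3);  // lower bound on the rate
      return config.SerializeAsString();
    }

The accessor names follow the standard proto2 generated-code conventions for the fields declared above, matching the style already used in parameter_optimizer_test.cpp.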