提交 bc26df79 编写于 作者: D dzhwinter

"polish code style and update based review comment"

上级 5a1e678b
......@@ -46,7 +46,7 @@ protected:
};
// TODO(zhihong): design problem of dynamic datatype, need to fix it
typedef TensorT<real> Tensor;
typedef TensorT<float> Tensor;
} // namespace optimizer
} // namespace paddle
......@@ -8,7 +8,13 @@ namespace optimizer {
class AdadeltaOptimizer : public ParameterOptimizer {
public:
AdadeltaOptimizer(double rho, double epsilon, double decay, LrPolicy *lr)
: ParameterOptimizer(lr), rho_(rho), epsilon_(epsilon), decay_(decay) {}
: ParameterOptimizer(lr),
accum_gradient_(nullptr),
accum_delta_(nullptr),
update_delta_(nullptr),
rho_(rho),
epsilon_(epsilon),
decay_(decay) {}
~AdadeltaOptimizer() {
if (accum_gradient_) delete accum_gradient_;
if (accum_delta_) delete accum_delta_;
......@@ -16,13 +22,11 @@ public:
}
void Update(const Tensor *gradient);
void set_weight(Tensor *p);
real *get_weight() const;
private:
Tensor *accum_gradient_;
Tensor *accum_delta_;
Tensor *update_delta_;
double rho_;
double epsilon_;
double decay_;
......
......@@ -8,13 +8,15 @@ namespace optimizer {
class AdagradOptimizer : public ParameterOptimizer {
public:
AdagradOptimizer(double epsilon, double decay, LrPolicy *lr)
: ParameterOptimizer(lr), epsilon_(epsilon), decay_(decay) {}
: ParameterOptimizer(lr),
accum_gradient_(nullptr),
epsilon_(epsilon),
decay_(decay) {}
~AdagradOptimizer() {
if (accum_gradient_) delete accum_gradient_;
}
void Update(const Tensor *gradient);
void set_weight(Tensor *p);
real *get_weight() const;
private:
Tensor *accum_gradient_;
......
......@@ -10,6 +10,8 @@ public:
AdamOptimizer(
double beta_1, double beta_2, double epsilon, double decay, LrPolicy *lr)
: ParameterOptimizer(lr),
momentums_(nullptr),
velocitys_(nullptr),
beta_1_(beta_1),
beta_2_(beta_2),
epsilon_(epsilon),
......@@ -20,7 +22,6 @@ public:
}
void Update(const Tensor *gradient);
void set_weight(Tensor *p);
real *get_weight() const;
private:
Tensor *momentums_;
......
......@@ -20,7 +20,7 @@ public:
return learning_rate;
}
protected:
private:
double learning_rate;
};
......
......@@ -2,6 +2,7 @@
#include <string>
#include "parameter_optimizer.h"
using namespace paddle;
using namespace paddle::optimizer;
......@@ -26,6 +27,7 @@ MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
// TODO(zhihong): only implement below type, need to fix
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
......@@ -35,15 +37,20 @@ struct paddle_optimizer {
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
const int config_proto_len,
const char** state,
const int state_size) {
const paddle_element_type data_type,
void* param_buffer,
int num_bytes,
const char* state,
const int state_len) {
paddle_optimizer* optimizer = new paddle_optimizer;
std::string config(config_proto, config_proto + config_proto_len);
optimizer->impl = ParameterOptimizer::Create(config);
if (state != nullptr) {
std::string s(*state, *state + state_size);
std::string s(state, state + state_len);
optimizer->impl->DeSerializeState(s);
}
Tensor* param = new Tensor(reinterpret_cast<float*>(param_buffer), num_bytes);
optimizer->impl->set_weight(param);
return optimizer;
}
......@@ -57,28 +64,19 @@ int paddle_update_parameter(paddle_optimizer* o,
const void* grad_buffer,
int num_bytes) {
// TOOD(zhihong): datatype not work. need to add the runtime datatype
auto grad_type = reinterpret_cast<const real*>(grad_buffer);
Tensor* gradient = new Tensor(const_cast<real*>(grad_type), num_bytes);
auto grad_type = reinterpret_cast<const float*>(grad_buffer);
Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes);
o->impl->Update(gradient);
return PADDLE_SUCCESS;
}
int paddle_optimizer_set_weights(paddle_optimizer* o,
const paddle_element_type data_type,
void* param_buffer,
int num_bytes) {
// TOOD(zhihong): datatype not work. need to add the runtime datatype
Tensor* param = new Tensor(reinterpret_cast<real*>(param_buffer), num_bytes);
o->impl->set_weight(param);
return PADDLE_SUCCESS;
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
int param_size = 0;
*param_buffer = (void*)o->impl->get_weight(&param_size);
return param_size;
}
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
void* buffer = (void*)o->impl->get_weight();
return buffer;
}
int paddle_optimizer_get_state(paddle_optimizer* o, const char* state) {
state = o->impl->SerializeState();
return PADDLE_SUCCESS;
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
*state = o->impl->SerializeState();
return strlen(*state);
}
......@@ -3,19 +3,18 @@
#include <stdbool.h>
#include <stdint.h>
/*! \brief optimizer export C API. which will be used in
Case A, on Trainer (On ParameterServer Client) optimize gradient
Case B, on ParameterServer side optimize gradient
To simplify the configuration parsing. optimizer *do not* parse any config
e.g. learning rate should be calculated by the caller
/**
* @brief optimizer library in independent with other module
* which will be used in :
* Case A, the gradient optimized locally on the trainer.
*
* Case B, the gradient optimized on the parameter server.
*/
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief datatypes */
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
......@@ -25,7 +24,9 @@ typedef enum {
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
/*! \brief execute status code */
/**
* @brief execution status code
*/
const int32_t PADDLE_SUCCESS = 0;
const int32_t PADDLE_ERROR = -1;
......@@ -46,8 +47,11 @@ typedef struct paddle_optimizer paddle_optimizer;
*/
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
const int config_proto_len,
const char** state,
const int state_size);
const paddle_element_type data_type,
void* param_buffer,
int num_bytes,
const char* state,
const int state_len);
/**
* @brief release optimizer
......@@ -72,23 +76,17 @@ int paddle_update_parameter(paddle_optimizer* o,
/**
* @brief optimizer instance
* @param data_type datatype of gradient
* @param param_buffer, initilized parameter buffer
* @param num_bytes, parameter size
* @return return exec status
* @return return content length
*/
int paddle_optimizer_set_weights(paddle_optimizer* o,
const paddle_element_type data_type,
void* param_buffer,
int num_bytes);
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);
/**
* @brief optimizer instance
* @return return content of parameter buffer in optimizer
* @brief optimzizer instance
* @param training state for receive SerializeState
* @return return state_buffer length
*/
void* paddle_optimizer_get_weights(paddle_optimizer* o);
int paddle_optimizer_get_state(paddle_optimizer* o, const char* state);
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
#ifdef __cplusplus
}
......
......@@ -24,7 +24,8 @@ ParameterOptimizer *ParameterOptimizer::Create(
config.linear_lr().lr_decay_a(),
config.linear_lr().lr_decay_b());
// default
return nullptr;
LOG(WARNING) << " have not select any LrPolicy. use ConstLr in default";
return new ConstLr(0.1);
};
LrPolicy *lr = select_lr_policy(config);
auto select_optimizer =
......@@ -36,29 +37,32 @@ ParameterOptimizer *ParameterOptimizer::Create(
lr);
}
if (config.optimizer() == OptimizerConfig::Adadelta) {
return new AdagradOptimizer(
config.adagrad().epsilon(), config.adagrad().decay(), lr);
return new AdadeltaOptimizer(config.adadelta().rho(),
config.adadelta().epsilon(),
config.adadelta().decay(),
lr);
}
if (config.optimizer() == OptimizerConfig::Adagrad) {
return new AdagradOptimizer(
config.adagrad().epsilon(), config.adagrad().decay(), lr);
}
if (config.optimizer() == OptimizerConfig::Adam) {
return new AdadeltaOptimizer(config.adadelta().rho(),
config.adadelta().epsilon(),
config.adadelta().decay(),
return new AdamOptimizer(config.adam().beta_1(),
config.adam().beta_2(),
config.adam().epsilon(),
config.adam().decay(),
lr);
}
// default
return new SGDOptimizer(config.sgd().momentum(),
config.sgd().decay(),
config.sgd().nesterov(),
lr);
LOG(WARNING)
<< "have not select any Optimizer. use SGDOptimizer in default";
return new SGDOptimizer(0.0, 0.0, false, lr);
};
return select_optimizer(config);
}
real *ParameterOptimizer::get_weight() const {
float *ParameterOptimizer::get_weight(int *param_size) const {
*param_size = (int)parameter_->size();
return parameter_->get_buffer();
}
......
......@@ -25,11 +25,10 @@ public:
virtual const char *SerializeState();
virtual void DeSerializeState(const std::string &state);
virtual void Update(const Tensor *gradient) = 0;
virtual real *get_weight() const;
virtual float *get_weight(int *param_size) const;
virtual void set_weight(Tensor *parameter);
protected:
OptimizerConfig config_;
Tensor *parameter_;
// learning rate policy
......
......@@ -77,7 +77,8 @@ public:
opts[i]->set_weight(p);
}
for (size_t i = 0; i < opts.size(); ++i) {
real* newp = (real*)opts[i]->get_weight();
int s = 0;
float* newp = (float*)opts[i]->get_weight(&s);
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(newp[j], (*p)[j]);
}
......
......@@ -9,14 +9,18 @@ namespace optimizer {
class SGDOptimizer : public ParameterOptimizer {
public:
SGDOptimizer(double m, double d, bool n, LrPolicy* lr)
: ParameterOptimizer(lr), momentum_(m), decay_(d), nesterov_(n) {}
: ParameterOptimizer(lr),
momentums_(nullptr),
momentum_(m),
decay_(d),
nesterov_(n) {}
virtual ~SGDOptimizer() { delete momentums_; }
void Update(const Tensor* gradient);
const char* SerializeState();
void DeSerializeState(const std::string& state);
void set_weight(Tensor* p);
real* get_weight() const;
float* get_weight(int* param_size) const;
private:
Tensor* momentums_;
......
......@@ -16,7 +16,7 @@ void SGDOptimizer::set_weight(Tensor *p) {
void SGDOptimizer::Update(const Tensor *gradient) {
num_sample_passed_ += 1;
double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
real velocity = 0.0;
float velocity = 0.0;
Tensor &param = *parameter_;
const Tensor &grad = *gradient;
Tensor &m = *momentums_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册