Commit 6f1c91da authored by Q qiaolongfei

refine code

Parent 99dc6064
......@@ -14,7 +14,6 @@ int main() {
client c = paddle_new_pserver_client(addr, 1);
retry:
if (paddle_begin_init_params(c)) {
-  paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
......
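For readers new to the cclient, the hunk above drops a local paddle_parameter declaration that the test no longer needs. The sketch below shows how such a parameter is put together before it is handed to the client, mirroring the initNewParameter helper later in this diff; the declarations are stand-ins (the enum value and struct layout are placeholders, only the field names come from the diff).

#include <cstdlib>
#include <cstring>

typedef int paddle_element_type;  // stand-in
static const paddle_element_type PADDLE_ELEMENT_TYPE_FLOAT32 = 0;  // placeholder value

struct paddle_parameter {  // stand-in; field names taken from the diff, layout is not the real one
  paddle_element_type element_type;
  char* name;
  unsigned char* content;
  int content_len;
};

static paddle_parameter* make_param(char* name, unsigned char* buf, int len) {
  paddle_parameter* p = (paddle_parameter*)malloc(sizeof(paddle_parameter));
  memset(p, 0, sizeof(paddle_parameter));  // zero every field first, as initNewParameter does below
  p->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
  p->name = name;        // not copied: the caller keeps the string alive
  p->content = buf;      // not copied: the caller keeps the data buffer alive
  p->content_len = len;
  return p;              // paired with paddle_release_param in the real code
}

int main() {
  static float values[4] = {0.f, 1.f, 2.f, 3.f};
  char name[] = "param_a";  // same name the test above registers
  paddle_parameter* p = make_param(name, (unsigned char*)values, (int)sizeof(values));
  free(p);  // the real code hands it to the client and releases it later
  return 0;
}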
......@@ -45,8 +45,8 @@ void NewRemoteParameterUpdater::init(
}
// init new parameter and gradient.
-  initNewParameter(newParameters_, PARAMETER_VALUE);
-  initNewParameter(newGradients_, PARAMETER_GRADIENT);
+  newParameters_ = initNewParameter(PARAMETER_VALUE);
+  newGradients_ = initNewParameter(PARAMETER_GRADIENT);
// init parameters: one trainer gets the opportunity to init the parameters and
// send them to the parameter server, the others get the initialized parameters
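The change in this hunk is an ownership cleanup: as the header hunk further down shows, initNewParameter used to fill a paddle_parameter**& output argument and now allocates the array itself and returns it, so each call site becomes a plain assignment. A toy sketch of the same before/after pattern (all names here are made up, not Paddle's):

#include <cstdlib>

struct Item { int v; };

// before-style: the helper fills a pointer passed by reference
static void makeItemsInto(Item**& out, int n) {
  out = (Item**)malloc(sizeof(Item*) * n);
  for (int i = 0; i < n; ++i) out[i] = (Item*)calloc(1, sizeof(Item));
}

// after-style: the helper owns the allocation and returns it
static Item** makeItems(int n) {
  Item** out = (Item**)malloc(sizeof(Item*) * n);
  for (int i = 0; i < n; ++i) out[i] = (Item*)calloc(1, sizeof(Item));
  return out;
}

int main() {
  Item** a = NULL;
  makeItemsInto(a, 2);      // output parameter: easy to leave 'a' uninitialized by mistake
  Item** b = makeItems(2);  // plain assignment, like newParameters_ = initNewParameter(...)
  for (int i = 0; i < 2; ++i) { free(a[i]); free(b[i]); }
  free(a);
  free(b);
  return 0;
}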
......@@ -60,7 +60,7 @@ void NewRemoteParameterUpdater::init(
LOG(INFO) << "paddle_begin_init_params done";
} else {
paddle_get_params(
-        parameterClient_, names_, newParameters_, (int)parameters_.size());
+        parameterClient_, names_, newParameters_, parameterSize());
}
LOG(INFO) << "NewRemoteParameterUpdater initialized";
......
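The only change in this hunk is replacing the inline (int)parameters_.size() cast with the new parameterSize() helper. A minimal stand-in of the idea: the count handed to the cclient is an int while std::vector::size() returns size_t, and the helper keeps that narrowing in one place. Only parameterSize() is taken verbatim from the header hunk below; the surrounding struct is not the real class.

#include <cstdio>
#include <vector>

struct UpdaterLike {  // stand-in, not the real updater class
  std::vector<int> parameters_;
  int parameterSize() { return (int)parameters_.size(); }
};

int main() {
  UpdaterLike u;
  u.parameters_.resize(4);
  printf("parameterSize() = %d\n", u.parameterSize());  // prints 4; the size_t -> int narrowing lives in one place
  return 0;
}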
......@@ -32,9 +32,9 @@ public:
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec);
~NewRemoteParameterUpdater() {
-    if (newGradients_) {
-      paddle_pserver_client_release(parameterClient_);
-    }
+    releaseNewParameter(newParameters_);
+    releaseNewParameter(newGradients_);
+    if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_);
}
/**
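The destructor now releases the parameter arrays first and then the client, guarding the client release with parameterClient_ >= 0 instead of using newGradients_ as a proxy for "the client exists". Reading the handle as "non-negative means connected" is an inference from that guard, not something the diff states; the toy below only illustrates the teardown order with stand-in types.

#include <cstdio>
#include <cstdlib>

struct HolderLike {
  int clientHandle = -1;   // stand-in handle; negative means "never connected" (an assumption)
  int* buffer = nullptr;   // stand-in for the malloc'ed parameter arrays

  ~HolderLike() {
    free(buffer);  // release owned memory first; free() accepts a null pointer
    if (clientHandle >= 0) {
      printf("release(%d)\n", clientHandle);  // stand-in for paddle_pserver_client_release
    }
  }
};

int main() {
  HolderLike h;
  h.clientHandle = 0;
  h.buffer = (int*)malloc(4 * sizeof(int));
  return 0;  // destructor frees the buffer, then releases handle 0
}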
......@@ -57,37 +57,49 @@ public:
virtual void startPass();
virtual bool finishPass();
+  int parameterSize() { return (int)parameters_.size(); }
-protected:
-  /**
-   * work need to do after finishBatch
-   */
-  virtual void updateImpl(Parameter* para);
-private:
-  int parameterSize() {
-    return (int)parameters_.size();
-  }
/**
* init parameter of paddle pserver cclient.
-   * @param new_paras
+   * @param new_params
* @param type
*/
-  void initNewParameter(paddle_parameter**& new_paras, ParameterType type) {
-    new_paras =
+  paddle_parameter** initNewParameter(ParameterType type) {
+    paddle_parameter** new_params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
for (int i = 0; i < parameterSize(); ++i) {
-      new_paras[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
-      memset(new_paras[i], 0, sizeof(paddle_parameter));
+      new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
+      memset(new_params[i], 0, sizeof(paddle_parameter));
}
for (int i = 0; i < parameterSize(); ++i) {
-      ParameterPtr para = parameters_[i];
-      new_paras[i]->content_len = 10;
-      new_paras[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
-      new_paras[i]->name = (char*)para->getName().c_str();
-      new_paras[i]->content =
-          (unsigned char*)(para->getBuf(type).get()->getData());
-      new_paras[i]->content_len = (int)para->getBuf(type).get()->getSize();
+      ParameterPtr param = parameters_[i];
+      new_params[i]->content_len = 10;
+      new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
+      new_params[i]->name = (char*)param->getName().c_str();
+      new_params[i]->content =
+          (unsigned char*)(param->getBuf(type).get()->getData());
+      new_params[i]->content_len = (int)param->getBuf(type).get()->getSize();
}
+    return new_params;
}
+protected:
+  /**
+   * work need to do after finishBatch
+   */
+  virtual void updateImpl(Parameter* para);
+  void releaseNewParameter(paddle_parameter** newParams) {
+    if (newParams != NULL) {
+      for (int i = 0; i < parameterSize(); ++i) {
+        paddle_release_param(newParams[i]);
+      }
+    }
+  }
protected:
/// internal parameter client object for exchanging data with pserver
......
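Two follow-ups on the header hunk above. content_len is first set to 10 and then immediately overwritten with the real buffer size, so the first assignment is dead code. And releaseNewParameter releases each element with paddle_release_param but never frees the pointer array that initNewParameter obtained from malloc. The sketch below (stand-in declarations, not the committed code) shows a symmetric teardown that also frees the array:

#include <cstdlib>

struct paddle_parameter { int content_len; };  // stand-in, not the real layout
static void paddle_release_param(paddle_parameter* p) { free(p); }  // stand-in body

// Release every element (as releaseNewParameter above does), then also free the
// pointer array that initNewParameter obtained from malloc.
static void release_param_array(paddle_parameter** params, int n) {
  if (params == NULL) return;
  for (int i = 0; i < n; ++i) paddle_release_param(params[i]);
  free(params);
}

int main() {
  int n = 2;
  paddle_parameter** params = (paddle_parameter**)malloc(sizeof(paddle_parameter*) * n);
  for (int i = 0; i < n; ++i) params[i] = (paddle_parameter*)calloc(1, sizeof(paddle_parameter));
  release_param_array(params, n);  // elements and array both freed
  return 0;
}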