Commit 6f1c91da authored by qiaolongfei

refine code

Parent 99dc6064
@@ -14,7 +14,6 @@ int main() {
   client c = paddle_new_pserver_client(addr, 1);
 retry:
   if (paddle_begin_init_params(c)) {
     paddle_parameter param;
     char name_a[] = "param_a";
     char name_b[] = "param_b";
......
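For orientation, the hunk above is from the C test program for the pserver cclient: one trainer wins paddle_begin_init_params() and initializes the parameters, while the others skip initialization and fetch the values later. A minimal sketch of that control flow follows; it is not part of the commit, the prototypes are paraphrased from the call sites visible in this diff, and the `client` typedef is an assumption.

    // Hedged sketch of the trainer-side init protocol exercised by the test above.
    // The declarations below are paraphrased from the call sites in this diff and
    // may not match the real cclient header exactly.
    extern "C" {
    typedef int client;  // assumption: the cclient handle is an integer
    client paddle_new_pserver_client(char* addr, int selected);
    int paddle_begin_init_params(client c);
    void paddle_pserver_client_release(client c);
    }

    int run_init_protocol() {
      char addr[] = "localhost:3000";  // assumed test address
      client c = paddle_new_pserver_client(addr, 1);
      if (paddle_begin_init_params(c)) {
        // This trainer was elected to initialize: it would fill paddle_parameter
        // structs (name, element_type, content, content_len) and push them.
      } else {
        // Another trainer already initialized; this one would call
        // paddle_get_params() instead, as NewRemoteParameterUpdater::init does.
      }
      paddle_pserver_client_release(c);
      return 0;
    }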
@@ -45,8 +45,8 @@ void NewRemoteParameterUpdater::init(
   }
   // init new parameter and gradient.
-  initNewParameter(newParameters_, PARAMETER_VALUE);
-  initNewParameter(newGradients_, PARAMETER_GRADIENT);
+  newParameters_ = initNewParameter(PARAMETER_VALUE);
+  newGradients_ = initNewParameter(PARAMETER_GRADIENT);
   // init parameter, one trainer will get the opportunity to int parameter and
   // send them to parameter server. Others will get the initialized parameter
@@ -60,7 +60,7 @@ void NewRemoteParameterUpdater::init(
     LOG(INFO) << "paddle_begin_init_params done";
   } else {
     paddle_get_params(
-        parameterClient_, names_, newParameters_, (int)parameters_.size());
+        parameterClient_, names_, newParameters_, parameterSize());
   }
   LOG(INFO) << "NewRemoteParameterUpdater initialized";
......
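The .cpp change above swaps an out-parameter style (initNewParameter(newParameters_, ...)) for a returned array and routes the element count through parameterSize(). The pattern itself, independent of any Paddle types, is sketched below with purely illustrative names (Block, makeBlocks, freeBlocks); it is not code from the repository.

    #include <cstdlib>
    #include <cstring>

    // Illustrative stand-in for paddle_parameter: a block of bytes plus a length.
    struct Block {
      unsigned char* data;
      int len;
    };

    // Allocate-and-return helper, mirroring the refactored initNewParameter():
    // the caller receives ownership instead of passing a T**& out-parameter.
    static Block** makeBlocks(int n) {
      Block** blocks = (Block**)malloc(sizeof(Block*) * n);
      for (int i = 0; i < n; ++i) {
        blocks[i] = (Block*)malloc(sizeof(Block));
        memset(blocks[i], 0, sizeof(Block));
      }
      return blocks;
    }

    // Matching release helper, mirroring releaseNewParameter() in the header diff.
    static void freeBlocks(Block** blocks, int n) {
      if (blocks == NULL) return;
      for (int i = 0; i < n; ++i) free(blocks[i]);
      free(blocks);
    }

    int main() {
      const int n = 2;
      Block** values = makeBlocks(n);     // cf. newParameters_ = initNewParameter(...)
      Block** gradients = makeBlocks(n);  // cf. newGradients_  = initNewParameter(...)
      freeBlocks(values, n);
      freeBlocks(gradients, n);
      return 0;
    }

Each call site now reads as a single assignment, and ownership is explicit: whoever assigns the returned array is responsible for releasing it, which is what the destructor change in the header below does.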
@@ -32,9 +32,9 @@ public:
   NewRemoteParameterUpdater(const OptimizationConfig& config,
                             const std::string pserverSpec);
   ~NewRemoteParameterUpdater() {
-    if (newGradients_) {
-      paddle_pserver_client_release(parameterClient_);
-    }
+    releaseNewParameter(newParameters_);
+    releaseNewParameter(newGradients_);
+    if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_);
   }
   /**
@@ -57,37 +57,49 @@ public:
   virtual void startPass();
   virtual bool finishPass();
-  int parameterSize() { return (int)parameters_.size(); }
-
-  /**
-   * init parameter of paddle pserver cclient.
-   * @param new_paras
-   * @param type
-   */
-  void initNewParameter(paddle_parameter**& new_paras, ParameterType type) {
-    new_paras =
-        (paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
-    for (int i = 0; i < parameterSize(); ++i) {
-      new_paras[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
-      memset(new_paras[i], 0, sizeof(paddle_parameter));
-    }
-    for (int i = 0; i < parameterSize(); ++i) {
-      ParameterPtr para = parameters_[i];
-      new_paras[i]->content_len = 10;
-      new_paras[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
-      new_paras[i]->name = (char*)para->getName().c_str();
-      new_paras[i]->content =
-          (unsigned char*)(para->getBuf(type).get()->getData());
-      new_paras[i]->content_len = (int)para->getBuf(type).get()->getSize();
-    }
-  }
-
-protected:
-  /**
-   * work need to do after finishBatch
-   */
-  virtual void updateImpl(Parameter* para);
-
+protected:
+  /**
+   * work need to do after finishBatch
+   */
+  virtual void updateImpl(Parameter* para);
+
+private:
+  int parameterSize() {
+    return (int)parameters_.size();
+  }
+
+  /**
+   * init parameter of paddle pserver cclient.
+   * @param new_params
+   * @param type
+   */
+  paddle_parameter** initNewParameter(ParameterType type) {
+    paddle_parameter** new_params =
+        (paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
+    for (int i = 0; i < parameterSize(); ++i) {
+      new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
+      memset(new_params[i], 0, sizeof(paddle_parameter));
+    }
+    for (int i = 0; i < parameterSize(); ++i) {
+      ParameterPtr param = parameters_[i];
+      new_params[i]->content_len = 10;
+      new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
+      new_params[i]->name = (char*)param->getName().c_str();
+      new_params[i]->content =
+          (unsigned char*)(param->getBuf(type).get()->getData());
+      new_params[i]->content_len = (int)param->getBuf(type).get()->getSize();
+    }
+    return new_params;
+  }
+
+  void releaseNewParameter(paddle_parameter** newParams) {
+    if (newParams != NULL) {
+      for (int i = 0; i < parameterSize(); ++i) {
+        paddle_release_param(newParams[i]);
+      }
+    }
+  }
+
 protected:
   /// internal parameter client object for exchanging data with pserver
......
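The header change above pairs the allocation helper with releaseNewParameter() and guards the client release with parameterClient_ >= 0, so a destructor that runs before the client was created does nothing. As a design note, the same "release only if the handle is valid" cleanup can also be expressed as a small RAII guard; the sketch below uses made-up names (Handle, release_handle, ClientGuard) and is not how the commit implements it.

    #include <cstdio>

    using Handle = int;  // assumption: the cclient handle is an int, negative when unset

    static void release_handle(Handle h) { std::printf("released handle %d\n", h); }

    class ClientGuard {
     public:
      explicit ClientGuard(Handle h) : handle_(h) {}
      // Mirrors the guarded release in ~NewRemoteParameterUpdater():
      // only a handle that was actually created gets released.
      ~ClientGuard() {
        if (handle_ >= 0) release_handle(handle_);
      }
      ClientGuard(const ClientGuard&) = delete;
      ClientGuard& operator=(const ClientGuard&) = delete;

     private:
      Handle handle_;
    };

    int main() {
      ClientGuard created(3);      // released automatically at scope exit
      ClientGuard notCreated(-1);  // nothing to release
      return 0;
    }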