From 6f1c91da992b9f7b230633c0ac56db184d4df5c2 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 8 Jun 2017 22:38:30 +0800 Subject: [PATCH] refine code --- go/pserver/cclient/test/main.c | 1 - paddle/trainer/NewRemoteParameterUpdater.cpp | 6 +- paddle/trainer/NewRemoteParameterUpdater.h | 68 ++++++++++++-------- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/go/pserver/cclient/test/main.c b/go/pserver/cclient/test/main.c index 0ad890daa2f..b95abf96b1d 100644 --- a/go/pserver/cclient/test/main.c +++ b/go/pserver/cclient/test/main.c @@ -14,7 +14,6 @@ int main() { client c = paddle_new_pserver_client(addr, 1); retry: if (paddle_begin_init_params(c)) { - paddle_parameter param; char name_a[] = "param_a"; char name_b[] = "param_b"; diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp index 9060052e113..13110adb450 100644 --- a/paddle/trainer/NewRemoteParameterUpdater.cpp +++ b/paddle/trainer/NewRemoteParameterUpdater.cpp @@ -45,8 +45,8 @@ void NewRemoteParameterUpdater::init( } // init new parameter and gradient. - initNewParameter(newParameters_, PARAMETER_VALUE); - initNewParameter(newGradients_, PARAMETER_GRADIENT); + newParameters_ = initNewParameter(PARAMETER_VALUE); + newGradients_ = initNewParameter(PARAMETER_GRADIENT); // init parameter, one trainer will get the opportunity to int parameter and // send them to parameter server. Others will get the initialized parameter @@ -60,7 +60,7 @@ void NewRemoteParameterUpdater::init( LOG(INFO) << "paddle_begin_init_params done"; } else { paddle_get_params( - parameterClient_, names_, newParameters_, (int)parameters_.size()); + parameterClient_, names_, newParameters_, parameterSize()); } LOG(INFO) << "NewRemoteParameterUpdater initialized"; diff --git a/paddle/trainer/NewRemoteParameterUpdater.h b/paddle/trainer/NewRemoteParameterUpdater.h index 33640bc8a38..5fd404dcf8c 100644 --- a/paddle/trainer/NewRemoteParameterUpdater.h +++ b/paddle/trainer/NewRemoteParameterUpdater.h @@ -32,9 +32,9 @@ public: NewRemoteParameterUpdater(const OptimizationConfig& config, const std::string pserverSpec); ~NewRemoteParameterUpdater() { - if (newGradients_) { - paddle_pserver_client_release(parameterClient_); - } + releaseNewParameter(newParameters_); + releaseNewParameter(newGradients_); + if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_); } /** @@ -57,37 +57,49 @@ public: virtual void startPass(); virtual bool finishPass(); - int parameterSize() { return (int)parameters_.size(); } - +protected: /** - * init parameter of paddle pserver cclient. - * @param new_paras - * @param type + * work need to do after finishBatch */ - void initNewParameter(paddle_parameter**& new_paras, ParameterType type) { - new_paras = - (paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize()); - for (int i = 0; i < parameterSize(); ++i) { - new_paras[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter)); - memset(new_paras[i], 0, sizeof(paddle_parameter)); + virtual void updateImpl(Parameter* para); + +private: + int parameterSize() { + return (int)parameters_.size(); } - for (int i = 0; i < parameterSize(); ++i) { - ParameterPtr para = parameters_[i]; - new_paras[i]->content_len = 10; - new_paras[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32; - new_paras[i]->name = (char*)para->getName().c_str(); - new_paras[i]->content = - (unsigned char*)(para->getBuf(type).get()->getData()); - new_paras[i]->content_len = (int)para->getBuf(type).get()->getSize(); + /** + * init parameter of paddle pserver cclient. + * @param new_params + * @param type + */ + paddle_parameter** initNewParameter(ParameterType type) { + paddle_parameter** new_params = + (paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize()); + for (int i = 0; i < parameterSize(); ++i) { + new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter)); + memset(new_params[i], 0, sizeof(paddle_parameter)); + } + + for (int i = 0; i < parameterSize(); ++i) { + ParameterPtr param = parameters_[i]; + new_params[i]->content_len = 10; + new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32; + new_params[i]->name = (char*)param->getName().c_str(); + new_params[i]->content = + (unsigned char*)(param->getBuf(type).get()->getData()); + new_params[i]->content_len = (int)param->getBuf(type).get()->getSize(); + } + return new_params; } - } -protected: - /** - * work need to do after finishBatch - */ - virtual void updateImpl(Parameter* para); + void releaseNewParameter(paddle_parameter** newParams) { + if (newParams != NULL) { + for (int i = 0; i < parameterSize(); ++i) { + paddle_release_param(newParams[i]); + } + } + } protected: /// internal parameter client object for exchanging data with pserver -- GitLab