提交 ebba2b13 编写于 作者: Q qiaolongfei

update code with new cclient

上级 8941a385
...@@ -11,10 +11,9 @@ ...@@ -11,10 +11,9 @@
void sendGrads(paddle_pserver_client c) { void sendGrads(paddle_pserver_client c) {
unsigned char grad_a[2000] = {2}; unsigned char grad_a[2000] = {2};
unsigned char grad_b[3000] = {3}; unsigned char grad_b[3000] = {3};
paddle_gradient grads[2] = { paddle_gradient grad1 = {"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
{"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000}, paddle_gradient grad2 = {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000}}; paddle_gradient* grads[2] = {&grad1, &grad2};
if (paddle_send_grads(c, grads, 2)) { if (paddle_send_grads(c, grads, 2)) {
fail(); fail();
} }
......
...@@ -30,30 +30,36 @@ void print_parameter(paddle_gradient* param) { ...@@ -30,30 +30,36 @@ void print_parameter(paddle_gradient* param) {
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1); paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
char* names[] = {"param_a", "param_b"}; char* names[] = {"param_a", "param_b"};
retry: retry:
printf("init parameter to pserver:\n");
if (paddle_begin_init_params(c)) {
paddle_parameter param;
real param_content1[] = {0.1, 0.2, 0.3}; real param_content1[] = {0.1, 0.2, 0.3};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; real param_content2[] = {0.4, 0.5, 0.6};
param.name = names[0]; paddle_parameter** params =
param.content = (unsigned char*)param_content1; (paddle_parameter**)malloc(sizeof(paddle_parameter*) * 2);
param.content_len = 3 * sizeof(real); params[0] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
if (paddle_init_param(c, param, NULL, 0) != 0) { params[0]->name = names[0];
params[0]->content = (unsigned char*)param_content1;
params[0]->content_len = 3 * sizeof(real);
params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
params[1] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
params[1]->name = names[1];
params[1]->content = (unsigned char*)param_content2;
params[1]->content_len = 3 * sizeof(real);
params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
if (paddle_begin_init_params(c)) {
if (paddle_init_param(c, *params[0], NULL, 0) != 0) {
goto retry; goto retry;
} }
real param_content2[] = {0.4, 0.5, 0.6}; if (paddle_init_param(c, *params[1], NULL, 0) != 0) {
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.name = names[1];
param.content = (unsigned char*)param_content2;
param.content_len = 3 * sizeof(real);
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry; goto retry;
} }
if (paddle_finish_init_params(c) != 0) { if (paddle_finish_init_params(c) != 0) {
goto retry; goto retry;
} }
...@@ -61,13 +67,13 @@ retry: ...@@ -61,13 +67,13 @@ retry:
fail(); fail();
} }
printf("get initialized parameters from pserver:\n"); printf("get inited parameters from pserver:\n");
paddle_parameter* param_ptrs[2] = {NULL, NULL}; // get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, names, param_ptrs, 2) != 0) { if (paddle_get_params(c, params, 2) != 0) {
fail(); fail();
} }
print_parameter(param_ptrs[0]); print_parameter(params[0]);
print_parameter(param_ptrs[1]); print_parameter(params[1]);
printf("send gradient to pserver:\n"); printf("send gradient to pserver:\n");
real gradient_content1[] = {0.01, 0.02, 0.03}; real gradient_content1[] = {0.01, 0.02, 0.03};
...@@ -87,6 +93,7 @@ retry: ...@@ -87,6 +93,7 @@ retry:
grads[1]->content_len = 3 * sizeof(real); grads[1]->content_len = 3 * sizeof(real);
grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32; grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
printf("print gradient sent to pserver:\n");
print_parameter(grads[0]); print_parameter(grads[0]);
print_parameter(grads[1]); print_parameter(grads[1]);
...@@ -96,15 +103,11 @@ retry: ...@@ -96,15 +103,11 @@ retry:
printf("get updated parameters from pserver:\n"); printf("get updated parameters from pserver:\n");
// get parameters again by reusing the allocated parameter buffers. // get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, names, param_ptrs, 2) != 0) { if (paddle_get_params(c, params, 2) != 0) {
fail(); fail();
} }
print_parameter(params[0]);
print_parameter(param_ptrs[0]); print_parameter(params[1]);
print_parameter(param_ptrs[1]);
paddle_release_param(param_ptrs[0]);
paddle_release_param(param_ptrs[1]);
if (paddle_save_model(c, "/tmp/") != 0) { if (paddle_save_model(c, "/tmp/") != 0) {
fail(); fail();
......
...@@ -25,7 +25,6 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater( ...@@ -25,7 +25,6 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater(
: parameterClient_(-1), : parameterClient_(-1),
newParameters_(nullptr), newParameters_(nullptr),
newGradients_(nullptr), newGradients_(nullptr),
names_(nullptr),
pserverSpec_(pserverSpec) {} pserverSpec_(pserverSpec) {}
void NewRemoteParameterUpdater::init( void NewRemoteParameterUpdater::init(
...@@ -41,12 +40,6 @@ void NewRemoteParameterUpdater::init( ...@@ -41,12 +40,6 @@ void NewRemoteParameterUpdater::init(
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(), parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0); FLAGS_trainer_id == 0);
// init names_ for get parameter through paddle_cclient
names_ = (char **)malloc(parameterSize() * sizeof(char *));
for (int i = 0; i < parameterSize(); ++i) {
names_[i] = (char *)parameters_[i]->getName().c_str();
}
// init new parameter and gradient. // init new parameter and gradient.
newParameters_ = initNewParameter(PARAMETER_VALUE); newParameters_ = initNewParameter(PARAMETER_VALUE);
newGradients_ = initNewParameter(PARAMETER_GRADIENT); newGradients_ = initNewParameter(PARAMETER_GRADIENT);
...@@ -68,7 +61,7 @@ void NewRemoteParameterUpdater::init( ...@@ -68,7 +61,7 @@ void NewRemoteParameterUpdater::init(
LOG(INFO) << "paddle_begin_init_params done"; LOG(INFO) << "paddle_begin_init_params done";
} else { } else {
paddle_get_params( paddle_get_params(
parameterClient_, names_, newParameters_, parameterSize()); parameterClient_, newParameters_, parameterSize());
} }
LOG(INFO) << "NewRemoteParameterUpdater initialized"; LOG(INFO) << "NewRemoteParameterUpdater initialized";
...@@ -80,7 +73,7 @@ void NewRemoteParameterUpdater::finishBatch(real cost) { ...@@ -80,7 +73,7 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
// send gradient to parameter server. // send gradient to parameter server.
paddle_send_grads(parameterClient_, newGradients_, parameterSize()); paddle_send_grads(parameterClient_, newGradients_, parameterSize());
// get the updated parameter from parameterClient. // get the updated parameter from parameterClient.
paddle_get_params(parameterClient_, names_, newParameters_, parameterSize()); paddle_get_params(parameterClient_, newParameters_, parameterSize());
// clear gradient after update parameter. // clear gradient after update parameter.
for (auto &para : parameters_) { for (auto &para : parameters_) {
......
...@@ -32,9 +32,6 @@ public: ...@@ -32,9 +32,6 @@ public:
NewRemoteParameterUpdater(const OptimizationConfig& config, NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec); const std::string pserverSpec);
~NewRemoteParameterUpdater() { ~NewRemoteParameterUpdater() {
if (names_ != nullptr) {
free(names_);
}
releaseNewParameter(newParameters_); releaseNewParameter(newParameters_);
releaseNewParameter(newGradients_); releaseNewParameter(newGradients_);
if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_); if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_);
...@@ -105,13 +102,11 @@ private: ...@@ -105,13 +102,11 @@ private:
protected: protected:
/// internal parameter client object for exchanging data with pserver /// internal parameter client object for exchanging data with pserver
client parameterClient_; paddle_pserver_client parameterClient_;
/// the parameters for new pserver client /// the parameters for new pserver client
paddle_parameter** newParameters_; paddle_parameter** newParameters_;
/// the gradinets for new pserver client /// the gradinets for new pserver client
paddle_parameter** newGradients_; paddle_parameter** newGradients_;
/// the names for new parameters.
char** names_;
/// the specification of parameter server "host1:port,host1:port" /// the specification of parameter server "host1:port,host1:port"
std::string pserverSpec_; std::string pserverSpec_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册