From da3e84a6d25fe75f63a624e4e523aba7a8c378c6 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Jun 2017 21:43:59 +0800 Subject: [PATCH] change trainer_id --- go/pserver/cclient/test/mnist_test.py | 5 +---- go/pserver/cclient/test/test_train.py | 2 +- paddle/trainer/NewRemoteParameterUpdater.cpp | 4 ++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/go/pserver/cclient/test/mnist_test.py b/go/pserver/cclient/test/mnist_test.py index c77af4913..c3a3af55e 100644 --- a/go/pserver/cclient/test/mnist_test.py +++ b/go/pserver/cclient/test/mnist_test.py @@ -56,7 +56,7 @@ def convolutional_neural_network(img): def main(): - paddle.init(use_gpu=False, trainer_count=1, trainer_id=1) + paddle.init(use_gpu=False, trainer_count=1) # define network topology images = paddle.layer.data( @@ -92,9 +92,6 @@ def main(): print "Pass %d, Batch %d, Cost %f, %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics) - with gzip.open('params.tar.gz', 'w') as f: - parameters.to_tar(f) - elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.batch( paddle.dataset.mnist.test(), batch_size=128)) diff --git a/go/pserver/cclient/test/test_train.py b/go/pserver/cclient/test/test_train.py index ddd6371e0..3f8d5d793 100644 --- a/go/pserver/cclient/test/test_train.py +++ b/go/pserver/cclient/test/test_train.py @@ -4,7 +4,7 @@ import paddle.v2.dataset.uci_housing as uci_housing def main(): # init - paddle.init(use_gpu=False, trainer_count=1, trainer_id=1) + paddle.init(use_gpu=False, trainer_count=1) # network config x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp index 0f879dbde..d554e0975 100644 --- a/paddle/trainer/NewRemoteParameterUpdater.cpp +++ b/paddle/trainer/NewRemoteParameterUpdater.cpp @@ -39,8 +39,8 @@ void NewRemoteParameterUpdater::init( } // create parameter server client. - parameterClient_ = - paddle_new_pserver_client((char *)pserverSpec_.c_str(), FLAGS_trainer_id); + parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(), + FLAGS_trainer_id == 0); // init names_ for get parameter through paddle_cclient names_ = (char **)malloc(parameterSize() * sizeof(char *)); -- GitLab