From 32b28c6429451b4ea7b91e002b139919491ea3a3 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 27 Dec 2016 11:50:06 +0800 Subject: [PATCH] add remote updater in api and swig --- paddle/api/Paddle.swig | 1 + paddle/api/PaddleAPI.h | 2 ++ paddle/api/ParameterUpdater.cpp | 18 +++++++++++++++--- paddle/trainer/RemoteParameterUpdater.h | 4 ++-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 3365927f9b..068ba286c0 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -178,6 +178,7 @@ namespace std { %newobject ParameterOptimizer::create; %newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterUpdater::createLocalUpdater; +%newobject ParameterUpdater::createRemoteUpdater; %feature("director") UpdateCallback; %feature("autodoc", 1); // To generate method stub, for code hint in ide diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 0a273f9f6f..f70a8ce26b 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -797,6 +797,8 @@ private: public: static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); + static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, + int passCount); ~ParameterUpdater(); /** diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 7cd8ed7e39..e84bb63866 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -15,15 +15,27 @@ limitations under the License. */ #include "PaddleAPI.h" #include "PaddleAPIPrivate.h" +#include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h" ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {} ParameterUpdater *ParameterUpdater::createLocalUpdater( OptimizationConfig *config) { - auto param = new ParameterUpdater(); - param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); - return param; + auto updater = new ParameterUpdater(); + updater->m->updater.reset( + new paddle::SgdThreadUpdater(config->m->getConfig())); + return updater; +} + +ParameterUpdater *ParameterUpdater::createRemoteUpdater( + OptimizationConfig *config, int passCount) { + auto updater = new ParameterUpdater(); + std::unique_ptr localUpdater; + localUpdater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); + updater->m->updater.reset(new paddle::ConcurrentRemoteParameterUpdater( + config->m->getConfig(), passCount, std::move(localUpdater))); + return updater; } ParameterUpdater::~ParameterUpdater() { delete m; } diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index 7794b20900..5e82c94475 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater { public: RemoteParameterUpdater( const OptimizationConfig& config, - int expectedPpassCount, + int expectedPassCount, std::unique_ptr&& localUpdater = nullptr); ~RemoteParameterUpdater() { if (controllerThread_) { @@ -146,7 +146,7 @@ protected: BatchStatus batchStatus_; /// controller thread for sync-sgd std::unique_ptr controllerThread_; - /// passed alread finished + /// passed already finished int64_t passCount_; /// expected passes to finished int64_t expectedPassCount_; -- GitLab