diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 3365927f9b59936244230bed439808fa7ead2c61..068ba286c07d8854a1a7c7042224a679b50b4957 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -178,6 +178,7 @@ namespace std { %newobject ParameterOptimizer::create; %newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterUpdater::createLocalUpdater; +%newobject ParameterUpdater::createRemoteUpdater; %feature("director") UpdateCallback; %feature("autodoc", 1); // To generate method stub, for code hint in ide diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 0a273f9f6f942a9ca3f4404dfc843eb88d650daa..f70a8ce26b165f7989b3182ed0a2feb320572efb 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -797,6 +797,8 @@ private: public: static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); + static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, + int passCount); ~ParameterUpdater(); /** diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 7cd8ed7e3907489a60f37090df6f51492def2612..e84bb63866094d004c03e4c8c0c4838c99c551e6 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -15,15 +15,27 @@ limitations under the License. */ #include "PaddleAPI.h" #include "PaddleAPIPrivate.h" +#include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h" ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {} ParameterUpdater *ParameterUpdater::createLocalUpdater( OptimizationConfig *config) { - auto param = new ParameterUpdater(); - param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); - return param; + auto updater = new ParameterUpdater(); + updater->m->updater.reset( + new paddle::SgdThreadUpdater(config->m->getConfig())); + return updater; +} + +ParameterUpdater *ParameterUpdater::createRemoteUpdater( + OptimizationConfig *config, int passCount) { + auto updater = new ParameterUpdater(); + std::unique_ptr localUpdater; + localUpdater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); + updater->m->updater.reset(new paddle::ConcurrentRemoteParameterUpdater( + config->m->getConfig(), passCount, std::move(localUpdater))); + return updater; } ParameterUpdater::~ParameterUpdater() { delete m; } diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index 7794b209009a3429e810074b61e1d5bffa8b3a4e..5e82c944751629632ea8d16992bd8f4178a2fbd5 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater { public: RemoteParameterUpdater( const OptimizationConfig& config, - int expectedPpassCount, + int expectedPassCount, std::unique_ptr&& localUpdater = nullptr); ~RemoteParameterUpdater() { if (controllerThread_) { @@ -146,7 +146,7 @@ protected: BatchStatus batchStatus_; /// controller thread for sync-sgd std::unique_ptr controllerThread_; - /// passed alread finished + /// passed already finished int64_t passCount_; /// expected passes to finished int64_t expectedPassCount_;