diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 3365927f9b59936244230bed439808fa7ead2c61..068ba286c07d8854a1a7c7042224a679b50b4957 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -178,6 +178,7 @@ namespace std { %newobject ParameterOptimizer::create; %newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterUpdater::createLocalUpdater; +%newobject ParameterUpdater::createRemoteUpdater; %feature("director") UpdateCallback; %feature("autodoc", 1); // To generate method stub, for code hint in ide diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 09c891871a5ca8571216d211203fe8643fc3a63f..81c9eed0bccd5ad63f524cdb011fc73cd568f465 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -803,6 +803,8 @@ private: public: static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); + static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, + int passCount); ~ParameterUpdater(); /** diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 7cd8ed7e3907489a60f37090df6f51492def2612..75b0ae7cb6cc8c9ad0f8fe69963b7439a44bf55e 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -15,15 +15,25 @@ limitations under the License. */ #include "PaddleAPI.h" #include "PaddleAPIPrivate.h" +#include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h" ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {} ParameterUpdater *ParameterUpdater::createLocalUpdater( OptimizationConfig *config) { - auto param = new ParameterUpdater(); - param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); - return param; + auto updater = new ParameterUpdater(); + updater->m->updater.reset( + new paddle::SgdThreadUpdater(config->m->getConfig())); + return updater; +} + +ParameterUpdater *ParameterUpdater::createRemoteUpdater( + OptimizationConfig *config, int passCount) { + auto updater = new ParameterUpdater(); + updater->m->updater.reset(new paddle::RemoteParameterUpdater( + config->m->getConfig(), passCount, nullptr)); + return updater; } ParameterUpdater::~ParameterUpdater() { delete m; } diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index 7794b209009a3429e810074b61e1d5bffa8b3a4e..5e82c944751629632ea8d16992bd8f4178a2fbd5 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater { public: RemoteParameterUpdater( const OptimizationConfig& config, - int expectedPpassCount, + int expectedPassCount, std::unique_ptr&& localUpdater = nullptr); ~RemoteParameterUpdater() { if (controllerThread_) { @@ -146,7 +146,7 @@ protected: BatchStatus batchStatus_; /// controller thread for sync-sgd std::unique_ptr controllerThread_; - /// passed alread finished + /// passed already finished int64_t passCount_; /// expected passes to finished int64_t expectedPassCount_;