From bad503ff08e36f6af19b8e7203cf0ce3507bd80d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Fri, 14 Apr 2017 23:33:29 +0800 Subject: [PATCH] support RemoteSparseUpdater --- paddle/api/PaddleAPI.h | 9 +++++---- paddle/api/ParameterUpdater.cpp | 15 ++++++++++++--- python/paddle/v2/optimizer.py | 4 ++-- python/paddle/v2/trainer.py | 3 ++- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index f5ead40682..c8800519bd 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include "paddle/utils/Common.h" #include "paddle/utils/GlobalConstants.h" +#include "paddle/gserver/gradientmachines/GradientMachine.h" /// Import PaddlePaddle's enumeration into global namespace. using namespace paddle::enumeration_wrapper; // NOLINT @@ -468,9 +469,9 @@ private: }; enum GradientMatchineCreateMode { - CREATE_MODE_NORMAL = 0, - CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3, - CREATE_MODE_TESTING = 4 + CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal, + CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining, + CREATE_MODE_TESTING = paddle::GradientMachine::kTesting }; struct ParameterConfigPrivate; @@ -818,7 +819,7 @@ private: public: static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, - int passCount); + int passCount, bool userSparseUpdater); ~ParameterUpdater(); /** diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 75b0ae7cb6..e96ccc9285 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -29,10 +29,19 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( } ParameterUpdater *ParameterUpdater::createRemoteUpdater( - OptimizationConfig *config, int passCount) { + OptimizationConfig *config, int passCount, bool userSparseUpdater) { auto updater = new ParameterUpdater(); - updater->m->updater.reset(new paddle::RemoteParameterUpdater( - config->m->getConfig(), passCount, nullptr)); + auto remoteUpdater = new paddle::RemoteParameterUpdater( + config->m->getConfig(), passCount, nullptr); + if (userSparseUpdater) { + std::unique_ptr remoteUpdaterPtr; + remoteUpdaterPtr.reset(remoteUpdater); + auto sparseRemoteUpdater = new paddle::SparseRemoteParameterUpdaterComposite( + config->m->getConfig(), passCount, false, std::move(remoteUpdaterPtr)); + updater->m->updater.reset(sparseRemoteUpdater); + } else { + updater->m->updater.reset(remoteUpdater); + } return updater; } diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py index 1a01d95c20..6fefd7b2f2 100644 --- a/python/paddle/v2/optimizer.py +++ b/python/paddle/v2/optimizer.py @@ -41,9 +41,9 @@ class Optimizer(object): def create_local_updater(self): return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__) - def create_remote_updater(self, pass_num): + def create_remote_updater(self, pass_num, use_sparse_updater): return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__, - pass_num) + pass_num, use_sparse_updater) class Momentum(Optimizer): diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 2dac95b63d..dc23eb5b0d 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -97,7 +97,8 @@ class SGD(object): if self.__is_local__: updater = self.__optimizer__.create_local_updater() else: - updater = self.__optimizer__.create_remote_updater(num_passes) + updater = self.__optimizer__.create_remote_updater(num_passes, + self.__use_sparse_updater__) updater.init(self.__gradient_machine__) self.__gradient_machine__.start() -- GitLab