diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index f5ead40682c69ffd610a3fe64207cb077d848949..c8800519bd2fcd8d3f0f75d5f8f093985ec227ee 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/utils/Common.h"
 #include "paddle/utils/GlobalConstants.h"
+#include "paddle/gserver/gradientmachines/GradientMachine.h"
 
 /// Import PaddlePaddle's enumeration into global namespace.
 using namespace paddle::enumeration_wrapper;  // NOLINT
@@ -468,9 +469,9 @@ private:
 };
 
 enum GradientMatchineCreateMode {
-  CREATE_MODE_NORMAL = 0,
-  CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3,
-  CREATE_MODE_TESTING = 4
+  CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
+  CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
+  CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
 };
 
 struct ParameterConfigPrivate;
@@ -818,7 +819,7 @@ private:
 public:
   static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
   static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
-                                               int passCount);
+                                               int passCount, bool useSparseUpdater);
   ~ParameterUpdater();
 
   /**
diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp
index 75b0ae7cb6cc8c9ad0f8fe69963b7439a44bf55e..e96ccc928549d0f2e433cc6914b721f5e073b545 100644
--- a/paddle/api/ParameterUpdater.cpp
+++ b/paddle/api/ParameterUpdater.cpp
@@ -29,10 +29,19 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
 }
 
 ParameterUpdater *ParameterUpdater::createRemoteUpdater(
-    OptimizationConfig *config, int passCount) {
+    OptimizationConfig *config, int passCount, bool useSparseUpdater) {
   auto updater = new ParameterUpdater();
-  updater->m->updater.reset(new paddle::RemoteParameterUpdater(
-      config->m->getConfig(), passCount, nullptr));
+  auto remoteUpdater = new paddle::RemoteParameterUpdater(
+      config->m->getConfig(), passCount, nullptr);
+  if (useSparseUpdater) {
+    std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
+    remoteUpdaterPtr.reset(remoteUpdater);
+    auto sparseRemoteUpdater = new paddle::SparseRemoteParameterUpdaterComposite(
+        config->m->getConfig(), passCount, false, std::move(remoteUpdaterPtr));
+    updater->m->updater.reset(sparseRemoteUpdater);
+  } else {
+    updater->m->updater.reset(remoteUpdater);
+  }
   return updater;
 }
 
diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py
index 1a01d95c205c0626374e1814a170ce2d58f23a60..6fefd7b2f241342da85e0d08ea88c378f98a674b 100644
--- a/python/paddle/v2/optimizer.py
+++ b/python/paddle/v2/optimizer.py
@@ -41,9 +41,9 @@ class Optimizer(object):
     def create_local_updater(self):
         return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
 
-    def create_remote_updater(self, pass_num):
+    def create_remote_updater(self, pass_num, use_sparse_updater):
         return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
-                                                             pass_num)
+                                                             pass_num, use_sparse_updater)
 
 
 class Momentum(Optimizer):
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 2dac95b63d550733c54ee5fb13d2c02272fb1af5..dc23eb5b0d74ab69b662747fb6e1a77f8f7dae46 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -97,7 +97,8 @@ class SGD(object):
         if self.__is_local__:
             updater = self.__optimizer__.create_local_updater()
         else:
-            updater = self.__optimizer__.create_remote_updater(num_passes)
+            updater = self.__optimizer__.create_remote_updater(num_passes,
+                self.__use_sparse_updater__)
         updater.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
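
For reference, a minimal sketch (not part of the patch) of how the new flag reaches the SWIG layer. The `Momentum` arguments and the pass count are illustrative assumptions; in normal use `SGD.train()` performs this call internally via the `trainer.py` branch patched above.

```python
import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# Any Optimizer subclass works; Momentum is the one visible in this diff.
optimizer = paddle.optimizer.Momentum(momentum=0.9, learning_rate=1e-3)

# With use_sparse_updater=True, createRemoteUpdater() wraps the dense
# RemoteParameterUpdater in a SparseRemoteParameterUpdaterComposite, so
# sparsely-updated parameters go through the sparse pserver protocol while
# dense parameters keep the existing remote update path.
updater = optimizer.create_remote_updater(pass_num=10,
                                          use_sparse_updater=True)
```

Note the composite wraps rather than replaces the dense updater, which is why `createRemoteUpdater` moves the freshly built `RemoteParameterUpdater` into a `unique_ptr` before handing it to `SparseRemoteParameterUpdaterComposite`.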