提交 bad503ff 编写于 作者: Q qiaolongfei

support RemoteSparseUpdater

上级 6802b65c
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/utils/Common.h" #include "paddle/utils/Common.h"
#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
/// Import PaddlePaddle's enumeration into global namespace. /// Import PaddlePaddle's enumeration into global namespace.
using namespace paddle::enumeration_wrapper; // NOLINT using namespace paddle::enumeration_wrapper; // NOLINT
...@@ -468,9 +469,9 @@ private: ...@@ -468,9 +469,9 @@ private:
}; };
enum GradientMatchineCreateMode { enum GradientMatchineCreateMode {
CREATE_MODE_NORMAL = 0, CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3, CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
CREATE_MODE_TESTING = 4 CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
}; };
struct ParameterConfigPrivate; struct ParameterConfigPrivate;
...@@ -818,7 +819,7 @@ private: ...@@ -818,7 +819,7 @@ private:
public: public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount); int passCount, bool userSparseUpdater);
~ParameterUpdater(); ~ParameterUpdater();
/** /**
......
...@@ -29,10 +29,19 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ...@@ -29,10 +29,19 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
} }
ParameterUpdater *ParameterUpdater::createRemoteUpdater( ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount) { OptimizationConfig *config, int passCount, bool userSparseUpdater) {
auto updater = new ParameterUpdater(); auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::RemoteParameterUpdater( auto remoteUpdater = new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr)); config->m->getConfig(), passCount, nullptr);
if (userSparseUpdater) {
std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
remoteUpdaterPtr.reset(remoteUpdater);
auto sparseRemoteUpdater = new paddle::SparseRemoteParameterUpdaterComposite(
config->m->getConfig(), passCount, false, std::move(remoteUpdaterPtr));
updater->m->updater.reset(sparseRemoteUpdater);
} else {
updater->m->updater.reset(remoteUpdater);
}
return updater; return updater;
} }
......
...@@ -41,9 +41,9 @@ class Optimizer(object): ...@@ -41,9 +41,9 @@ class Optimizer(object):
def create_local_updater(self): def create_local_updater(self):
return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__) return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
def create_remote_updater(self, pass_num): def create_remote_updater(self, pass_num, use_sparse_updater):
return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__, return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
pass_num) pass_num, use_sparse_updater)
class Momentum(Optimizer): class Momentum(Optimizer):
......
...@@ -97,7 +97,8 @@ class SGD(object): ...@@ -97,7 +97,8 @@ class SGD(object):
if self.__is_local__: if self.__is_local__:
updater = self.__optimizer__.create_local_updater() updater = self.__optimizer__.create_local_updater()
else: else:
updater = self.__optimizer__.create_remote_updater(num_passes) updater = self.__optimizer__.create_remote_updater(num_passes,
self.__use_sparse_updater__)
updater.init(self.__gradient_machine__) updater.init(self.__gradient_machine__)
self.__gradient_machine__.start() self.__gradient_machine__.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册