Commit bad503ff authored by: Q qiaolongfei

support RemoteSparseUpdater

Parent 6802b65c
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/utils/Common.h"
 #include "paddle/utils/GlobalConstants.h"
+#include "paddle/gserver/gradientmachines/GradientMachine.h"
 /// Import PaddlePaddle's enumeration into global namespace.
 using namespace paddle::enumeration_wrapper;  // NOLINT
@@ -468,9 +469,9 @@ private:
 };
 enum GradientMatchineCreateMode {
-  CREATE_MODE_NORMAL = 0,
-  CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3,
-  CREATE_MODE_TESTING = 4
+  CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
+  CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
+  CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
 };
 struct ParameterConfigPrivate;
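The three literals being replaced (0, 3 and 4) are expected to be numerically identical to paddle::GradientMachine::kNormal, kSgdSparseCpuTraining and kTesting, so this hunk should be a pure de-duplication rather than a behavior change. A hedged sanity check, assuming SWIG exports the enum as module-level constants and that the wrapper module is imported the same way as in the Python hunks below:

    import py_paddle.swig_paddle as swig_api  # import path assumed, not part of this diff

    # If these hold, existing callers that pass raw integers are unaffected.
    assert swig_api.CREATE_MODE_NORMAL == 0
    assert swig_api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING == 3
    assert swig_api.CREATE_MODE_TESTING == 4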
@@ -818,7 +819,7 @@ private:
 public:
   static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
   static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
-                                               int passCount);
+                                               int passCount, bool userSparseUpdater);
   ~ParameterUpdater();
   /**
......
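Because the new bool parameter has no default value, any existing caller of createRemoteUpdater with two arguments would presumably stop working once the bindings are regenerated; the Python wrapper is updated in the same commit below. A hedged sketch of the required change on the caller side (opt_conf and pass_num are illustrative names):

    # Before this change (two-argument call, now expected to fail):
    #   swig_api.ParameterUpdater.createRemoteUpdater(opt_conf, pass_num)
    # After this change the flag must be passed explicitly:
    updater = swig_api.ParameterUpdater.createRemoteUpdater(opt_conf, pass_num, False)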
@@ -29,10 +29,19 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
 }
 ParameterUpdater *ParameterUpdater::createRemoteUpdater(
-    OptimizationConfig *config, int passCount) {
+    OptimizationConfig *config, int passCount, bool userSparseUpdater) {
   auto updater = new ParameterUpdater();
-  updater->m->updater.reset(new paddle::RemoteParameterUpdater(
-      config->m->getConfig(), passCount, nullptr));
+  auto remoteUpdater = new paddle::RemoteParameterUpdater(
+      config->m->getConfig(), passCount, nullptr);
+  if (userSparseUpdater) {
+    std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
+    remoteUpdaterPtr.reset(remoteUpdater);
+    auto sparseRemoteUpdater = new paddle::SparseRemoteParameterUpdaterComposite(
+        config->m->getConfig(), passCount, false, std::move(remoteUpdaterPtr));
+    updater->m->updater.reset(sparseRemoteUpdater);
+  } else {
+    updater->m->updater.reset(remoteUpdater);
+  }
   return updater;
 }
......
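When userSparseUpdater is true, ownership of the dense RemoteParameterUpdater is transferred through the unique_ptr into a SparseRemoteParameterUpdaterComposite, which presumably drives dense parameters through the wrapped remote updater and sparse parameters through the sparse remote protocol. Seen from Python, the choice is just the new third argument; a hedged sketch (opt_conf and the pass count are illustrative):

    # Both calls return a ParameterUpdater handle; only the C++ object behind
    # it differs. opt_conf is an OptimizationConfig created elsewhere.
    dense_updater = swig_api.ParameterUpdater.createRemoteUpdater(opt_conf, 10, False)
    sparse_updater = swig_api.ParameterUpdater.createRemoteUpdater(opt_conf, 10, True)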
@@ -41,9 +41,9 @@ class Optimizer(object):
     def create_local_updater(self):
         return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
-    def create_remote_updater(self, pass_num):
+    def create_remote_updater(self, pass_num, use_sparse_updater):
         return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
-                                                             pass_num)
+                                                             pass_num, use_sparse_updater)
 class Momentum(Optimizer):
......
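Because the concrete optimizers such as Momentum inherit from Optimizer, they pick up the new parameter unchanged; a hedged usage sketch (the Momentum constructor arguments are assumptions about the v2 optimizer API, not taken from this diff):

    # Request a remote updater that also performs sparse remote updates.
    optimizer = Momentum(momentum=0.9, learning_rate=1e-3)  # constructor arguments assumed
    updater = optimizer.create_remote_updater(pass_num=10, use_sparse_updater=True)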
@@ -97,7 +97,8 @@ class SGD(object):
         if self.__is_local__:
             updater = self.__optimizer__.create_local_updater()
         else:
-            updater = self.__optimizer__.create_remote_updater(num_passes)
+            updater = self.__optimizer__.create_remote_updater(num_passes,
+                self.__use_sparse_updater__)
         updater.init(self.__gradient_machine__)
         self.__gradient_machine__.start()
......
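The trainer only forwards a flag it already holds; __use_sparse_updater__ is presumably derived elsewhere in SGD from whether the topology contains parameters configured for sparse remote updates. A minimal restatement of the dispatch, for illustration only (the helper name is not part of PaddlePaddle):

    def choose_updater(optimizer, is_local, num_passes, use_sparse_updater):
        # Local training keeps the existing path; remote training now also states
        # whether sparse remote updates are needed, so the C++ side can build a
        # SparseRemoteParameterUpdaterComposite.
        if is_local:
            return optimizer.create_local_updater()
        return optimizer.create_remote_updater(num_passes, use_sparse_updater)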