From cf86ca04b4682c0f1ecf24324ed3dcc7769cea63 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Wed, 19 Apr 2017 10:23:26 +0800
Subject: [PATCH] refine code

---
 paddle/api/ParameterUpdater.cpp |  3 +--
 python/paddle/v2/optimizer.py   | 25 ++++++++++++++++++++-----
 python/paddle/v2/trainer.py     | 24 ++++++++++++++++--------
 3 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp
index 9dfd12ccbe7..79921ea6e78 100644
--- a/paddle/api/ParameterUpdater.cpp
+++ b/paddle/api/ParameterUpdater.cpp
@@ -34,8 +34,7 @@ ParameterUpdater *ParameterUpdater::createRemoteUpdater(
   auto remoteUpdater = new paddle::RemoteParameterUpdater(
       config->m->getConfig(), passCount, nullptr);
   if (useSparseUpdater) {
-    std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
-    remoteUpdaterPtr.reset(remoteUpdater);
+    std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr(remoteUpdater);
     auto sparseRemoteUpdater =
         new paddle::SparseRemoteParameterUpdaterComposite(
             config->m->getConfig(),
diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py
index 887b2567a14..17c56a2b993 100644
--- a/python/paddle/v2/optimizer.py
+++ b/python/paddle/v2/optimizer.py
@@ -38,19 +38,34 @@ class Optimizer(object):
         assert isinstance(tmp, swig_api.ParameterOptimizer)
         return tmp.getParameterTypes()
 
-    def create_local_updater(self):
+    def __create_local_updater__(self):
         return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
 
-    def create_remote_updater(self, pass_num, use_sparse_updater):
+    def __create_remote_updater__(self, pass_num, use_sparse_updater):
         return swig_api.ParameterUpdater.createRemoteUpdater(
             self.__opt_conf__, pass_num, use_sparse_updater)
 
     def create_updater(self, is_local, num_passes, use_sparse_updater):
+        """
+        Create a proper parameter updater according to the configuration.
+        :param is_local: whether to create a local or a remote parameter updater
+        :param num_passes: the remote parameter updater uses this to configure
+            the parameter server.
+        :param use_sparse_updater: when using a remote updater, if any parameter
+            is sparse, the updater has to do some extra work:
+
+        .. code-block:: python
+
+            if use_sparse_remote_updater:
+                gradient_machine.prefetch(in_args)
+                parameter_updater.getParametersRemote()
+        :return: parameter_updater
+        """
         if is_local:
-            parameter_updater = self.create_local_updater()
+            parameter_updater = self.__create_local_updater__()
         else:
-            parameter_updater = self.create_remote_updater(num_passes,
-                                                           use_sparse_updater)
+            parameter_updater = self.__create_remote_updater__(
+                num_passes, use_sparse_updater)
         return parameter_updater
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 9caaeca2efe..552c6690a60 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -78,12 +78,24 @@ class SGD(object):
         assert isinstance(gm, api.GradientMachine)
         self.__gradient_machine__ = gm
         self.__gradient_machine__.randParameters()
-        parameters.append_gradient_machine(gm)
+        self.__parameters__.append_gradient_machine(gm)
         self.__parameter_updater__ = None
 
-    def use_remote_sparse_updater(self):
+    def __use_remote_sparse_updater__(self):
         return self.__use_sparse_updater__ and not self.__is_local__
 
+    def __prepare_parameter__(self, in_args):
+        """
+        Prepare parameters before the forward/backward pass.
+        When using the remote sparse updater, parameters have to be
+        fetched from the parameter server according to the input arguments.
+        :param in_args: input arguments of this batch.
+        :return:
+        """
+        if self.__use_remote_sparse_updater__():
+            self.__gradient_machine__.prefetch(in_args)
+            self.__parameter_updater__.getParametersRemote()
+
     def train(self, reader, num_passes=1, event_handler=None, feeding=None):
         """
         Training method. Will train num_passes of input data.
@@ -125,9 +137,7 @@ class SGD(object):
                 pass_type = self.__parameter_updater__.startBatch(
                     len(data_batch))
                 in_args = feeder(data_batch)
-                if self.use_remote_sparse_updater():
-                    self.__gradient_machine__.prefetch(in_args)
-                    self.__parameter_updater__.getParametersRemote()
+                self.__prepare_parameter__(in_args)
                 self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                           pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
@@ -161,9 +171,7 @@ class SGD(object):
             for data_batch in reader():
                 num_samples += len(data_batch)
                 in_args = feeder(data_batch)
-                if self.use_remote_sparse_updater():
-                    self.__gradient_machine__.prefetch(in_args)
-                    self.__parameter_updater__.getParametersRemote()
+                self.__prepare_parameter__(in_args)
                 self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
                 total_cost += out_args.sum()
                 self.__gradient_machine__.eval(evaluator)
-- 
GitLab
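Reviewer note: below is a condensed sketch of the per-batch flow after this
refactor, stitched together from the hunks above. It is illustrative only, not
part of the patch: optimizer, gradient_machine, reader, feeder, and out_args
stand for the objects that SGD already sets up, and the argument values
(num_passes=10, etc.) are made up for the example.

.. code-block:: python

    # Updater construction is now delegated to the optimizer; whether the
    # updater is local, remote, or sparse-remote is decided once here
    # instead of being branched on inside the training loop.
    parameter_updater = optimizer.create_updater(
        is_local=False, num_passes=10, use_sparse_updater=True)

    for data_batch in reader():
        pass_type = parameter_updater.startBatch(len(data_batch))
        in_args = feeder(data_batch)
        # __prepare_parameter__(in_args) reduces to these two calls when the
        # remote sparse updater is active: prefetch which sparse rows this
        # batch touches, then pull them from the parameter server. train()
        # and test() now share this path instead of inlining it separately.
        gradient_machine.prefetch(in_args)
        parameter_updater.getParametersRemote()
        gradient_machine.forwardBackward(in_args, out_args, pass_type)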