Commit cf86ca04 authored by qiaolongfei

refine code

Parent cfff9467
@@ -34,8 +34,7 @@ ParameterUpdater *ParameterUpdater::createRemoteUpdater(
   auto remoteUpdater = new paddle::RemoteParameterUpdater(
       config->m->getConfig(), passCount, nullptr);
   if (useSparseUpdater) {
-    std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
-    remoteUpdaterPtr.reset(remoteUpdater);
+    std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr(remoteUpdater);
     auto sparseRemoteUpdater =
         new paddle::SparseRemoteParameterUpdaterComposite(
             config->m->getConfig(),
@@ -38,19 +38,34 @@ class Optimizer(object):
         assert isinstance(tmp, swig_api.ParameterOptimizer)
         return tmp.getParameterTypes()
 
-    def create_local_updater(self):
+    def __create_local_updater__(self):
         return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
 
-    def create_remote_updater(self, pass_num, use_sparse_updater):
+    def __create_remote_updater__(self, pass_num, use_sparse_updater):
         return swig_api.ParameterUpdater.createRemoteUpdater(
             self.__opt_conf__, pass_num, use_sparse_updater)
 
     def create_updater(self, is_local, num_passes, use_sparse_updater):
+        """
+        Create the proper parameter_updater according to the configuration.
+        :param is_local: create a local or a remote parameter updater.
+        :param num_passes: the remote parameter updater uses this to configure
+            the parameter server.
+        :param use_sparse_updater: when using a remote updater, sparse
+            parameters need some extra steps:
+        .. code-block:: python
+
+            if use_sparse_remote_updater:
+                gradient_machine.prefetch(in_args)
+                parameter_updater.getParametersRemote()
+        :return: parameter_updater
+        """
         if is_local:
-            parameter_updater = self.create_local_updater()
+            parameter_updater = self.__create_local_updater__()
         else:
-            parameter_updater = self.create_remote_updater(num_passes,
-                                                           use_sparse_updater)
+            parameter_updater = self.__create_remote_updater__(
+                num_passes, use_sparse_updater)
         return parameter_updater
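As a usage note (not part of this commit), here is a minimal sketch of how the public `create_updater` entry point is typically called after this rename. The optimizer class and its arguments below are assumptions based on the paddle.v2 API rather than anything shown in this diff.

```python
import paddle.v2 as paddle

# Assumed optimizer instance; any paddle.v2 Optimizer subclass would do here.
optimizer = paddle.optimizer.Momentum(momentum=0.9)

# Callers go through create_updater(); the renamed __create_local_updater__ and
# __create_remote_updater__ helpers are now internal details of Optimizer.
parameter_updater = optimizer.create_updater(
    is_local=False,            # False: train against a remote parameter server
    num_passes=10,             # forwarded to the remote updater / pserver setup
    use_sparse_updater=True)   # sparse parameters need prefetch + getParametersRemote
```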
@@ -78,12 +78,24 @@ class SGD(object):
         assert isinstance(gm, api.GradientMachine)
         self.__gradient_machine__ = gm
         self.__gradient_machine__.randParameters()
-        parameters.append_gradient_machine(gm)
+        self.__parameters__.append_gradient_machine(gm)
         self.__parameter_updater__ = None
 
-    def use_remote_sparse_updater(self):
+    def __use_remote_sparse_updater__(self):
         return self.__use_sparse_updater__ and not self.__is_local__
 
+    def __prepare_parameter__(self, in_args):
+        """
+        Prepare parameters before the forward/backward pass.
+        1. When using the remote sparse updater, parameters must be fetched
+           from the parameter server according to the input arguments.
+        :param in_args: input arguments of this batch.
+        :return:
+        """
+        if self.__use_remote_sparse_updater__():
+            self.__gradient_machine__.prefetch(in_args)
+            self.__parameter_updater__.getParametersRemote()
+
     def train(self, reader, num_passes=1, event_handler=None, feeding=None):
         """
         Training method. Will train num_passes of input data.
@@ -125,9 +137,7 @@ class SGD(object):
                 pass_type = self.__parameter_updater__.startBatch(
                     len(data_batch))
                 in_args = feeder(data_batch)
-                if self.use_remote_sparse_updater():
-                    self.__gradient_machine__.prefetch(in_args)
-                    self.__parameter_updater__.getParametersRemote()
+                self.__prepare_parameter__(in_args)
                 self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                           pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
@@ -161,9 +171,7 @@ class SGD(object):
         for data_batch in reader():
             num_samples += len(data_batch)
             in_args = feeder(data_batch)
-            if self.use_remote_sparse_updater():
-                self.__gradient_machine__.prefetch(in_args)
-                self.__parameter_updater__.getParametersRemote()
+            self.__prepare_parameter__(in_args)
             self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
             total_cost += out_args.sum()
             self.__gradient_machine__.eval(evaluator)
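To make the loop changes above concrete, here is a hedged sketch of what both call sites (`SGD.train` and `SGD.test`) now do: the inline sparse-remote block is replaced by a single `__prepare_parameter__` call before the forward/backward pass. The wrapper function and its name are hypothetical and not part of this commit; `sgd`, `in_args`, `out_args`, and `pass_type` follow the names used in the diff.

```python
def prepare_and_forward_backward(sgd, in_args, out_args, pass_type):
    """Hypothetical wrapper illustrating the refactor: the sparse-remote
    prefetch logic formerly inlined in SGD.train and SGD.test now lives
    behind SGD.__prepare_parameter__."""
    # Equivalent to the removed inline block:
    #   if sgd.__use_remote_sparse_updater__():
    #       sgd.__gradient_machine__.prefetch(in_args)
    #       sgd.__parameter_updater__.getParametersRemote()
    sgd.__prepare_parameter__(in_args)
    # Forward/backward proceeds exactly as before.
    sgd.__gradient_machine__.forwardBackward(in_args, out_args, pass_type)
```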