提交 cf86ca04 编写于 作者: Q qiaolongfei

refine code

上级 cfff9467
...@@ -34,8 +34,7 @@ ParameterUpdater *ParameterUpdater::createRemoteUpdater( ...@@ -34,8 +34,7 @@ ParameterUpdater *ParameterUpdater::createRemoteUpdater(
auto remoteUpdater = new paddle::RemoteParameterUpdater( auto remoteUpdater = new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr); config->m->getConfig(), passCount, nullptr);
if (useSparseUpdater) { if (useSparseUpdater) {
std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr; std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr(remoteUpdater);
remoteUpdaterPtr.reset(remoteUpdater);
auto sparseRemoteUpdater = auto sparseRemoteUpdater =
new paddle::SparseRemoteParameterUpdaterComposite( new paddle::SparseRemoteParameterUpdaterComposite(
config->m->getConfig(), config->m->getConfig(),
......
...@@ -38,19 +38,34 @@ class Optimizer(object): ...@@ -38,19 +38,34 @@ class Optimizer(object):
assert isinstance(tmp, swig_api.ParameterOptimizer) assert isinstance(tmp, swig_api.ParameterOptimizer)
return tmp.getParameterTypes() return tmp.getParameterTypes()
def create_local_updater(self): def __create_local_updater__(self):
return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__) return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
def create_remote_updater(self, pass_num, use_sparse_updater): def __create_remote_updater__(self, pass_num, use_sparse_updater):
return swig_api.ParameterUpdater.createRemoteUpdater( return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater) self.__opt_conf__, pass_num, use_sparse_updater)
def create_updater(self, is_local, num_passes, use_sparse_updater): def create_updater(self, is_local, num_passes, use_sparse_updater):
"""
create proper parameter_updater by configuration.
:param is_local: create local or remote parameter updater
:param num_passes: remote parameter updater will use this to config
parameter server.
:param use_sparse_updater: when use remote updater, if some parameter is
sparse, updater should do some extra thing:
.. code-block:: python
if use_sparse_remote_updater:
gradient_machine.prefetch(in_args)
parameter_updater.getParametersRemote()
:return: parameter_updater
"""
if is_local: if is_local:
parameter_updater = self.create_local_updater() parameter_updater = self.__create_local_updater__()
else: else:
parameter_updater = self.create_remote_updater(num_passes, parameter_updater = self.__create_remote_updater__(
use_sparse_updater) num_passes, use_sparse_updater)
return parameter_updater return parameter_updater
......
...@@ -78,12 +78,24 @@ class SGD(object): ...@@ -78,12 +78,24 @@ class SGD(object):
assert isinstance(gm, api.GradientMachine) assert isinstance(gm, api.GradientMachine)
self.__gradient_machine__ = gm self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters() self.__gradient_machine__.randParameters()
parameters.append_gradient_machine(gm) self.__parameters__.append_gradient_machine(gm)
self.__parameter_updater__ = None self.__parameter_updater__ = None
def use_remote_sparse_updater(self): def __use_remote_sparse_updater__(self):
return self.__use_sparse_updater__ and not self.__is_local__ return self.__use_sparse_updater__ and not self.__is_local__
def __prepare_parameter__(self, in_args):
"""
prepare parameter before forward backward.
1. When use remote sparse updater, parameters should be got
from ps according to input arguments.
:param in_args: input arguments of this batch.
:return:
"""
if self.__use_remote_sparse_updater__():
self.__gradient_machine__.prefetch(in_args)
self.__parameter_updater__.getParametersRemote()
def train(self, reader, num_passes=1, event_handler=None, feeding=None): def train(self, reader, num_passes=1, event_handler=None, feeding=None):
""" """
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
...@@ -125,9 +137,7 @@ class SGD(object): ...@@ -125,9 +137,7 @@ class SGD(object):
pass_type = self.__parameter_updater__.startBatch( pass_type = self.__parameter_updater__.startBatch(
len(data_batch)) len(data_batch))
in_args = feeder(data_batch) in_args = feeder(data_batch)
if self.use_remote_sparse_updater(): self.__prepare_parameter__(in_args)
self.__gradient_machine__.prefetch(in_args)
self.__parameter_updater__.getParametersRemote()
self.__gradient_machine__.forwardBackward(in_args, out_args, self.__gradient_machine__.forwardBackward(in_args, out_args,
pass_type) pass_type)
self.__gradient_machine__.eval(pass_evaluator) self.__gradient_machine__.eval(pass_evaluator)
...@@ -161,9 +171,7 @@ class SGD(object): ...@@ -161,9 +171,7 @@ class SGD(object):
for data_batch in reader(): for data_batch in reader():
num_samples += len(data_batch) num_samples += len(data_batch)
in_args = feeder(data_batch) in_args = feeder(data_batch)
if self.use_remote_sparse_updater(): self.__prepare_parameter__(in_args)
self.__gradient_machine__.prefetch(in_args)
self.__parameter_updater__.getParametersRemote()
self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST) self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
total_cost += out_args.sum() total_cost += out_args.sum()
self.__gradient_machine__.eval(evaluator) self.__gradient_machine__.eval(evaluator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册