Commit cfff9467 authored by qiaolongfei

optimizer parameter_updater

Move parameter-updater creation out of SGD.train into a new Optimizer.create_updater factory method, so the trainer no longer branches on local vs. remote training itself.

Parent 6295f2d6
--- a/python/paddle/v2/optimizer.py
+++ b/python/paddle/v2/optimizer.py
@@ -45,6 +45,14 @@ class Optimizer(object):
         return swig_api.ParameterUpdater.createRemoteUpdater(
             self.__opt_conf__, pass_num, use_sparse_updater)
 
+    def create_updater(self, is_local, num_passes, use_sparse_updater):
+        if is_local:
+            parameter_updater = self.create_local_updater()
+        else:
+            parameter_updater = self.create_remote_updater(num_passes,
+                                                           use_sparse_updater)
+        return parameter_updater
+
 
 class Momentum(Optimizer):
     def __init__(self, momentum=None, sparse=False, **kwargs):
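For context, a minimal usage sketch of the new factory method (hypothetical values; Momentum and its keyword arguments stand in for whatever optimizer configuration the caller already has):

    from paddle.v2 import optimizer as v2_optimizer

    # Hypothetical configuration; any Optimizer subclass works the same way.
    opt = v2_optimizer.Momentum(momentum=0.9, learning_rate=1e-3)

    # is_local=True returns a local updater; otherwise create_updater
    # delegates to createRemoteUpdater with num_passes and use_sparse_updater.
    updater = opt.create_updater(
        is_local=True, num_passes=10, use_sparse_updater=False)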
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -102,13 +102,9 @@ class SGD(object):
             event_handler = default_event_handler
         __check_train_args__(**locals())
 
-        if self.__is_local__:
-            parameter_updater = self.__optimizer__.create_local_updater()
-        else:
-            parameter_updater = self.__optimizer__.create_remote_updater(
-                num_passes, self.__use_sparse_updater__)
-        self.__parameter_updater__ = parameter_updater
-        parameter_updater.init(self.__gradient_machine__)
+        self.__parameter_updater__ = self.__optimizer__.create_updater(
+            self.__is_local__, num_passes, self.__use_sparse_updater__)
+        self.__parameter_updater__.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
         batch_evaluator = self.__gradient_machine__.makeEvaluator()
@@ -120,27 +116,28 @@ class SGD(object):
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
-            parameter_updater.startPass()
+            self.__parameter_updater__.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
-                pass_type = parameter_updater.startBatch(len(data_batch))
+                pass_type = self.__parameter_updater__.startBatch(
+                    len(data_batch))
                 in_args = feeder(data_batch)
                 if self.use_remote_sparse_updater():
                     self.__gradient_machine__.prefetch(in_args)
-                    parameter_updater.getParametersRemote()
+                    self.__parameter_updater__.getParametersRemote()
                 self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                           pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
                 self.__gradient_machine__.eval(batch_evaluator)
                 for each_param in self.__gradient_machine__.getNonStaticParameters(
                 ):
-                    parameter_updater.update(each_param)
+                    self.__parameter_updater__.update(each_param)
                 cost_sum = out_args.sum()
                 cost = cost_sum / len(data_batch)
-                parameter_updater.finishBatch(cost)
+                self.__parameter_updater__.finishBatch(cost)
                 batch_evaluator.finish()
                 event_handler(
                     v2_event.EndIteration(
@@ -149,7 +146,7 @@ class SGD(object):
                         cost=cost,
                         evaluator=batch_evaluator))
 
-            parameter_updater.finishPass()
+            self.__parameter_updater__.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
         self.__gradient_machine__.finish()
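As the trainer hunks show, SGD.train now drives the shared self.__parameter_updater__ through a fixed per-pass/per-batch lifecycle. A simplified, self-contained sketch of that call protocol (not PaddlePaddle code; DummyUpdater is a stand-in for the swig_api.ParameterUpdater the real trainer uses):

    # One startPass/finishPass pair per pass, one startBatch/finishBatch
    # pair per batch, and update() once per non-static parameter between.
    class DummyUpdater(object):
        def startPass(self):
            print("start pass")

        def startBatch(self, batch_size):
            print("start batch of %d" % batch_size)

        def update(self, param):
            print("update %s" % param)

        def finishBatch(self, cost):
            print("finish batch, cost=%.3f" % cost)

        def finishPass(self):
            print("finish pass")

    updater = DummyUpdater()
    updater.startPass()
    for batch in [["w", "b"], ["w", "b"]]:
        updater.startBatch(len(batch))
        for param in batch:
            updater.update(param)
        updater.finishBatch(0.5)
    updater.finishPass()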