From cfff9467ac4df2731c3aadaf4060f93c49a885e8 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Mon, 17 Apr 2017 10:08:45 +0800
Subject: [PATCH] optimizer parameter_updater

---
 python/paddle/v2/optimizer.py |  8 ++++++++
 python/paddle/v2/trainer.py   | 23 ++++++++++-------------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py
index 7bac1ea3b..887b2567a 100644
--- a/python/paddle/v2/optimizer.py
+++ b/python/paddle/v2/optimizer.py
@@ -45,6 +45,14 @@ class Optimizer(object):
         return swig_api.ParameterUpdater.createRemoteUpdater(
             self.__opt_conf__, pass_num, use_sparse_updater)
 
+    def create_updater(self, is_local, num_passes, use_sparse_updater):
+        if is_local:
+            parameter_updater = self.create_local_updater()
+        else:
+            parameter_updater = self.create_remote_updater(num_passes,
+                                                           use_sparse_updater)
+        return parameter_updater
+
 
 class Momentum(Optimizer):
     def __init__(self, momentum=None, sparse=False, **kwargs):
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index c1f964a81..9caaeca2e 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -102,13 +102,9 @@ class SGD(object):
             event_handler = default_event_handler
         __check_train_args__(**locals())
 
-        if self.__is_local__:
-            parameter_updater = self.__optimizer__.create_local_updater()
-        else:
-            parameter_updater = self.__optimizer__.create_remote_updater(
-                num_passes, self.__use_sparse_updater__)
-        self.__parameter_updater__ = parameter_updater
-        parameter_updater.init(self.__gradient_machine__)
+        self.__parameter_updater__ = self.__optimizer__.create_updater(
+            self.__is_local__, num_passes, self.__use_sparse_updater__)
+        self.__parameter_updater__.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
         batch_evaluator = self.__gradient_machine__.makeEvaluator()
@@ -120,27 +116,28 @@ class SGD(object):
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
-            parameter_updater.startPass()
+            self.__parameter_updater__.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
-                pass_type = parameter_updater.startBatch(len(data_batch))
+                pass_type = self.__parameter_updater__.startBatch(
+                    len(data_batch))
                 in_args = feeder(data_batch)
                 if self.use_remote_sparse_updater():
                     self.__gradient_machine__.prefetch(in_args)
-                    parameter_updater.getParametersRemote()
+                    self.__parameter_updater__.getParametersRemote()
                 self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                           pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
                 self.__gradient_machine__.eval(batch_evaluator)
                 for each_param in self.__gradient_machine__.getNonStaticParameters(
                 ):
-                    parameter_updater.update(each_param)
+                    self.__parameter_updater__.update(each_param)
                 cost_sum = out_args.sum()
                 cost = cost_sum / len(data_batch)
-                parameter_updater.finishBatch(cost)
+                self.__parameter_updater__.finishBatch(cost)
                 batch_evaluator.finish()
                 event_handler(
                     v2_event.EndIteration(
@@ -149,7 +146,7 @@ class SGD(object):
                         cost=cost,
                         evaluator=batch_evaluator))
 
-            parameter_updater.finishPass()
+            self.__parameter_updater__.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
         self.__gradient_machine__.finish()
-- 
GitLab
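
The net effect of the patch is that SGD.train no longer branches on locality
itself; it asks the optimizer for the right updater in one call. A minimal
sketch of the new call site under the paddle.v2 API follows; the Momentum
arguments, the num_passes value, and calling create_updater directly (rather
than through SGD.train, which is how this patch uses it) are illustrative
assumptions, not part of the patch:

    import paddle.v2 as paddle

    paddle.init(use_gpu=False, trainer_count=1)
    optimizer = paddle.optimizer.Momentum(momentum=0.9, learning_rate=2e-4)

    # One call now covers both cases: is_local=True returns a local updater,
    # is_local=False a remote (parameter-server) updater, with num_passes and
    # use_sparse_updater forwarded to createRemoteUpdater.
    updater = optimizer.create_updater(
        is_local=True, num_passes=10, use_sparse_updater=False)

    # As in the patched trainer, the updater is then bound to the gradient
    # machine before training starts (gradient_machine is assumed here):
    # updater.init(gradient_machine)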