diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 30fc2a0886ba776f1ad42fed4f254c773323afdd..c1f964a8106d0887d401560cea7d9a184e10abf8 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -79,6 +79,10 @@ class SGD(object): self.__gradient_machine__ = gm self.__gradient_machine__.randParameters() parameters.append_gradient_machine(gm) + self.__parameter_updater__ = None + + def use_remote_sparse_updater(self): + return self.__use_sparse_updater__ and not self.__is_local__ def train(self, reader, num_passes=1, event_handler=None, feeding=None): """ @@ -103,6 +107,7 @@ class SGD(object): else: parameter_updater = self.__optimizer__.create_remote_updater( num_passes, self.__use_sparse_updater__) + self.__parameter_updater__ = parameter_updater parameter_updater.init(self.__gradient_machine__) self.__gradient_machine__.start() @@ -122,11 +127,12 @@ class SGD(object): v2_event.BeginIteration( pass_id=pass_id, batch_id=batch_id)) pass_type = parameter_updater.startBatch(len(data_batch)) - if self.__use_sparse_updater__ and not self.__is_local__: - self.__gradient_machine__.prefetch(feeder(data_batch)) + in_args = feeder(data_batch) + if self.use_remote_sparse_updater(): + self.__gradient_machine__.prefetch(in_args) parameter_updater.getParametersRemote() - self.__gradient_machine__.forwardBackward( - feeder(data_batch), out_args, pass_type) + self.__gradient_machine__.forwardBackward(in_args, out_args, + pass_type) self.__gradient_machine__.eval(pass_evaluator) self.__gradient_machine__.eval(batch_evaluator) for each_param in self.__gradient_machine__.getNonStaticParameters( @@ -157,8 +163,11 @@ class SGD(object): num_samples = 0.0 for data_batch in reader(): num_samples += len(data_batch) - self.__gradient_machine__.forward( - feeder(data_batch), out_args, api.PASS_TEST) + in_args = feeder(data_batch) + if self.use_remote_sparse_updater(): + self.__gradient_machine__.prefetch(in_args) + self.__parameter_updater__.getParametersRemote() + self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST) total_cost += out_args.sum() self.__gradient_machine__.eval(evaluator)