diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index 725328ce4d29b5e39910d7ad9c78fd2f07fc0a9e..be6be556a73843c28c8883bb2eb5770ffa6774de 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -859,6 +859,13 @@ public:
    */
   void update(Parameter* param);
 
+  /**
+   * @brief only get required sparse rows by default.
+   * @param fullSize: get full matrix parameter if *fullSize* set
+   * @param apply: get PARAMETER_APPLY on pserver if *apply* set
+   */
+  void getParametersRemote(bool fullSize = false, bool apply = false);
+
   /**
    * @brief restore the average parameter.
    * @note It is only used in AverageOptimizer. Restore will get the current
diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp
index 708379ded5b74cc37fb1ad5ec98950ac247ce36e..ce2ac33d44970e9802d25ade356b688ca4d62f6e 100644
--- a/paddle/api/ParameterUpdater.cpp
+++ b/paddle/api/ParameterUpdater.cpp
@@ -72,6 +72,10 @@ void ParameterUpdater::update(Parameter *param) {
   m->updater->update(paddleParam);
 }
 
+void ParameterUpdater::getParametersRemote(bool fullSize, bool apply) {
+  m->updater->getParametersRemote(fullSize, apply);
+}
+
 void ParameterUpdater::restore() { m->updater->restore(); }
 
 void ParameterUpdater::apply() { m->updater->apply(); }
diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py
index 86e7549e97201cb06af01d6e2c37f85375954262..ff28c85c53dc8255b6ad5e3975b07f72a9a64e4b 100644
--- a/python/paddle/v2/topology.py
+++ b/python/paddle/v2/topology.py
@@ -78,10 +78,12 @@ class Topology(object):
         check if any parameter require to use sparse_update
         :return:
         """
+        use_sparse = False
         for parameter in self.__model_config__.parameters:
             if parameter.sparse_update or parameter.sparse_remote_update:
-                return True
-        return False
+                use_sparse = True
+                break
+        return use_sparse
 
     def proto(self):
         return self.__model_config__
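Note: the new ParameterUpdater::getParametersRemote() exposes the two flags documented in the PaddleAPI.h comment above. A minimal usage sketch from the Python side (illustrative only, not part of the patch; it assumes a parameter_updater obtained from create_remote_updater(), as in the trainer.py changes below):

    # Default: fetch only the sparse parameter rows needed by the next batch.
    parameter_updater.getParametersRemote()

    # fullSize=True pulls the full (dense) parameter matrices from the
    # pservers; apply=True additionally fetches the PARAMETER_APPLY copies.
    parameter_updater.getParametersRemote(True, True)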
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 80f243b4137d3dce4fdc6026542c9dc8b8a8e765..30fc2a0886ba776f1ad42fed4f254c773323afdd 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -65,7 +65,6 @@ class SGD(object):
         self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
         # # In local mode, disable sparse_remote_update.
         if is_local:
-            self.__use_sparse_updater__ = False
             for param in self.__topology_in_proto__.parameters:
                 if param.sparse_remote_update:
                     param.sparse_remote_update = False
@@ -100,11 +99,11 @@ class SGD(object):
         __check_train_args__(**locals())
 
         if self.__is_local__:
-            updater = self.__optimizer__.create_local_updater()
+            parameter_updater = self.__optimizer__.create_local_updater()
         else:
-            updater = self.__optimizer__.create_remote_updater(
+            parameter_updater = self.__optimizer__.create_remote_updater(
                 num_passes, self.__use_sparse_updater__)
-        updater.init(self.__gradient_machine__)
+        parameter_updater.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
         batch_evaluator = self.__gradient_machine__.makeEvaluator()
@@ -116,26 +115,26 @@ class SGD(object):
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
-            updater.startPass()
+            parameter_updater.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
-                pass_type = updater.startBatch(len(data_batch))
-                if self.__use_sparse_updater__:
+                pass_type = parameter_updater.startBatch(len(data_batch))
+                if self.__use_sparse_updater__ and not self.__is_local__:
                     self.__gradient_machine__.prefetch(feeder(data_batch))
-                    updater.getParametersRemote()
+                    parameter_updater.getParametersRemote()
                 self.__gradient_machine__.forwardBackward(
                     feeder(data_batch), out_args, pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
                 self.__gradient_machine__.eval(batch_evaluator)
                 for each_param in self.__gradient_machine__.getNonStaticParameters(
                 ):
-                    updater.update(each_param)
+                    parameter_updater.update(each_param)
                 cost_sum = out_args.sum()
                 cost = cost_sum / len(data_batch)
-                updater.finishBatch(cost)
+                parameter_updater.finishBatch(cost)
                 batch_evaluator.finish()
                 event_handler(
                     v2_event.EndIteration(
                         pass_id=pass_id, batch_id=batch_id,
@@ -144,7 +143,7 @@
                         cost=cost,
                         evaluator=batch_evaluator))
 
-            updater.finishPass()
+            parameter_updater.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
         self.__gradient_machine__.finish()
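Taken together, the trainer changes make the remote sparse path explicit. A condensed sketch of the per-batch flow after this patch (evaluators and event handling omitted; bare names such as parameter_updater, gradient_machine, use_sparse_updater and is_local stand in for the corresponding self.__...__ attributes above):

    pass_type = parameter_updater.startBatch(len(data_batch))
    if use_sparse_updater and not is_local:
        # Remote sparse training only: declare which rows this batch touches,
        # then pull just those rows from the parameter servers.
        gradient_machine.prefetch(feeder(data_batch))
        parameter_updater.getParametersRemote()
    gradient_machine.forwardBackward(feeder(data_batch), out_args, pass_type)
    for each_param in gradient_machine.getNonStaticParameters():
        parameter_updater.update(each_param)
    parameter_updater.finishBatch(out_args.sum() / len(data_batch))

In local mode sparse_remote_update is still cleared on each parameter, but __use_sparse_updater__ is no longer forced to False; the new "and not self.__is_local__" guard is what keeps the prefetch/getParametersRemote pair from running locally.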