diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index c4f5dca26cc6a5e9fdd23ee27b594ced29a25c7a..f5ead40682c69ffd610a3fe64207cb077d848949 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -469,6 +469,7 @@ private:
 
 enum GradientMatchineCreateMode {
   CREATE_MODE_NORMAL = 0,
+  CREATE_MODE_SGD_SPARSE_CPU_TRAINING = 3,
   CREATE_MODE_TESTING = 4
 };
diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py
index 737b6bf1e2eb60281d4d6e92667d9fe91e243704..86e7549e97201cb06af01d6e2c37f85375954262 100644
--- a/python/paddle/v2/topology.py
+++ b/python/paddle/v2/topology.py
@@ -73,6 +73,16 @@ class Topology(object):
         assert isinstance(self.__model_config__, ModelConfig)
 
+    def use_sparse_updater(self):
+        """
+        Check whether any parameter requires sparse_update.
+        :return: True if any parameter requires sparse update, False otherwise.
+        """
+        for parameter in self.__model_config__.parameters:
+            if parameter.sparse_update or parameter.sparse_remote_update:
+                return True
+        return False
+
     def proto(self):
         return self.__model_config__
 
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index f5797a86c2b71502a7791453ff86c6a486c9f185..2dac95b63d550733c54ee5fb13d2c02272fb1af5 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -42,7 +42,7 @@ class SGD(object):
     :type extra_layers: paddle.v2.config_base.Layer
     """
 
-    def __init__(self, cost, parameters, update_equation, extra_layers=None):
+    def __init__(self, cost, parameters, update_equation, extra_layers=None, is_local=True):
         if not isinstance(parameters, v2_parameters.Parameters):
             raise TypeError('parameters should be parameters')
 
@@ -55,15 +55,21 @@
         self.__topology__ = topology
         self.__parameters__ = parameters
         self.__topology_in_proto__ = topology.proto()
-
-        # In local mode, disable sparse_remote_update.
-        for param in self.__topology_in_proto__.parameters:
-            if param.sparse_remote_update:
-                param.sparse_remote_update = False
-
+        self.__is_local__ = is_local
+
+        self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
+        # In local mode, disable sparse_remote_update.
+        if is_local:
+            self.__use_sparse_updater__ = False
+            for param in self.__topology_in_proto__.parameters:
+                if param.sparse_remote_update:
+                    param.sparse_remote_update = False
+
+        self.__gm_create_mode__ = api.CREATE_MODE_NORMAL if not \
+            self.__use_sparse_updater__ else api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING
         self.__data_types__ = topology.data_type()
         gm = api.GradientMachine.createFromConfigProto(
-            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
+            self.__topology_in_proto__, self.__gm_create_mode__,
             self.__optimizer__.enable_types())
         assert isinstance(gm, api.GradientMachine)
         self.__gradient_machine__ = gm
@@ -88,7 +94,10 @@
             event_handler = default_event_handler
         __check_train_args__(**locals())
 
-        updater = self.__optimizer__.create_local_updater()
+        if self.__is_local__:
+            updater = self.__optimizer__.create_local_updater()
+        else:
+            updater = self.__optimizer__.create_remote_updater(num_passes)
         updater.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
@@ -108,6 +117,9 @@
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
                 pass_type = updater.startBatch(len(data_batch))
+                if self.__use_sparse_updater__:
+                    self.__gradient_machine__.prefetch(feeder(data_batch))
+                    updater.getParametersRemote()
                 self.__gradient_machine__.forwardBackward(
                     feeder(data_batch), out_args, pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
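
Usage note (not part of the patch): a minimal sketch of what makes the new `Topology.use_sparse_updater()` return True. Any parameter configured with `sparse_update` (or `sparse_remote_update`) trips the check; the toy embedding network below is an assumption for illustration, not code from this change.

```python
import paddle.v2 as paddle
from paddle.v2.topology import Topology

paddle.init(use_gpu=False, trainer_count=1)

# A toy network whose embedding parameter opts into sparse updates.
word = paddle.layer.data(
    name='word', type=paddle.data_type.integer_value(10000))
emb = paddle.layer.embedding(
    input=word,
    size=32,
    param_attr=paddle.attr.Param(sparse_update=True))

# The new check walks the parameter configs and reports the sparse flag.
assert Topology(emb).use_sparse_updater()
```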
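With this change, switching a model to distributed training is a constructor flag. A sketch under assumptions (a parameter-server job is already running; cluster setup, readers, and event handlers are omitted): `is_local=False` selects `create_remote_updater(num_passes)`, and because the embedding parameter is sparse, the gradient machine is created with `CREATE_MODE_SGD_SPARSE_CPU_TRAINING` and sparse rows are prefetched before each forward/backward.

```python
import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# Toy classifier over a sparse word embedding (illustrative only).
word = paddle.layer.data(
    name='word', type=paddle.data_type.integer_value(10000))
label = paddle.layer.data(
    name='label', type=paddle.data_type.integer_value(2))
emb = paddle.layer.embedding(
    input=word, size=32,
    param_attr=paddle.attr.Param(sparse_update=True))
predict = paddle.layer.fc(
    input=emb, size=2, act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=predict, label=label)

parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)

# is_local=False: parameters are pulled/pushed through the remote updater
# rather than updated in-process (requires a running pserver job).
trainer = paddle.trainer.SGD(cost=cost,
                             parameters=parameters,
                             update_equation=optimizer,
                             is_local=False)
```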