From 8a335b50bec5e94077a312552c8294b5e4425abe Mon Sep 17 00:00:00 2001
From: dongdaxiang
Date: Tue, 29 Jan 2019 20:28:54 +0800
Subject: [PATCH] add downpour device_worker pb configuration

---
 paddle/fluid/framework/trainer_desc.proto |  1 -
 python/paddle/fluid/async_executor.py     | 37 +++++++++++++++++++++++
 python/paddle/fluid/trainer_desc.py       | 33 ++++++++++++++++++--
 3 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto
index 54b698cd530..a3054b61b07 100644
--- a/paddle/fluid/framework/trainer_desc.proto
+++ b/paddle/fluid/framework/trainer_desc.proto
@@ -59,7 +59,6 @@ message TableParameter {
   optional int64 table_id = 1;
   repeated string dense_value_name = 2;
   repeated string dense_grad_name = 3;
-  repeated int32 dense_table_size = 4;
   repeated int32 push_dense_wait_times = 5;
   // sparse table only
   repeated string sparse_key_name = 6;
diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py
index 25f95ffbb0a..7068f51331b 100644
--- a/python/paddle/fluid/async_executor.py
+++ b/python/paddle/fluid/async_executor.py
@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
 from google.protobuf import text_format
 from . import io
 from .data_feed_desc import DataFeedDesc
+from .trainer_desc import TrainerDesc, MultiTrainer, DistMultiTrainer
 from .distributed import ps_instance
 from .contrib.utils import hdfs_utils as hdfs
 
@@ -89,6 +90,38 @@ class AsyncExecutor(object):
         self.executor = core.AsyncExecutor(scope, p)
         self.instance = None
 
+    def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
+        if program is None:
+            program = default_main_program()
+        program_desc = program.desc
+
+        if data_feed is None:
+            raise ValueError('ValueError: data_feed should be provided')
+
+        if filelist is None:
+            raise ValueError('ValueError: filelist should be provided')
+
+        if isinstance(filelist, str):
+            filelist = [filelist]
+
+        if not isinstance(thread_num, int):
+            raise TypeError('TypeError: thread_num should be a positive number')
+
+        is_local = self.instance == None
+        trainer = None
+        if is_local:
+            trainer = MultiTrainer(data_feed=data_feed, worker="Hogwild")
+        else:
+            trainer = DistMultiTrainer(
+                data_feed, worker="Downpour", fleet_desc=self.dist_desc)
+
+        # define a trainer and a device_worker here
+        trainer.set_thread(thread_num)
+        trainer.set_filelist(filelist)
+        trainer.set_data_feed(data_feed)
+        self.executor.run_from_files(program_desc, trainer._desc(), debug)
+
+    '''
     def run(self,
             program,
             data_feed,
@@ -160,6 +193,7 @@ class AsyncExecutor(object):
 
         self.executor.run_from_files(program_desc, data_feed.desc(), filelist,
                                      thread_num, fetch_var_names, mode, debug)
+    '''
 
     def download_data(self,
                       afs_path,
@@ -250,6 +284,7 @@ class AsyncExecutor(object):
             raise ValueError(
                 'instance is None, please run config_distributed_nodes init instance'
             )
+        self.init_desc = init_desc
         self.executor.init_server(dist_desc, self.instance._rankid)
         ip = self.executor.start_server()
         self.instance.set_ip(ip)
@@ -270,6 +305,8 @@ class AsyncExecutor(object):
             raise ValueError(
                 'instance is None, please run config_distributed_nodes init instance'
             )
+
+        self.dist_desc = dist_desc
         place = core.CPUPlace()
         executor = Executor(place)
         executor.run(startup_program)
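
Usage sketch (illustrative, not part of the diff): with this change, AsyncExecutor.run()
no longer hands the DataFeedDesc and filelist straight to the C++ side. It first builds a
TrainerDesc through MultiTrainer (local mode, HogwildWorker) or DistMultiTrainer
(distributed mode, DownpourWorker) and passes trainer._desc() to run_from_files. Below is
a minimal local-mode sketch of the intended call pattern; it assumes PaddlePaddle with
this patch applied, and the data_feed.proto / train_data file names are hypothetical.
Note that in this revision the trainer constructors still name their first parameter
dataset while run() passes data_feed=, so the new path is not yet runnable end to end.

    import paddle.fluid as fluid

    place = fluid.CPUPlace()

    # Build the programs as usual and run the startup program with a normal Executor
    # before handing the main program to AsyncExecutor.
    startup_program = fluid.default_startup_program()
    main_program = fluid.default_main_program()
    fluid.Executor(place).run(startup_program)

    # Input pipeline description; 'data_feed.proto' is a hypothetical text-format
    # DataFeedDesc file, as in the existing AsyncExecutor examples.
    data_feed = fluid.DataFeedDesc('data_feed.proto')

    async_executor = fluid.AsyncExecutor(place)

    # No distributed instance has been configured, so run() takes the local branch
    # (MultiTrainer + HogwildWorker). After config_distributed_nodes()/init_worker(),
    # it would build a DistMultiTrainer with a DownpourWorker from the saved fleet
    # descriptor instead.
    async_executor.run(main_program,
                       data_feed,
                       filelist=['train_data/part-1', 'train_data/part-2'],
                       thread_num=2,
                       fetch=[])
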
diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index 77ee951dbda..85bfb0a4ee3 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from paddle.fluid.proto import trainer_desc_pb2
+import ps_pb2 as pslib
 from google.protobuf import text_format
 
 __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
@@ -42,7 +43,7 @@ class TrainerDesc(object):
 
 
 class MultiTrainer(TrainerDesc):
-    def __init__(self, worker="Hogwild"):
+    def __init__(self, dataset=None, worker="Hogwild"):
         super(MultiTrainer, self).__init__()
         if worker == "Hogwild":
             self.proto_desc.device_worker_name = worker + "Worker"
@@ -53,11 +54,39 @@ class MultiTrainer(TrainerDesc):
 
 
 class DistMultiTrainer(TrainerDesc):
-    def __init__(self, worker='Downpour'):
+    def __init__(self, dataset=None, worker='Downpour', fleet_desc=None):
         super(DistMultiTrainer, self).__init__()
         if worker == "Downpour":
             self.proto_desc.device_worker_name = worker + "Worker"
             self.proto_desc.class_name = "DistMultiTrainer"
+            self.proto_desc.data_feed.CopyFrom(dataset)
+            downpour = self.proto_desc.downpour_param.add()
+            # sparse table should specify:
+            sparse_table = downpour.sparse_table.add()
+            sparse_table.table_id = \
+                fleet_desc.trainer_param.sparse_table.table_id
+            sparse_table.sparse_key_name.CopyFrom(fleet_desc.trainer_param()
+                                                  .sparse_table().slot_key())
+            sparse_table.sparse_value_name.CopyFrom(fleet_desc.trainer_param(
+            ).sparse_table().slot_value())
+            sparse_table.sparse_grad_name.CopyFrom(fleet_desc.trainer_param(
+            ).sparse_table().slot_gradient())
+            sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param.accessor.fea_dim - 2
+            sparse_table.fea_dim = downpour.emb_dim + 2
+            sparse_table.label_var_name = "click"
+
+            # dense table should specify:
+            dense_table = downpour.dense_table.add()
+            dense_table.table_id = \
+                fleet_desc.trainer_param.dense_table.table_id
+            # dense_value_name
+            dense_table.dense_value_name.CopyFrom(fleet_desc.trainer_param(
+            ).dense_table().dense_variable_name)
+            # dense_grad_name
+            dense_table.dense_grad_name.CopyFrom(fleet_desc.trainer_param(
+            ).dense_table().dense_gradient_name)
+            downpour.skipped_ops.extend(fleet_desc.trainer_param.skip_op)
+            print(str(self.proto_desc))
         else:
             raise ValueError('ValueError: DeviceWorker %s '
                              'is not supported in DistMultiTrainer' % worker)
-- 
GitLab
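
For reference, a rough sketch of the TrainerDesc wiring that the DistMultiTrainer
constructor above is aiming for. This is not the committed code: the hunk still calls
protobuf fields as if they were methods (fleet_desc.trainer_param().sparse_table()...),
applies CopyFrom to repeated string fields, and takes emb_dim from downpour when
computing fea_dim. The sketch uses plain attribute access, fills the repeated string
fields with extend(), reads emb_dim back from the sparse table entry it just set, and
assumes the pslib fleet descriptor stores its sparse/dense tables and
downpour_table_param as repeated entries (hence the [0] indexing). build_downpour_desc
is a hypothetical helper name, not an API introduced by the patch.

    from paddle.fluid.proto import trainer_desc_pb2


    def build_downpour_desc(dataset, fleet_desc):
        # dataset: a DataFeedDesc protobuf message; fleet_desc: the pslib fleet
        # descriptor (assumed layout: trainer_param with repeated sparse_table /
        # dense_table entries, server_param with a downpour accessor).
        proto_desc = trainer_desc_pb2.TrainerDesc()
        proto_desc.device_worker_name = "DownpourWorker"
        proto_desc.class_name = "DistMultiTrainer"
        proto_desc.data_feed.CopyFrom(dataset)
        downpour = proto_desc.downpour_param.add()  # mirrors the patch's .add()

        # Sparse table: slot names come from the fleet descriptor; the diff sets
        # emb_dim = accessor.fea_dim - 2 and fea_dim = emb_dim + 2.
        src_sparse = fleet_desc.trainer_param.sparse_table[0]
        accessor = (fleet_desc.server_param.downpour_server_param
                    .downpour_table_param[0].accessor)
        sparse_table = downpour.sparse_table.add()
        sparse_table.table_id = src_sparse.table_id
        sparse_table.sparse_key_name.extend(src_sparse.slot_key)
        sparse_table.sparse_value_name.extend(src_sparse.slot_value)
        sparse_table.sparse_grad_name.extend(src_sparse.slot_gradient)
        sparse_table.emb_dim = accessor.fea_dim - 2
        sparse_table.fea_dim = sparse_table.emb_dim + 2
        sparse_table.label_var_name = "click"

        # Dense table: parameter and gradient variable names, as in the diff.
        src_dense = fleet_desc.trainer_param.dense_table[0]
        dense_table = downpour.dense_table.add()
        dense_table.table_id = src_dense.table_id
        dense_table.dense_value_name.extend(src_dense.dense_variable_name)
        dense_table.dense_grad_name.extend(src_dense.dense_gradient_name)

        # Ops listed in the fleet descriptor that the device worker should skip.
        downpour.skipped_ops.extend(fleet_desc.trainer_param.skip_op)
        return proto_desc

A TrainerDesc assembled this way is what run() serializes with trainer._desc() and hands
to the C++ trainer through run_from_files.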