From 57ac412b98990ac1d946ad32de30b07a15d0a18f Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Tue, 11 Dec 2018 17:48:25 +0800 Subject: [PATCH] download data --- python/paddle/fluid/async_executor.py | 22 ++++++++++++++++++- python/paddle/fluid/contrib/utils/__init__.py | 4 ++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index 787a6a6b9..cce7ec5cc 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -25,6 +25,7 @@ from google.protobuf import text_format from . import io from .data_feed_desc import DataFeedDesc from .distributed import ps_instance +from .contrib.utils import hdfs_utils as hdfs __all__ = ['AsyncExecutor'] @@ -152,6 +153,22 @@ class AsyncExecutor(object): data_feed.desc(), filelist, thread_num, fetch_var_names, debug) + def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12): + hadoop_home = "$HADOOP_HOME" + + configs = { + "fs.default.name": fs_default_name, + "hadoop.job.ugi": ugi + } + + client = hdfs.HDFSClient(hadoop_home, configs) + downloads = hdfs.multi_download( + client, + afs_path, + local_path, + self.instance.get_worker_index(), + self.instance.get_node_cnt() / 2, + multi_processes=process_num) def config_distributed_nodes(self, dist_opt): @@ -179,10 +196,11 @@ class AsyncExecutor(object): self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait init model + self.instance.barrier_all() #wait for download_data self.instance.barrier_all() #wait worker do all things self.instance.barrier_all() #sync - def init_worker(self, dist_desc): + def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi): self.instance.barrier_all() #wait all server start ips = self.instance.gather_ips() self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid) @@ -190,6 +208,8 @@ class AsyncExecutor(object): if self.instance.is_first_worker(): self.executor.init_model() self.instance.barrier_all() #wait init model + self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12) + self.instance.barrier_all() #wait for download_data def init_model(self): self.executor.init_model() diff --git a/python/paddle/fluid/contrib/utils/__init__.py b/python/paddle/fluid/contrib/utils/__init__.py index 6e479bdc2..2fe9f702f 100644 --- a/python/paddle/fluid/contrib/utils/__init__.py +++ b/python/paddle/fluid/contrib/utils/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. from __future__ import print_function -from . import lookup_table_utils -from .lookup_table_utils import * +#from . import lookup_table_utils +#from .lookup_table_utils import * from . import hdfs_utils from .hdfs_utils import * -- GitLab