提交 57ac412b 编写于 作者: H heqiaozhi

Add download_data to AsyncExecutor and call it from init_worker

上级 72968400
...@@ -25,6 +25,7 @@ from google.protobuf import text_format ...@@ -25,6 +25,7 @@ from google.protobuf import text_format
from . import io from . import io
from .data_feed_desc import DataFeedDesc from .data_feed_desc import DataFeedDesc
from .distributed import ps_instance from .distributed import ps_instance
from .contrib.utils import hdfs_utils as hdfs
__all__ = ['AsyncExecutor'] __all__ = ['AsyncExecutor']
...@@ -152,6 +153,22 @@ class AsyncExecutor(object): ...@@ -152,6 +153,22 @@ class AsyncExecutor(object):
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, debug)
def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12):
    """
    Fetch data from ``afs_path`` into ``local_path`` through
    ``hdfs.multi_download``, passing this worker's index and half of the
    node count so the helper can decide what this worker downloads.

    Args:
        afs_path: remote source path forwarded to ``hdfs.multi_download``.
        local_path: local destination path forwarded to
            ``hdfs.multi_download``.
        fs_default_name: value stored under the "fs.default.name" key of
            the HDFS client configuration.
        ugi: value stored under the "hadoop.job.ugi" key of the HDFS
            client configuration.
        process_num: forwarded as ``multi_processes``. Default 12.

    Returns:
        None
    """
    # The literal string is passed through unchanged.
    # NOTE(review): presumably expanded later by the shell that runs the
    # hadoop command -- confirm against HDFSClient.
    hadoop_home = "$HADOOP_HOME"

    configs = {
        "fs.default.name": fs_default_name,
        "hadoop.job.ugi": ugi
    }

    client = hdfs.HDFSClient(hadoop_home, configs)

    worker_index = self.instance.get_worker_index()
    # Floor division keeps the count an integer on Python 3 as well as on
    # Python 2; plain "/" would hand a float to multi_download on Python 3.
    # NOTE(review): the halving presumably reflects that half of the nodes
    # are servers rather than workers -- confirm.
    worker_cnt = self.instance.get_node_cnt() // 2

    hdfs.multi_download(
        client,
        afs_path,
        local_path,
        worker_index,
        worker_cnt,
        multi_processes=process_num)
def config_distributed_nodes(self, dist_opt): def config_distributed_nodes(self, dist_opt):
...@@ -179,10 +196,11 @@ class AsyncExecutor(object): ...@@ -179,10 +196,11 @@ class AsyncExecutor(object):
self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.executor.gather_servers(ips, self.instance.get_node_cnt())
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
self.instance.barrier_all() #wait for download_data
self.instance.barrier_all() #wait worker do all things self.instance.barrier_all() #wait worker do all things
self.instance.barrier_all() #sync self.instance.barrier_all() #sync
def init_worker(self, dist_desc): def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi):
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid) self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
...@@ -190,6 +208,8 @@ class AsyncExecutor(object): ...@@ -190,6 +208,8 @@ class AsyncExecutor(object):
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.init_model() self.executor.init_model()
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12)
self.instance.barrier_all() #wait for download_data
def init_model(self): def init_model(self):
self.executor.init_model() self.executor.init_model()
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
from . import lookup_table_utils #from . import lookup_table_utils
from .lookup_table_utils import * #from .lookup_table_utils import *
from . import hdfs_utils from . import hdfs_utils
from .hdfs_utils import * from .hdfs_utils import *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册