Commit 854ee964 authored by dongdaxiang

add doc string for async_executor.py

Parent e52bb816
@@ -89,8 +89,14 @@ class AsyncExecutor(object):
self.executor = core.AsyncExecutor(scope, p)
self.instance = None
def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False):
def run(self,
program,
data_feed,
filelist,
thread_num,
fetch,
mode="",
debug=False):
"""
Run program by this AsyncExecutor. Training dataset will be in filelist.
Users can also inspect certain variables by naming them in parameter
@@ -110,6 +116,7 @@ class AsyncExecutor(object):
thread_num(int): number of concurrent training threads. See
:code:`Note` for how to set this properly
fetch(str|list): the var name or a list of var names to inspect
mode(str): run mode of this interface
debug(bool): When set to True, fetch vars will be printed to
standard output after each minibatch
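As an illustrative sketch only, in the style of the docstring examples elsewhere in this file (main_program, feeder and the part files are hypothetical placeholders):

>>> exe = fluid.AsyncExecutor()
>>> exe.run(main_program, feeder,
>>>         ["./data/part-0", "./data/part-1"],
>>>         thread_num=4, fetch=["loss"], debug=True)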
@@ -154,14 +161,37 @@ class AsyncExecutor(object):
data_feed.desc(), filelist, thread_num,
fetch_var_names, mode, debug)
def download_data(self, afs_path, local_path, fs_default_name, ugi, file_cnt, hadoop_home="$HADOOP_HOME", process_num=12):
def download_data(self,
afs_path,
local_path,
fs_default_name,
ugi,
file_cnt,
hadoop_home="$HADOOP_HOME",
process_num=12):
"""
download_data is a default download method for distributed training.
A user can also download data without this method.
Example:
>>> exe = fluid.AsyncExecutor()
>>> exe.download_data("/xxx/xxx/xx/",
>>>                   "./data",
>>>                   "afs://xxx.xxx.xxx.xxx:9901",
>>>                   "xxx,yyy")
Args:
afs_path(str): afs_path defined by users
local_path(str): download data path
fs_default_name(str): file system server address
ugi(str): hadoop ugi
file_cnt(int): a user can specify file count for debugging
hadoop_home(str): hadoop home path
process_num(int): number of download processes
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError('instance is None, please run '
'config_distributed_nodes to init the instance')
configs = {
"fs.default.name": fs_default_name,
"hadoop.job.ugi": ugi
}
configs = {"fs.default.name": fs_default_name, "hadoop.job.ugi": ugi}
client = hdfs.HDFSClient(hadoop_home, configs)
downloads = hdfs.multi_download(
@@ -172,35 +202,53 @@ class AsyncExecutor(object):
self.instance.get_node_cnt() / 2,
file_cnt,
multi_processes=process_num)
#self.instance.barrier_all() #wait for download_data #TODO only barrier worker
self.instance.barrier_worker() #wait for download_data #TODO only barrier worker
def config_distributed_nodes(self):
self.instance = ps_instance.PaddlePSInstance(1, 2)
return self.instance
# get total rank
# get rank index
# get iplists
# get hadoop info
pass
self.instance.barrier_worker() #wait for download_data
def get_instance(self):
"""
get the current node's instance so that a user can do operations
in a distributed setting
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError(
'instance is None, please run config_distributed_nodes to init the instance'
)
return self.instance
def config_distributed_nodes(self):
"""
if a user needs to run a distributed async executor,
he or she needs to do a global configuration so that
information about the current process can be obtained
"""
self.instance = ps_instance.PaddlePSInstance(1, 2)
return self.instance
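As a hedged sketch of the global configuration step (assuming the returned PaddlePSInstance exposes is_server()/is_worker() role checks from the companion ps_instance module; dist_desc and startup_program are prepared elsewhere):

>>> exe = fluid.AsyncExecutor()
>>> instance = exe.config_distributed_nodes()
>>> if instance.is_server():
>>>     exe.init_server(dist_desc)
>>> elif instance.is_worker():
>>>     exe.init_worker(dist_desc, startup_program)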
def stop(self):
"""
at the end of the process, users should call stop to shut down
servers and barrier all workers
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError(
'instance is None, please run config_distributed_nodes to init the instance'
)
self.instance.barrier_worker() #worker do all things
if self.instance.is_first_worker():
self.executor.stop_server()
self.instance.barrier_worker() #sync
def init_server(self, dist_desc):
"""
initialize the server of the current node if the current process is a server
Args:
dist_desc(str): a protobuf string that describes
how to init a worker and a server
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError(
'instance is None, please run config_distributed_nodes to init the instance'
)
self.executor.init_server(dist_desc, self.instance._rankid)
ip = self.executor.start_server()
self.instance.set_ip(ip)
@@ -210,27 +258,51 @@ class AsyncExecutor(object):
self.instance.barrier_all() #wait all worker start
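A minimal server-side sketch under the same assumptions as above (dist_desc is a protobuf string prepared elsewhere, as described in the Args):

>>> if instance.is_server():
>>>     exe.init_server(dist_desc)  # starts the server and registers its ip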
def init_worker(self, dist_desc, startup_program):
"""
initialize the worker of the current node if the current process is a worker
Args:
dist_desc(str): a protobuf string that describes
how to init a worker and a server
startup_program(fluid.Program): startup program of current process
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError(
'instance is None, please run config_distributed_nodes to init the instance'
)
place = core.CPUPlace()
executor = Executor(place)
executor.run(startup_program)
self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
self.executor.init_worker(dist_desc, ips,
self.instance.get_node_cnt(),
self.instance._rankid)
self.instance.barrier_all() #wait all worker start
if self.instance.is_first_worker():
self.executor.init_model()
self.instance.barrier_worker() #wait init model
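And the worker-side counterpart, again only a sketch (feeder and filelist are hypothetical placeholders):

>>> if instance.is_worker():
>>>     exe.init_worker(dist_desc, startup_program)
>>>     exe.run(main_program, feeder, filelist, thread_num=4, fetch=["loss"])
>>>     exe.stop()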
def init_model(self):
"""
init_model command that can be invoked from one of the workers;
model parameters are initialized in servers
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError(
'instance is None, please run config_distributed_nodes to init the instance'
)
self.executor.init_model()
def save_model(self, save_path):
"""
save_model command that can be invoked from one of the workers;
model parameters are saved in servers and uploaded to the save_path
of the file system
Args:
save_path(str): path to file system
"""
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
raise ValueError(
'instance is None, please run config_distributed_nodes to init the instance'
)
self.executor.save_model(save_path)
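For example, saving from the first worker only (a sketch; is_first_worker() comes from the companion ps_instance module and the path is hypothetical):

>>> if instance.is_first_worker():
>>>     exe.save_model("./saved_model")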