Commit 854ee964 authored by dongdaxiang

add doc string for async_executor.py

Parent commit: e52bb816
...@@ -89,8 +89,14 @@ class AsyncExecutor(object): ...@@ -89,8 +89,14 @@ class AsyncExecutor(object):
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
self.instance = None self.instance = None
def run(self,
def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False): program,
data_feed,
filelist,
thread_num,
fetch,
mode="",
debug=False):
""" """
Run program by this AsyncExecutor. Training dataset will be in filelist. Run program by this AsyncExecutor. Training dataset will be in filelist.
Users can also inspect certain variables by naming them in parameter Users can also inspect certain variables by naming them in parameter
...@@ -110,6 +116,7 @@ class AsyncExecutor(object): ...@@ -110,6 +116,7 @@ class AsyncExecutor(object):
thread_num(int): number of concurrent training threads. See thread_num(int): number of concurrent training threads. See
:code:`Note` for how to set this properly :code:`Note` for how to set this properly
fetch(str|list): the var name or a list of var names to inspect fetch(str|list): the var name or a list of var names to inspect
mode(str): run mode of this interface
debug(bool): When set to True, fetch vars will be printed to debug(bool): When set to True, fetch vars will be printed to
standard output after each minibatch standard output after each minibatch
...@@ -154,83 +161,148 @@ class AsyncExecutor(object): ...@@ -154,83 +161,148 @@ class AsyncExecutor(object):
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, mode, debug) fetch_var_names, mode, debug)
def download_data(self, afs_path, local_path, fs_default_name, ugi, file_cnt, hadoop_home="$HADOOP_HOME", process_num=12): def download_data(self,
afs_path,
local_path,
fs_default_name,
ugi,
file_cnt,
hadoop_home="$HADOOP_HOME",
process_num=12):
"""
download_data is a default download method for distributed training
a user download data without this method
Example:
>>> exe = fluid.AsyncExecutor()
>>> exe.download_data("/xxx/xxx/xx/",
>>> "./data", "afs://
>>> xxx.xxx.xxx.xxx:9901", "xxx,yyy")
Args:
afs_path(str): afs_path defined by users
local_path(str): download data path
fs_default_name(str): file system server address
ugi(str): hadoop ugi
file_cn(int): a user can specify file number for debugging
hadoop_home(str): hadoop home path
process_num(int): download process num
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError('instance is None, please run'
'config_distributed_nodes init instance')
configs = {
"fs.default.name": fs_default_name, configs = {"fs.default.name": fs_default_name, "hadoop.job.ugi": ugi}
"hadoop.job.ugi": ugi
}
client = hdfs.HDFSClient(hadoop_home, configs) client = hdfs.HDFSClient(hadoop_home, configs)
downloads = hdfs.multi_download( downloads = hdfs.multi_download(
client, client,
afs_path, afs_path,
local_path, local_path,
self.instance.get_worker_index(), self.instance.get_worker_index(),
self.instance.get_node_cnt() / 2, self.instance.get_node_cnt() / 2,
file_cnt, file_cnt,
multi_processes=process_num) multi_processes=process_num)
#self.instance.barrier_all() #wait for download_data #TODO only barriere worker self.instance.barrier_worker() #wait for download_data
self.instance.barrier_worker() #wait for download_data #TODO only barriere worker
def config_distributed_nodes(self):
self.instance = ps_instance.PaddlePSInstance(1, 2)
return self.instance
# get total rank
# get rank index
# get iplists
# get hadoop info
pass
def get_instance(self): def get_instance(self):
"""
get current node's instance so that user can do operations
in distributed setting
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError(
'instance is None, please run config_distributed_nodes init instance'
)
return self.instance
def config_distributed_nodes(self):
"""
if a user needs to run distributed async executor
he or she needs to do a global configuration so that
information of current process can be obtained
"""
self.instance = ps_instance.PaddlePSInstance(1, 2)
return self.instance return self.instance
def stop(self): def stop(self):
"""
at the end of process, users should call stop to servers
and barrier all workers
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError(
self.instance.barrier_worker() #worker do all things 'instance is None, please run config_distributed_nodes init instance'
)
self.instance.barrier_worker() #worker do all things
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.stop_server() self.executor.stop_server()
self.instance.barrier_worker() #sync self.instance.barrier_worker() #sync
def init_server(self, dist_desc): def init_server(self, dist_desc):
"""
initialize server of current node if current process is a server
Args:
dist_desc(str): a protobuf string that describes
how to init a worker and a server
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError(
'instance is None, please run config_distributed_nodes init instance'
)
self.executor.init_server(dist_desc, self.instance._rankid) self.executor.init_server(dist_desc, self.instance._rankid)
ip = self.executor.start_server() ip = self.executor.start_server()
self.instance.set_ip(ip) self.instance.set_ip(ip)
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.executor.gather_servers(ips, self.instance.get_node_cnt())
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
def init_worker(self, dist_desc, startup_program): def init_worker(self, dist_desc, startup_program):
"""
initialize worker of current node if current process is a worker
Args:
dist_desc(str): a protobuf string that describes
how to init a worker and a server
startup_program(fluid.Program): startup program of current process
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError(
'instance is None, please run config_distributed_nodes init instance'
)
place = core.CPUPlace() place = core.CPUPlace()
executor = Executor(place) executor = Executor(place)
executor.run(startup_program) executor.run(startup_program)
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid) self.executor.init_worker(dist_desc, ips,
self.instance.barrier_all() #wait all worker start self.instance.get_node_cnt(),
self.instance._rankid)
self.instance.barrier_all() #wait all worker start
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.init_model() self.executor.init_model()
self.instance.barrier_worker() #wait init model self.instance.barrier_worker() #wait init model
def init_model(self): def init_model(self):
"""
init_model command that can be invoked from one of the worker
model parameters are initialized in servers
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError(
'instance is None, please run config_distributed_nodes init instance'
)
self.executor.init_model() self.executor.init_model()
def save_model(self, save_path): def save_model(self, save_path):
"""
save_model command that can be invoked from one of the worker
model parameters are saved in servers and upload to save_path of file system
Args:
save_path(str): path to file system
"""
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError(
'instance is None, please run config_distributed_nodes init instance'
)
self.executor.save_model(save_path) self.executor.save_model(save_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册