提交 1500c8e6 编写于 作者: H heqiaozhi

is instance is None

上级 10ed9e0a
...@@ -87,6 +87,7 @@ class AsyncExecutor(object): ...@@ -87,6 +87,7 @@ class AsyncExecutor(object):
scope = global_scope() scope = global_scope()
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
self.instance = None
def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False): def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False):
""" """
...@@ -154,6 +155,9 @@ class AsyncExecutor(object): ...@@ -154,6 +155,9 @@ class AsyncExecutor(object):
def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12): def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12):
#hadoop_home = "$HADOOP_HOME" #hadoop_home = "$HADOOP_HOME"
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
hadoop_home = "~/tools/hadoop-xingtian/hadoop/" hadoop_home = "~/tools/hadoop-xingtian/hadoop/"
configs = { configs = {
...@@ -182,15 +186,21 @@ class AsyncExecutor(object): ...@@ -182,15 +186,21 @@ class AsyncExecutor(object):
pass pass
def get_instance(self): def get_instance(self):
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
return self.instance return self.instance
def stop_server(self): def stop_server(self):
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
self.instance.barrier_all() #worker do all things self.instance.barrier_all() #worker do all things
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.stop_server() self.executor.stop_server()
self.instance.barrier_all() #sync self.instance.barrier_all() #sync
def init_server(self, dist_desc): def init_server(self, dist_desc):
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
self.executor.init_server(dist_desc, self.instance._rankid) self.executor.init_server(dist_desc, self.instance._rankid)
ip = self.executor.start_server() ip = self.executor.start_server()
self.instance.set_ip(ip) self.instance.set_ip(ip)
...@@ -204,6 +214,8 @@ class AsyncExecutor(object): ...@@ -204,6 +214,8 @@ class AsyncExecutor(object):
self.instance.barrier_all() #sync self.instance.barrier_all() #sync
def init_worker(self, dist_desc, startup_program): def init_worker(self, dist_desc, startup_program):
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
place = core.CPUPlace() place = core.CPUPlace()
executor = Executor(place) executor = Executor(place)
executor.run(startup_program) executor.run(startup_program)
...@@ -217,8 +229,12 @@ class AsyncExecutor(object): ...@@ -217,8 +229,12 @@ class AsyncExecutor(object):
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
def init_model(self): def init_model(self):
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
self.executor.init_model() self.executor.init_model()
def save_model(self, save_path): def save_model(self, save_path):
if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance')
self.executor.save_model(save_path) self.executor.save_model(save_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册