Commit a5d3f512 authored by xiexionghang

fix code style

Parent 505d7e7c
@@ -62,7 +62,8 @@ class Trainer(object):
             traceback.print_exc()
             print('Catch Exception:%s' % str(err))
             sys.stdout.flush()
-            self._context['is_exit'] = self.handle_processor_exception(status, context, err)
+            self._context['is_exit'] = self.handle_processor_exception(
+                status, context, err)

     def other_status_processor(self, context):
         """
@@ -72,7 +73,7 @@ class Trainer(object):
         """
         print('unknow context_status:%s, do nothing' % context['status'])
         time.sleep(60)

     def handle_processor_exception(self, status, context, exception):
         """
         when exception throwed from processor, will call this func to handle it
@@ -82,7 +83,6 @@ class Trainer(object):
         print('Exit app. catch exception in precoss status:%s, except:%s' \
             % (context['status'], str(exception)))
         return True

     def reload_train_context(self):
         """
...
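For context on how these pieces interact: the trainer is a small status machine whose run loop dispatches to a processor per status and routes any processor failure through handle_processor_exception. A minimal, self-contained sketch of that pattern (class name, dispatch-by-status-name rule, and the uninit/exit statuses here are illustrative assumptions, not the project's actual loop):

import sys
import time
import traceback


class MiniTrainer(object):
    """Toy status machine; dispatch by '<status>_processor' name is an assumption."""

    def __init__(self):
        self._context = {'status': 'uninit', 'is_exit': False}

    def uninit_processor(self, context):
        context['status'] = 'exit'  # pretend init succeeded, then stop

    def other_status_processor(self, context):
        print('unknown context_status:%s, do nothing' % context['status'])
        time.sleep(1)

    def handle_processor_exception(self, status, context, exception):
        print('Exit app. catch exception in process status:%s, except:%s'
              % (context['status'], str(exception)))
        return True  # True tells the run loop to exit

    def run(self):
        while not self._context['is_exit']:
            status = self._context['status']
            if status == 'exit':
                break
            processor = getattr(self, status + '_processor',
                                self.other_status_processor)
            try:
                processor(self._context)
            except Exception as err:
                traceback.print_exc()
                print('Catch Exception:%s' % str(err))
                sys.stdout.flush()
                self._context['is_exit'] = self.handle_processor_exception(
                    status, self._context, err)


MiniTrainer().run()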
@@ -66,7 +66,6 @@ class TimeSplitDatasetHolder(DatasetHolder):
     """
     Dataset with time split dir. root_path/$DAY/$HOUR
     """
     def __init__(self, config):
         """
         init data root_path, time_split_interval, data_path_format
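For reference, the config keys named in this docstring and read via self._config[...] elsewhere in these hunks suggest a dict of roughly this shape; every value below is an illustrative placeholder, not a project default:

# Illustrative only: key names come from this class, values are made up.
dataset_config = {
    'root_path': 'afs:/user/feed/data',   # hypothetical
    'data_path_format': '%Y%m%d/%H',      # hypothetical
    'time_split_interval': 5,             # minutes, hypothetical
    'filename_prefix': 'part-',           # used when sharding files by node
    'dataset_type': 'InMemoryDataset',    # passed to fluid.DatasetFactory
    'batch_size': 32,
    'load_thread': 10,
    'preload_thread': 10,
    'fs_name': 'afs://example.afs.com:9902',  # hypothetical
}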
@@ -113,8 +112,8 @@ class TimeSplitDatasetHolder(DatasetHolder):
             True/False
         """
         is_ready = True
-        data_time, windows_mins = self._format_data_time(daytime_str,
-                                                          time_window_mins)
+        data_time, windows_mins = self._format_data_time(
+            daytime_str, time_window_mins)
         while time_window_mins > 0:
             file_path = self._path_generator.generate_path(
                 'donefile_path', {'time_format': data_time})
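The ready check walks the requested time window in fixed steps and probes one donefile per slice. A hedged standalone sketch of that loop (the helper name, step size argument, and the callable standing in for the path generator / file handler are assumptions):

import datetime


def window_is_ready(start_time, time_window_mins, split_interval_mins,
                    donefile_exists):
    """Return True only if every slice in the window has its donefile.

    donefile_exists is a callable(datetime) -> bool standing in for the
    _path_generator / _data_file_handler pair used in the class above.
    """
    data_time = start_time
    while time_window_mins > 0:
        if not donefile_exists(data_time):
            return False
        time_window_mins -= split_interval_mins
        data_time += datetime.timedelta(minutes=split_interval_mins)
    return True


# e.g. a one-hour window checked in 5-minute slices
window_is_ready(datetime.datetime(2020, 1, 1, 10, 0), 60, 5, lambda t: True)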
@@ -142,18 +141,19 @@ class TimeSplitDatasetHolder(DatasetHolder):
             list, data_shard[node_idx]
         """
         data_file_list = []
-        data_time, windows_mins = self._format_data_time(daytime_str,
-                                                          time_window_mins)
+        data_time, windows_mins = self._format_data_time(
+            daytime_str, time_window_mins)
         while time_window_mins > 0:
             file_path = self._path_generator.generate_path(
                 'data_path', {'time_format': data_time})
             sub_file_list = self._data_file_handler.ls(file_path)
             for sub_file in sub_file_list:
                 sub_file_name = self._data_file_handler.get_file_name(sub_file)
-                if not sub_file_name.startswith(self._config[
-                        'filename_prefix']):
+                if not sub_file_name.startswith(
+                        self._config['filename_prefix']):
                     continue
-                postfix= sub_file_name.split(self._config['filename_prefix'])[1]
+                postfix = sub_file_name.split(
+                    self._config['filename_prefix'])[1]
                 if postfix.isdigit():
                     if int(postfix) % node_num == node_idx:
                         data_file_list.append(sub_file)
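The sharding rule above is a simple modulo over the numeric file suffix. A standalone restatement of just that rule, with a tiny usage example:

def shard_files(file_names, filename_prefix, node_num, node_idx):
    """Pick the part files whose numeric suffix maps to this node.

    Mirrors the selection above: keep files named <prefix><digits>
    where int(digits) % node_num == node_idx; skip everything else.
    """
    selected = []
    for name in file_names:
        if not name.startswith(filename_prefix):
            continue
        postfix = name.split(filename_prefix)[1]
        if postfix.isdigit() and int(postfix) % node_num == node_idx:
            selected.append(name)
    return selected


# with 2 nodes, node 0 gets part-0 / part-2 and node 1 gets part-1 / part-3
print(shard_files(['part-0', 'part-1', 'part-2', 'part-3', '_SUCCESS'],
                  'part-', 2, 0))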
@@ -167,8 +167,8 @@ class TimeSplitDatasetHolder(DatasetHolder):
     def _alloc_dataset(self, file_list):
         """ """
-        dataset = fluid.DatasetFactory().create_dataset(self._config[
-            'dataset_type'])
+        dataset = fluid.DatasetFactory().create_dataset(
+            self._config['dataset_type'])
         dataset.set_batch_size(self._config['batch_size'])
         dataset.set_thread(self._config['load_thread'])
         dataset.set_hdfs_config(self._config['fs_name'],
@@ -207,8 +207,8 @@ class TimeSplitDatasetHolder(DatasetHolder):
                 params['node_num'],
                 params['node_idx'])
             self._datasets[begin_time] = self._alloc_dataset(file_list)
-            self._datasets[begin_time].preload_into_memory(self._config[
-                'preload_thread'])
+            self._datasets[begin_time].preload_into_memory(
+                self._config['preload_thread'])
             return True
         return False
...
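Taken together, _alloc_dataset builds a fluid dataset from the config and the per-node file list, and the caller preloads it asynchronously. A hedged usage sketch of the same calls outside the class (all literal values are placeholders; 'fs_ugi' and the file list are assumptions, only the other config keys and the fluid calls appear in the hunks above):

import paddle.fluid as fluid

config = {
    'dataset_type': 'InMemoryDataset',
    'batch_size': 32,
    'load_thread': 10,
    'preload_thread': 10,
    'fs_name': 'afs://example.afs.com:9902',  # hypothetical
    'fs_ugi': 'user,passwd',                  # hypothetical, not shown above
}

dataset = fluid.DatasetFactory().create_dataset(config['dataset_type'])
dataset.set_batch_size(config['batch_size'])
dataset.set_thread(config['load_thread'])
dataset.set_hdfs_config(config['fs_name'], config['fs_ugi'])
dataset.set_filelist(['part-0', 'part-1'])  # the per-node shard from get_file_list
dataset.preload_into_memory(config['preload_thread'])
# a real pipeline also sets input vars / pipe command before loading,
# and calls wait_preload_done() before training on the dataset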
@@ -70,8 +70,8 @@ def set_global_envs(envs):
                 nests = copy.deepcopy(namespace_nests)
                 nests.append(k)
                 fatten_env_namespace(nests, v)
-            elif (k == "dataset" or k == "phase" or
-                  k == "runner") and isinstance(v, list):
+            elif (k == "dataset" or k == "phase"
+                  or k == "runner") and isinstance(v, list):
                 for i in v:
                     if i.get("name") is None:
                         raise ValueError("name must be in dataset list ", v)
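fatten_env_namespace recursively flattens a nested config dict into dotted keys and requires every dataset/phase/runner list entry to carry a name. A simplified sketch of that idea (how the real code keys the list entries and where it stores the results are assumptions):

import copy


def flatten_namespace(namespace_nests, local_envs, out):
    """Nested dicts become dotted keys in `out`; list sections need a "name"."""
    for k, v in local_envs.items():
        if isinstance(v, dict):
            nests = copy.deepcopy(namespace_nests)
            nests.append(k)
            flatten_namespace(nests, v, out)
        elif k in ("dataset", "phase", "runner") and isinstance(v, list):
            for i in v:
                if i.get("name") is None:
                    raise ValueError("name must be in dataset list ", v)
                nests = copy.deepcopy(namespace_nests)
                nests.append(k)
                nests.append(i["name"])
                flatten_namespace(nests, i, out)
        else:
            out[".".join(namespace_nests + [k])] = v


envs = {}
flatten_namespace([], {"train": {"epochs": 10},
                       "dataset": [{"name": "sample", "batch_size": 32}]}, envs)
# envs now holds e.g. {"train.epochs": 10, "dataset.sample.batch_size": 32, ...}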
@@ -169,15 +169,14 @@ def pretty_print_envs(envs, header=None):
 def lazy_instance_by_package(package, class_name):
     try:
-        model_package = __import__(package,
-                                   globals(), locals(), package.split("."))
+        model_package = __import__(package, globals(), locals(),
+                                   package.split("."))
         instance = getattr(model_package, class_name)
         return instance
     except Exception, err:
         traceback.print_exc()
         print('Catch Exception:%s' % str(err))
         return None


 def lazy_instance_by_fliename(abs, class_name):
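lazy_instance_by_package is a guarded dynamic import; typical usage looks like this (the module path and class name below are hypothetical):

# On failure the helper prints the traceback and returns None instead of raising.
ModelClass = lazy_instance_by_package("my_project.models.dnn", "Model")
if ModelClass is not None:
    model = ModelClass()

Passing package.split(".") as the fromlist makes __import__ return the leaf module rather than the top-level package, which is why getattr can then resolve the class directly.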
@@ -186,8 +185,8 @@ def lazy_instance_by_fliename(abs, class_name):
         sys.path.append(dirname)
         package = os.path.splitext(os.path.basename(abs))[0]
-        model_package = __import__(package,
-                                   globals(), locals(), package.split("."))
+        model_package = __import__(package, globals(), locals(),
+                                   package.split("."))
         instance = getattr(model_package, class_name)
         return instance
     except Exception, err:
...
@@ -101,6 +101,7 @@ def make_datetime(date_str, fmt=None):
         return datetime.datetime.strptime(date_str, '%Y%m%d%H%M')
     return datetime.datetime.strptime(date_str, fmt)


 def wroker_numric_opt(fleet, value, env, opt):
     """
     numric count opt for workers
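On the make_datetime context above: when no fmt is given it falls back to the compact minute-resolution format, otherwise it uses the caller's format. Illustrative calls:

# default branch, '%Y%m%d%H%M'
dt = make_datetime('202001011230')                 # datetime(2020, 1, 1, 12, 30)
# explicit format
day = make_datetime('2020-01-01', fmt='%Y-%m-%d')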
@@ -116,6 +117,7 @@ def wroker_numric_opt(fleet, value, env, opt):
     fleet._role_maker.all_reduce_worker(local_value, global_value, opt)
     return global_value[0]


 def worker_numric_sum(fleet, value, env="mpi"):
     """R
     """
@@ -139,6 +141,7 @@ def worker_numric_max(fleet, value, env="mpi"):
     """
     return wroker_numric_opt(fleet, value, env, "max")


 def print_log(log_str, params):
     """R
     """
@@ -153,6 +156,7 @@ def print_log(log_str, params):
     if 'stdout' in params:
         params['stdout'] += log_str + '\n'


 def rank0_print(log_str, fleet):
     """R
     """
@@ -171,7 +175,6 @@ class CostPrinter(object):
     """
     For count cost time && print cost log
     """
     def __init__(self, callback, callback_params):
         """R
         """
@@ -207,7 +210,6 @@ class PathGenerator(object):
     """
     generate path with template & runtime variables
     """
     def __init__(self, config):
         """R
         """
@@ -228,8 +230,8 @@ class PathGenerator(object):
         """
         if template_name in self._templates:
             if 'time_format' in param:
-                str = param['time_format'].strftime(self._templates[
-                    template_name])
+                str = param['time_format'].strftime(
+                    self._templates[template_name])
                 return str.format(**param)
             return self._templates[template_name].format(**param)
         else:
...
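generate_path expands a template in two passes: strftime codes are filled from param['time_format'], then the remaining {placeholders} are filled by str.format. An illustrative walk-through of those two steps (template table and keys below are hypothetical):

import datetime

# '%Y%m%d/%H' is expanded by strftime, '{prefix}' by str.format afterwards,
# in the same order as the method above.
templates = {'data_path': '/app/{prefix}/%Y%m%d/%H'}

param = {'time_format': datetime.datetime(2020, 1, 1, 13, 0), 'prefix': 'join'}
path = param['time_format'].strftime(templates['data_path'])
path = path.format(**param)
print(path)  # /app/join/20200101/13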