From a5d3f512d657ad6c8a8b8f338f7b3a290ba70d45 Mon Sep 17 00:00:00 2001 From: xiexionghang Date: Mon, 8 Jun 2020 16:44:10 +0800 Subject: [PATCH] fix code style --- core/trainer.py | 6 +++--- core/utils/dataset_holder.py | 24 ++++++++++++------------ core/utils/envs.py | 13 ++++++------- core/utils/util.py | 10 ++++++---- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/core/trainer.py b/core/trainer.py index 030ee869..ecb069a3 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -62,7 +62,8 @@ class Trainer(object): traceback.print_exc() print('Catch Exception:%s' % str(err)) sys.stdout.flush() - self._context['is_exit'] = self.handle_processor_exception(status, context, err) + self._context['is_exit'] = self.handle_processor_exception( + status, context, err) def other_status_processor(self, context): """ @@ -72,7 +73,7 @@ class Trainer(object): """ print('unknow context_status:%s, do nothing' % context['status']) time.sleep(60) - + def handle_processor_exception(self, status, context, exception): """ when exception throwed from processor, will call this func to handle it @@ -82,7 +83,6 @@ class Trainer(object): print('Exit app. catch exception in precoss status:%s, except:%s' \ % (context['status'], str(exception))) return True - def reload_train_context(self): """ diff --git a/core/utils/dataset_holder.py b/core/utils/dataset_holder.py index 9b8a3d9f..748ce447 100755 --- a/core/utils/dataset_holder.py +++ b/core/utils/dataset_holder.py @@ -66,7 +66,6 @@ class TimeSplitDatasetHolder(DatasetHolder): """ Dataset with time split dir. root_path/$DAY/$HOUR """ - def __init__(self, config): """ init data root_path, time_split_interval, data_path_format @@ -113,8 +112,8 @@ class TimeSplitDatasetHolder(DatasetHolder): True/False """ is_ready = True - data_time, windows_mins = self._format_data_time(daytime_str, - time_window_mins) + data_time, windows_mins = self._format_data_time( + daytime_str, time_window_mins) while time_window_mins > 0: file_path = self._path_generator.generate_path( 'donefile_path', {'time_format': data_time}) @@ -142,18 +141,19 @@ class TimeSplitDatasetHolder(DatasetHolder): list, data_shard[node_idx] """ data_file_list = [] - data_time, windows_mins = self._format_data_time(daytime_str, - time_window_mins) + data_time, windows_mins = self._format_data_time( + daytime_str, time_window_mins) while time_window_mins > 0: file_path = self._path_generator.generate_path( 'data_path', {'time_format': data_time}) sub_file_list = self._data_file_handler.ls(file_path) for sub_file in sub_file_list: sub_file_name = self._data_file_handler.get_file_name(sub_file) - if not sub_file_name.startswith(self._config[ - 'filename_prefix']): + if not sub_file_name.startswith( + self._config['filename_prefix']): continue - postfix= sub_file_name.split(self._config['filename_prefix'])[1] + postfix = sub_file_name.split( + self._config['filename_prefix'])[1] if postfix.isdigit(): if int(postfix) % node_num == node_idx: data_file_list.append(sub_file) @@ -167,8 +167,8 @@ class TimeSplitDatasetHolder(DatasetHolder): def _alloc_dataset(self, file_list): """ """ - dataset = fluid.DatasetFactory().create_dataset(self._config[ - 'dataset_type']) + dataset = fluid.DatasetFactory().create_dataset( + self._config['dataset_type']) dataset.set_batch_size(self._config['batch_size']) dataset.set_thread(self._config['load_thread']) dataset.set_hdfs_config(self._config['fs_name'], @@ -207,8 +207,8 @@ class TimeSplitDatasetHolder(DatasetHolder): params['node_num'], params['node_idx']) self._datasets[begin_time] = self._alloc_dataset(file_list) - self._datasets[begin_time].preload_into_memory(self._config[ - 'preload_thread']) + self._datasets[begin_time].preload_into_memory( + self._config['preload_thread']) return True return False diff --git a/core/utils/envs.py b/core/utils/envs.py index dffc069a..09691259 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -70,8 +70,8 @@ def set_global_envs(envs): nests = copy.deepcopy(namespace_nests) nests.append(k) fatten_env_namespace(nests, v) - elif (k == "dataset" or k == "phase" or - k == "runner") and isinstance(v, list): + elif (k == "dataset" or k == "phase" + or k == "runner") and isinstance(v, list): for i in v: if i.get("name") is None: raise ValueError("name must be in dataset list ", v) @@ -169,15 +169,14 @@ def pretty_print_envs(envs, header=None): def lazy_instance_by_package(package, class_name): try: - model_package = __import__(package, - globals(), locals(), package.split(".")) + model_package = __import__(package, globals(), locals(), + package.split(".")) instance = getattr(model_package, class_name) return instance except Exception, err: traceback.print_exc() print('Catch Exception:%s' % str(err)) return None - def lazy_instance_by_fliename(abs, class_name): @@ -186,8 +185,8 @@ def lazy_instance_by_fliename(abs, class_name): sys.path.append(dirname) package = os.path.splitext(os.path.basename(abs))[0] - model_package = __import__(package, - globals(), locals(), package.split(".")) + model_package = __import__(package, globals(), locals(), + package.split(".")) instance = getattr(model_package, class_name) return instance except Exception, err: diff --git a/core/utils/util.py b/core/utils/util.py index eb52fcdb..5b0ea22f 100755 --- a/core/utils/util.py +++ b/core/utils/util.py @@ -101,6 +101,7 @@ def make_datetime(date_str, fmt=None): return datetime.datetime.strptime(date_str, '%Y%m%d%H%M') return datetime.datetime.strptime(date_str, fmt) + def wroker_numric_opt(fleet, value, env, opt): """ numric count opt for workers @@ -116,6 +117,7 @@ def wroker_numric_opt(fleet, value, env, opt): fleet._role_maker.all_reduce_worker(local_value, global_value, opt) return global_value[0] + def worker_numric_sum(fleet, value, env="mpi"): """R """ @@ -139,6 +141,7 @@ def worker_numric_max(fleet, value, env="mpi"): """ return wroker_numric_opt(fleet, value, env, "max") + def print_log(log_str, params): """R """ @@ -153,6 +156,7 @@ def print_log(log_str, params): if 'stdout' in params: params['stdout'] += log_str + '\n' + def rank0_print(log_str, fleet): """R """ @@ -171,7 +175,6 @@ class CostPrinter(object): """ For count cost time && print cost log """ - def __init__(self, callback, callback_params): """R """ @@ -207,7 +210,6 @@ class PathGenerator(object): """ generate path with template & runtime variables """ - def __init__(self, config): """R """ @@ -228,8 +230,8 @@ class PathGenerator(object): """ if template_name in self._templates: if 'time_format' in param: - str = param['time_format'].strftime(self._templates[ - template_name]) + str = param['time_format'].strftime( + self._templates[template_name]) return str.format(**param) return self._templates[template_name].format(**param) else: -- GitLab