From c62b8d6dc8e4722ea4f158941bee8e02128a23f9 Mon Sep 17 00:00:00 2001 From: xionghang <674952820@qq.com> Date: Mon, 8 Jun 2020 20:00:25 +0800 Subject: [PATCH] revert it --- core/metrics/auc_metrics.py | 2 +- core/modules/modul/build.py | 9 +- core/trainer.py | 27 ++++- core/utils/dataset_holder.py | 12 +- core/utils/envs.py | 36 ++++-- core/utils/util.py | 222 ++++++++++------------------------- 6 files changed, 122 insertions(+), 186 deletions(-) diff --git a/core/metrics/auc_metrics.py b/core/metrics/auc_metrics.py index 085c8499..431411f3 100755 --- a/core/metrics/auc_metrics.py +++ b/core/metrics/auc_metrics.py @@ -66,7 +66,7 @@ class AUCMetric(Metric): old_metric_shape = np.array(metric.shape) metric = metric.reshape(-1) global_metric = np.copy(metric) * 0 - self.fleet._role_maker._node_type_comm.Allreduce(metric, global_metric) + self.fleet._role_maker.all_reduce_worker(metric, global_metric) global_metric = global_metric.reshape(old_metric_shape) return global_metric[0] diff --git a/core/modules/modul/build.py b/core/modules/modul/build.py index dae77717..681f9ed8 100755 --- a/core/modules/modul/build.py +++ b/core/modules/modul/build.py @@ -33,8 +33,13 @@ def create(config): model = None if config['mode'] == 'fluid': - model = YamlModel(config) - model.train_net() + if config['layer_file'].endswith(".py"): + model_class = envs.lazy_instance_by_fliename(config['layer_file'], + "Model") + model = model_class(config) + else: + model = YamlModel(config) + model.train() return model diff --git a/core/trainer.py b/core/trainer.py index 46b77b75..ecb069a3 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -17,6 +17,7 @@ import os import time import sys import yaml +import traceback from paddle import fluid @@ -51,10 +52,18 @@ class Trainer(object): Return: None : run a processor for this status """ - if context['status'] in self._status_processor: - self._status_processor[context['status']](context) - else: - self.other_status_processor(context) + status = context['status'] + try: + if status in self._status_processor: + self._status_processor[context['status']](context) + else: + self.other_status_processor(context) + except Exception, err: + traceback.print_exc() + print('Catch Exception:%s' % str(err)) + sys.stdout.flush() + self._context['is_exit'] = self.handle_processor_exception( + status, context, err) def other_status_processor(self, context): """ @@ -65,6 +74,16 @@ class Trainer(object): print('unknow context_status:%s, do nothing' % context['status']) time.sleep(60) + def handle_processor_exception(self, status, context, exception): + """ + when exception throwed from processor, will call this func to handle it + Return: + bool exit_app or not + """ + print('Exit app. catch exception in precoss status:%s, except:%s' \ + % (context['status'], str(exception))) + return True + def reload_train_context(self): """ context maybe update timely, reload for update diff --git a/core/utils/dataset_holder.py b/core/utils/dataset_holder.py index a75d52b6..70355a54 100755 --- a/core/utils/dataset_holder.py +++ b/core/utils/dataset_holder.py @@ -71,7 +71,7 @@ class TimeSplitDatasetHolder(DatasetHolder): """ init data root_path, time_split_interval, data_path_format """ - Dataset.__init__(self, config) + DatasetHolder.__init__(self, config) if 'data_donefile' not in config or config['data_donefile'] is None: config['data_donefile'] = config['data_path'] + "/to.hadoop.done" self._path_generator = util.PathGenerator({ @@ -153,8 +153,14 @@ class TimeSplitDatasetHolder(DatasetHolder): if not sub_file_name.startswith(self._config[ 'filename_prefix']): continue - if hash(sub_file_name) % node_num == node_idx: - data_file_list.append(sub_file) + postfix = sub_file_name.split(self._config['filename_prefix'])[ + 1] + if postfix.isdigit(): + if int(postfix) % node_num == node_idx: + data_file_list.append(sub_file) + else: + if hash(sub_file_name) % node_num == node_idx: + data_file_list.append(sub_file) time_window_mins = time_window_mins - self._split_interval data_time = data_time + datetime.timedelta( minutes=self._split_interval) diff --git a/core/utils/envs.py b/core/utils/envs.py index f768e14a..15036f67 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -18,6 +18,7 @@ import copy import os import socket import sys +import traceback global_envs = {} @@ -167,22 +168,31 @@ def pretty_print_envs(envs, header=None): def lazy_instance_by_package(package, class_name): - models = get_global_env("train.model.models") - model_package = __import__(package, - globals(), locals(), package.split(".")) - instance = getattr(model_package, class_name) - return instance + try: + model_package = __import__(package, + globals(), locals(), package.split(".")) + instance = getattr(model_package, class_name) + return instance + except Exception, err: + traceback.print_exc() + print('Catch Exception:%s' % str(err)) + return None def lazy_instance_by_fliename(abs, class_name): - dirname = os.path.dirname(abs) - sys.path.append(dirname) - package = os.path.splitext(os.path.basename(abs))[0] - - model_package = __import__(package, - globals(), locals(), package.split(".")) - instance = getattr(model_package, class_name) - return instance + try: + dirname = os.path.dirname(abs) + sys.path.append(dirname) + package = os.path.splitext(os.path.basename(abs))[0] + + model_package = __import__(package, + globals(), locals(), package.split(".")) + instance = getattr(model_package, class_name) + return instance + except Exception, err: + traceback.print_exc() + print('Catch Exception:%s' % str(err)) + return None def get_platform(): diff --git a/core/utils/util.py b/core/utils/util.py index 34f26c6d..381d35ca 100755 --- a/core/utils/util.py +++ b/core/utils/util.py @@ -14,8 +14,9 @@ import datetime import os +import sys import time - +import numpy as np from paddle import fluid from paddlerec.core.utils import fs as fs @@ -101,10 +102,65 @@ def make_datetime(date_str, fmt=None): return datetime.datetime.strptime(date_str, fmt) -def rank0_print(log_str): +def wroker_numric_opt(fleet, value, env, opt): + """ + numric count opt for workers + Args: + value: value for count + env: mpi/gloo + opt: count operator, SUM/MAX/MIN/AVG + Return: + count result + """ + local_value = np.array([value]) + global_value = np.copy(local_value) * 0 + fleet._role_maker.all_reduce_worker(local_value, global_value, opt) + return global_value[0] + + +def worker_numric_sum(fleet, value, env="mpi"): """R """ - print_log(log_str, {'master': True}) + return wroker_numric_opt(fleet, value, env, "sum") + + +def worker_numric_avg(fleet, value, env="mpi"): + """R + """ + return worker_numric_sum(fleet, value, env) / fleet.worker_num() + + +def worker_numric_min(fleet, value, env="mpi"): + """R + """ + return wroker_numric_opt(fleet, value, env, "min") + + +def worker_numric_max(fleet, value, env="mpi"): + """R + """ + return wroker_numric_opt(fleet, value, env, "max") + + +def print_log(log_str, params): + """R + """ + time_str = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()) + log_str = time_str + " " + log_str + if 'master' in params and params['master']: + if 'index' in params and params['index'] == 0: + print(log_str) + else: + print(log_str) + sys.stdout.flush() + if 'stdout' in params: + params['stdout'] += log_str + '\n' + + +def rank0_print(log_str, fleet): + """R + """ + print_log(log_str, {'master': True, 'index': fleet.worker_index()}) def print_cost(cost, params): @@ -182,163 +238,3 @@ class PathGenerator(object): return self._templates[template_name].format(**param) else: return "" - - -class TimeTrainPass(object): - """ - timely pass - define pass time_interval && start_time && end_time - """ - - def __init__(self, global_config): - """R - """ - self._config = global_config['epoch'] - if '+' in self._config['days']: - day_str = self._config['days'].replace(' ', '') - day_fields = day_str.split('+') - self._begin_day = make_datetime(day_fields[0].strip()) - if len(day_fields) == 1 or len(day_fields[1]) == 0: - # 100 years, meaning to continuous running - self._end_day = self._begin_day + datetime.timedelta( - days=36500) - else: - # example: 2020212+10 - run_day = int(day_fields[1].strip()) - self._end_day = self._begin_day + datetime.timedelta( - days=run_day) - else: - # example: {20191001..20191031} - days = os.popen("echo -n " + self._config['days']).read().split( - " ") - self._begin_day = make_datetime(days[0]) - self._end_day = make_datetime(days[len(days) - 1]) - self._checkpoint_interval = self._config['checkpoint_interval'] - self._dump_inference_interval = self._config['dump_inference_interval'] - self._interval_per_pass = self._config[ - 'train_time_interval'] # train N min data per pass - - self._pass_id = 0 - self._inference_pass_id = 0 - self._pass_donefile_handler = None - if 'pass_donefile_name' in self._config: - self._train_pass_donefile = global_config[ - 'output_path'] + '/' + self._config['pass_donefile_name'] - if fs.is_afs_path(self._train_pass_donefile): - self._pass_donefile_handler = fs.FileHandler(global_config[ - 'io']['afs']) - else: - self._pass_donefile_handler = fs.FileHandler(global_config[ - 'io']['local_fs']) - - last_done = self._pass_donefile_handler.cat( - self._train_pass_donefile).strip().split('\n')[-1] - done_fileds = last_done.split('\t') - if len(done_fileds) > 4: - self._base_key = done_fileds[1] - self._checkpoint_model_path = done_fileds[2] - self._checkpoint_pass_id = int(done_fileds[3]) - self._inference_pass_id = int(done_fileds[4]) - self.init_pass_by_id(done_fileds[0], self._checkpoint_pass_id) - - def max_pass_num_day(self): - """R - """ - return 24 * 60 / self._interval_per_pass - - def save_train_progress(self, day, pass_id, base_key, model_path, - is_checkpoint): - """R - """ - if is_checkpoint: - self._checkpoint_pass_id = pass_id - self._checkpoint_model_path = model_path - done_content = "%s\t%s\t%s\t%s\t%d\n" % ( - day, base_key, self._checkpoint_model_path, - self._checkpoint_pass_id, pass_id) - self._pass_donefile_handler.write(done_content, - self._train_pass_donefile, 'a') - pass - - def init_pass_by_id(self, date_str, pass_id): - """ - init pass context with pass_id - Args: - date_str: example "20200110" - pass_id(int): pass_id of date - """ - date_time = make_datetime(date_str) - if pass_id < 1: - pass_id = 0 - if (date_time - self._begin_day).total_seconds() > 0: - self._begin_day = date_time - self._pass_id = pass_id - mins = self._interval_per_pass * (pass_id - 1) - self._current_train_time = date_time + datetime.timedelta(minutes=mins) - - def init_pass_by_time(self, datetime_str): - """ - init pass context with datetime - Args: - date_str: example "20200110000" -> "%Y%m%d%H%M" - """ - self._current_train_time = make_datetime(datetime_str) - minus = self._current_train_time.hour * 60 + self._current_train_time.minute - self._pass_id = minus / self._interval_per_pass + 1 - - def current_pass(self): - """R - """ - return self._pass_id - - def next(self): - """R - """ - has_next = True - old_pass_id = self._pass_id - if self._pass_id < 1: - self.init_pass_by_time(self._begin_day.strftime("%Y%m%d%H%M")) - else: - next_time = self._current_train_time + datetime.timedelta( - minutes=self._interval_per_pass) - if (next_time - self._end_day).total_seconds() > 0: - has_next = False - else: - self.init_pass_by_time(next_time.strftime("%Y%m%d%H%M")) - if has_next and (self._inference_pass_id < self._pass_id or - self._pass_id < old_pass_id): - self._inference_pass_id = self._pass_id - 1 - return has_next - - def is_checkpoint_pass(self, pass_id): - """R - """ - if pass_id < 1: - return True - if pass_id == self.max_pass_num_day(): - return False - if pass_id % self._checkpoint_interval == 0: - return True - return False - - def need_dump_inference(self, pass_id): - """R - """ - return self._inference_pass_id < pass_id and pass_id % self._dump_inference_interval == 0 - - def date(self, delta_day=0): - """ - get train date - Args: - delta_day(int): n day afer current_train_date - Return: - date(current_train_time + delta_day) - """ - return (self._current_train_time + datetime.timedelta(days=delta_day) - ).strftime("%Y%m%d") - - def timestamp(self, delta_day=0): - """R - """ - return (self._current_train_time + datetime.timedelta(days=delta_day) - ).timestamp() -- GitLab