未验证 提交 e270b2ad 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #51 from xionghang/master

fix bugs in dataset and metrics
...@@ -66,7 +66,7 @@ class AUCMetric(Metric): ...@@ -66,7 +66,7 @@ class AUCMetric(Metric):
old_metric_shape = np.array(metric.shape) old_metric_shape = np.array(metric.shape)
metric = metric.reshape(-1) metric = metric.reshape(-1)
global_metric = np.copy(metric) * 0 global_metric = np.copy(metric) * 0
self.fleet._role_maker._node_type_comm.Allreduce(metric, global_metric) self.fleet._role_maker.all_reduce_worker(metric, global_metric)
global_metric = global_metric.reshape(old_metric_shape) global_metric = global_metric.reshape(old_metric_shape)
return global_metric[0] return global_metric[0]
......
...@@ -33,8 +33,13 @@ def create(config): ...@@ -33,8 +33,13 @@ def create(config):
model = None model = None
if config['mode'] == 'fluid': if config['mode'] == 'fluid':
if config['layer_file'].endswith(".py"):
model_class = envs.lazy_instance_by_fliename(config['layer_file'],
"Model")
model = model_class(config)
else:
model = YamlModel(config) model = YamlModel(config)
model.train_net() model.train()
return model return model
......
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
import time import time
import sys import sys
import yaml import yaml
import traceback
from paddle import fluid from paddle import fluid
...@@ -51,10 +52,18 @@ class Trainer(object): ...@@ -51,10 +52,18 @@ class Trainer(object):
Return: Return:
None : run a processor for this status None : run a processor for this status
""" """
if context['status'] in self._status_processor: status = context['status']
try:
if status in self._status_processor:
self._status_processor[context['status']](context) self._status_processor[context['status']](context)
else: else:
self.other_status_processor(context) self.other_status_processor(context)
except Exception, err:
traceback.print_exc()
print('Catch Exception:%s' % str(err))
sys.stdout.flush()
self._context['is_exit'] = self.handle_processor_exception(
status, context, err)
def other_status_processor(self, context): def other_status_processor(self, context):
""" """
...@@ -65,6 +74,16 @@ class Trainer(object): ...@@ -65,6 +74,16 @@ class Trainer(object):
print('unknow context_status:%s, do nothing' % context['status']) print('unknow context_status:%s, do nothing' % context['status'])
time.sleep(60) time.sleep(60)
def handle_processor_exception(self, status, context, exception):
    """Handle an exception thrown from a status processor.

    Called by the run loop after a processor raised; decides whether the
    application should exit.

    Args:
        status: name of the status that was being processed
        context: trainer context dict; context['status'] is the failing stage
        exception: the exception instance caught by the caller
    Return:
        bool: True if the app should exit (always True in this base class;
        subclasses may override to retry/recover)
    """
    # fixed typo in the original log message ("precoss" -> "process")
    print('Exit app. catch exception in process status:%s, except:%s' \
        % (context['status'], str(exception)))
    return True
def reload_train_context(self): def reload_train_context(self):
""" """
context maybe update timely, reload for update context maybe update timely, reload for update
......
...@@ -71,7 +71,7 @@ class TimeSplitDatasetHolder(DatasetHolder): ...@@ -71,7 +71,7 @@ class TimeSplitDatasetHolder(DatasetHolder):
""" """
init data root_path, time_split_interval, data_path_format init data root_path, time_split_interval, data_path_format
""" """
Dataset.__init__(self, config) DatasetHolder.__init__(self, config)
if 'data_donefile' not in config or config['data_donefile'] is None: if 'data_donefile' not in config or config['data_donefile'] is None:
config['data_donefile'] = config['data_path'] + "/to.hadoop.done" config['data_donefile'] = config['data_path'] + "/to.hadoop.done"
self._path_generator = util.PathGenerator({ self._path_generator = util.PathGenerator({
...@@ -153,6 +153,12 @@ class TimeSplitDatasetHolder(DatasetHolder): ...@@ -153,6 +153,12 @@ class TimeSplitDatasetHolder(DatasetHolder):
if not sub_file_name.startswith(self._config[ if not sub_file_name.startswith(self._config[
'filename_prefix']): 'filename_prefix']):
continue continue
postfix = sub_file_name.split(self._config['filename_prefix'])[
1]
if postfix.isdigit():
if int(postfix) % node_num == node_idx:
data_file_list.append(sub_file)
else:
if hash(sub_file_name) % node_num == node_idx: if hash(sub_file_name) % node_num == node_idx:
data_file_list.append(sub_file) data_file_list.append(sub_file)
time_window_mins = time_window_mins - self._split_interval time_window_mins = time_window_mins - self._split_interval
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
import os import os
import socket import socket
import sys import sys
import traceback
global_envs = {} global_envs = {}
...@@ -167,14 +168,19 @@ def pretty_print_envs(envs, header=None): ...@@ -167,14 +168,19 @@ def pretty_print_envs(envs, header=None):
def lazy_instance_by_package(package, class_name):
    """Import `package` lazily and return its attribute `class_name`.

    Args:
        package: dotted module path, e.g. "models.rank.dnn"
        class_name: attribute (usually a class) to fetch from the module
    Return:
        the attribute on success; None when the import or lookup fails
        (the traceback is printed for diagnosis).
    """
    try:
        model_package = __import__(package,
                                   globals(), locals(), package.split("."))
        instance = getattr(model_package, class_name)
        return instance
    # `except ... as err` replaces the py2-only `except ..., err`
    # and is valid on both Python 2.6+ and Python 3.
    except Exception as err:
        traceback.print_exc()
        print('Catch Exception:%s' % str(err))
        return None
def lazy_instance_by_fliename(abs, class_name):
    """Load attribute `class_name` from the python source file at path `abs`.

    Appends the file's directory to sys.path and imports the file by its
    basename (side effect: the directory stays on sys.path).

    NOTE(review): parameter name `abs` shadows the builtin; kept unchanged
    for interface compatibility with existing callers.

    Args:
        abs: absolute/relative path to a .py file
        class_name: attribute (usually a class) to fetch from the module
    Return:
        the attribute on success; None on any failure (traceback printed).
    """
    try:
        dirname = os.path.dirname(abs)
        sys.path.append(dirname)
        package = os.path.splitext(os.path.basename(abs))[0]
        model_package = __import__(package,
                                   globals(), locals(), package.split("."))
        instance = getattr(model_package, class_name)
        return instance
    # `except ... as err` replaces the py2-only `except ..., err`
    # and is valid on both Python 2.6+ and Python 3.
    except Exception as err:
        traceback.print_exc()
        print('Catch Exception:%s' % str(err))
        return None
def get_platform(): def get_platform():
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
import datetime import datetime
import os import os
import sys
import time import time
import numpy as np
from paddle import fluid from paddle import fluid
from paddlerec.core.utils import fs as fs from paddlerec.core.utils import fs as fs
...@@ -101,10 +102,65 @@ def make_datetime(date_str, fmt=None): ...@@ -101,10 +102,65 @@ def make_datetime(date_str, fmt=None):
return datetime.datetime.strptime(date_str, fmt) return datetime.datetime.strptime(date_str, fmt)
def wroker_numric_opt(fleet, value, env, opt):
    """Reduce a scalar across all workers with the given operator.

    Args:
        fleet: fleet instance; its role maker performs the all-reduce
        value: local scalar contribution of this worker
        env: communication backend name (mpi/gloo)
        opt: reduce operator name, e.g. "sum"/"max"/"min"
    Return:
        the globally reduced scalar
    """
    send_buf = np.array([value])
    recv_buf = np.copy(send_buf) * 0
    fleet._role_maker.all_reduce_worker(send_buf, recv_buf, opt)
    return recv_buf[0]
def worker_numric_sum(fleet, value, env="mpi"):
    """Aggregate `value` across all workers with a SUM reduction."""
    result = wroker_numric_opt(fleet, value, env, "sum")
    return result
def worker_numric_avg(fleet, value, env="mpi"):
    """Average `value` across all workers: global sum / worker count."""
    total = worker_numric_sum(fleet, value, env)
    return total / fleet.worker_num()
def worker_numric_min(fleet, value, env="mpi"):
    """Take the minimum of `value` across all workers."""
    result = wroker_numric_opt(fleet, value, env, "min")
    return result
def worker_numric_max(fleet, value, env="mpi"):
    """Take the maximum of `value` across all workers."""
    result = wroker_numric_opt(fleet, value, env, "max")
    return result
def print_log(log_str, params):
    """Print a timestamped log line, optionally restricted to the master.

    If params['master'] is truthy, only the worker whose params['index']
    equals 0 prints; otherwise every caller prints. When params contains a
    'stdout' key, the timestamped line is also appended to that string.
    """
    stamp = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())
    log_str = "%s %s" % (stamp, log_str)
    # non-master mode prints everywhere; master mode prints on index 0 only
    should_print = (not params.get('master')) or params.get('index') == 0
    if should_print:
        print(log_str)
    sys.stdout.flush()
    if 'stdout' in params:
        params['stdout'] += log_str + '\n'
def rank0_print(log_str, fleet):
    """Print `log_str` only on worker 0 (the master)."""
    params = {'master': True, 'index': fleet.worker_index()}
    print_log(log_str, params)
def print_cost(cost, params): def print_cost(cost, params):
...@@ -182,163 +238,3 @@ class PathGenerator(object): ...@@ -182,163 +238,3 @@ class PathGenerator(object):
return self._templates[template_name].format(**param) return self._templates[template_name].format(**param)
else: else:
return "" return ""
class TimeTrainPass(object):
    """
    Timely training pass: splits a configured day range into fixed-length
    time intervals ("passes") and tracks progress via a donefile.
    Defines pass time_interval && start_time && end_time.
    """

    def __init__(self, global_config):
        """Parse the 'epoch' config section and restore saved progress.

        Args:
            global_config: dict with an 'epoch' section ('days',
                'train_time_interval', 'checkpoint_interval',
                'dump_inference_interval', optional 'pass_donefile_name')
                plus 'output_path' and 'io' settings for the donefile.
        """
        self._config = global_config['epoch']
        if '+' in self._config['days']:
            # form "YYYYMMDD+N": begin day plus a number of days to run
            day_str = self._config['days'].replace(' ', '')
            day_fields = day_str.split('+')
            self._begin_day = make_datetime(day_fields[0].strip())
            if len(day_fields) == 1 or len(day_fields[1]) == 0:
                # 100 years, meaning to continuous running
                self._end_day = self._begin_day + datetime.timedelta(
                    days=36500)
            else:
                # example: 2020212+10
                run_day = int(day_fields[1].strip())
                self._end_day = self._begin_day + datetime.timedelta(
                    days=run_day)
        else:
            # example: {20191001..20191031} — brace range expanded by the
            # shell via os.popen("echo -n ...")
            days = os.popen("echo -n " + self._config['days']).read().split(
                " ")
            self._begin_day = make_datetime(days[0])
            self._end_day = make_datetime(days[len(days) - 1])
        self._checkpoint_interval = self._config['checkpoint_interval']
        self._dump_inference_interval = self._config['dump_inference_interval']
        self._interval_per_pass = self._config[
            'train_time_interval']  # train N min data per pass
        self._pass_id = 0
        self._inference_pass_id = 0
        self._pass_donefile_handler = None
        if 'pass_donefile_name' in self._config:
            self._train_pass_donefile = global_config[
                'output_path'] + '/' + self._config['pass_donefile_name']
            # choose AFS or local filesystem handler based on the path
            if fs.is_afs_path(self._train_pass_donefile):
                self._pass_donefile_handler = fs.FileHandler(global_config[
                    'io']['afs'])
            else:
                self._pass_donefile_handler = fs.FileHandler(global_config[
                    'io']['local_fs'])
            # resume from the last line of the donefile:
            # day \t base_key \t model_path \t checkpoint_pass \t inference_pass
            last_done = self._pass_donefile_handler.cat(
                self._train_pass_donefile).strip().split('\n')[-1]
            done_fileds = last_done.split('\t')
            if len(done_fileds) > 4:
                self._base_key = done_fileds[1]
                self._checkpoint_model_path = done_fileds[2]
                self._checkpoint_pass_id = int(done_fileds[3])
                self._inference_pass_id = int(done_fileds[4])
                self.init_pass_by_id(done_fileds[0], self._checkpoint_pass_id)

    def max_pass_num_day(self):
        """Return the number of passes in one day (24h / interval).

        NOTE(review): plain '/' — integer division under Python 2 (this file
        uses py2 `except X, e` syntax elsewhere), float under Python 3.
        """
        return 24 * 60 / self._interval_per_pass

    def save_train_progress(self, day, pass_id, base_key, model_path,
                            is_checkpoint):
        """Append a progress record to the donefile.

        Args:
            day: day string being trained
            pass_id: pass index just finished
            base_key: base key recorded alongside the model
            model_path: path of the saved model
            is_checkpoint: when True, update the remembered checkpoint
                pass id / model path before writing
        """
        if is_checkpoint:
            self._checkpoint_pass_id = pass_id
            self._checkpoint_model_path = model_path
        done_content = "%s\t%s\t%s\t%s\t%d\n" % (
            day, base_key, self._checkpoint_model_path,
            self._checkpoint_pass_id, pass_id)
        # 'a' — append so the donefile keeps full history
        self._pass_donefile_handler.write(done_content,
                                          self._train_pass_donefile, 'a')
        pass

    def init_pass_by_id(self, date_str, pass_id):
        """
        init pass context with pass_id
        Args:
            date_str: example "20200110"
            pass_id(int): pass_id of date
        """
        date_time = make_datetime(date_str)
        if pass_id < 1:
            pass_id = 0
        # only move the begin day forward, never backward
        if (date_time - self._begin_day).total_seconds() > 0:
            self._begin_day = date_time
        self._pass_id = pass_id
        mins = self._interval_per_pass * (pass_id - 1)
        self._current_train_time = date_time + datetime.timedelta(minutes=mins)

    def init_pass_by_time(self, datetime_str):
        """
        init pass context with datetime
        Args:
            date_str: example "20200110000" -> "%Y%m%d%H%M"
        """
        self._current_train_time = make_datetime(datetime_str)
        minus = self._current_train_time.hour * 60 + self._current_train_time.minute
        # NOTE(review): py2 integer division assumed here; under py3 this
        # yields a float pass id — confirm before porting
        self._pass_id = minus / self._interval_per_pass + 1

    def current_pass(self):
        """Return the current pass id."""
        return self._pass_id

    def next(self):
        """Advance to the next pass.

        Return:
            bool: True while the next pass is still inside the configured
            day range; False once the end day is reached.
        """
        has_next = True
        old_pass_id = self._pass_id
        if self._pass_id < 1:
            # first call: start at the beginning of the configured range
            self.init_pass_by_time(self._begin_day.strftime("%Y%m%d%H%M"))
        else:
            next_time = self._current_train_time + datetime.timedelta(
                minutes=self._interval_per_pass)
            if (next_time - self._end_day).total_seconds() > 0:
                has_next = False
            else:
                self.init_pass_by_time(next_time.strftime("%Y%m%d%H%M"))
        # reset inference marker when the pass moved forward or wrapped
        # to a new day (pass id decreased)
        if has_next and (self._inference_pass_id < self._pass_id or
                         self._pass_id < old_pass_id):
            self._inference_pass_id = self._pass_id - 1
        return has_next

    def is_checkpoint_pass(self, pass_id):
        """Return True when a checkpoint should be saved after `pass_id`.

        First pass always checkpoints; the last pass of a day never does;
        otherwise every `checkpoint_interval`-th pass does.
        """
        if pass_id < 1:
            return True
        if pass_id == self.max_pass_num_day():
            return False
        if pass_id % self._checkpoint_interval == 0:
            return True
        return False

    def need_dump_inference(self, pass_id):
        """Return True when an inference model should be dumped for `pass_id`."""
        return self._inference_pass_id < pass_id and pass_id % self._dump_inference_interval == 0

    def date(self, delta_day=0):
        """
        get train date
        Args:
            delta_day(int): n day afer current_train_date
        Return:
            date(current_train_time + delta_day)
        """
        return (self._current_train_time + datetime.timedelta(days=delta_day)
                ).strftime("%Y%m%d")

    def timestamp(self, delta_day=0):
        """Return the POSIX timestamp of current_train_time + delta_day.

        NOTE(review): datetime.timestamp() exists only on Python 3 —
        confirm the intended interpreter before relying on this.
        """
        return (self._current_train_time + datetime.timedelta(days=delta_day)
                ).timestamp()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册