Commit 8a8d8acb authored by xiexionghang

fix bugs and add interface

Parent cdc0f7c9
......@@ -66,7 +66,7 @@ class AUCMetric(Metric):
old_metric_shape = np.array(metric.shape)
metric = metric.reshape(-1)
global_metric = np.copy(metric) * 0
self.fleet._role_maker._node_type_comm.Allreduce(metric, global_metric)
self.fleet._role_maker.all_reduce_worker(metric, global_metric)
global_metric = global_metric.reshape(old_metric_shape)
return global_metric[0]
......
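The hunk above swaps the raw MPI Allreduce on _node_type_comm for the higher-level all_reduce_worker API. A minimal, self-contained sketch of the flatten/reduce/reshape pattern being wrapped; fake_all_reduce is a hypothetical single-process stand-in for the fleet call, not the real API:

import numpy as np

def aggregate_metric(metric, all_reduce):
    """Flatten, sum across workers via all_reduce, then restore the shape."""
    old_shape = metric.shape
    flat = metric.reshape(-1)
    global_flat = np.zeros_like(flat)  # same effect as np.copy(flat) * 0
    all_reduce(flat, global_flat)      # element-wise sum over all workers
    return global_flat.reshape(old_shape)

def fake_all_reduce(send, recv):
    # Single-worker stand-in: the global sum of one value is the value itself.
    recv[:] = send

print(aggregate_metric(np.arange(6.0).reshape(2, 3), fake_all_reduce))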
......@@ -22,7 +22,7 @@ from paddlerec.core.model import Model
from paddlerec.core.utils import table
def create(config):
def create_model(config):
"""
Create a model instance by config
Args:
......@@ -33,8 +33,12 @@ def create(config):
model = None
if config['mode'] == 'fluid':
model = YamlModel(config)
model.train_net()
if config['layer_file'].endswith(".py"):
model_class = envs.lazy_instance_by_fliename(config['layer_file'], "Model")
model = model_class(config)
else:
model = YamlModel(config)
model.train()
return model
......
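The renamed create_model now branches on whether the configured layer_file is a Python file. A hedged, standalone sketch of that dispatch using importlib directly instead of paddlerec's lazy_instance_by_fliename helper; the loader name and the None fallback are illustrative, the real code falls back to YamlModel(config):

import importlib.util

def load_class_from_file(path, class_name):
    # Import a module directly from a .py path and return the named class.
    spec = importlib.util.spec_from_file_location("user_layer", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)

def create_model(config):
    if config['layer_file'].endswith(".py"):
        model_class = load_class_from_file(config['layer_file'], "Model")
        return model_class(config)
    return None  # the real code builds YamlModel(config) here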
......@@ -17,6 +17,7 @@ import os
import time
import sys
import yaml
import traceback
from paddle import fluid
......@@ -51,10 +52,17 @@ class Trainer(object):
Return:
None : run a processor for this status
"""
if context['status'] in self._status_processor:
self._status_processor[context['status']](context)
else:
self.other_status_processor(context)
status = context['status']
try:
if status in self._status_processor:
self._status_processor[status](context)
else:
self.other_status_processor(context)
except Exception as err:
traceback.print_exc()
print('Catch Exception:%s' % str(err))
sys.stdout.flush()
self._context['is_exit'] = self.handle_processor_exception(status, context, err)
def other_status_processor(self, context):
"""
......@@ -64,6 +72,17 @@ class Trainer(object):
"""
print('unknown context_status:%s, do nothing' % context['status'])
time.sleep(60)
def handle_processor_exception(self, status, context, exception):
"""
called when a processor throws an exception; decides how to handle it
Return:
bool: whether to exit the app
"""
print('Exit app. Caught exception in process status:%s, except:%s' \
% (context['status'], str(exception)))
return True
def reload_train_context(self):
"""
......
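A minimal sketch of the dispatch-with-exception-handling pattern the Trainer now uses: look up a processor by status, fall back for unknown statuses, and route any exception through a handler that decides whether the app exits. MiniTrainer and its sample processor are illustrative only:

import traceback

class MiniTrainer(object):
    def __init__(self):
        self._status_processor = {'train': lambda ctx: print('training', ctx)}
        self._context = {'is_exit': False}

    def handle_processor_exception(self, status, context, err):
        print('Exit app. Caught exception in process status:%s' % status)
        return True

    def run(self, context):
        status = context['status']
        try:
            if status in self._status_processor:
                self._status_processor[status](context)
            else:
                print('unknown context_status:%s, do nothing' % status)
        except Exception as err:
            traceback.print_exc()
            self._context['is_exit'] = self.handle_processor_exception(
                status, context, err)

MiniTrainer().run({'status': 'train'})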
......@@ -71,7 +71,7 @@ class TimeSplitDatasetHolder(DatasetHolder):
"""
init data root_path, time_split_interval, data_path_format
"""
Dataset.__init__(self, config)
DatasetHolder.__init__(self, config)
if 'data_donefile' not in config or config['data_donefile'] is None:
config['data_donefile'] = config['data_path'] + "/to.hadoop.done"
self._path_generator = util.PathGenerator({
......@@ -153,7 +153,11 @@ class TimeSplitDatasetHolder(DatasetHolder):
if not sub_file_name.startswith(self._config[
'filename_prefix']):
continue
if hash(sub_file_name) % node_num == node_idx:
postfix = sub_file_name.split(self._config['filename_prefix'])[1]
if postfix.isdigit():
if int(postfix) % node_num == node_idx:
data_file_list.append(sub_file)
elif hash(sub_file_name) % node_num == node_idx:
data_file_list.append(sub_file)
time_window_mins = time_window_mins - self._split_interval
data_time = data_time + datetime.timedelta(
......
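The sharding change above prefers a deterministic assignment when the filename carries a numeric postfix after the configured prefix, keeping the old hash-based split only as a fallback. A self-contained sketch (note that Python 3 salts str hashes per process, so the hash branch is stable across workers only if PYTHONHASHSEED is fixed):

def belongs_to_node(filename, prefix, node_idx, node_num):
    if not filename.startswith(prefix):
        return False
    postfix = filename.split(prefix)[1]
    if postfix.isdigit():
        # deterministic: shard by the numeric suffix
        return int(postfix) % node_num == node_idx
    # fallback: shard by hash of the full name
    return hash(filename) % node_num == node_idx

print(belongs_to_node('part-00003', 'part-', 1, 2))  # True: 3 % 2 == 1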
......@@ -167,22 +167,28 @@ def pretty_print_envs(envs, header=None):
def lazy_instance_by_package(package, class_name):
models = get_global_env("train.model.models")
model_package = __import__(package,
try:
model_package = __import__(package,
globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
instance = getattr(model_package, class_name)
return instance
except Exception as err:
return None
def lazy_instance_by_fliename(abs, class_name):
dirname = os.path.dirname(abs)
sys.path.append(dirname)
package = os.path.splitext(os.path.basename(abs))[0]
try:
dirname = os.path.dirname(abs)
sys.path.append(dirname)
package = os.path.splitext(os.path.basename(abs))[0]
model_package = __import__(package,
model_package = __import__(package,
globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
instance = getattr(model_package, class_name)
return instance
except Exception as err:
return None
def get_platform():
......
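Both helpers now swallow import errors and return None. A hedged sketch of the by-filename variant: append the file's directory to sys.path, __import__ the module by its basename, and fetch the class; lazy_class_from_file is an illustrative name, not the paddlerec API:

import os
import sys

def lazy_class_from_file(abs_path, class_name):
    try:
        sys.path.append(os.path.dirname(abs_path))
        package = os.path.splitext(os.path.basename(abs_path))[0]
        module = __import__(package, globals(), locals(), package.split("."))
        return getattr(module, class_name)
    except Exception:
        return None  # as in the diff, any failure yields None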
......@@ -14,8 +14,9 @@
import datetime
import os
import sys
import time
import numpy as np
from paddle import fluid
from paddlerec.core.utils import fs as fs
......@@ -100,11 +101,62 @@ def make_datetime(date_str, fmt=None):
return datetime.datetime.strptime(date_str, '%Y%m%d%H%M')
return datetime.datetime.strptime(date_str, fmt)
def wroker_numric_opt(fleet, value, env, opt):
"""
numeric reduce op across workers
Args:
value: local value to reduce
env: mpi/gloo
opt: reduce operator, one of sum/max/min/avg
Return:
the reduced result
"""
local_value = np.array([value])
global_value = np.copy(local_value) * 0
fleet._role_maker.all_reduce_worker(local_value, global_value, opt)
return global_value[0]
def worker_numric_sum(fleet, value, env="mpi"):
"""R
"""
return wroker_numric_opt(fleet, value, env, "sum")
def worker_numric_avg(fleet, value, env="mpi"):
"""R
"""
return worker_numric_sum(fleet, value, env) / fleet.worker_num()
def rank0_print(log_str):
def worker_numric_min(fleet, value, env="mpi"):
"""R
"""
print_log(log_str, {'master': True})
return wroker_numric_opt(fleet, value, env, "min")
def worker_numric_max(fleet, value, env="mpi"):
"""R
"""
return wroker_numric_opt(fleet, value, env, "max")
def print_log(log_str, params):
"""R
"""
time_str = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())
log_str = time_str + " " + log_str
if 'master' in params and params['master']:
if 'index' in params and params['index'] == 0:
print(log_str)
else:
print(log_str)
sys.stdout.flush()
if 'stdout' in params:
params['stdout'] += log_str + '\n'
def rank0_print(log_str, fleet):
"""R
"""
print_log(log_str, {'master': True, 'index': fleet.worker_index()})
def print_cost(cost, params):
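The new helpers reduce a scalar across workers by boxing it in a one-element array. A sketch with a hypothetical single-worker FakeFleet standing in for fleet._role_maker, showing that the avg helper is just the sum divided by worker_num():

import numpy as np

class FakeFleet(object):
    """Single-worker stand-in: reducing one value returns the value itself."""
    def all_reduce_worker(self, send, recv, opt):
        recv[:] = send
    def worker_num(self):
        return 1

def worker_numeric_opt(fleet, value, opt):
    local = np.array([float(value)])
    global_ = np.zeros_like(local)
    fleet.all_reduce_worker(local, global_, opt)
    return global_[0]

fleet = FakeFleet()
total = worker_numeric_opt(fleet, 3.5, 'sum')
print(total, total / fleet.worker_num())  # avg = sum / worker_num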
......@@ -182,163 +234,3 @@ class PathGenerator(object):
return self._templates[template_name].format(**param)
else:
return ""
class TimeTrainPass(object):
"""
timely pass
define pass time_interval && start_time && end_time
"""
def __init__(self, global_config):
"""R
"""
self._config = global_config['epoch']
if '+' in self._config['days']:
day_str = self._config['days'].replace(' ', '')
day_fields = day_str.split('+')
self._begin_day = make_datetime(day_fields[0].strip())
if len(day_fields) == 1 or len(day_fields[1]) == 0:
# 100 years, i.e. run continuously
self._end_day = self._begin_day + datetime.timedelta(
days=36500)
else:
# example: 20200212+10
run_day = int(day_fields[1].strip())
self._end_day = self._begin_day + datetime.timedelta(
days=run_day)
else:
# example: {20191001..20191031}
days = os.popen("echo -n " + self._config['days']).read().split(
" ")
self._begin_day = make_datetime(days[0])
self._end_day = make_datetime(days[len(days) - 1])
self._checkpoint_interval = self._config['checkpoint_interval']
self._dump_inference_interval = self._config['dump_inference_interval']
self._interval_per_pass = self._config[
'train_time_interval'] # train N min data per pass
self._pass_id = 0
self._inference_pass_id = 0
self._pass_donefile_handler = None
if 'pass_donefile_name' in self._config:
self._train_pass_donefile = global_config[
'output_path'] + '/' + self._config['pass_donefile_name']
if fs.is_afs_path(self._train_pass_donefile):
self._pass_donefile_handler = fs.FileHandler(global_config[
'io']['afs'])
else:
self._pass_donefile_handler = fs.FileHandler(global_config[
'io']['local_fs'])
last_done = self._pass_donefile_handler.cat(
self._train_pass_donefile).strip().split('\n')[-1]
done_fileds = last_done.split('\t')
if len(done_fileds) > 4:
self._base_key = done_fileds[1]
self._checkpoint_model_path = done_fileds[2]
self._checkpoint_pass_id = int(done_fileds[3])
self._inference_pass_id = int(done_fileds[4])
self.init_pass_by_id(done_fileds[0], self._checkpoint_pass_id)
def max_pass_num_day(self):
"""R
"""
return 24 * 60 // self._interval_per_pass
def save_train_progress(self, day, pass_id, base_key, model_path,
is_checkpoint):
"""R
"""
if is_checkpoint:
self._checkpoint_pass_id = pass_id
self._checkpoint_model_path = model_path
done_content = "%s\t%s\t%s\t%s\t%d\n" % (
day, base_key, self._checkpoint_model_path,
self._checkpoint_pass_id, pass_id)
self._pass_donefile_handler.write(done_content,
self._train_pass_donefile, 'a')
pass
def init_pass_by_id(self, date_str, pass_id):
"""
init pass context with pass_id
Args:
date_str: example "20200110"
pass_id(int): pass_id of date
"""
date_time = make_datetime(date_str)
if pass_id < 1:
pass_id = 0
if (date_time - self._begin_day).total_seconds() > 0:
self._begin_day = date_time
self._pass_id = pass_id
mins = self._interval_per_pass * (pass_id - 1)
self._current_train_time = date_time + datetime.timedelta(minutes=mins)
def init_pass_by_time(self, datetime_str):
"""
init pass context with datetime
Args:
datetime_str: example "202001100000" -> "%Y%m%d%H%M"
"""
self._current_train_time = make_datetime(datetime_str)
mins = self._current_train_time.hour * 60 + self._current_train_time.minute
self._pass_id = mins // self._interval_per_pass + 1
def current_pass(self):
"""R
"""
return self._pass_id
def next(self):
"""R
"""
has_next = True
old_pass_id = self._pass_id
if self._pass_id < 1:
self.init_pass_by_time(self._begin_day.strftime("%Y%m%d%H%M"))
else:
next_time = self._current_train_time + datetime.timedelta(
minutes=self._interval_per_pass)
if (next_time - self._end_day).total_seconds() > 0:
has_next = False
else:
self.init_pass_by_time(next_time.strftime("%Y%m%d%H%M"))
if has_next and (self._inference_pass_id < self._pass_id or
self._pass_id < old_pass_id):
self._inference_pass_id = self._pass_id - 1
return has_next
def is_checkpoint_pass(self, pass_id):
"""R
"""
if pass_id < 1:
return True
if pass_id == self.max_pass_num_day():
return False
if pass_id % self._checkpoint_interval == 0:
return True
return False
def need_dump_inference(self, pass_id):
"""R
"""
return self._inference_pass_id < pass_id and pass_id % self._dump_inference_interval == 0
def date(self, delta_day=0):
"""
get train date
Args:
delta_day(int): n days after current_train_time
Return:
date(current_train_time + delta_day)
"""
return (self._current_train_time + datetime.timedelta(days=delta_day)
).strftime("%Y%m%d")
def timestamp(self, delta_day=0):
"""R
"""
return (self._current_train_time + datetime.timedelta(days=delta_day)
).timestamp()
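The removed TimeTrainPass derived a pass id from the time of day: minutes since midnight, integer-divided by the per-pass interval, plus one. A standalone check of that arithmetic:

import datetime

def pass_id_for(dt, interval_per_pass_min):
    mins = dt.hour * 60 + dt.minute
    return mins // interval_per_pass_min + 1

dt = datetime.datetime.strptime('202001100230', '%Y%m%d%H%M')
print(pass_id_for(dt, 30))  # 150 minutes into the day -> pass 6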