未验证 提交 c21afb28 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #27 from xixiaoyao/master

Add saving of the inference model every n training steps
......@@ -304,19 +304,6 @@ class Controller(object):
instances[i].task_reuse_scope = instances[j].name
break
# parse Reader and Paradigm for each instance
for inst in instances:
reader_name = inst.config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = inst.config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
inst.Reader = Reader
inst.Paradigm = Paradigm
self.instances = instances
self.mrs = mrs
self.Backbone = Backbone
......@@ -361,7 +348,6 @@ class Controller(object):
task_attrs = []
pred_task_attrs = []
for inst in instances:
train_reader = inst.Reader(inst.config, phase='train')
inst.reader['train'] = train_reader
train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf)
......@@ -605,6 +591,9 @@ class Controller(object):
global_step += 1
cur_task.cur_train_step += 1
if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0:
cur_task.save(suffix='.step'+str(cur_task.cur_train_step))
if global_step % main_conf.get('print_every_n_steps', 5) == 0:
loss = rt_outputs[cur_task.name+'/loss']
loss = np.mean(np.squeeze(loss)).tolist()
......@@ -635,8 +624,11 @@ class Controller(object):
assert isinstance(task_instance, str)
if isinstance(inference_model_dir, str):
assert os.path.exists(inference_model_dir), inference_model_dir+" not found."
if not self.has_init_pred and inference_model_dir is None:
raise ValueError('infer_model_path is required for prediction.')
# if not self.has_init_pred and inference_model_dir is None:
# raise ValueError('infer_model_path is required for prediction.')
if inference_model_dir is None:
assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model')
instance = None
for inst in self.instances:
......
......@@ -18,16 +18,40 @@ from paddlepalm.interface import task_paradigm as base_paradigm
import os
import json
from paddle import fluid
import importlib
from paddlepalm.default_settings import *
def check_req_args(conf, name):
    """Check that a task config holds the keys needed to build a TaskInstance.

    Args:
        conf: task configuration; only key membership (`in`) is tested.
        name: task instance name, used as the prefix of every error message.

    Raises:
        AssertionError: if 'reader' is missing, if 'paradigm' is missing, or
            if neither 'train_file' nor 'pred_file' is present.
    """
    # Raise explicitly instead of using `assert` statements so the checks are
    # not stripped when Python runs with -O. The exception type and the
    # messages are kept unchanged for existing callers.
    if 'reader' not in conf:
        raise AssertionError(name+': reader is required to build TaskInstance.')
    if 'paradigm' not in conf:
        raise AssertionError(name+': paradigm is required to build TaskInstance.')
    if 'train_file' not in conf and 'pred_file' not in conf:
        raise AssertionError(name+': at least train_file or pred_file should be provided to build TaskInstance.')
class TaskInstance(object):
def __init__(self, name, id, config={}, verbose=True):
def __init__(self, name, id, config, verbose=True):
self._name = name
self._config = config
self._verbose = verbose
check_req_args(config, name)
# parse Reader and Paradigm
reader_name = config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
self._Reader = Reader
self._Paradigm = Paradigm
self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
self._save_infermodel_every_n_steps = config.get('save_infermodel_every_n_steps', -1)
# the following flags can be fetched from the instance config file
self._is_target = config.get('is_target', True)
......@@ -56,9 +80,6 @@ class TaskInstance(object):
self._pred_fetch_name_list = []
self._pred_fetch_var_list = []
self._Reader = None
self._Paradigm = None
self._exe = fluid.Executor(fluid.CPUPlace())
self._save_protocol = {
......@@ -87,7 +108,9 @@ class TaskInstance(object):
dirpath = self._save_infermodel_path + suffix
self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
# fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
prog = fluid.default_main_program().clone()
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
conf = {}
for k, strv in self._save_protocol.items():
......@@ -116,23 +139,23 @@ class TaskInstance(object):
def Reader(self):
return self._Reader
@Reader.setter
def Reader(self, cls):
assert base_reader.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_reader.__name__, \
cls.__bases__[-1].__name__)
self._Reader = cls
# @Reader.setter
# def Reader(self, cls):
# assert base_reader.__name__ == cls.__bases__[-1].__name__, \
# "expect: {}, receive: {}.".format(base_reader.__name__, \
# cls.__bases__[-1].__name__)
# self._Reader = cls
@property
def Paradigm(self):
    """Return the paradigm class stored on this instance (`self._Paradigm`).

    Read-only accessor: it hands back the stored object unchanged.
    """
    return self._Paradigm
@Paradigm.setter
def Paradigm(self, cls):
assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_paradigm.__name__, \
cls.__bases__[-1].__name__)
self._Paradigm = cls
# @Paradigm.setter
# def Paradigm(self, cls):
# assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
# "expect: {}, receive: {}.".format(base_paradigm.__name__, \
# cls.__bases__[-1].__name__)
# self._Paradigm = cls
@property
def config(self):
......@@ -201,6 +224,10 @@ class TaskInstance(object):
if self._verbose:
print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
@property
def save_infermodel_every_n_steps(self):
    """Return the configured interval, in train steps, between inference-model saves.

    Read-only accessor for `self._save_infermodel_every_n_steps`; the value
    is returned as stored.
    """
    return self._save_infermodel_every_n_steps
@property
def expected_train_steps(self):
    """Return the value stored in `self._expected_train_steps`.

    Read-only accessor; the value is returned as stored.
    """
    return self._expected_train_steps
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册