提交 eef09152 编写于 作者: X xixiaoyao

add saving infermodel with num steps

上级 40283138
...@@ -304,19 +304,6 @@ class Controller(object): ...@@ -304,19 +304,6 @@ class Controller(object):
instances[i].task_reuse_scope = instances[j].name instances[i].task_reuse_scope = instances[j].name
break break
# parse Reader and Paradigm for each instance
for inst in instances:
reader_name = inst.config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = inst.config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
inst.Reader = Reader
inst.Paradigm = Paradigm
self.instances = instances self.instances = instances
self.mrs = mrs self.mrs = mrs
self.Backbone = Backbone self.Backbone = Backbone
...@@ -605,6 +592,9 @@ class Controller(object): ...@@ -605,6 +592,9 @@ class Controller(object):
global_step += 1 global_step += 1
cur_task.cur_train_step += 1 cur_task.cur_train_step += 1
if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0:
cur_task.save(suffix='-step'+str(cur_task.cur_train_step))
if global_step % main_conf.get('print_every_n_steps', 5) == 0: if global_step % main_conf.get('print_every_n_steps', 5) == 0:
loss = rt_outputs[cur_task.name+'/loss'] loss = rt_outputs[cur_task.name+'/loss']
loss = np.mean(np.squeeze(loss)).tolist() loss = np.mean(np.squeeze(loss)).tolist()
...@@ -635,8 +625,11 @@ class Controller(object): ...@@ -635,8 +625,11 @@ class Controller(object):
assert isinstance(task_instance, str) assert isinstance(task_instance, str)
if isinstance(inference_model_dir, str): if isinstance(inference_model_dir, str):
assert os.path.exists(inference_model_dir), inference_model_dir+" not found." assert os.path.exists(inference_model_dir), inference_model_dir+" not found."
if not self.has_init_pred and inference_model_dir is None: # if not self.has_init_pred and inference_model_dir is None:
raise ValueError('infer_model_path is required for prediction.') # raise ValueError('infer_model_path is required for prediction.')
if inference_model_dir is None:
assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model')
instance = None instance = None
for inst in self.instances: for inst in self.instances:
......
...@@ -19,13 +19,34 @@ import os ...@@ -19,13 +19,34 @@ import os
import json import json
from paddle import fluid from paddle import fluid
def check_req_args(conf, name):
assert 'reader' in conf, name+': reader is required to build TaskInstance.'
assert 'paradigm' in conf, name+': paradigm is required to build TaskInstance.'
assert 'train_file' in conf or 'pred_file' in conf, name+': at least train_file or pred_file should be provided to build TaskInstance.'
class TaskInstance(object): class TaskInstance(object):
def __init__(self, name, id, config={}, verbose=True): def __init__(self, name, id, config, verbose=True):
self._name = name self._name = name
self._config = config self._config = config
self._verbose = verbose self._verbose = verbose
check_req_args(config)
# parse Reader and Paradigm
reader_name = config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
self._Reader = Reader
self._Paradigm = Paradigm
self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model') self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt') self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
...@@ -116,23 +137,23 @@ class TaskInstance(object): ...@@ -116,23 +137,23 @@ class TaskInstance(object):
def Reader(self): def Reader(self):
return self._Reader return self._Reader
@Reader.setter # @Reader.setter
def Reader(self, cls): # def Reader(self, cls):
assert base_reader.__name__ == cls.__bases__[-1].__name__, \ # assert base_reader.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_reader.__name__, \ # "expect: {}, receive: {}.".format(base_reader.__name__, \
cls.__bases__[-1].__name__) # cls.__bases__[-1].__name__)
self._Reader = cls # self._Reader = cls
@property @property
def Paradigm(self): def Paradigm(self):
return self._Paradigm return self._Paradigm
@Paradigm.setter # @Paradigm.setter
def Paradigm(self, cls): # def Paradigm(self, cls):
assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ # assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_paradigm.__name__, \ # "expect: {}, receive: {}.".format(base_paradigm.__name__, \
cls.__bases__[-1].__name__) # cls.__bases__[-1].__name__)
self._Paradigm = cls # self._Paradigm = cls
@property @property
def config(self): def config(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册