提交 eef09152 编写于 作者: X xixiaoyao

add saving infermodel with num steps

上级 40283138
......@@ -304,19 +304,6 @@ class Controller(object):
instances[i].task_reuse_scope = instances[j].name
break
# parse Reader and Paradigm for each instance
for inst in instances:
reader_name = inst.config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = inst.config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
inst.Reader = Reader
inst.Paradigm = Paradigm
self.instances = instances
self.mrs = mrs
self.Backbone = Backbone
......@@ -605,6 +592,9 @@ class Controller(object):
global_step += 1
cur_task.cur_train_step += 1
if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0:
cur_task.save(suffix='-step'+str(cur_task.cur_train_step))
if global_step % main_conf.get('print_every_n_steps', 5) == 0:
loss = rt_outputs[cur_task.name+'/loss']
loss = np.mean(np.squeeze(loss)).tolist()
......@@ -635,8 +625,11 @@ class Controller(object):
assert isinstance(task_instance, str)
if isinstance(inference_model_dir, str):
assert os.path.exists(inference_model_dir), inference_model_dir+" not found."
if not self.has_init_pred and inference_model_dir is None:
raise ValueError('infer_model_path is required for prediction.')
# if not self.has_init_pred and inference_model_dir is None:
# raise ValueError('infer_model_path is required for prediction.')
if inference_model_dir is None:
assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model')
instance = None
for inst in self.instances:
......
......@@ -19,13 +19,34 @@ import os
import json
from paddle import fluid
def check_req_args(conf, name):
assert 'reader' in conf, name+': reader is required to build TaskInstance.'
assert 'paradigm' in conf, name+': paradigm is required to build TaskInstance.'
assert 'train_file' in conf or 'pred_file' in conf, name+': at least train_file or pred_file should be provided to build TaskInstance.'
class TaskInstance(object):
def __init__(self, name, id, config={}, verbose=True):
def __init__(self, name, id, config, verbose=True):
self._name = name
self._config = config
self._verbose = verbose
check_req_args(config)
# parse Reader and Paradigm
reader_name = config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
self._Reader = Reader
self._Paradigm = Paradigm
self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
......@@ -116,23 +137,23 @@ class TaskInstance(object):
def Reader(self):
return self._Reader
@Reader.setter
def Reader(self, cls):
assert base_reader.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_reader.__name__, \
cls.__bases__[-1].__name__)
self._Reader = cls
# @Reader.setter
# def Reader(self, cls):
# assert base_reader.__name__ == cls.__bases__[-1].__name__, \
# "expect: {}, receive: {}.".format(base_reader.__name__, \
# cls.__bases__[-1].__name__)
# self._Reader = cls
@property
def Paradigm(self):
return self._Paradigm
@Paradigm.setter
def Paradigm(self, cls):
assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_paradigm.__name__, \
cls.__bases__[-1].__name__)
self._Paradigm = cls
# @Paradigm.setter
# def Paradigm(self, cls):
# assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
# "expect: {}, receive: {}.".format(base_paradigm.__name__, \
# cls.__bases__[-1].__name__)
# self._Paradigm = cls
@property
def config(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册