diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index aa94c3496be9ba78abcb322706396850bc915a69..03c50a81df4b1e0f3cf279e42a0c0b2d14d8bb72 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -304,19 +304,6 @@ class Controller(object): instances[i].task_reuse_scope = instances[j].name break - # parse Reader and Paradigm for each instance - for inst in instances: - reader_name = inst.config['reader'] - reader_mod = importlib.import_module(READER_DIR + '.' + reader_name) - Reader = getattr(reader_mod, 'Reader') - - parad_name = inst.config['paradigm'] - parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name) - Paradigm = getattr(parad_mod, 'TaskParadigm') - - inst.Reader = Reader - inst.Paradigm = Paradigm - self.instances = instances self.mrs = mrs self.Backbone = Backbone @@ -605,6 +592,9 @@ class Controller(object): global_step += 1 cur_task.cur_train_step += 1 + if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0: + cur_task.save(suffix='-step'+str(cur_task.cur_train_step)) + if global_step % main_conf.get('print_every_n_steps', 5) == 0: loss = rt_outputs[cur_task.name+'/loss'] loss = np.mean(np.squeeze(loss)).tolist() @@ -635,8 +625,11 @@ class Controller(object): assert isinstance(task_instance, str) if isinstance(inference_model_dir, str): assert os.path.exists(inference_model_dir), inference_model_dir+" not found." - if not self.has_init_pred and inference_model_dir is None: - raise ValueError('infer_model_path is required for prediction.') + # if not self.has_init_pred and inference_model_dir is None: + # raise ValueError('infer_model_path is required for prediction.') + if inference_model_dir is None: + assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model." + inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model') instance = None for inst in self.instances: diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py index a2b1b05125d72570cf634816f28bd06e50801032..f16f98555b85c303511e1a03fc1084f27b291bc4 100644 --- a/paddlepalm/task_instance.py +++ b/paddlepalm/task_instance.py @@ -19,13 +19,34 @@ import os import json from paddle import fluid + +def check_req_args(conf, name): + assert 'reader' in conf, name+': reader is required to build TaskInstance.' + assert 'paradigm' in conf, name+': paradigm is required to build TaskInstance.' + assert 'train_file' in conf or 'pred_file' in conf, name+': at least train_file or pred_file should be provided to build TaskInstance.' + + class TaskInstance(object): - def __init__(self, name, id, config={}, verbose=True): + def __init__(self, name, id, config, verbose=True): self._name = name self._config = config self._verbose = verbose + check_req_args(config) + + # parse Reader and Paradigm + reader_name = config['reader'] + reader_mod = importlib.import_module(READER_DIR + '.' + reader_name) + Reader = getattr(reader_mod, 'Reader') + + parad_name = config['paradigm'] + parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name) + Paradigm = getattr(parad_mod, 'TaskParadigm') + + self._Reader = Reader + self._Paradigm = Paradigm + self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model') self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt') @@ -116,23 +137,23 @@ class TaskInstance(object): def Reader(self): return self._Reader - @Reader.setter - def Reader(self, cls): - assert base_reader.__name__ == cls.__bases__[-1].__name__, \ - "expect: {}, receive: {}.".format(base_reader.__name__, \ - cls.__bases__[-1].__name__) - self._Reader = cls + # @Reader.setter + # def Reader(self, cls): + # assert base_reader.__name__ == cls.__bases__[-1].__name__, \ + # "expect: {}, receive: {}.".format(base_reader.__name__, \ + # cls.__bases__[-1].__name__) + # self._Reader = cls @property def Paradigm(self): return self._Paradigm - @Paradigm.setter - def Paradigm(self, cls): - assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ - "expect: {}, receive: {}.".format(base_paradigm.__name__, \ - cls.__bases__[-1].__name__) - self._Paradigm = cls + # @Paradigm.setter + # def Paradigm(self, cls): + # assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ + # "expect: {}, receive: {}.".format(base_paradigm.__name__, \ + # cls.__bases__[-1].__name__) + # self._Paradigm = cls @property def config(self):