From eef09152b36b4771235246f4c56bbca94fa76014 Mon Sep 17 00:00:00 2001 From: xixiaoyao Date: Sun, 24 Nov 2019 20:58:30 +0800 Subject: [PATCH] add saving infermodel with num steps --- paddlepalm/mtl_controller.py | 23 ++++++------------ paddlepalm/task_instance.py | 47 ++++++++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 28 deletions(-) diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index aa94c34..03c50a8 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -304,19 +304,6 @@ class Controller(object): instances[i].task_reuse_scope = instances[j].name break - # parse Reader and Paradigm for each instance - for inst in instances: - reader_name = inst.config['reader'] - reader_mod = importlib.import_module(READER_DIR + '.' + reader_name) - Reader = getattr(reader_mod, 'Reader') - - parad_name = inst.config['paradigm'] - parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name) - Paradigm = getattr(parad_mod, 'TaskParadigm') - - inst.Reader = Reader - inst.Paradigm = Paradigm - self.instances = instances self.mrs = mrs self.Backbone = Backbone @@ -605,6 +592,9 @@ class Controller(object): global_step += 1 cur_task.cur_train_step += 1 + if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0: + cur_task.save(suffix='-step'+str(cur_task.cur_train_step)) + if global_step % main_conf.get('print_every_n_steps', 5) == 0: loss = rt_outputs[cur_task.name+'/loss'] loss = np.mean(np.squeeze(loss)).tolist() @@ -635,8 +625,11 @@ class Controller(object): assert isinstance(task_instance, str) if isinstance(inference_model_dir, str): assert os.path.exists(inference_model_dir), inference_model_dir+" not found." - if not self.has_init_pred and inference_model_dir is None: - raise ValueError('infer_model_path is required for prediction.') + # if not self.has_init_pred and inference_model_dir is None: + # raise ValueError('infer_model_path is required for prediction.') + if inference_model_dir is None: + assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model." + inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model') instance = None for inst in self.instances: diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py index a2b1b05..f16f985 100644 --- a/paddlepalm/task_instance.py +++ b/paddlepalm/task_instance.py @@ -19,13 +19,34 @@ import os import json from paddle import fluid + +def check_req_args(conf, name): + assert 'reader' in conf, name+': reader is required to build TaskInstance.' + assert 'paradigm' in conf, name+': paradigm is required to build TaskInstance.' + assert 'train_file' in conf or 'pred_file' in conf, name+': at least train_file or pred_file should be provided to build TaskInstance.' + + class TaskInstance(object): - def __init__(self, name, id, config={}, verbose=True): + def __init__(self, name, id, config, verbose=True): self._name = name self._config = config self._verbose = verbose + check_req_args(config) + + # parse Reader and Paradigm + reader_name = config['reader'] + reader_mod = importlib.import_module(READER_DIR + '.' + reader_name) + Reader = getattr(reader_mod, 'Reader') + + parad_name = config['paradigm'] + parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name) + Paradigm = getattr(parad_mod, 'TaskParadigm') + + self._Reader = Reader + self._Paradigm = Paradigm + self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model') self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt') @@ -116,23 +137,23 @@ class TaskInstance(object): def Reader(self): return self._Reader - @Reader.setter - def Reader(self, cls): - assert base_reader.__name__ == cls.__bases__[-1].__name__, \ - "expect: {}, receive: {}.".format(base_reader.__name__, \ - cls.__bases__[-1].__name__) - self._Reader = cls + # @Reader.setter + # def Reader(self, cls): + # assert base_reader.__name__ == cls.__bases__[-1].__name__, \ + # "expect: {}, receive: {}.".format(base_reader.__name__, \ + # cls.__bases__[-1].__name__) + # self._Reader = cls @property def Paradigm(self): return self._Paradigm - @Paradigm.setter - def Paradigm(self, cls): - assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ - "expect: {}, receive: {}.".format(base_paradigm.__name__, \ - cls.__bases__[-1].__name__) - self._Paradigm = cls + # @Paradigm.setter + # def Paradigm(self, cls): + # assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ + # "expect: {}, receive: {}.".format(base_paradigm.__name__, \ + # cls.__bases__[-1].__name__) + # self._Paradigm = cls @property def config(self): -- GitLab