diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
index 03c50a81df4b1e0f3cf279e42a0c0b2d14d8bb72..fccafd433e707dfd6d6c6f06dd973d035bfe6071 100755
--- a/paddlepalm/mtl_controller.py
+++ b/paddlepalm/mtl_controller.py
@@ -348,7 +348,7 @@ class Controller(object):
         task_attrs = []
         pred_task_attrs = []
         for inst in instances:
             train_reader = inst.Reader(inst.config, phase='train')
             inst.reader['train'] = train_reader
             train_parad = inst.Paradigm(inst.config, phase='train',
                                         backbone_config=bb_conf)
@@ -593,7 +593,7 @@ class Controller(object):
                 cur_task.cur_train_step += 1
 
                 if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0:
-                    cur_task.save(suffix='-step'+str(cur_task.cur_train_step))
+                    cur_task.save(suffix='.step'+str(cur_task.cur_train_step))
 
                 if global_step % main_conf.get('print_every_n_steps', 5) == 0:
                     loss = rt_outputs[cur_task.name+'/loss']
diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py
index f16f98555b85c303511e1a03fc1084f27b291bc4..3f1131967d08c43ca1eda03f5e7368f689cc81a6 100644
--- a/paddlepalm/task_instance.py
+++ b/paddlepalm/task_instance.py
@@ -18,6 +18,8 @@ from paddlepalm.interface import task_paradigm as base_paradigm
 import os
 import json
 from paddle import fluid
+import importlib
+from paddlepalm.default_settings import *
 
 
 def check_req_args(conf, name):
@@ -33,7 +35,7 @@ class TaskInstance(object):
         self._config = config
         self._verbose = verbose
 
-        check_req_args(config)
+        check_req_args(config, name)
 
         # parse Reader and Paradigm
         reader_name = config['reader']
@@ -49,6 +51,7 @@ class TaskInstance(object):
 
         self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
         self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
+        self._save_infermodel_every_n_steps = config.get('save_infermodel_every_n_steps', -1)
 
         # following flags can be fetch from instance config file
         self._is_target = config.get('is_target', True)
@@ -77,9 +80,6 @@ class TaskInstance(object):
         self._pred_fetch_name_list = []
         self._pred_fetch_var_list = []
 
-        self._Reader = None
-        self._Paradigm = None
-
         self._exe = fluid.Executor(fluid.CPUPlace())
 
         self._save_protocol = {
@@ -108,7 +108,9 @@ class TaskInstance(object):
         dirpath = self._save_infermodel_path + suffix
         self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
 
-        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
+        # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
+        prog = fluid.default_main_program().clone()
+        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
 
         conf = {}
         for k, strv in self._save_protocol.items():
@@ -222,6 +224,10 @@ class TaskInstance(object):
         if self._verbose:
             print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
 
+    @property
+    def save_infermodel_every_n_steps(self):
+        return self._save_infermodel_every_n_steps
+
     @property
     def expected_train_steps(self):
         return self._expected_train_steps