未验证 提交 c21afb28 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #27 from xixiaoyao/master

add save infermodel with train steps
...@@ -304,19 +304,6 @@ class Controller(object): ...@@ -304,19 +304,6 @@ class Controller(object):
instances[i].task_reuse_scope = instances[j].name instances[i].task_reuse_scope = instances[j].name
break break
# parse Reader and Paradigm for each instance
for inst in instances:
reader_name = inst.config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = inst.config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
inst.Reader = Reader
inst.Paradigm = Paradigm
self.instances = instances self.instances = instances
self.mrs = mrs self.mrs = mrs
self.Backbone = Backbone self.Backbone = Backbone
...@@ -361,7 +348,6 @@ class Controller(object): ...@@ -361,7 +348,6 @@ class Controller(object):
task_attrs = [] task_attrs = []
pred_task_attrs = [] pred_task_attrs = []
for inst in instances: for inst in instances:
train_reader = inst.Reader(inst.config, phase='train') train_reader = inst.Reader(inst.config, phase='train')
inst.reader['train'] = train_reader inst.reader['train'] = train_reader
train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf) train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf)
...@@ -605,6 +591,9 @@ class Controller(object): ...@@ -605,6 +591,9 @@ class Controller(object):
global_step += 1 global_step += 1
cur_task.cur_train_step += 1 cur_task.cur_train_step += 1
if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0:
cur_task.save(suffix='.step'+str(cur_task.cur_train_step))
if global_step % main_conf.get('print_every_n_steps', 5) == 0: if global_step % main_conf.get('print_every_n_steps', 5) == 0:
loss = rt_outputs[cur_task.name+'/loss'] loss = rt_outputs[cur_task.name+'/loss']
loss = np.mean(np.squeeze(loss)).tolist() loss = np.mean(np.squeeze(loss)).tolist()
...@@ -635,8 +624,11 @@ class Controller(object): ...@@ -635,8 +624,11 @@ class Controller(object):
assert isinstance(task_instance, str) assert isinstance(task_instance, str)
if isinstance(inference_model_dir, str): if isinstance(inference_model_dir, str):
assert os.path.exists(inference_model_dir), inference_model_dir+" not found." assert os.path.exists(inference_model_dir), inference_model_dir+" not found."
if not self.has_init_pred and inference_model_dir is None: # if not self.has_init_pred and inference_model_dir is None:
raise ValueError('infer_model_path is required for prediction.') # raise ValueError('infer_model_path is required for prediction.')
if inference_model_dir is None:
assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model')
instance = None instance = None
for inst in self.instances: for inst in self.instances:
......
...@@ -18,16 +18,40 @@ from paddlepalm.interface import task_paradigm as base_paradigm ...@@ -18,16 +18,40 @@ from paddlepalm.interface import task_paradigm as base_paradigm
import os import os
import json import json
from paddle import fluid from paddle import fluid
import importlib
from paddlepalm.default_settings import *
def check_req_args(conf, name):
    """Validate that a task config has the keys needed to build a TaskInstance.

    Args:
        conf: task configuration; any container supporting the ``in`` operator
            on its keys (e.g. a dict).
        name: task-instance name, used as the prefix of every error message.

    Raises:
        AssertionError: if ``reader`` or ``paradigm`` is missing, or if
            neither ``train_file`` nor ``pred_file`` is present.
    """
    # Raise explicitly instead of using `assert` statements: asserts are
    # compiled away under `python -O`, which would silently skip this
    # validation and surface later as an opaque KeyError. AssertionError is
    # kept so the exception type seen by callers is unchanged.
    for key in ('reader', 'paradigm'):
        if key not in conf:
            raise AssertionError(name + ': ' + key + ' is required to build TaskInstance.')
    if 'train_file' not in conf and 'pred_file' not in conf:
        raise AssertionError(name + ': at least train_file or pred_file should be provided to build TaskInstance.')
class TaskInstance(object): class TaskInstance(object):
def __init__(self, name, id, config={}, verbose=True): def __init__(self, name, id, config, verbose=True):
self._name = name self._name = name
self._config = config self._config = config
self._verbose = verbose self._verbose = verbose
check_req_args(config, name)
# parse Reader and Paradigm
reader_name = config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = config['paradigm']
parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
Paradigm = getattr(parad_mod, 'TaskParadigm')
self._Reader = Reader
self._Paradigm = Paradigm
self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model') self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt') self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
self._save_infermodel_every_n_steps = config.get('save_infermodel_every_n_steps', -1)
# following flags can be fetch from instance config file # following flags can be fetch from instance config file
self._is_target = config.get('is_target', True) self._is_target = config.get('is_target', True)
...@@ -56,9 +80,6 @@ class TaskInstance(object): ...@@ -56,9 +80,6 @@ class TaskInstance(object):
self._pred_fetch_name_list = [] self._pred_fetch_name_list = []
self._pred_fetch_var_list = [] self._pred_fetch_var_list = []
self._Reader = None
self._Paradigm = None
self._exe = fluid.Executor(fluid.CPUPlace()) self._exe = fluid.Executor(fluid.CPUPlace())
self._save_protocol = { self._save_protocol = {
...@@ -87,7 +108,9 @@ class TaskInstance(object): ...@@ -87,7 +108,9 @@ class TaskInstance(object):
dirpath = self._save_infermodel_path + suffix dirpath = self._save_infermodel_path + suffix
self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list] self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True) # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
prog = fluid.default_main_program().clone()
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
conf = {} conf = {}
for k, strv in self._save_protocol.items(): for k, strv in self._save_protocol.items():
...@@ -116,23 +139,23 @@ class TaskInstance(object): ...@@ -116,23 +139,23 @@ class TaskInstance(object):
def Reader(self): def Reader(self):
return self._Reader return self._Reader
@Reader.setter # @Reader.setter
def Reader(self, cls): # def Reader(self, cls):
assert base_reader.__name__ == cls.__bases__[-1].__name__, \ # assert base_reader.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_reader.__name__, \ # "expect: {}, receive: {}.".format(base_reader.__name__, \
cls.__bases__[-1].__name__) # cls.__bases__[-1].__name__)
self._Reader = cls # self._Reader = cls
@property @property
def Paradigm(self): def Paradigm(self):
return self._Paradigm return self._Paradigm
@Paradigm.setter # @Paradigm.setter
def Paradigm(self, cls): # def Paradigm(self, cls):
assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ # assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_paradigm.__name__, \ # "expect: {}, receive: {}.".format(base_paradigm.__name__, \
cls.__bases__[-1].__name__) # cls.__bases__[-1].__name__)
self._Paradigm = cls # self._Paradigm = cls
@property @property
def config(self): def config(self):
...@@ -201,6 +224,10 @@ class TaskInstance(object): ...@@ -201,6 +224,10 @@ class TaskInstance(object):
if self._verbose: if self._verbose:
print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio)) print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
# Read-only accessor for the per-task inference-model save interval.
# The backing field is read from the instance config key
# 'save_infermodel_every_n_steps' and defaults to -1 when the key is absent.
# The training loop only saves when this value is > 0 and the task's
# cur_train_step is a multiple of it, so the default disables periodic saves.
@property
def save_infermodel_every_n_steps(self):
return self._save_infermodel_every_n_steps
@property @property
def expected_train_steps(self): def expected_train_steps(self):
return self._expected_train_steps return self._expected_train_steps
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册