Commit 085a13d2 authored by: xixiaoyao

fix pred

@@ -40,6 +40,9 @@ class Trainer(object):
         self._task_head = task_head
         self._pred_head = None
 
+        self._train_init = False
+        self._predict_init = False
+
         # if save_predict_model:
         #     self._save_predict_model = True
         #     assert pred_head is not None, "pred_head is required to save predict model."
@@ -220,7 +223,7 @@ class Trainer(object):
         for _id, block in enumerate(self._train_prog.blocks):
             for var in block.vars:
                 print("[debug] : %d, %s" % (_id, var))
+        self._loss_var = loss_var
         return loss_var
 
     def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
@@ -296,30 +299,44 @@ class Trainer(object):
         distribute_feeder_fn = iterator_fn
         return distribute_feeder_fn()
 
-    def random_init_params(self):
+    def _init_exe_prog(self, for_train=True):
         assert self._train_init_prog is not None, "train graph not found! You should build_forward first before you random init parameters."
-        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=loss_var.name)
+        self._train_init = True
+        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
         on_gpu = gpu_dev_count > 0
         self._exe = helper.build_executor(on_gpu)
+        if not for_train:
+            raise NotImplementedError()
+
+    def random_init_params(self):
+        if not self._train_init:
+            self._init_exe_prog()
         print('random init params...')
         self._exe.run(self._train_init_prog)
 
     def load_ckpt(self, model_path, phase='train'):
         # load pretrain model (or ckpt)
-        assert self._exe is not None, "You need to random_init_params before load checkpoints."
+        # assert self._exe is not None, "You need to random_init_params before load checkpoints."
+        if phase == 'train' and not self._train_init:
+            self._init_exe_prog()
+        if phase == 'predict' and not self._predict_init:
+            pass
+
         if phase == 'train':
             assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint."
             saver.init_pretraining_params(
                 self._exe,
                 model_path,
-                main_program=self._train_init_prog)
+                main_program=self._train_init_prog,
+                strict=True)
         elif phase == 'predict':
             assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint."
             saver.init_pretraining_params(
                 self._exe,
                 model_path,
-                main_program=self._pred_init_prog)
+                main_program=self._pred_init_prog,
+                strict=True)
         else:
             raise NotImplementedError()
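The hunk above defers executor and CompiledProgram construction into _init_exe_prog() and guards it with the _train_init flag, so random_init_params() and load_ckpt() can each be called first without an explicit setup step. A minimal standalone sketch of this lazy-initialization pattern (illustrative names only, not part of this commit or of the PaddlePALM API):

class LazyInitExample(object):
    """Illustrates the guard-flag pattern used by Trainer above."""

    def __init__(self):
        self._train_init = False
        self._exe = None

    def _init_exe_prog(self, for_train=True):
        # Expensive setup happens only once; the flag prevents rebuilding.
        self._train_init = True
        self._exe = "executor"  # placeholder for real executor construction
        if not for_train:
            raise NotImplementedError()

    def random_init_params(self):
        if not self._train_init:
            self._init_exe_prog()
        print('random init params...')

    def load_ckpt(self, model_path, phase='train'):
        if phase == 'train' and not self._train_init:
            self._init_exe_prog()
        print('loading checkpoint from %s' % model_path)


t = LazyInitExample()
t.load_ckpt('./pretrain_model')   # triggers _init_exe_prog() on first use
t.random_init_params()            # flag already set, no re-initialization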
@@ -397,6 +414,11 @@ class Trainer(object):
             task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')}
             self._task_head.postprocess(task_rt_outputs)
 
+            # rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}
+            task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')}
+            self._task_head.postprocess(task_rt_outputs)
+
             self._cur_train_step += 1
             self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch
@@ -578,11 +600,6 @@ class Trainer(object):
         #     self._cur_train_step = 1
         # if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
         #     self._train_finish = True
 
-    @property
-    def steps_pur_epoch(self):
-        return self._steps_pur_epoch
-
     @steps_pur_epoch.setter
     def steps_pur_epoch(self, value):
         self._steps_pur_epoch = value
@@ -46,20 +46,24 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []):
 
 def init_pretraining_params(exe,
                             pretraining_params_path,
+                            convert,
                             main_program):
     assert os.path.exists(pretraining_params_path
                           ), "[%s] can't be found." % pretraining_params_path
 
-    assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found."
-    print("Loading pretraining parameters from {}...".format(
-        pretraining_params_path))
-
-    with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
-        f.extractall(os.path.join(pretraining_params_path, '.temp'))
-
-    log_path = os.path.join(pretraining_params_path, '__palmmodel__')
-    pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
+    if convert:
+        assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found."
+        with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
+            f.extractall(os.path.join(pretraining_params_path, '.temp'))
+        log_path = os.path.join(pretraining_params_path, '__palmmodel__')
+        pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
+    else:
+        log_path = pretraining_params_path
+
+    print("Loading pretraining parameters from {}...".format(pretraining_params_path))
 
     def existed_params(var):
         if not isinstance(var, fluid.framework.Parameter):
@@ -73,8 +77,8 @@ def init_pretraining_params(exe,
         pretraining_params_path,
         main_program=main_program,
         predicate=existed_params)
-    shutil.rmtree(pretraining_params_path)
+    if convert:
+        shutil.rmtree(pretraining_params_path)
 
     print('')
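For reference, a self-contained sketch of the convert branch introduced above: when convert is enabled, parameters are expected inside a '__palmmodel__' tar archive, which is unpacked into a temporary '.temp' directory, loaded from there, and then the temporary copy is removed. The function and variable names below are illustrative, not the library's API:

import os
import shutil
import tarfile

def load_packed_params(params_path, convert):
    # Mirrors the convert logic of init_pretraining_params above (sketch only).
    if convert:
        archive = os.path.join(params_path, '__palmmodel__')
        assert os.path.exists(archive), "__palmmodel__ not found."
        load_dir = os.path.join(params_path, '.temp')
        with tarfile.open(archive, 'r') as f:
            f.extractall(load_dir)
    else:
        load_dir = params_path
    print("Loading pretraining parameters from {}...".format(load_dir))
    # ... actual parameter loading (e.g. fluid.io.load_vars) would read from load_dir here ...
    if convert:
        shutil.rmtree(load_dir)   # drop the temporary extracted copy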