提交 27ccfdc0 编写于 作者: W wangxiao1021

fix predict

上级 c809077c
......@@ -43,7 +43,7 @@ if __name__ == '__main__':
trainer.build_predict_forward(pred_ernie, cls_pred_head)
# step 6: load pretrained model
pred_model = trainer.load_pretrain(pre_params)
pred_model = trainer.load_predict_model(pre_params)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_cls_reader, phase='predict')
......
......@@ -359,7 +359,10 @@ class Trainer(object):
# load data
self._check_phase(phase)
assert self._shape_and_dtypes is not None or self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
if phase=='train':
assert self._shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
else:
assert self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
# 这里不确定是否要向上取整,需确认
# tail = self._num_examples % batch_size > 0
......@@ -463,8 +466,22 @@ class Trainer(object):
else:
raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")
def load_predict_model(self, model_path):
raise NotImplementedError()
def load_predict_model(self, model_path, convert=False):
"""
load pretrain models(backbone) for training.
Args:
model_path: the path of saved pretrained parameters.
"""
assert self._pred_prog is not None, "training graph not found. You should at least build_forward to load its pretrained parameters."
saver.init_pretraining_params(
self._exe,
model_path,
convert=convert,
main_program=self._pred_prog)
# raise NotImplementedError()
def load_pretrain(self, model_path, convert=False):
"""
......@@ -717,7 +734,7 @@ class Trainer(object):
if gpu_dev_count > 1:
feed, mask = batch
rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
num_fakes = decode_fake(len(rt_outputs[0]), mask, self._batch_size)
num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
for _ in range(num_fakes):
for item in rt_outputs:
item.pop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册