From 27ccfdc02dd10dcfd675480a2f9e6e9cba2c6db8 Mon Sep 17 00:00:00 2001
From: wangxiao1021
Date: Mon, 3 Feb 2020 16:58:01 +0800
Subject: [PATCH] fix predict

---
 examples/predict/run.py |  2 +-
 paddlepalm/trainer.py   | 25 +++++++++++++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/examples/predict/run.py b/examples/predict/run.py
index 6e00c26..811aaf4 100644
--- a/examples/predict/run.py
+++ b/examples/predict/run.py
@@ -43,7 +43,7 @@ if __name__ == '__main__':
     trainer.build_predict_forward(pred_ernie, cls_pred_head)
 
     # step 6: load pretrained model
-    pred_model = trainer.load_pretrain(pre_params)
+    pred_model = trainer.load_predict_model(pre_params)
 
     # step 7: fit prepared reader and data
     trainer.fit_reader(predict_cls_reader, phase='predict')
diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py
index 5bba571..cd78c2d 100644
--- a/paddlepalm/trainer.py
+++ b/paddlepalm/trainer.py
@@ -359,7 +359,10 @@ class Trainer(object):
 
         # load data
         self._check_phase(phase)
-        assert self._shape_and_dtypes is not None or self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
+        if phase == 'train':
+            assert self._shape_and_dtypes is not None, "You need to build_forward first to prepare input features."
+        else:
+            assert self._pred_shape_and_dtypes is not None, "You need to build_predict_forward first to prepare input features."
 
         # not sure whether this should round up; needs confirmation
         # tail = self._num_examples % batch_size > 0
@@ -463,8 +466,22 @@ class Trainer(object):
         else:
             raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")
 
-    def load_predict_model(self, model_path):
-        raise NotImplementedError()
+    def load_predict_model(self, model_path, convert=False):
+        """
+        load saved model parameters (backbone) for prediction.
+
+        Args:
+            model_path: the path of saved model parameters.
+        """
+
+        assert self._pred_prog is not None, "predict graph not found. You should at least build_predict_forward to load its parameters."
+
+        saver.init_pretraining_params(
+            self._exe,
+            model_path,
+            convert=convert,
+            main_program=self._pred_prog)
+        # raise NotImplementedError()
 
     def load_pretrain(self, model_path, convert=False):
         """
@@ -717,7 +734,7 @@ class Trainer(object):
         if gpu_dev_count > 1:
             feed, mask = batch
             rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
-            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._batch_size)
+            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
             for _ in range(num_fakes):
                 for item in rt_outputs:
                     item.pop()
--
GitLab
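
Usage note (not part of the patch): a minimal sketch of the prediction flow in examples/predict/run.py after this change, assuming the reader, backbone, head, and parameter path (predict_cls_reader, pred_ernie, cls_pred_head, pre_params) were built in the omitted earlier steps; the final trainer.predict(...) call is assumed from the rest of the example and is not shown in this diff.

    # step 5: build the prediction graph from the backbone and prediction head
    trainer.build_predict_forward(pred_ernie, cls_pred_head)

    # step 6: load saved backbone parameters into the prediction program;
    # this now goes through load_predict_model, which checks the predict
    # graph (self._pred_prog) instead of the training program
    pred_model = trainer.load_predict_model(pre_params)

    # step 7: bind the reader to the predict phase and run inference
    trainer.fit_reader(predict_cls_reader, phase='predict')
    trainer.predict()  # assumed final step; see run.py for the exact arguments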