From 8a99149aeec285d1741932a0a4acfd84d8bcc3a2 Mon Sep 17 00:00:00 2001 From: xixiaoyao Date: Wed, 8 Jan 2020 18:16:50 +0800 Subject: [PATCH] add predict --- paddlepalm/trainer.py | 54 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index 01be7e1..6d324d0 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -122,7 +122,16 @@ class Trainer(object): pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} scope = self.name + '.' with fluid.unique_name.guard(scope): - self._build_head(pred_task_inputs, phase='pred', scope=scope) + output_vars = self._build_head(pred_task_inputs, phase='pred', scope=scope) + + if output_vars is not None: + self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items()) + else: + self._pred_fetch_name_list = [] + self._pred_fetch_var_list = [] + + self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel() + return output_vars def build_forward(self, backbone, pred_backbone=None, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None): @@ -277,11 +286,35 @@ class Trainer(object): return distribute_feeder_fn() def random_init_params(self): + assert self._train_init_prog is not None, "train graph not foung! You should build_forward first before you random init parameters." on_gpu = gpu_dev_count > 0 self._exe = helper.build_executor(on_gpu) print('random init params...') self._exe.run(self._train_init_prog) + def load_ckpt(self, model_path, phase='train'): + # load pretrain model (or ckpt) + assert self._exe is not None, "You need to random_init_params before load pretrain models." + + if phase == 'train': + assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint." + saver.init_pretraining_params( + self._exe, + model_path, + main_program=self._train_init_prog) + elif phase == 'predict': + assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint." + saver.init_pretraining_params( + self._exe, + model_path, + main_program=self._pred_init_prog) + else: + raise NotImplementedError() + + + def load_predict_model(self, model_path): + raise NotImplementedError() + def load_pretrain(self, model_path): # load pretrain model (or ckpt) assert self._exe is not None, "You need to random_init_params before load pretrain models." @@ -400,6 +433,19 @@ class Trainer(object): rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} return rt_outputs + + def predict_one_batch(self, batch): + if gpu_dev_count > 1: + feed, mask = batch + rt_outputs = self.exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list) + while mask.pop() == False: + rt_outputs.pop() + else: + feed = self._feed_batch_process_fn(batch) + rt_outputs = self._exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list) + + rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} + def _build_head(self, net_inputs, phase, scope=""): @@ -407,11 +453,6 @@ class Trainer(object): output_vars = self._task_head.build(net_inputs, scope_name=scope) if phase == 'pred': output_vars = self._pred_head.build(net_inputs, scope_name=scope) - if output_vars is not None: - self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items()) - else: - self._pred_fetch_name_list = [] - self._pred_fetch_var_list = [] return output_vars def _postprocess(self, rt_outputs, phase): @@ -441,6 +482,7 @@ class Trainer(object): writer.write(json.dumps(conf, indent=1)) print(self._name + ': predict model saved at ' + dirpath) + def _load(self, infer_model_path=None): if infer_model_path is None: infer_model_path = self._save_infermodel_path -- GitLab