Commit 8a99149a authored by xixiaoyao

add predict

Parent 2efeb39b
@@ -122,7 +122,16 @@ class Trainer(object):
        pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
        scope = self.name + '.'
        with fluid.unique_name.guard(scope):
            self._build_head(pred_task_inputs, phase='pred', scope=scope)                # removed in this commit: the head's return value was discarded
            output_vars = self._build_head(pred_task_inputs, phase='pred', scope=scope)  # added: keep the head's output variables

        if output_vars is not None:
            self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
        else:
            self._pred_fetch_name_list = []
            self._pred_fetch_var_list = []

        self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel()
        return output_vars

    def build_forward(self, backbone, pred_backbone=None, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None):
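The compiled-program call above is what makes the predict graph runnable across the available devices. Below is a minimal, self-contained sketch of the fluid.CompiledProgram(...).with_data_parallel() pattern on a toy network; the network, tensor names and shapes are illustrative assumptions and are not part of PALM or of this commit.

# Sketch only: a toy softmax classifier compiled for data-parallel inference (Paddle 1.x fluid API).
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
pred = fluid.layers.fc(input=x, size=2, act='softmax')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Analogous to self._distribute_pred_prog in the diff, but built from the default program.
compiled_pred_prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel()

out = exe.run(compiled_pred_prog,
              feed={'x': np.random.rand(8, 4).astype('float32')},
              fetch_list=[pred])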
@@ -277,11 +286,35 @@ class Trainer(object):
        return distribute_feeder_fn()

    def random_init_params(self):
        assert self._train_init_prog is not None, "train graph not found! You should build_forward first before you randomly initialize parameters."
        on_gpu = gpu_dev_count > 0
        self._exe = helper.build_executor(on_gpu)
        print('random init params...')
        self._exe.run(self._train_init_prog)

    def load_ckpt(self, model_path, phase='train'):
        # load a pretrained model (or checkpoint)
        assert self._exe is not None, "You need to random_init_params before loading checkpoints."
        if phase == 'train':
            assert self._train_init_prog is not None, "train graph not found! You should build_forward first before loading a checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
                main_program=self._train_init_prog)
        elif phase == 'predict':
            assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before loading a checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
                main_program=self._pred_init_prog)
        else:
            raise NotImplementedError()

    def load_predict_model(self, model_path):
        raise NotImplementedError()

    def load_pretrain(self, model_path):
        # load a pretrained model (or checkpoint)
        assert self._exe is not None, "You need to random_init_params before loading pretrained models."
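For orientation, here is a hedged usage sketch of the checkpoint-loading path added above; it is not part of the diff, and trainer and ckpt_dir are assumed to have been prepared elsewhere (backbone/head construction and build_forward, plus the predict graph for the 'predict' phase).

# Hypothetical driver code; `trainer` and `ckpt_dir` are assumptions, not names from the commit.
trainer.random_init_params()                  # build the executor and run the train init program
trainer.load_ckpt(ckpt_dir, phase='train')    # restore parameters into the train graph
# Restoring into the predict graph only works after the predict head has been built:
trainer.load_ckpt(ckpt_dir, phase='predict')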
@@ -401,17 +434,25 @@ class Trainer(object):
        rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}
        return rt_outputs

    def predict_one_batch(self, batch):
        if gpu_dev_count > 1:
            feed, mask = batch
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_var_list)
            # drop outputs that correspond to padded (masked-out) samples
            while mask.pop() == False:
                rt_outputs.pop()
        else:
            feed = self._feed_batch_process_fn(batch)
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_var_list)
        rt_outputs = {k: v for k, v in zip(self._pred_fetch_name_list, rt_outputs)}
        return rt_outputs

    def _build_head(self, net_inputs, phase, scope=""):
        if phase == 'train':
            output_vars = self._task_head.build(net_inputs, scope_name=scope)
        elif phase == 'pred':
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
        if output_vars is not None:
            self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
        else:
            self._pred_fetch_name_list = []
            self._pred_fetch_var_list = []
        return output_vars

    def _postprocess(self, rt_outputs, phase):
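A hedged sketch of how predict_one_batch is expected to be driven once the predict graph and reader are set up; the iterator name pred_batches and the result handling are assumptions, not part of this commit.

# Hypothetical inference loop; each call returns a dict keyed by _pred_fetch_name_list.
results = []
for batch in pred_batches:
    rt_outputs = trainer.predict_one_batch(batch)
    results.append(rt_outputs)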
@@ -441,6 +482,7 @@ class Trainer(object):
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': predict model saved at ' + dirpath)

    def _load(self, infer_model_path=None):
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path
......