提交 7bdc468f 编写于 作者: W wuzewu

Add save_inference_model api

上级 b50fa1ea
......@@ -559,6 +559,11 @@ class BaseTask(object):
return [metric.name for metric in self.metrics] + [self.loss.name]
return [output.name for output in self.outputs]
@property
def fetch_var_list(self):
vars = self.main_program.global_block().vars
return [vars[varname] for varname in self.fetch_list]
@property
def tb_writer(self):
if not os.path.exists(self.config.checkpoint_dir):
......@@ -740,6 +745,20 @@ class BaseTask(object):
fluid.io.save_params(
self.exe, dirname=dirname, main_program=self.main_program)
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None):
with self.phase_guard("predict"):
fluid.io.save_inference_model(
dirname=dirname,
executor=self.exe,
feeded_var_names=self.feed_list,
target_vars=self.fetch_var_list,
main_program=self.main_program,
model_filename=model_filename,
params_filename=params_filename)
def finetune_and_eval(self):
return self.finetune(do_eval=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册