提交 a3f3e830 编写于 作者: J jiangjiajun

fix scope program for export inference

上级 aaf19f7e
...@@ -356,23 +356,24 @@ class BaseAPI: ...@@ -356,23 +356,24 @@ class BaseAPI:
def export_inference_model(self, save_dir): def export_inference_model(self, save_dir):
test_input_names = [var.name for var in list(self.test_inputs.values())] test_input_names = [var.name for var in list(self.test_inputs.values())]
test_outputs = list(self.test_outputs.values()) test_outputs = list(self.test_outputs.values())
if self.__class__.__name__ == 'MaskRCNN': with fluid.scope_guard(self.scope):
from paddlex.utils.save import save_mask_inference_model if self.__class__.__name__ == 'MaskRCNN':
save_mask_inference_model( from paddlex.utils.save import save_mask_inference_model
dirname=save_dir, save_mask_inference_model(
executor=self.exe, dirname=save_dir,
params_filename='__params__', executor=self.exe,
feeded_var_names=test_input_names, params_filename='__params__',
target_vars=test_outputs, feeded_var_names=test_input_names,
main_program=self.test_prog) target_vars=test_outputs,
else: main_program=self.test_prog)
fluid.io.save_inference_model( else:
dirname=save_dir, fluid.io.save_inference_model(
executor=self.exe, dirname=save_dir,
params_filename='__params__', executor=self.exe,
feeded_var_names=test_input_names, params_filename='__params__',
target_vars=test_outputs, feeded_var_names=test_input_names,
main_program=self.test_prog) target_vars=test_outputs,
main_program=self.test_prog)
model_info = self.get_model_info() model_info = self.get_model_info()
model_info['status'] = 'Infer' model_info['status'] = 'Infer'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册