提交 a6b24416 编写于 作者: Z Zhang, Guoming

Fix the invalid list operation on save_inference_model function.

test=develop
上级 6c71c1f8
...@@ -637,8 +637,8 @@ def save_inference_model(dirname, ...@@ -637,8 +637,8 @@ def save_inference_model(dirname,
if isinstance(target_vars, Variable): if isinstance(target_vars, Variable):
target_vars = [target_vars] target_vars = [target_vars]
elif export_for_deployment: elif export_for_deployment:
if not (bool(target_vars) and all( if not (bool(target_vars) and
isinstance(var, Variable) for var in target_vars)): all(isinstance(var, Variable) for var in target_vars)):
raise ValueError("'target_vars' should be a list of Variable.") raise ValueError("'target_vars' should be a list of Variable.")
if main_program is None: if main_program is None:
...@@ -667,10 +667,15 @@ def save_inference_model(dirname, ...@@ -667,10 +667,15 @@ def save_inference_model(dirname,
if export_for_deployment: if export_for_deployment:
main_program = main_program.clone() main_program = main_program.clone()
global_block = main_program.global_block() global_block = main_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops): for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False) op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch": if op.type == "feed" or op.type == "fetch":
global_block._remove_op(i) need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
main_program.desc.flush() main_program.desc.flush()
main_program = main_program._prune(targets=target_vars) main_program = main_program._prune(targets=target_vars)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册