From a6b24416e4522ca663f27594e0232732bc87d7f3 Mon Sep 17 00:00:00 2001 From: "Zhang, Guoming" Date: Tue, 27 Nov 2018 23:41:29 +0800 Subject: [PATCH] Fix the invalid list operation on save_inference_model function. test=develop --- python/paddle/fluid/io.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 26d7af87b..0782933c6 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -637,8 +637,8 @@ def save_inference_model(dirname, if isinstance(target_vars, Variable): target_vars = [target_vars] elif export_for_deployment: - if not (bool(target_vars) and all( - isinstance(var, Variable) for var in target_vars)): + if not (bool(target_vars) and + all(isinstance(var, Variable) for var in target_vars)): raise ValueError("'target_vars' should be a list of Variable.") if main_program is None: @@ -667,10 +667,15 @@ def save_inference_model(dirname, if export_for_deployment: main_program = main_program.clone() global_block = main_program.global_block() + need_to_remove_op_index = [] for i, op in enumerate(global_block.ops): op.desc.set_is_target(False) if op.type == "feed" or op.type == "fetch": - global_block._remove_op(i) + need_to_remove_op_index.append(i) + + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + main_program.desc.flush() main_program = main_program._prune(targets=target_vars) -- GitLab