diff --git a/mobile/tools/python/fluidtools/run.py b/mobile/tools/python/fluidtools/run.py index 6fa5842009121f31aedb409dc6d4ef85e3ddcd79..4773222536a53f28f433b2538d617233d05ab685 100644 --- a/mobile/tools/python/fluidtools/run.py +++ b/mobile/tools/python/fluidtools/run.py @@ -26,6 +26,7 @@ quantification = False quantification_fold = 1000 architecture = "arm-v7a" # architecture = "arm-v8a" +correct_persistable = False np.set_printoptions(linewidth=150) @@ -69,6 +70,18 @@ exe.run(fluid.default_startup_program()) # 加载模型 def load_model(model_path): prog, feeds, fetches = fluid.io.load_inference_model(dirname=model_path, executor=exe, model_filename="model", params_filename="params") + global correct_persistable + if correct_persistable: + ops = prog.current_block().ops + vars = prog.current_block().vars + for op in ops: + for var_name in op.output_arg_names: + if var_name == "fetch": + continue + var = vars[var_name] + if var.persistable: + pp_red("has found non-persistable output var : {}".format(var_name)) + var.persistable = False return (prog, feeds, fetches) prog, feeds, fetches = load_model(model_path)