From e1f475ad1e2fdfbc1628ab063fdaaa037468c44e Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Wed, 17 Jan 2018 17:15:42 -0800 Subject: [PATCH] refine code --- python/paddle/v2/fluid/io.py | 50 +++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index ad4837c3f7..1b17afbe3f 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -192,6 +192,33 @@ def get_inference_program(target_vars, main_program=None): return inference_program +def prepend_feed_ops(inference_program, feeded_var_names): + global_block = inference_program.global_block() + feed_var = global_block.create_var( + name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True) + + for i, name in enumerate(feeded_var_names): + out = global_block.var(name) + global_block.prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + +def append_fetch_ops(inference_program, fetch_var_names): + global_block = inference_program.global_block() + fetch_var = global_block.create_var( + name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True) + + for i, name in enumerate(fetch_var_names): + global_block.append_op( + type='fetch', + inputs={'X': [name]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + def save_inference_model(dirname, feeded_var_names, target_vars, @@ -244,27 +271,8 @@ def save_inference_model(dirname, # Save only programDesc of inference_program in binary format # in another file: __model__.dat - global_block = inference_program.global_block() - feed_var = global_block.create_var( - name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True) - - for i, name in enumerate(feeded_var_names): - out = global_block.var(name) - global_block.prepend_op( - type='feed', - inputs={'X': [feed_var]}, - outputs={'Out': [out]}, - attrs={'col': i}) - - fetch_var = global_block.create_var( - name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True) - - for i, name in enumerate(fetch_var_names): - global_block.append_op( - type='fetch', - inputs={'X': [name]}, - outputs={'Out': [fetch_var]}, - attrs={'col': i}) + prepend_feed_ops(inference_program, feeded_var_names) + append_fetch_ops(inference_program, fetch_var_names) with open(model_file_name + ".dat", "wb") as fp: fp.write(inference_program.desc.serialize_to_string()) -- GitLab