提交 e1f475ad 编写于 作者: K Kexin Zhao

refine code

上级 ed3e5717
...@@ -192,6 +192,33 @@ def get_inference_program(target_vars, main_program=None): ...@@ -192,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
return inference_program return inference_program
def prepend_feed_ops(inference_program, feeded_var_names):
global_block = inference_program.global_block()
feed_var = global_block.create_var(
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
for i, name in enumerate(feeded_var_names):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
def append_fetch_ops(inference_program, fetch_var_names):
global_block = inference_program.global_block()
fetch_var = global_block.create_var(
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
for i, name in enumerate(fetch_var_names):
global_block.append_op(
type='fetch',
inputs={'X': [name]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
def save_inference_model(dirname, def save_inference_model(dirname,
feeded_var_names, feeded_var_names,
target_vars, target_vars,
...@@ -244,27 +271,8 @@ def save_inference_model(dirname, ...@@ -244,27 +271,8 @@ def save_inference_model(dirname,
# Save only programDesc of inference_program in binary format # Save only programDesc of inference_program in binary format
# in another file: __model__.dat # in another file: __model__.dat
global_block = inference_program.global_block() prepend_feed_ops(inference_program, feeded_var_names)
feed_var = global_block.create_var( append_fetch_ops(inference_program, fetch_var_names)
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
for i, name in enumerate(feeded_var_names):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
fetch_var = global_block.create_var(
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
for i, name in enumerate(fetch_var_names):
global_block.append_op(
type='fetch',
inputs={'X': [name]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
with open(model_file_name + ".dat", "wb") as fp: with open(model_file_name + ".dat", "wb") as fp:
fp.write(inference_program.desc.serialize_to_string()) fp.write(inference_program.desc.serialize_to_string())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册