提交 ed3e5717 编写于 作者: K Kexin Zhao

fix bug

上级 dc168ed0
...@@ -89,7 +89,7 @@ void InferenceEngine::LoadInferenceModel( ...@@ -89,7 +89,7 @@ void InferenceEngine::LoadInferenceModel(
} }
bool InferenceEngine::IsParameter(const framework::VarDesc* var) { bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
if (var->Persistable()) { if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
// There are many unreachable variables in the program // There are many unreachable variables in the program
for (size_t i = 0; i < program_->Size(); ++i) { for (size_t i = 0; i < program_->Size(); ++i) {
const framework::BlockDesc& block = program_->Block(i); const framework::BlockDesc& block = program_->Block(i);
......
...@@ -15,6 +15,7 @@ import os ...@@ -15,6 +15,7 @@ import os
import cPickle as pickle import cPickle as pickle
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core
__all__ = [ __all__ = [
'save_vars', 'save_vars',
...@@ -244,10 +245,10 @@ def save_inference_model(dirname, ...@@ -244,10 +245,10 @@ def save_inference_model(dirname,
# Save only programDesc of inference_program in binary format # Save only programDesc of inference_program in binary format
# in another file: __model__.dat # in another file: __model__.dat
global_block = inference_program.global_block() global_block = inference_program.global_block()
feed_var = global_blok.create_var( feed_var = global_block.create_var(
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True) name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
for i, name in enumerated(feeded_var_names): for i, name in enumerate(feeded_var_names):
out = global_block.var(name) out = global_block.var(name)
global_block.prepend_op( global_block.prepend_op(
type='feed', type='feed',
...@@ -258,10 +259,10 @@ def save_inference_model(dirname, ...@@ -258,10 +259,10 @@ def save_inference_model(dirname,
fetch_var = global_block.create_var( fetch_var = global_block.create_var(
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True) name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
for i, name in enumerated(fetch_var_names): for i, name in enumerate(fetch_var_names):
global_block.append_op( global_block.append_op(
type='fetch', type='fetch',
inputs={'X': [var]}, inputs={'X': [name]},
outputs={'Out': [fetch_var]}, outputs={'Out': [fetch_var]},
attrs={'col': i}) attrs={'col': i})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册