提交 0876fc14 编写于 作者: Q qiaolongfei

fix feed var

上级 4977d99b
...@@ -275,17 +275,6 @@ class Executor(object): ...@@ -275,17 +275,6 @@ class Executor(object):
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs={'col': i}) attrs={'col': i})
for op in global_block.ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
if not has_fetch_operators(global_block, fetch_list, if not has_fetch_operators(global_block, fetch_list,
fetch_var_name): fetch_var_name):
for i, var in enumerate(fetch_list): for i, var in enumerate(fetch_list):
...@@ -297,6 +286,18 @@ class Executor(object): ...@@ -297,6 +286,18 @@ class Executor(object):
outputs={'Out': [fetch_var]}, outputs={'Out': [fetch_var]},
attrs={'col': i}) attrs={'col': i})
# feed var to framework
for op in program_cache.global_block().ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
self.executor.run(program_cache.desc, scope, 0, True, True) self.executor.run(program_cache.desc, scope, 0, True, True)
outs = [ outs = [
core.get_fetch_variable(scope, fetch_var_name, i) core.get_fetch_variable(scope, fetch_var_name, i)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册