未验证 提交 e8869a90 编写于 作者: L liym27 提交者: GitHub

Fix bug in ProgramTranslator.get_output, convert all items into VarBase in nested list. (#24267)

上级 381492fc
...@@ -27,8 +27,9 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_st ...@@ -27,8 +27,9 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_st
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers.utils import map_structure
__all__ = ['ProgramTranslator', 'convert_function_with_cache'] __all__ = ['ProgramTranslator', 'convert_function_with_cache']
...@@ -403,10 +404,11 @@ class ProgramTranslator(object): ...@@ -403,10 +404,11 @@ class ProgramTranslator(object):
if not program_cache.in_build_process: if not program_cache.in_build_process:
outputs = self._run(*args, **kwargs) outputs = self._run(*args, **kwargs)
with guard(): with guard():
outputs = map_structure(to_variable, outputs)
if len(outputs) == 1: if len(outputs) == 1:
outputs = to_variable(outputs[0]) outputs = outputs[0]
else: else:
outputs = tuple(to_variable(x) for x in outputs) outputs = tuple(outputs)
return outputs return outputs
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册