diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index a4d1131ab83502fdb7a41253b7e26dfa96838d4e..bb8fbeb50a4eaddc7277e0be25a4fe2d8795749c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -27,8 +27,9 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_st from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check -from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.data_feeder import check_type +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.layers.utils import map_structure __all__ = ['ProgramTranslator', 'convert_function_with_cache'] @@ -403,10 +404,11 @@ class ProgramTranslator(object): if not program_cache.in_build_process: outputs = self._run(*args, **kwargs) with guard(): + outputs = map_structure(to_variable, outputs) if len(outputs) == 1: - outputs = to_variable(outputs[0]) + outputs = outputs[0] else: - outputs = tuple(to_variable(x) for x in outputs) + outputs = tuple(outputs) return outputs def get_func(self, dygraph_func):